Skip to content

Commit

Permalink
Make sampling, integration and mode-estimation also return an Evaluat…
Browse files Browse the repository at this point in the history
…edMeasure
  • Loading branch information
oschulz committed Oct 29, 2024
1 parent 2e985d7 commit dbd394f
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 6 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ run more or less unchanged (with deprecation warnings). Also:

### New features

* Sampling, integration and mode-finding algorithms now generate a return
value `result = ..., evaluated::EvaluatedMeasure = ..., ...)` if their
target is a probability measure/distribution.

* The new `RAMTuning` is now the default (transform) tuning algorithm for
`RandomWalk` (formerly `MetropolisHastings`). It typically results in a much
faster burn-in process than `AdaptiveAffineTuning` (formerly
Expand Down
63 changes: 63 additions & 0 deletions src/algotypes/bat_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,66 @@ end
result_with_args(r::NamedTuple) = merge(r, (optargs = NamedTuple(),))

result_with_args(r::NamedTuple, optargs::NamedTuple) = merge(r, (optargs = optargs,))

function result_with_args(::Val, ::Any, r::NamedTuple, optargs::NamedTuple)
return result_with_args(r, optargs)
end

function result_with_args(::Val{resultname}, target::Union{AbstractMeasure,Distribution}, r::NamedTuple, optargs::NamedTuple) where resultname
measure = batmeasure(target)
augmented_result = _augment_bat_retval(Val(resultname), measure, r)
result_with_args(augmented_result, optargs)
end

function _augment_bat_retval(::Val{resultname}, measure, r::R) where {resultname,R}
if hasfield(R, :evaluated)
return r
else
if resultname == :samples
samples = r.result
elseif hasfield(R, :samples)
samples = r.samples
else
samples = maybe_samplesof(measure)
end

if resultname == :approx
approx = r.result
elseif hasfield(R, :approx)
approx = r.approx
else
approx = maybe_approxof(measure)
end

if resultname == :mass
mass = r.result
elseif hasfield(R, :mass)
mass = r.mass
else
mass = massof(measure)
end

if resultname == :modes
modes = r.result
elseif resultname == :mode
modes = [r.result]
elseif hasfield(R, :modes)
modes = r.modes
elseif hasfield(R, :mode)
modes = [r.mode]
else
modes = maybe_modesof(measure)
end

if resultname == :generator
generator = r.result
elseif hasfield(R, :generator)
generator = r.generator
else
generator = maybe_generator(measure)
end
evaluated = EvaluatedMeasure(unevaluated(measure), samples, approx, mass, modes, generator)
r_add = (result = r.result, evaluated = evaluated)
return merge(r_add, r)
end
end
6 changes: 4 additions & 2 deletions src/algotypes/integration_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ Calculate the integral (evidence) of `target`.
Returns a NamedTuple of the shape
```julia
(result = X::Measurements.Measurement, ...)
(result = X::Measurements.Measurement, evaluated::EvaluatedMeasure, ...)
```
(The field `evaluated` is only present if `target` is a measure.)
Result properties not listed here are algorithm-specific and are not part
of the stable public API.
Expand All @@ -42,7 +44,7 @@ function bat_integrate_impl end
function bat_integrate(target::AnySampleable, algorithm::IntegrationAlgorithm, context::BATContext)
orig_context = deepcopy(context)
r = bat_integrate_impl(target, algorithm, context)
result_with_args(r, (algorithm = algorithm, context = orig_context))
result_with_args(Val(:mass), target, r, (algorithm = algorithm, context = orig_context))
end

bat_integrate(target::AnySampleable) = bat_integrate(target, get_batcontext())
Expand Down
7 changes: 5 additions & 2 deletions src/algotypes/mode_estimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ Estimate the global mode of `target`.
Returns a NamedTuple of the shape
```julia
(result = X::DensitySampleVector, ...)
(result = X::DensitySampleVector, evaluated::EvaluatedMeasure, ...)
```
(The field `evaluated` is only present if `target` is a measure.)
Result properties not listed here are algorithm-specific and are not part
of the stable public API.
Expand All @@ -43,7 +46,7 @@ function bat_findmode_impl end
function bat_findmode(target::AnySampleable, algorithm, context::BATContext)
orig_context = deepcopy(context)
r = bat_findmode_impl(target, algorithm, context)
result_with_args(r, (algorithm = algorithm, context = orig_context))
result_with_args(Val(:mode), target, r, (algorithm = algorithm, context = orig_context))
end

bat_findmode(target::AnySampleable) = bat_findmode(target, get_batcontext())
Expand Down
5 changes: 3 additions & 2 deletions src/algotypes/sampling_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ Depending on sampling algorithm, the samples may be independent or correlated
Returns a NamedTuple of the shape
```julia
(result = X::DensitySampleVector, ...)
(result = X::DensitySampleVector, evaluated::EvaluatedMeasure, ...)
```
(The field `evaluated` is only present if `target` is a measure.)
Result properties not listed here are algorithm-specific and are not part
of the stable public API.
Expand Down Expand Up @@ -54,7 +55,7 @@ function bat_sample(target, algorithm::AbstractSamplingAlgorithm, context::BATCo
measure = convert_for(bat_sample, target)
orig_context = deepcopy(context)
r = bat_sample_impl(measure, algorithm, context)
result_with_args(r, (algorithm = algorithm, context = orig_context))
result_with_args(Val(:samples), target, r, (algorithm = algorithm, context = orig_context))
end

function bat_sample(target::AnySampleable)
Expand Down
1 change: 1 addition & 0 deletions test/samplers/test_mgvi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import MGVI, ForwardDiff
)
r = bat_sample(pstr, algorithm, context)
@test r.result isa DensitySampleVector
@test r.evaluated isa EvaluatedMeasure
@test first(r.result.info.converged) == false
@test last(r.result.info.converged) == true
@test unique(r.result.info.stepno) == 1:nsteps+1
Expand Down

0 comments on commit dbd394f

Please sign in to comment.