From 5948253f792aeec0e7418a438a5afddb0cbf0c4e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 23 Sep 2024 14:28:08 +0100 Subject: [PATCH 01/25] Replace old Gibbs sampler with the experimental one. --- src/Turing.jl | 1 - src/experimental/Experimental.jl | 16 - src/experimental/gibbs.jl | 488 -------------------------- src/mcmc/Inference.jl | 7 +- src/mcmc/abstractmcmc.jl | 4 + src/mcmc/gibbs.jl | 580 ++++++++++++++++++++++--------- src/mcmc/gibbs_conditional.jl | 88 ----- test/experimental/gibbs.jl | 270 -------------- test/mcmc/Inference.jl | 14 +- test/mcmc/ess.jl | 10 +- test/mcmc/gibbs.jl | 325 +++++++++++++++-- test/mcmc/gibbs_conditional.jl | 172 --------- test/mcmc/hmc.jl | 12 +- test/mcmc/mh.jl | 10 +- test/runtests.jl | 5 - test/skipped/explicit_ret.jl | 2 +- 16 files changed, 751 insertions(+), 1253 deletions(-) delete mode 100644 src/experimental/Experimental.jl delete mode 100644 src/experimental/gibbs.jl delete mode 100644 src/mcmc/gibbs_conditional.jl delete mode 100644 test/experimental/gibbs.jl delete mode 100644 test/mcmc/gibbs_conditional.jl diff --git a/src/Turing.jl b/src/Turing.jl index 8dfb8df28..8fcee6c18 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -86,7 +86,6 @@ export @model, # modelling Emcee, ESS, Gibbs, - GibbsConditional, HMC, # Hamiltonian-like sampling SGLD, SGHMC, diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl deleted file mode 100644 index 518538e6c..000000000 --- a/src/experimental/Experimental.jl +++ /dev/null @@ -1,16 +0,0 @@ -module Experimental - -using Random: Random -using AbstractMCMC: AbstractMCMC -using DynamicPPL: DynamicPPL, VarName -using Accessors: Accessors - -using DocStringExtensions: TYPEDFIELDS -using Distributions - -using ..Turing: Turing -using ..Turing.Inference: gibbs_rerun, InferenceAlgorithm - -include("gibbs.jl") - -end diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl deleted file mode 100644 index 596e6e283..000000000 --- a/src/experimental/gibbs.jl +++ /dev/null @@ -1,488 +0,0 @@ -# Basically like a `DynamicPPL.FixedContext` but -# 1. Hijacks the tilde pipeline to fix variables. -# 2. Computes the log-probability of the fixed variables. -# -# Purpose: avoid triggering resampling of variables we're conditioning on. -# - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. -# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to -# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable -# rather than only for the "true" observations. -# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline -# rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - values::Values - context::Ctx -end - -Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) - -DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::GibbsContext) = context.context -DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext) - -# has and get -has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn) -function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(has_conditioned_gibbs, context), vns) -end - -get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn) -function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(get_conditioned_gibbs, context), vns) -end - -# Tilde pipeline -function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) -end - -function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) -end - -# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. -make_broadcastable(x) = x -make_broadcastable(dist::Distribution) = tuple(dist) - -# Need the following two methods to properly support broadcasting over columns. -broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) -function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) - return loglikelihood(dist, x) -end - -# Needed to support broadcasting over columns for `MultivariateDistribution`s. -reconstruct_getvalue(dist, x) = x -function reconstruct_getvalue( - dist::MultivariateDistribution, - x::AbstractVector{<:AbstractVector{<:Real}} -) - return reduce(hcat, x[2:end]; init=x[1]) -end - -function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi -) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi) -end - - -""" - preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) - -Returns the preferred value type for a variable with the given `varinfo`. -""" -preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict -preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple -function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) - # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. - namedtuple_compatible = all(varinfo.metadata) do md - eltype(md.vns) <: VarName{<:Any,typeof(identity)} - end - return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict -end - -""" - condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) - -Return a `GibbsContext` with the given values treated as conditioned. - -# Arguments -- `context::DynamicPPL.AbstractContext`: The context to condition. -- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. - If multiple values are provided, we recursively condition on each of them. -""" -condition_gibbs(context::DynamicPPL.AbstractContext) = context -# For `NamedTuple` and `AbstractDict` we just construct the context. -function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) - return GibbsContext(values, context) -end -# If we get more than one argument, we just recurse. -function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) - return condition_gibbs( - condition_gibbs(context, value), - values... - ) -end - -# For `DynamicPPL.AbstractVarInfo` we just extract the values. -""" - condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) - -Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. -""" -function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) - return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) -end -function condition_gibbs( - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo, - varinfos::DynamicPPL.AbstractVarInfo... -) - return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) -end -# Allow calling this on a `DynamicPPL.Model` directly. -function condition_gibbs(model::DynamicPPL.Model, values...) - return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) -end - - -""" - make_conditional_model(model, varinfo, varinfos) - -Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. - -# Examples -```julia-repl -julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); - -julia> # A separate varinfo for each variable in `model`. - varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); - -julia> # The varinfo we want to NOT condition on. - target_varinfo = first(varinfos); - -julia> # Results in a model with only `m` conditioned. - conditioned_model = Turing.Inference.make_conditional(model, target_varinfo, varinfos); - -julia> result = conditioned_model(); - -julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` -true - -julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` -true -``` -""" -function make_conditional(model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos) - # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return condition_gibbs( - model, - filter(Base.Fix1(!==, target_varinfo), varinfos)... - ) -end -# Assumes the ones given are the ones to condition on. -function make_conditional(model::DynamicPPL.Model, varinfos) - return condition_gibbs( - model, - varinfos... - ) -end - -# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` -# or an `AbstractInferenceAlgorithm`. -wrap_algorithm_maybe(x) = x -wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) - -""" - Gibbs - -A type representing a Gibbs sampler. - -# Fields -$(TYPEDFIELDS) -""" -struct Gibbs{V,A} <: InferenceAlgorithm - "varnames representing variables for each sampler" - varnames::V - "samplers for each entry in `varnames`" - samplers::A -end - -# NamedTuple -Gibbs(; algs...) = Gibbs(NamedTuple(algs)) -function Gibbs(algs::NamedTuple) - return Gibbs( - map(s -> VarName{s}(), keys(algs)), - map(wrap_algorithm_maybe, values(algs)), - ) -end - -# AbstractDict -function Gibbs(algs::AbstractDict) - return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) -end -function Gibbs(algs::Pair...) - return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) -end - -# TODO: Remove when no longer needed. -DynamicPPL.getspace(::Gibbs) = () - -struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} - vi::V - states::S -end - -_maybevec(x) = vec(x) # assume it's iterable -_maybevec(x::Tuple) = [x...] -_maybevec(x::VarName) = [x] - -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - vi_base::DynamicPPL.AbstractVarInfo; - initial_params=nothing, - kwargs..., -) - alg = spl.alg - varnames = alg.varnames - samplers = alg.samplers - - # 1. Run the model once to get the varnames present + initial values to condition on. - vi_base = DynamicPPL.VarInfo(model) - - # Simple way of setting the initial parameters: set them in the `vi_base` - # if they are given so they propagate to the subset varinfos used by each sampler. - if initial_params !== nothing - vi_base = DynamicPPL.unflatten(vi_base, initial_params) - end - - # Create the varinfos for each sampler. - varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) - initial_params_all = if initial_params === nothing - fill(nothing, length(varnames)) - else - # Extract from the `vi_base`, which should have the values set correctly from above. - map(vi -> vi[:], varinfos) - end - - # 2. Construct a varinfo for every vn + sampler combo. - states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local - # Construct the conditional model. - model_local = make_conditional(model, varinfo_local, varinfos) - - # Take initial step. - new_state_local = last(AbstractMCMC.step( - rng, model_local, sampler_local; - # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. - # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. - initial_params=initial_params_local, - kwargs... - )) - - # Return the new state and the invlinked `varinfo`. - vi_local_state = Turing.Inference.varinfo(new_state_local) - vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) - DynamicPPL.invlink(vi_local_state, sampler_local, model_local) - else - vi_local_state - end - return (new_state_local, vi_local_state_linked) - end - - states = map(first, states_and_varinfos) - varinfos = map(last, states_and_varinfos) - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Turing.Inference.Transition(model, vi), GibbsState(vi, states) -end - -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - state::GibbsState; - kwargs..., -) - alg = spl.alg - samplers = alg.samplers - states = state.states - varinfos = map(Turing.Inference.varinfo, state.states) - @assert length(samplers) == length(state.states) - - # TODO: move this into a recursive function so we can unroll when reasonable? - for index = 1:length(samplers) - # Take the inner step. - new_state_local, new_varinfo_local = gibbs_step_inner( - rng, - model, - samplers, - states, - varinfos, - index; - kwargs..., - ) - - # Update the `states` and `varinfos`. - states = Accessors.setindex(states, new_state_local, index) - varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) - end - - # Combine the resulting varinfo objects. - # The last varinfo holds the correctly computed logp. - vi_base = state.vi - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!( - varinfos, - merge(vi_base, first(varinfos)), - firstindex(varinfos), - ) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Turing.Inference.Transition(model, vi), GibbsState(vi, states) -end - -# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. -function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) - # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide - # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact - # same `selector` as before but now with `rerun` set to `true` if needed. - return Accessors.@set sampler.selector.rerun = true -end - -# Interface we need a sampler to implement to work as a component in a Gibbs sampler. -""" - gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - -Check if the log-probability of the destination model needs to be recomputed. - -Defaults to `true` -""" -function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - return true -end - -# TODO: Remove `rng`? -function Turing.Inference.recompute_logprob!!( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler, - state -) - varinfo = Turing.Inference.varinfo(state) - # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, - # e.g. log-likelihood in the scenario of `ESS`. - # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. - sampler_rerun = make_rerun_sampler(model, sampler) - # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed - # `varinfo`, even if `varinfo` was linked. - varinfo_new = last(DynamicPPL.evaluate!!( - model, - varinfo, - # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. - DynamicPPL.SamplingContext(rng, sampler_rerun) - )) - # Update the state we're about to use if need be. - # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - return Turing.Inference.gibbs_state(model, sampler, state, varinfo_new) -end - -function gibbs_step_inner( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - samplers, - states, - varinfos, - index; - kwargs..., -) - # Needs to do a a few things. - sampler_local = samplers[index] - state_local = states[index] - varinfo_local = varinfos[index] - - # Make sure that all `varinfos` are linked. - varinfos_invlinked = map(varinfos) do vi - # NOTE: This is immutable linking! - # TODO: Do we need the `istrans` check here or should we just always use `invlink`? - # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 - DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - end - varinfo_local_invlinked = varinfos_invlinked[index] - - # 1. Create conditional model. - # Construct the conditional model. - # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, - # otherwise we're conditioning on values which are not in the support of the - # distributions. - model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) - - # Extract the previous sampler and state. - sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] - state_previous = states[index == 1 ? length(states) : index - 1] - - # 1. Re-run the sampler if needed. - if gibbs_requires_recompute_logprob( - model_local, - sampler_local, - sampler_previous, - state_local, - state_previous - ) - state_local = Turing.Inference.recompute_logprob!!( - rng, - model_local, - sampler_local, - state_local, - ) - end - - # 2. Take step with local sampler. - new_state_local = last( - AbstractMCMC.step( - rng, - model_local, - sampler_local, - state_local; - kwargs..., - ), - ) - - # 3. Extract the new varinfo. - # Return the resulting state and invlinked `varinfo`. - varinfo_local_state = Turing.Inference.varinfo(new_state_local) - varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) - DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) - else - varinfo_local_state - end - - # TODO: alternatively, we can return `states_new, varinfos_new, index_new` - return (new_state_local, varinfo_local_state_invlinked) -end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index b7bdf206b..495559871 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -46,7 +46,6 @@ export InferenceAlgorithm, ESS, Emcee, Gibbs, # classic sampling - GibbsConditional, HMC, SGLD, PolynomialStepsize, @@ -63,7 +62,6 @@ export InferenceAlgorithm, observe, dot_observe, predict, - isgibbscomponent, externalsampler ####################### @@ -526,22 +524,21 @@ end # Concrete algorithm implementations. # ####################################### +include("abstractmcmc.jl") include("ess.jl") include("hmc.jl") include("mh.jl") include("is.jl") include("particle_mcmc.jl") -include("gibbs_conditional.jl") include("gibbs.jl") include("sghmc.jl") include("emcee.jl") -include("abstractmcmc.jl") ################ # Typing tools # ################ -for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs, :Emcee) +for alg in (:SMC, :PG, :MH, :IS, :ESS, :Emcee) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index a350d2908..965c79706 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -27,6 +27,10 @@ function varinfo(state::TuringState) # TODO: Do we need to link here first? return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ) end +varinfo(state::AbstractVarInfo) = state +# TODO(mhauru) Could we have a type bound on the argument below, for documentation purposes? +varinfo(state) = state.vi + # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 736845b67..fb05b6475 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -1,101 +1,220 @@ -### -### Gibbs samplers / compositional samplers. -### +# Basically like a `DynamicPPL.FixedContext` but +# 1. Hijacks the tilde pipeline to fix variables. +# 2. Computes the log-probability of the fixed variables. +# +# Purpose: avoid triggering resampling of variables we're conditioning on. +# - Using standard `DynamicPPL.condition` results in conditioned variables being treated +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to +# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable +# rather than only for the "true" observations. +# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline +# rather than the `observe` pipeline for the conditioned variables. +struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + values::Values + context::Ctx +end -""" - isgibbscomponent(alg) +Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) -Determine whether algorithm `alg` is allowed as a Gibbs component. -""" -isgibbscomponent(alg) = false +DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::GibbsContext) = context.context +function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) + return GibbsContext(context.values, childcontext) +end -isgibbscomponent(::ESS) = true -isgibbscomponent(::GibbsConditional) = true -isgibbscomponent(::Hamiltonian) = true -isgibbscomponent(::MH) = true -isgibbscomponent(::PG) = true +# has and get +function has_conditioned_gibbs(context::GibbsContext, vn::VarName) + return DynamicPPL.hasvalue(context.values, vn) +end +function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(has_conditioned_gibbs, context), vns) +end -const TGIBBS = Union{InferenceAlgorithm,GibbsConditional} +function get_conditioned_gibbs(context::GibbsContext, vn::VarName) + return DynamicPPL.getvalue(context.values, vn) +end +function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return map(Base.Fix1(get_conditioned_gibbs, context), vns) +end -""" - Gibbs(algs...) +# Tilde pipeline +function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end -Compositional MCMC interface. Gibbs sampling combines one or more -sampling algorithms, each of which samples from a different set of -variables in a model. + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) +end -Example: -```julia -@model function gibbs_example(x) - v1 ~ Normal(0,1) - v2 ~ Categorical(5) +function DynamicPPL.tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi +) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, vn, vi + ) end -# Use PG for a 'v2' variable, and use HMC for the 'v1' variable. -# Note that v2 is discrete, so the PG sampler is more appropriate -# than is HMC. -alg = Gibbs(HMC(0.2, 3, :v1), PG(20, :v2)) -``` +# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. +make_broadcastable(x) = x +make_broadcastable(dist::Distribution) = tuple(dist) -One can also pass the number of iterations for each Gibbs component using the following syntax: -- `alg = Gibbs((HMC(0.2, 3, :v1), n_hmc), (PG(20, :v2), n_pg))` -where `n_hmc` and `n_pg` are the number of HMC and PG iterations for each Gibbs iteration. +# Need the following two methods to properly support broadcasting over columns. +broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) +function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) + return loglikelihood(dist, x) +end -Tips: -- `HMC` and `NUTS` are fast samplers and can throw off particle-based -methods like Particle Gibbs. You can increase the effectiveness of particle sampling by including -more particles in the particle sampler. -""" -struct Gibbs{space,N,A<:NTuple{N,TGIBBS},B<:NTuple{N,Int}} <: InferenceAlgorithm - algs::A # component sampling algorithms - iterations::B - function Gibbs{space,N,A,B}( - algs::A, iterations::B - ) where {space,N,A<:NTuple{N,TGIBBS},B<:NTuple{N,Int}} - all(isgibbscomponent, algs) || - error("all algorithms have to support Gibbs sampling") - return new{space,N,A,B}(algs, iterations) +# Needed to support broadcasting over columns for `MultivariateDistribution`s. +reconstruct_getvalue(dist, x) = x +function reconstruct_getvalue( + dist::MultivariateDistribution, x::AbstractVector{<:AbstractVector{<:Real}} +) + return reduce(hcat, x[2:end]; init=x[1]) +end + +function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return value, broadcast_logpdf(right, value), vi end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume( + DynamicPPL.childcontext(context), right, left, vns, vi + ) end -function Gibbs(alg1::TGIBBS, algrest::Vararg{TGIBBS,N}) where {N} - algs = (alg1, algrest...) - iterations = ntuple(Returns(1), Val(N + 1)) - # obtain space for sampling algorithms - space = Tuple(union(getspace.(algs)...)) - return Gibbs{space,N + 1,typeof(algs),typeof(iterations)}(algs, iterations) +function DynamicPPL.dot_tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi +) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return value, broadcast_logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi + ) end -function Gibbs(arg1::Tuple{<:TGIBBS,Int}, argrest::Vararg{Tuple{<:TGIBBS,Int},N}) where {N} - allargs = (arg1, argrest...) - algs = map(first, allargs) - iterations = map(last, allargs) - # obtain space for sampling algorithms - space = Tuple(union(getspace.(algs)...)) - return Gibbs{space,N + 1,typeof(algs),typeof(iterations)}(algs, iterations) +""" + preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) + +Returns the preferred value type for a variable with the given `varinfo`. +""" +preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict +preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple +function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) + # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. + namedtuple_compatible = all(varinfo.metadata) do md + eltype(md.vns) <: VarName{<:Any,typeof(identity)} + end + return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict end """ - GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} + condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) + +Return a `GibbsContext` with the given values treated as conditioned. -Stores a `VarInfo` for use in sampling, and a `Tuple` of `Samplers` that -the `Gibbs` sampler iterates through for each `step!`. +# Arguments +- `context::DynamicPPL.AbstractContext`: The context to condition. +- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. + If multiple values are provided, we recursively condition on each of them. """ -struct GibbsState{V<:VarInfo,S<:Tuple{Vararg{Sampler}},T} - vi::V - samplers::S - states::T +condition_gibbs(context::DynamicPPL.AbstractContext) = context +# For `NamedTuple` and `AbstractDict` we just construct the context. +function condition_gibbs( + context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict} +) + return GibbsContext(values, context) +end +# If we get more than one argument, we just recurse. +function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) + return condition_gibbs(condition_gibbs(context, value), values...) end -# extract varinfo object from state +# For `DynamicPPL.AbstractVarInfo` we just extract the values. """ - gibbs_varinfo(model, sampler, state) + condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) -Return the variables corresponding to the current `state` of the Gibbs component `sampler`. +Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. """ -gibbs_varinfo(model, sampler, state) = varinfo(state) -varinfo(state) = state.vi -varinfo(state::AbstractVarInfo) = state +function condition_gibbs( + context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo +) + return condition_gibbs( + context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)) + ) +end +function condition_gibbs( + context::DynamicPPL.AbstractContext, + varinfo::DynamicPPL.AbstractVarInfo, + varinfos::DynamicPPL.AbstractVarInfo..., +) + return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) +end +# Allow calling this on a `DynamicPPL.Model` directly. +function condition_gibbs(model::DynamicPPL.Model, values...) + return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) +end + +""" + make_conditional_model(model, varinfo, varinfos) + +Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. + +# Examples +```julia-repl +julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); + +julia> # A separate varinfo for each variable in `model`. + varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); + +julia> # The varinfo we want to NOT condition on. + target_varinfo = first(varinfos); + +julia> # Results in a model with only `m` conditioned. + conditioned_model = make_conditional(model, target_varinfo, varinfos); + +julia> result = conditioned_model(); + +julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` +true + +julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` +true +``` +""" +function make_conditional( + model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos +) + # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. + return condition_gibbs(model, filter(Base.Fix1(!==, target_varinfo), varinfos)...) +end +# Assumes the ones given are the ones to condition on. +function make_conditional(model::DynamicPPL.Model, varinfos) + return condition_gibbs(model, varinfos...) +end + +# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` +# or an `AbstractInferenceAlgorithm`. +wrap_algorithm_maybe(x) = x +wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) """ gibbs_state(model, sampler, state, varinfo) @@ -130,122 +249,263 @@ function gibbs_state( end """ - gibbs_rerun(prev_alg, alg) + Gibbs -Check if the model should be rerun to recompute the log density before sampling with the -Gibbs component `alg` and after sampling from Gibbs component `prev_alg`. +A type representing a Gibbs sampler. -By default, the function returns `true`. +# Fields +$(TYPEDFIELDS) """ -gibbs_rerun(prev_alg, alg) = true +struct Gibbs{V,A} <: InferenceAlgorithm + "varnames representing variables for each sampler" + varnames::V + "samplers for each entry in `varnames`" + samplers::A +end + +# NamedTuple +Gibbs(; algs...) = Gibbs(NamedTuple(algs)) +function Gibbs(algs::NamedTuple) + return Gibbs( + map(s -> VarName{s}(), keys(algs)), map(wrap_algorithm_maybe, values(algs)) + ) +end -# `vi.logp` already contains the log joint probability if the previous sampler -# used a `GibbsConditional` or one of the standard `Hamiltonian` algorithms -gibbs_rerun(::GibbsConditional, ::MH) = false -gibbs_rerun(::Hamiltonian, ::MH) = false +# AbstractDict +function Gibbs(algs::AbstractDict) + return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) +end +function Gibbs(algs::Pair...) + return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) +end -# `vi.logp` already contains the log joint probability if the previous sampler -# used a `GibbsConditional` or a `MH` algorithm -gibbs_rerun(::MH, ::Hamiltonian) = false -gibbs_rerun(::GibbsConditional, ::Hamiltonian) = false +# TODO: Remove when no longer needed. +DynamicPPL.getspace(::Gibbs) = () -# do not have to recompute `vi.logp` since it is not used in `step` -gibbs_rerun(prev_alg, ::GibbsConditional) = false +struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} + vi::V + states::S +end -# Do not recompute `vi.logp` since it is reset anyway in `step` -gibbs_rerun(prev_alg, ::PG) = false +_maybevec(x) = vec(x) # assume it's iterable +_maybevec(x::Tuple) = [x...] +_maybevec(x::VarName) = [x] -# Initialize the Gibbs sampler. function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, spl::Sampler{<:Gibbs}, vi::AbstractVarInfo; kwargs... + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + vi_base::DynamicPPL.AbstractVarInfo; + initial_params=nothing, + kwargs..., ) - # TODO: Technically this only works for `VarInfo` or `ThreadSafeVarInfo{<:VarInfo}`. - # Should we enforce this? - - # Create tuple of samplers - algs = spl.alg.algs - i = 0 - samplers = map(algs) do alg - i += 1 - if i == 1 - prev_alg = algs[end] - else - prev_alg = algs[i - 1] - end - rerun = gibbs_rerun(prev_alg, alg) - selector = DynamicPPL.Selector(Symbol(typeof(alg)), rerun) - Sampler(alg, model, selector) - end + alg = spl.alg + varnames = alg.varnames + samplers = alg.samplers - # Add Gibbs to gids for all variables. - for sym in keys(vi.metadata) - vns = getfield(vi.metadata, sym).vns + # 1. Run the model once to get the varnames present + initial values to condition on. + vi_base = DynamicPPL.VarInfo(model) - for vn in vns - # update the gid for the Gibbs sampler - DynamicPPL.updategid!(vi, vn, spl) + # Simple way of setting the initial parameters: set them in the `vi_base` + # if they are given so they propagate to the subset varinfos used by each sampler. + if initial_params !== nothing + vi_base = DynamicPPL.unflatten(vi_base, initial_params) + end - # try to store each subsampler's gid in the VarInfo - for local_spl in samplers - DynamicPPL.updategid!(vi, vn, local_spl) - end - end + # Create the varinfos for each sampler. + varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) + initial_params_all = if initial_params === nothing + fill(nothing, length(varnames)) + else + # Extract from the `vi_base`, which should have the values set correctly from above. + map(vi -> vi[:], varinfos) end - # Compute initial states of the local samplers. - states = map(samplers) do local_spl - # Recompute `vi.logp` if needed. - if local_spl.selector.rerun - vi = last( - DynamicPPL.evaluate!!( - model, vi, DynamicPPL.SamplingContext(rng, local_spl) - ), - ) + # 2. Construct a varinfo for every vn + sampler combo. + states_and_varinfos = map( + samplers, varinfos, initial_params_all + ) do sampler_local, varinfo_local, initial_params_local + # Construct the conditional model. + model_local = make_conditional(model, varinfo_local, varinfos) + + # Take initial step. + new_state_local = last( + AbstractMCMC.step( + rng, + model_local, + sampler_local; + # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. + # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. + initial_params=initial_params_local, + kwargs..., + ), + ) + + # Return the new state and the invlinked `varinfo`. + vi_local_state = varinfo(new_state_local) + vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) + DynamicPPL.invlink(vi_local_state, sampler_local, model_local) + else + vi_local_state end + return (new_state_local, vi_local_state_linked) + end + + states = map(first, states_and_varinfos) + varinfos = map(last, states_and_varinfos) - # Compute initial state. - _, state = DynamicPPL.initialstep(rng, model, local_spl, vi; kwargs...) + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) + ) - # Update `VarInfo` object. - vi = gibbs_varinfo(model, local_spl, state) + return Transition(model, vi), GibbsState(vi, states) +end - return state +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + state::GibbsState; + kwargs..., +) + alg = spl.alg + samplers = alg.samplers + states = state.states + varinfos = map(varinfo, state.states) + @assert length(samplers) == length(state.states) + + # TODO: move this into a recursive function so we can unroll when reasonable? + for index in 1:length(samplers) + # Take the inner step. + new_state_local, new_varinfo_local = gibbs_step_inner( + rng, model, samplers, states, varinfos, index; kwargs... + ) + + # Update the `states` and `varinfos`. + states = Accessors.setindex(states, new_state_local, index) + varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) end - # Compute initial transition and state. - transition = Transition(model, vi) - state = GibbsState(vi, samplers, states) + # Combine the resulting varinfo objects. + # The last varinfo holds the correctly computed logp. + vi_base = state.vi + + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!( + varinfos, merge(vi_base, first(varinfos)), firstindex(varinfos) + ) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) + ) - return transition, state + return Transition(model, vi), GibbsState(vi, states) end -# Subsequent steps -function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:Gibbs}, state::GibbsState; kwargs... -) - # Iterate through each of the samplers. - vi = state.vi - samplers = state.samplers - states = map(samplers, spl.alg.iterations, state.states) do _sampler, iteration, _state - # Recompute `vi.logp` if needed. - if _sampler.selector.rerun - vi = last(DynamicPPL.evaluate!!(model, rng, vi, _sampler)) - end +# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. +function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) + # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide + # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact + # same `selector` as before but now with `rerun` set to `true` if needed. + return Accessors.@set sampler.selector.rerun = true +end + +# Interface we need a sampler to implement to work as a component in a Gibbs sampler. +""" + gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - # Update state of current sampler with updated `VarInfo` object. - current_state = gibbs_state(model, _sampler, _state, vi) +Check if the log-probability of the destination model needs to be recomputed. - # Step through the local sampler. - newstate = current_state - for _ in 1:iteration - _, newstate = AbstractMCMC.step(rng, model, _sampler, newstate; kwargs...) - end +Defaults to `true` +""" +function gibbs_requires_recompute_logprob( + model_dst, sampler_dst, sampler_src, state_dst, state_src +) + return true +end - # Update `VarInfo` object. - vi = gibbs_varinfo(model, _sampler, newstate) +# TODO: Remove `rng`? +function recompute_logprob!!( + rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, state +) + vi = varinfo(state) + # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, + # e.g. log-likelihood in the scenario of `ESS`. + # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. + sampler_rerun = make_rerun_sampler(model, sampler) + # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed + # `varinfo`, even if `varinfo` was linked. + vi_new = last( + DynamicPPL.evaluate!!( + model, + vi, + # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. + DynamicPPL.SamplingContext(rng, sampler_rerun), + ) + ) + # Update the state we're about to use if need be. + # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. + return gibbs_state(model, sampler, state, vi_new) +end + +function gibbs_step_inner( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + samplers, + states, + varinfos, + index; + kwargs..., +) + # Needs to do a a few things. + sampler_local = samplers[index] + state_local = states[index] + varinfo_local = varinfos[index] + + # Make sure that all `varinfos` are linked. + varinfos_invlinked = map(varinfos) do vi + # NOTE: This is immutable linking! + # TODO: Do we need the `istrans` check here or should we just always use `invlink`? + # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 + DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi + end + varinfo_local_invlinked = varinfos_invlinked[index] + + # 1. Create conditional model. + # Construct the conditional model. + # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, + # otherwise we're conditioning on values which are not in the support of the + # distributions. + model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) + + # Extract the previous sampler and state. + sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] + state_previous = states[index == 1 ? length(states) : index - 1] + + # 1. Re-run the sampler if needed. + if gibbs_requires_recompute_logprob( + model_local, sampler_local, sampler_previous, state_local, state_previous + ) + state_local = recompute_logprob!!(rng, model_local, sampler_local, state_local) + end - return newstate + # 2. Take step with local sampler. + new_state_local = last( + AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) + ) + + # 3. Extract the new varinfo. + # Return the resulting state and invlinked `varinfo`. + varinfo_local_state = varinfo(new_state_local) + varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) + DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) + else + varinfo_local_state end - return Transition(model, vi), GibbsState(vi, samplers, states) + # TODO: alternatively, we can return `states_new, varinfos_new, index_new` + return (new_state_local, varinfo_local_state_invlinked) end diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl deleted file mode 100644 index fda79315b..000000000 --- a/src/mcmc/gibbs_conditional.jl +++ /dev/null @@ -1,88 +0,0 @@ -""" - GibbsConditional(sym, conditional) - -A "pseudo-sampler" to manually provide analytical Gibbs conditionals to `Gibbs`. -`GibbsConditional(:x, cond)` will sample the variable `x` according to the conditional `cond`, which -must therefore be a function from a `NamedTuple` of the conditioned variables to a `Distribution`. - - -The `NamedTuple` that is passed in contains all random variables from the model in an unspecified -order, taken from the [`VarInfo`](@ref) object over which the model is run. Scalars and vectors are -stored in their respective shapes. The tuple also contains the value of the conditioned variable -itself, which can be useful, but using it creates something that is not a Gibbs sampler anymore (see -[here](https://github.com/TuringLang/Turing.jl/pull/1275#discussion_r434240387)). - -# Examples - -```julia -α_0 = 2.0 -θ_0 = inv(3.0) -x = [1.5, 2.0] -N = length(x) - -@model function inverse_gdemo(x) - λ ~ Gamma(α_0, θ_0) - σ = sqrt(1 / λ) - m ~ Normal(0, σ) - @. x ~ \$(Normal(m, σ)) -end - -# The conditionals can be formulated in terms of the following statistics: -x_bar = mean(x) # sample mean -s2 = var(x; mean=x_bar, corrected=false) # sample variance -m_n = N * x_bar / (N + 1) - -function cond_m(c) - λ_n = c.λ * (N + 1) - σ_n = sqrt(1 / λ_n) - return Normal(m_n, σ_n) -end - -function cond_λ(c) - α_n = α_0 + (N - 1) / 2 + 1 - β_n = s2 * N / 2 + c.m^2 / 2 + inv(θ_0) - return Gamma(α_n, inv(β_n)) -end - -m = inverse_gdemo(x) - -sample(m, Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)), 10) -``` -""" -struct GibbsConditional{S,C} - conditional::C - - function GibbsConditional(sym::Symbol, conditional::C) where {C} - return new{sym,C}(conditional) - end -end - -DynamicPPL.getspace(::GibbsConditional{S}) where {S} = (S,) - -function DynamicPPL.initialstep( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:GibbsConditional}, - vi::AbstractVarInfo; - kwargs..., -) - return nothing, vi -end - -function AbstractMCMC.step( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:GibbsConditional}, - vi::AbstractVarInfo; - kwargs..., -) - condvals = DynamicPPL.values_as(DynamicPPL.invlink(vi, model), NamedTuple) - conddist = spl.alg.conditional(condvals) - updated = rand(rng, conddist) - # Setindex allows only vectors in this case. - vi = setindex!!(vi, [updated;], spl) - # Update log joint probability. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - - return nothing, vi -end diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl deleted file mode 100644 index 0f0740f14..000000000 --- a/test/experimental/gibbs.jl +++ /dev/null @@ -1,270 +0,0 @@ -module ExperimentalGibbsTests - -using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo -using ..NumericalTests: check_MoGtest_default, check_MoGtest_default_z_vector, check_gdemo, - check_numerical, two_sample_test -using DynamicPPL -using Random -using Test -using Turing -using Turing.Inference: AdvancedHMC, AdvancedMH -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff - -function check_transition_varnames( - transition::Turing.Inference.Transition, - parent_varnames -) - transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val - [first(vn_and_val)] - end - # Varnames in `transition` should be subsumed by those in `vns`. - for vn in transition_varnames - @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) - end -end - -const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ - Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_literal_dot_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, -} -has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false -has_dot_assume(::Model) = true - -@testset "Gibbs using `condition`" begin - @testset "Demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - # Run one sampler on variables starting with `s` and another on variables starting with `m`. - vns_s = filter(vns) do vn - DynamicPPL.getsym(vn) == :s - end - vns_m = filter(vns) do vn - DynamicPPL.getsym(vn) == :m - end - - samplers = [ - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => NUTS(), - ), - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => HMC(0.01, 4), - ) - ] - - if !has_dot_assume(model) - # Add in some MH samplers, which are not compatible with `.~`. - append!( - samplers, - [ - Turing.Experimental.Gibbs( - vns_s => HMC(0.01, 4), - vns_m => MH(), - ), - Turing.Experimental.Gibbs( - vns_s => MH(), - vns_m => HMC(0.01, 4), - ) - ] - ) - end - - @testset "$sampler" for sampler in samplers - # Check that taking steps performs as expected. - rng = Random.default_rng() - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - check_transition_varnames(transition, vns) - end - end - - @testset "comparison with 'gold-standard' samples" begin - num_iterations = 1_000 - thinning = 10 - num_chains = 4 - - # Determine initial parameters to make comparison as fair as possible. - posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) - initial_params = DynamicPPL.TestUtils.update_values!!( - DynamicPPL.VarInfo(model), - posterior_mean, - DynamicPPL.TestUtils.varnames(model), - )[:] - initial_params = fill(initial_params, num_chains) - - # Sampler to use for Gibbs components. - sampler_inner = HMC(0.1, 32) - sampler = Turing.Experimental.Gibbs( - vns_s => sampler_inner, - vns_m => sampler_inner, - ) - Random.seed!(42) - chain = sample( - model, - sampler, - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - discard_initial=1_000, - thinning=thinning - ) - - # "Ground truth" samples. - # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL. - Random.seed!(42) - chain_true = sample( - model, - NUTS(), - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - thinning=thinning, - ) - - # Perform KS test to ensure that the chains are similar. - xs = Array(chain) - xs_true = Array(chain_true) - for i = 1:size(xs, 2) - @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) - # Let's make sure that the significance level is not too low by - # checking that the KS test fails for some simple transformations. - # TODO: Replace the heuristic below with closed-form implementations - # of the targets, once they are implemented in DynamicPPL. - @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) - end - end - end - end - - @testset "multiple varnames" begin - rng = Random.default_rng() - - @testset "with both `s` and `m` as random" begin - model = gdemo(1.5, 2.0) - vns = (@varname(s), @varname(m)) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - - # `sample` - Random.seed!(42) - chain = sample(model, alg, 10_000; progress=false) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) - end - - @testset "without `m` as random" begin - model = gdemo(1.5, 2.0) | (m=7 / 6,) - vns = (@varname(s),) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - end - end - - @testset "CSMC + ESS" begin - rng = Random.default_rng() - model = MoGtest_default - alg = Turing.Experimental.Gibbs( - (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! - Random.seed!(42) - chain = sample(MoGtest_default, alg, 1000; progress=false) - check_MoGtest_default(chain, atol = 0.2) - end - - @testset "CSMC + ESS (usage of implicit varname)" begin - rng = Random.default_rng() - model = MoGtest_default_z_vector - alg = Turing.Experimental.Gibbs( - @varname(z) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z[1]), @varname(z[2]), @varname(z[3]), @varname(z[4]), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! - Random.seed!(42) - chain = sample(model, alg, 1000; progress=false) - check_MoGtest_default_z_vector(chain, atol = 0.2) - end - - @testset "externsalsampler" begin - @model function demo_gibbs_external() - m1 ~ Normal() - m2 ~ Normal() - - -1 ~ Normal(m1, 1) - +1 ~ Normal(m1 + m2, 1) - - return (; m1, m2) - end - - model = demo_gibbs_external() - samplers_inner = [ - externalsampler(AdvancedMH.RWMH(1)), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoForwardDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff(compile=true)), - ] - @testset "$(sampler_inner)" for sampler_inner in samplers_inner - sampler = Turing.Experimental.Gibbs( - @varname(m1) => sampler_inner, - @varname(m2) => sampler_inner, - ) - Random.seed!(42) - chain = sample(model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0) - check_numerical(chain, [:m1, :m2], [-0.2, 0.6], atol=0.1) - end - end -end - -end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 15ec6149c..4a6e0e9a6 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -33,15 +33,15 @@ ADUtils.install_tapir && import Tapir PG(10), IS(), MH(), - Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + Gibbs(; s=PG(3), m=HMC(0.4, 8; adtype=adbackend)), + Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), ) else ( HMC(0.1, 7; adtype=adbackend), IS(), MH(), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), ) end for sampler in samplers @@ -85,7 +85,7 @@ ADUtils.install_tapir && import Tapir alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend) alg2 = PG(20) - alg3 = Gibbs(PG(30, :s), HMC(0.2, 4, :m; adtype=adbackend)) + alg3 = Gibbs(; s=PG(30), m=HMC(0.2, 4; adtype=adbackend)) chn1 = sample(gdemo_default, alg1, 5000; save_state=true) check_gdemo(chn1) @@ -234,7 +234,7 @@ ADUtils.install_tapir && import Tapir smc = SMC() pg = PG(10) - gibbs = Gibbs(HMC(0.2, 3, :p; adtype=adbackend), PG(10, :x)) + gibbs = Gibbs(; p=HMC(0.2, 3; adtype=adbackend), x=PG(10)) chn_s = sample(testbb(obs), smc, 1000) chn_p = sample(testbb(obs), pg, 2000) @@ -261,7 +261,7 @@ ADUtils.install_tapir && import Tapir return s, m end - gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m; adtype=adbackend)) + gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8; adtype=adbackend)) chain = sample(fggibbstest(xs), gibbs, 2) end @testset "new grammar" begin @@ -367,7 +367,7 @@ ADUtils.install_tapir && import Tapir @test all(isone, res_pg[:x]) end @testset "sample" begin - alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s)) + alg = Gibbs(; m=HMC(0.2, 3; adtype=adbackend), s=PG(10)) chn = sample(gdemo_default, alg, 1000) end @testset "vectorization @." begin diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 0a1c23a9e..da03e686d 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -38,7 +38,7 @@ using Turing c3 = sample(demodot_default, s1, N) c4 = sample(demodot_default, s2, N) - s3 = Gibbs(ESS(:m), MH(:s)) + s3 = Gibbs(; m=ESS(), s=MH()) c5 = sample(gdemo_default, s3, N) end @@ -52,13 +52,17 @@ using Turing check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8]; atol=0.1) Random.seed!(100) - alg = Gibbs(CSMC(15, :s), ESS(:m)) + alg = Gibbs(; s=CSMC(15), m=ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) # MoGtest Random.seed!(125) - alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) + alg = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(m2) => ESS(), + ) chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain; atol=0.1) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 6868cb5e8..354a19537 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -1,30 +1,65 @@ module GibbsTests -using ..Models: MoGtest_default, gdemo, gdemo_default -using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical +using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo, gdemo_default +using ..NumericalTests: + check_MoGtest_default, + check_MoGtest_default_z_vector, + check_gdemo, + check_numerical, + two_sample_test import ..ADUtils using Distributions: InverseGamma, Normal using Distributions: sample +using DynamicPPL: DynamicPPL using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff using Test: @test, @testset using Turing using Turing: Inference +using Turing.Inference: AdvancedHMC, AdvancedMH using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess ADUtils.install_tapir && import Tapir +function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames) + transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val + [first(vn_and_val)] + end + # Varnames in `transition` should be subsumed by those in `parent_varnames`. + for vn in transition_varnames + @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) + end +end + +const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_literal_dot_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, +} +has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false +has_dot_assume(::DynamicPPL.Model) = true + @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "gibbs constructor" begin N = 500 - s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) - s2 = Gibbs(PG(10, :s, :m)) - s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) - for s in (s1, s2, s3, s4, s5, s6) + s1 = begin + alg = HMC(0.1, 5, :s, :m; adtype=adbackend) + Gibbs(; s=alg, m=alg) + end + s2 = begin + alg = PG(10) + Gibbs(@varname(s) => alg, @varname(m) => alg) + end + s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) + s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) + s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) + s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) + s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) + for s in (s1, s2, s3, s4, s5, s6, s7) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end @@ -34,22 +69,15 @@ ADUtils.install_tapir && import Tapir c4 = sample(gdemo_default, s4, N) c5 = sample(gdemo_default, s5, N) c6 = sample(gdemo_default, s6, N) + c7 = sample(gdemo_default, s7, N) - # Test gid of each samplers g = Turing.Sampler(s3, gdemo_default) - - _, state = AbstractMCMC.step(Random.default_rng(), gdemo_default, g) - @test state.samplers[1].selector != g.selector - @test state.samplers[2].selector != g.selector - @test state.samplers[1].selector != state.samplers[2].selector - - # run sampler: progress logging should be disabled and - # it should return a Chains object @test sample(gdemo_default, g, N) isa MCMCChains.Chains end + @testset "gibbs inference" begin Random.seed!(100) - alg = Gibbs(CSMC(15, :s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:m], [7 / 6]; atol=0.15) # Be more relaxed with the tolerance of the variance. @@ -57,11 +85,11 @@ ADUtils.install_tapir && import Tapir Random.seed!(100) - alg = Gibbs(MH(:s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - alg = Gibbs(CSMC(15, :s), ESS(:m)) + alg = Gibbs(; s=CSMC(15), m=ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) @@ -71,15 +99,17 @@ ADUtils.install_tapir && import Tapir Random.seed!(200) gibbs = Gibbs( - PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2; adtype=adbackend) + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), ) chain = sample(MoGtest_default, gibbs, 10_000) check_MoGtest_default(chain; atol=0.15) Random.seed!(200) for alg in [ - Gibbs((MH(:s), 2), (HMC(0.2, 4, :m; adtype=adbackend), 1)), - Gibbs((MH(:s), 1), (HMC(0.2, 4, :m; adtype=adbackend), 2)), + # The new syntax for specifying a sampler to run twice for one variable. + Gibbs(s => MH(), s => MH(), m => HMC(0.2, 4; adtype=adbackend)), + Gibbs(s => MH(), m => HMC(0.2, 4), m => HMC(0.2, 4); adtype=adbackend), ] chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_gdemo(chain; atol=0.15) @@ -113,9 +143,10 @@ ADUtils.install_tapir && import Tapir return nothing end - alg = Gibbs(MH(:s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) sample(model, alg, 100; callback=callback) end + @testset "dynamic model" begin @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} N = length(y) @@ -136,10 +167,250 @@ ADUtils.install_tapir && import Tapir m[k] ~ Normal(1.0, 1.0) end end - model = imm(randn(100), 1.0) + model = imm(Random.randn(100), 1.0) # https://github.com/TuringLang/Turing.jl/issues/1725 # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100); - sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m; adtype=adbackend)), 100) + sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) + end + + @testset "Demo models" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + # Run one sampler on variables starting with `s` and another on variables starting with `m`. + vns_s = filter(vns) do vn + DynamicPPL.getsym(vn) == :s + end + vns_m = filter(vns) do vn + DynamicPPL.getsym(vn) == :m + end + + samplers = [ + Turing.Gibbs(vns_s => NUTS(), vns_m => NUTS()), + Turing.Gibbs(vns_s => NUTS(), vns_m => HMC(0.01, 4)), + ] + + if !has_dot_assume(model) + # Add in some MH samplers, which are not compatible with `.~`. + append!( + samplers, + [ + Turing.Gibbs(vns_s => HMC(0.01, 4), vns_m => MH()), + Turing.Gibbs(vns_s => MH(), vns_m => HMC(0.01, 4)), + ], + ) + end + + @testset "$sampler" for sampler in samplers + # Check that taking steps performs as expected. + rng = Random.default_rng() + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(sampler) + ) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(sampler), state + ) + check_transition_varnames(transition, vns) + end + end + + # Run the Gibbs sampler and NUTS on the same model, compare statistics of the + # chains. + @testset "comparison with 'gold-standard' samples" begin + num_iterations = 1_000 + thinning = 10 + num_chains = 4 + + # Determine initial parameters to make comparison as fair as possible. + posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) + initial_params = DynamicPPL.TestUtils.update_values!!( + DynamicPPL.VarInfo(model), + posterior_mean, + DynamicPPL.TestUtils.varnames(model), + )[:] + initial_params = fill(initial_params, num_chains) + + # Sampler to use for Gibbs components. + sampler_inner = HMC(0.1, 32) + sampler = Turing.Gibbs(vns_s => sampler_inner, vns_m => sampler_inner) + Random.seed!(42) + chain = sample( + model, + sampler, + MCMCThreads(), + num_iterations, + num_chains; + progress=false, + initial_params=initial_params, + discard_initial=1_000, + thinning=thinning, + ) + + # "Ground truth" samples. + # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL. + Random.seed!(42) + chain_true = sample( + model, + NUTS(), + MCMCThreads(), + num_iterations, + num_chains; + progress=false, + initial_params=initial_params, + thinning=thinning, + ) + + # Perform KS test to ensure that the chains are similar. + xs = Array(chain) + xs_true = Array(chain_true) + for i in 1:size(xs, 2) + @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) + # Let's make sure that the significance level is not too low by + # checking that the KS test fails for some simple transformations. + # TODO: Replace the heuristic below with closed-form implementations + # of the targets, once they are implemented in DynamicPPL. + @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) + end + end + end + end + + @testset "multiple varnames" begin + rng = Random.default_rng() + + @testset "with both `s` and `m` as random" begin + model = gdemo(1.5, 2.0) + vns = (@varname(s), @varname(m)) + alg = Turing.Gibbs(vns => MH()) + + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # `sample` + Random.seed!(42) + chain = sample(model, alg, 10_000; progress=false) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) + end + + @testset "without `m` as random" begin + model = gdemo(1.5, 2.0) | (m=7 / 6,) + vns = (@varname(s),) + alg = Turing.Gibbs(vns => MH()) + + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + end + end + + @testset "CSMC + ESS" begin + rng = Random.default_rng() + model = MoGtest_default + alg = Turing.Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) + vns = ( + @varname(z1), + @varname(z2), + @varname(z3), + @varname(z4), + @varname(mu1), + @varname(mu2) + ) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # Sample! + Random.seed!(42) + chain = sample(MoGtest_default, alg, 1000; progress=false) + check_MoGtest_default(chain; atol=0.2) + end + + @testset "CSMC + ESS (usage of implicit varname)" begin + rng = Random.default_rng() + model = MoGtest_default_z_vector + alg = Turing.Gibbs( + @varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS() + ) + vns = ( + @varname(z[1]), + @varname(z[2]), + @varname(z[3]), + @varname(z[4]), + @varname(mu1), + @varname(mu2) + ) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # Sample! + Random.seed!(42) + chain = sample(model, alg, 1000; progress=false) + check_MoGtest_default_z_vector(chain; atol=0.2) + end + + @testset "externsalsampler" begin + @model function demo_gibbs_external() + m1 ~ Normal() + m2 ~ Normal() + + -1 ~ Normal(m1, 1) + +1 ~ Normal(m1 + m2, 1) + + return (; m1, m2) + end + + model = demo_gibbs_external() + samplers_inner = [ + externalsampler(AdvancedMH.RWMH(1)), + externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoForwardDiff()), + externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff()), + externalsampler( + AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff(; compile=true) + ), + ] + @testset "$(sampler_inner)" for sampler_inner in samplers_inner + sampler = Turing.Gibbs( + @varname(m1) => sampler_inner, @varname(m2) => sampler_inner + ) + Random.seed!(42) + chain = sample( + model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0 + ) + check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1) + end end end diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl deleted file mode 100644 index 3f02c7594..000000000 --- a/test/mcmc/gibbs_conditional.jl +++ /dev/null @@ -1,172 +0,0 @@ -module GibbsConditionalTests - -using ..Models: gdemo, gdemo_default -using ..NumericalTests: check_gdemo, check_numerical -import ..ADUtils -using Clustering: Clustering -using Distributions: Categorical, InverseGamma, Normal, sample -using ForwardDiff: ForwardDiff -using LinearAlgebra: Diagonal, I -using Random: Random -using ReverseDiff: ReverseDiff -using StableRNGs: StableRNG -using StatsBase: counts -using StatsFuns: StatsFuns -using Test: @test, @testset -using Turing - -ADUtils.install_tapir && import Tapir - -@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in ADUtils.adbackends - Random.seed!(1000) - rng = StableRNG(123) - - @testset "gdemo" begin - # We consider the model - # ```math - # s ~ InverseGamma(2, 3) - # m ~ Normal(0, √s) - # xᵢ ~ Normal(m, √s), i = 1, …, N, - # ``` - # with ``N = 2`` observations ``x₁ = 1.5`` and ``x₂ = 2``. - - # The conditionals and posterior can be formulated in terms of the following statistics: - N = 2 - x_mean = 1.75 # sample mean ``∑ xᵢ / N`` - x_var = 0.0625 # sample variance ``∑ (xᵢ - x_bar)^2 / N`` - m_n = 3.5 / 3 # ``∑ xᵢ / (N + 1)`` - - # Conditional distribution - # ```math - # m | s, x ~ Normal(m_n, sqrt(s / (N + 1))) - # ``` - cond_m = let N = N, m_n = m_n - c -> Normal(m_n, sqrt(c.s / (N + 1))) - end - - # Conditional distribution - # ```math - # s | m, x ~ InverseGamma(2 + (N + 1) / 2, 3 + (m^2 + ∑ (xᵢ - m)^2) / 2) = - # InverseGamma(2 + (N + 1) / 2, 3 + m^2 / 2 + N / 2 * (x_var + (x_mean - m)^2)) - # ``` - cond_s = let N = N, x_mean = x_mean, x_var = x_var - c -> InverseGamma( - 2 + (N + 1) / 2, 3 + c.m^2 / 2 + N / 2 * (x_var + (x_mean - c.m)^2) - ) - end - - # Three Gibbs samplers: - # one for each variable fixed to the posterior mean - s_posterior_mean = 49 / 24 - sampler1 = Gibbs( - GibbsConditional(:m, cond_m), - GibbsConditional(:s, _ -> Normal(s_posterior_mean, 0)), - ) - chain = sample(rng, gdemo_default, sampler1, 10_000) - cond_m_mean = mean(cond_m((s=s_posterior_mean,))) - check_numerical(chain, [:m, :s], [cond_m_mean, s_posterior_mean]) - @test all(==(s_posterior_mean), chain[:s][2:end]) - - m_posterior_mean = 7 / 6 - sampler2 = Gibbs( - GibbsConditional(:m, _ -> Normal(m_posterior_mean, 0)), - GibbsConditional(:s, cond_s), - ) - chain = sample(rng, gdemo_default, sampler2, 10_000) - cond_s_mean = mean(cond_s((m=m_posterior_mean,))) - check_numerical(chain, [:m, :s], [m_posterior_mean, cond_s_mean]) - @test all(==(m_posterior_mean), chain[:m][2:end]) - - # and one for both using the conditional - sampler3 = Gibbs(GibbsConditional(:m, cond_m), GibbsConditional(:s, cond_s)) - chain = sample(rng, gdemo_default, sampler3, 10_000) - check_gdemo(chain) - end - - @testset "GMM" begin - Random.seed!(1000) - rng = StableRNG(123) - # We consider the model - # ```math - # μₖ ~ Normal(m, σ_μ), k = 1, …, K, - # zᵢ ~ Categorical(π), i = 1, …, N, - # xᵢ ~ Normal(μ_{zᵢ}, σₓ), i = 1, …, N, - # ``` - # with ``K = 2`` clusters, ``N = 20`` observations, and the following parameters: - K = 2 # number of clusters - π = fill(1 / K, K) # uniform cluster weights - m = 0.5 # prior mean of μₖ - σ²_μ = 4.0 # prior variance of μₖ - σ²_x = 0.01 # observation variance - N = 20 # number of observations - - # We generate data - μ_data = rand(rng, Normal(m, sqrt(σ²_μ)), K) - z_data = rand(rng, Categorical(π), N) - x_data = rand(rng, MvNormal(μ_data[z_data], σ²_x * I)) - - @model function mixture(x) - μ ~ $(MvNormal(fill(m, K), σ²_μ * I)) - z ~ $(filldist(Categorical(π), N)) - x ~ MvNormal(μ[z], $(σ²_x * I)) - return x - end - model = mixture(x_data) - - # Conditional distribution ``z | μ, x`` - # see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf - cond_z = let x = x_data, log_π = log.(π), σ_x = sqrt(σ²_x) - c -> begin - dists = map(x) do xi - logp = log_π .+ logpdf.(Normal.(c.μ, σ_x), xi) - return Categorical(StatsFuns.softmax!(logp)) - end - return arraydist(dists) - end - end - - # Conditional distribution ``μ | z, x`` - # see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf - cond_μ = let K = K, x_data = x_data, inv_σ²_μ = inv(σ²_μ), inv_σ²_x = inv(σ²_x) - c -> begin - # Convert cluster assignments to one-hot encodings - z_onehot = c.z .== (1:K)' - - # Count number of observations in each cluster - n = vec(sum(z_onehot; dims=1)) - - # Compute mean and variance of the conditional distribution - μ_var = @. inv(inv_σ²_x * n + inv_σ²_μ) - μ_mean = (z_onehot' * x_data) .* inv_σ²_x .* μ_var - - return MvNormal(μ_mean, Diagonal(μ_var)) - end - end - - estimate(chain, var) = dropdims(mean(Array(group(chain, var)); dims=1); dims=1) - function estimatez(chain, var, range) - z = Int.(Array(group(chain, var))) - return map(i -> findmax(counts(z[:, i], range))[2], 1:size(z, 2)) - end - - lμ_data, uμ_data = extrema(μ_data) - - # Compare three Gibbs samplers - sampler1 = Gibbs(GibbsConditional(:z, cond_z), GibbsConditional(:μ, cond_μ)) - sampler2 = Gibbs(GibbsConditional(:z, cond_z), MH(:μ)) - sampler3 = Gibbs(GibbsConditional(:z, cond_z), HMC(0.01, 7, :μ; adtype=adbackend)) - for sampler in (sampler1, sampler2, sampler3) - chain = sample(rng, model, sampler, 10_000) - - μ_hat = estimate(chain, :μ) - lμ_hat, uμ_hat = extrema(μ_hat) - @test isapprox([lμ_data, uμ_data], [lμ_hat, uμ_hat], atol=0.1) - - z_hat = estimatez(chain, :z, 1:2) - ari, _, _, _ = Clustering.randindex(z_data, Int.(z_hat)) - @test isapprox(ari, 1, atol=0.1) - end - end -end - -end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index dde977a6f..889be13c5 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -130,9 +130,9 @@ ADUtils.install_tapir && import Tapir @testset "hmcda inference" begin alg1 = HMCDA(500, 0.8, 0.015; adtype=adbackend) - # alg2 = Gibbs(HMCDA(200, 0.8, 0.35, :m; adtype=adbackend), HMC(0.25, 3, :s; adtype=adbackend)) + # alg2 = Gibbs(; m=HMCDA(200, 0.8, 0.35; adtype=adbackend), s=HMC(0.25, 3; adtype=adbackend)) - # alg3 = Gibbs(HMC(0.25, 3, :m; adtype=adbackend), PG(30, 3, :s)) + # alg3 = Gibbs(; m=HMC(0.25, 3; adtype=adbackend), s=PG(30, 3)) # alg3 = PG(50, 2000) res1 = sample(rng, gdemo_default, alg1, 3000) @@ -147,7 +147,7 @@ ADUtils.install_tapir && import Tapir @testset "hmcda+gibbs inference" begin rng = StableRNG(123) Random.seed!(12345) # particle samplers do not support user-provided `rng` yet - alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend)) + alg3 = Gibbs(; s=PG(20), m=HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend)) res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000) check_gdemo(res3) @@ -200,9 +200,9 @@ ADUtils.install_tapir && import Tapir @test size(c2, 1) == 500 end @testset "AHMC resize" begin - alg1 = Gibbs(PG(10, :m), NUTS(100, 0.65, :s; adtype=adbackend)) - alg2 = Gibbs(PG(10, :m), HMC(0.1, 3, :s; adtype=adbackend)) - alg3 = Gibbs(PG(10, :m), HMCDA(100, 0.65, 0.3, :s; adtype=adbackend)) + alg1 = Gibbs(; m=PG(10), s=NUTS(100, 0.65; adtype=adbackend)) + alg2 = Gibbs(; m=PG(10), s=HMC(0.1, 3; adtype=adbackend)) + alg3 = Gibbs(; m=PG(10), s=HMCDA(100, 0.65, 0.3; adtype=adbackend)) @test sample(rng, gdemo_default, alg1, 300) isa Chains @test sample(rng, gdemo_default, alg2, 300) isa Chains @test sample(rng, gdemo_default, alg3, 300) isa Chains diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index a01d3dc25..f454db5a0 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -32,7 +32,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) c2 = sample(gdemo_default, s2, N) c3 = sample(gdemo_default, s3, N) - s4 = Gibbs(MH(:m), MH(:s)) + s4 = Gibbs(; m=MH(), s=MH()) c4 = sample(gdemo_default, s4, N) # s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) @@ -62,14 +62,16 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) Random.seed!(125) # MH within Gibbs - alg = Gibbs(MH(:m), MH(:s)) + alg = Gibbs(; m=MH(), s=MH()) chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params) check_gdemo(chain; atol=0.1) Random.seed!(125) # MoGtest gibbs = Gibbs( - CSMC(15, :z1, :z2, :z3, :z4), MH((:mu1, GKernel(1)), (:mu2, GKernel(1))) + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => MH((:mu1, GKernel(1))), + @varname(mu2) => MH((:mu2, GKernel(1))), ) chain = sample( MoGtest_default, @@ -167,7 +169,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) vc_μ = convert(Array, 1e-4 * I(2)) vc_σ = convert(Array, 1e-4 * I(2)) - alg = Gibbs(MH((:μ, vc_μ)), MH((:σ, vc_σ))) + alg = Gibbs(; μ=MH((:μ, vc_μ)), σ=MH((:σ, vc_σ))) chn = sample( mod, diff --git a/test/runtests.jl b/test/runtests.jl index 1aa8bb635..ba9aafd2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,7 +46,6 @@ end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" begin @timeit_include("mcmc/gibbs.jl") - @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") @@ -65,10 +64,6 @@ end end end - @testset "experimental" begin - @timeit_include("experimental/gibbs.jl") - end - @testset "variational optimisers" begin @timeit_include("variational/optimisers.jl") end diff --git a/test/skipped/explicit_ret.jl b/test/skipped/explicit_ret.jl index c1340464f..2dabc09bd 100644 --- a/test/skipped/explicit_ret.jl +++ b/test/skipped/explicit_ret.jl @@ -12,7 +12,7 @@ end mf = test_ex_rt() for alg in - [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(PG(20, 1, :x), HMC(0.2, 3, :y))] + [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(; x=PG(20, 1), y=HMC(0.2, 3))] chn = sample(mf, alg) @test mean(chn[:x]) ≈ 10.0 atol = 0.2 @test mean(chn[:y]) ≈ 5.0 atol = 0.2 From 5a3e4a66cdacd931c023631576829d8401bf5207 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 23 Sep 2024 15:59:40 +0100 Subject: [PATCH 02/25] Remove dead references to experimental --- .github/workflows/Tests.yml | 5 ++--- src/Turing.jl | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 8de296e5e..770eab9a7 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -22,9 +22,8 @@ jobs: - "mcmc/hmc.jl" - "mcmc/abstractmcmc.jl" - "mcmc/Inference.jl" - - "experimental/gibbs.jl" - "mcmc/ess.jl" - - "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl" + - "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl" version: - '1.7' - '1' @@ -79,7 +78,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 # TODO: Use julia-actions/julia-runtest when test_args are supported # Custom calls of Pkg.test tend to miss features such as e.g. adjustments for CompatHelper PRs - # Ref https://github.com/julia-actions/julia-runtest/pull/73 + # Ref https://github.com/julia-actions/julia-runtest/pull/73 - name: Call Pkg.test run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test-args }} - uses: julia-actions/julia-processcoverage@v1 diff --git a/src/Turing.jl b/src/Turing.jl index 8fcee6c18..027c190a3 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -55,7 +55,6 @@ using .Variational include("optimisation/Optimisation.jl") using .Optimisation -include("experimental/Experimental.jl") include("deprecated.jl") # to be removed in the next minor version release ########### From 09c739d0c040545017977a749b5f2305ffd5d462 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 23 Sep 2024 16:00:32 +0100 Subject: [PATCH 03/25] Remove mention of experimental from JuliaFormatter conf --- .JuliaFormatter.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index d0e00b45f..745726d46 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -6,9 +6,7 @@ import_to_using = false # These ignores should be removed once the relevant PRs are merged/closed. ignore = [ # https://github.com/TuringLang/Turing.jl/pull/2231/files - "src/experimental/gibbs.jl", "src/mcmc/abstractmcmc.jl", - "test/experimental/gibbs.jl", "test/test_utils/numerical_tests.jl", # https://github.com/TuringLang/Turing.jl/pull/2218/files "src/mcmc/Inference.jl", From 58ebb259af5adc843790c57d7654e3a656a62687 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 10:05:20 +0100 Subject: [PATCH 04/25] Add tests for deprecated constructor --- src/mcmc/gibbs.jl | 20 ++++++++++++++++++++ test/mcmc/gibbs.jl | 38 ++++++++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index fb05b6475..754451a50 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -279,6 +279,26 @@ function Gibbs(algs::Pair...) return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) end +# The below constructor only serves to provide backwards compatibility with the constructor +# of the old Gibbs sampler. It is deprecated and will be removed in the future. +function Gibbs(algs::InferenceAlgorithm...) + alg_dict = Dict{Any,InferenceAlgorithm}() + for alg in algs + space = getspace(alg) + space_vns = if (space isa Symbol || space isa VarName) + space + else + tuple((s isa Symbol ? VarName{s}() : s for s in space)...) + end + alg_dict[space_vns] = alg + end + Base.depwarn( + "Specifying which sampler to use with which variable using syntax like `Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. Please use `Gibbs(; x=NUTS(), y=MH())` instead.", + :Gibbs, + ) + return Gibbs(alg_dict) +end + # TODO: Remove when no longer needed. DynamicPPL.getspace(::Gibbs) = () diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 354a19537..dbf4271c1 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -14,7 +14,7 @@ using DynamicPPL: DynamicPPL using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff -using Test: @test, @testset +using Test: @test, @test_deprecated, @testset using Turing using Turing: Inference using Turing.Inference: AdvancedHMC, AdvancedMH @@ -44,16 +44,34 @@ has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false has_dot_assume(::DynamicPPL.Model) = true @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends - @testset "gibbs constructor" begin - N = 500 - s1 = begin - alg = HMC(0.1, 5, :s, :m; adtype=adbackend) - Gibbs(; s=alg, m=alg) - end - s2 = begin - alg = PG(10) - Gibbs(@varname(s) => alg, @varname(m) => alg) + @testset "Deprecated Gibbs constructors" begin + N = 10 + @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) + @test_deprecated s2 = Gibbs(PG(10, :s, :m)) + @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) + for s in (s1, s2, s3, s4, s5, s6) + @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end + + # Check that the samplers work despite using the deprecated constructor. + sample(gdemo_default, s1, N) + sample(gdemo_default, s2, N) + sample(gdemo_default, s3, N) + sample(gdemo_default, s4, N) + sample(gdemo_default, s5, N) + sample(gdemo_default, s6, N) + + g = Turing.Sampler(s3, gdemo_default) + @test sample(gdemo_default, g, N) isa MCMCChains.Chains + end + + @testset "Gibbs constructors" begin + N = 10 + s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5, :s, :m; adtype=adbackend)) + s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) From 771573293eb42dbf45226ac4689bd04ae8352855 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 14:43:21 +0100 Subject: [PATCH 05/25] Fix deprecated Gibbs constructors. Add HISTORY entry. --- HISTORY.md | 16 ++++++++++++++++ src/mcmc/gibbs.jl | 35 +++++++++++++++++++++++++---------- test/mcmc/gibbs.jl | 4 +++- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 5b1cad0ed..11d08e12c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,19 @@ +# Release 0.35.0 + +## Breaking changes + +0.35.0 introduces a new Gibbs sampler. It's been included in several previous releases as `Turing.Experimental.Gibbs`, but now takes over the old Gibbs sampler, which gets removed completely. + +The new Gibbs sampler supports the same user-facing interface as the old one. However, given +that the internals of it having been completely rewritten in a very different manner, there +may be accidental breakage that we haven't anticipated. Please report any you find. + +`GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking. + +The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable. + +Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(:x), 2), (MH(:y), 1))` has been deprecated. The new way to achieve this effect is to list the same sampler multiple times, e.g. as `hmc = HMC(); mh = MH(); Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)`. + # Release 0.33.0 ## Breaking changes diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 754451a50..445bc433f 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -279,24 +279,39 @@ function Gibbs(algs::Pair...) return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) end -# The below constructor only serves to provide backwards compatibility with the constructor -# of the old Gibbs sampler. It is deprecated and will be removed in the future. +# The below two constructors only provide backwards compatibility with the constructor of +# the old Gibbs sampler. They are deprecated and will be removed in the future. function Gibbs(algs::InferenceAlgorithm...) - alg_dict = Dict{Any,InferenceAlgorithm}() - for alg in algs + varnames = map(algs) do alg space = getspace(alg) - space_vns = if (space isa Symbol || space isa VarName) + if (space isa VarName) space + elseif (space isa Symbol) + VarName{space}() else tuple((s isa Symbol ? VarName{s}() : s for s in space)...) end - alg_dict[space_vns] = alg end - Base.depwarn( - "Specifying which sampler to use with which variable using syntax like `Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. Please use `Gibbs(; x=NUTS(), y=MH())` instead.", - :Gibbs, + msg = ( + "Specifying which sampler to use with which variable using syntax like " * + "`Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. " * + "Please use `Gibbs(; x=NUTS(), y=MH())` instead. If you want different iteration " * + "counts for different subsamplers, use e.g. " * + "`Gibbs(@varname(x) => NUTS(), @varname(x) => NUTS(), @varname(y) => MH())`" ) - return Gibbs(alg_dict) + Base.depwarn(msg, :Gibbs) + return Gibbs(varnames, map(wrap_algorithm_maybe, algs)) +end + +function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) + algs = Iterators.map(first, algs_with_iters) + iters = Iterators.map(last, algs_with_iters) + algs_duplicated = Iterators.flatten(( + Iterators.repeated(alg, iter) for (alg, iter) in zip(algs, iters) + )) + # This calls the other deprecated constructor from above, hence no need for a depwarn + # here. + return Gibbs(algs_duplicated...) end # TODO: Remove when no longer needed. diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index dbf4271c1..9162c6cf8 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -52,7 +52,8 @@ has_dot_assume(::DynamicPPL.Model) = true @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) - for s in (s1, s2, s3, s4, s5, s6) + @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3)) + for s in (s1, s2, s3, s4, s5, s6, s7) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end @@ -63,6 +64,7 @@ has_dot_assume(::DynamicPPL.Model) = true sample(gdemo_default, s4, N) sample(gdemo_default, s5, N) sample(gdemo_default, s6, N) + sample(gdemo_default, s7, N) g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains From 672f7d986bcacc3fffc438de14b7581ce839ccc2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 14:43:38 +0100 Subject: [PATCH 06/25] Bump version to 0.35.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5f5c86b04..e22c4f4ae 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.34.1" +version = "0.35.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 7bf5abefde676b5aad9f293319ee7eac49f047ba Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 15:39:50 +0100 Subject: [PATCH 07/25] Add Gibbs constructor test for repeat samplers --- test/mcmc/gibbs.jl | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 9162c6cf8..5082a5f4f 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -79,17 +79,25 @@ has_dot_assume(::DynamicPPL.Model) = true s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) - for s in (s1, s2, s3, s4, s5, s6, s7) + s8 = begin + hmc = HMC(0.1, 5; adtype=adbackend) + pg = PG(10) + vns = @varname(s) + vnm = @varname(m) + Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) + end + for s in (s1, s2, s3, s4, s5, s6, s7, s8) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end - c1 = sample(gdemo_default, s1, N) - c2 = sample(gdemo_default, s2, N) - c3 = sample(gdemo_default, s3, N) - c4 = sample(gdemo_default, s4, N) - c5 = sample(gdemo_default, s5, N) - c6 = sample(gdemo_default, s6, N) - c7 = sample(gdemo_default, s7, N) + sample(gdemo_default, s1, N) + sample(gdemo_default, s2, N) + sample(gdemo_default, s3, N) + sample(gdemo_default, s4, N) + sample(gdemo_default, s5, N) + sample(gdemo_default, s6, N) + sample(gdemo_default, s7, N) + sample(gdemo_default, s8, N) g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains From 85bcfa51cf048c5f73054a136fde37e21d7e99f2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 15:57:31 +0100 Subject: [PATCH 08/25] Fix typo in test/mcmc/ess.jl --- test/mcmc/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index da03e686d..8d9697d9a 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -61,7 +61,7 @@ using Turing alg = Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), @varname(mu1) => ESS(), - @varname(m2) => ESS(), + @varname(mu2) => ESS(), ) chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain; atol=0.1) From 6f9679ac659210276fd1cb9acdc2a728573699e3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 16:11:34 +0100 Subject: [PATCH 09/25] Use provided rng to initialise VarInfo in Gibbs --- src/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 445bc433f..571d694e3 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -339,7 +339,7 @@ function DynamicPPL.initialstep( samplers = alg.samplers # 1. Run the model once to get the varnames present + initial values to condition on. - vi_base = DynamicPPL.VarInfo(model) + vi_base = DynamicPPL.VarInfo(rng, model) # Simple way of setting the initial parameters: set them in the `vi_base` # if they are given so they propagate to the subset varinfos used by each sampler. From f247ad997a26151f42b180c1e2121df76d9ca6d5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 8 Oct 2024 15:01:02 +0100 Subject: [PATCH 10/25] Fix a typo in GibbsContext --- src/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 571d694e3..4986ff87d 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -15,7 +15,7 @@ struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.Abstra context::Ctx end -Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) +GibbsContext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context From d19afe18a971b6c623acea3aefff48ac59498911 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 11 Oct 2024 16:55:08 +0100 Subject: [PATCH 11/25] Fix the Gibbs sampler --- Project.toml | 2 +- src/mcmc/gibbs.jl | 182 +++++++++++++++++----------------------------- 2 files changed, 66 insertions(+), 118 deletions(-) diff --git a/Project.toml b/Project.toml index e22c4f4ae..e23cb5f0b 100644 --- a/Project.toml +++ b/Project.toml @@ -50,7 +50,7 @@ TuringOptimExt = "Optim" [compat] ADTypes = "0.2, 1" -AbstractMCMC = "5.2" +AbstractMCMC = "5.5" Accessors = "0.1" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6" AdvancedMH = "0.8" diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 4986ff87d..7581eeedd 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -173,42 +173,14 @@ function condition_gibbs(model::DynamicPPL.Model, values...) return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) end -""" - make_conditional_model(model, varinfo, varinfos) - -Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. - -# Examples -```julia-repl -julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); - -julia> # A separate varinfo for each variable in `model`. - varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); - -julia> # The varinfo we want to NOT condition on. - target_varinfo = first(varinfos); - -julia> # Results in a model with only `m` conditioned. - conditioned_model = make_conditional(model, target_varinfo, varinfos); - -julia> result = conditioned_model(); - -julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` -true - -julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` -true -``` -""" function make_conditional( - model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos + model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) - # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return condition_gibbs(model, filter(Base.Fix1(!==, target_varinfo), varinfos)...) -end -# Assumes the ones given are the ones to condition on. -function make_conditional(model::DynamicPPL.Model, varinfos) - return condition_gibbs(model, varinfos...) + not_target_variables = filter( + x -> !(any(Iterators.map(vn -> subsumes(vn, x), target_variables))), keys(varinfo) + ) + vi_filtered = subset(varinfo, not_target_variables) + return condition_gibbs(model, vi_filtered) end # HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` @@ -219,13 +191,15 @@ wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) """ gibbs_state(model, sampler, state, varinfo) -Return an updated state, taking into account the variables sampled by other Gibbs components. +Return an updated state for a component sampler. + +This takes into account changes caused by other Gibbs components. # Arguments - `model`: model targeted by the Gibbs sampler. - `sampler`: the sampler for this Gibbs component. - `state`: the state of `sampler` computed in the previous iteration. -- `varinfo`: the variables, including the ones sampled by other Gibbs components. +- `varinfo`: the current values of the variables relevant for this sampler. """ gibbs_state(model, sampler, state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo function gibbs_state(model, sampler, state::PGState, varinfo::AbstractVarInfo) @@ -237,12 +211,13 @@ function gibbs_state( model::Model, spl::Sampler{<:Hamiltonian}, state::HMCState, varinfo::AbstractVarInfo ) # Update hamiltonian - θ_old = varinfo[spl] - hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_old)) + θ_new = varinfo[:] + hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_new)) + # Update the parameter values in `state.z`. # TODO: Avoid mutation - resize!(state.z.θ, length(θ_old)) - state.z.θ .= θ_old + resize!(state.z.θ, length(θ_new)) + state.z.θ .= θ_new z = state.z return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) @@ -348,55 +323,41 @@ function DynamicPPL.initialstep( end # Create the varinfos for each sampler. - varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) + local_varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) initial_params_all = if initial_params === nothing fill(nothing, length(varnames)) else # Extract from the `vi_base`, which should have the values set correctly from above. - map(vi -> vi[:], varinfos) + map(vi -> vi[:], local_varinfos) end # 2. Construct a varinfo for every vn + sampler combo. - states_and_varinfos = map( - samplers, varinfos, initial_params_all - ) do sampler_local, varinfo_local, initial_params_local + states = [] + for (varnames_local, sampler_local, initial_params_local) in + zip(varnames, samplers, initial_params_all) # Construct the conditional model. - model_local = make_conditional(model, varinfo_local, varinfos) + model_local = make_conditional(model, _maybevec(varnames_local), vi_base) # Take initial step. - new_state_local = last( - AbstractMCMC.step( - rng, - model_local, - sampler_local; - # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. - # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. - initial_params=initial_params_local, - kwargs..., - ), + _, new_state_local = AbstractMCMC.step( + rng, + model_local, + sampler_local; + # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. + # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. + initial_params=initial_params_local, + kwargs..., ) - - # Return the new state and the invlinked `varinfo`. - vi_local_state = varinfo(new_state_local) - vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) - DynamicPPL.invlink(vi_local_state, sampler_local, model_local) + vi_local = varinfo(new_state_local) + vi_local = if DynamicPPL.istrans(vi_local) + DynamicPPL.invlink(vi_local, sampler_local, model_local) else - vi_local_state + vi_local end - return (new_state_local, vi_local_state_linked) + vi_base = merge(vi_base, vi_local) + push!(states, new_state_local) end - - states = map(first, states_and_varinfos) - varinfos = map(last, states_and_varinfos) - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) - ) - - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi_base), GibbsState(vi_base, states) end function AbstractMCMC.step( @@ -406,37 +367,23 @@ function AbstractMCMC.step( state::GibbsState; kwargs..., ) + vi = varinfo(state) alg = spl.alg + varnames = alg.varnames samplers = alg.samplers states = state.states - varinfos = map(varinfo, state.states) @assert length(samplers) == length(state.states) # TODO: move this into a recursive function so we can unroll when reasonable? for index in 1:length(samplers) # Take the inner step. - new_state_local, new_varinfo_local = gibbs_step_inner( - rng, model, samplers, states, varinfos, index; kwargs... + vi, new_state_local = gibbs_step_inner( + rng, model, varnames, samplers, states, vi, index; kwargs... ) - # Update the `states` and `varinfos`. + # Update the `states` states = Accessors.setindex(states, new_state_local, index) - varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) end - - # Combine the resulting varinfo objects. - # The last varinfo holds the correctly computed logp. - vi_base = state.vi - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!( - varinfos, merge(vi_base, first(varinfos)), firstindex(varinfos) - ) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) - ) - return Transition(model, vi), GibbsState(vi, states) end @@ -486,40 +433,50 @@ function recompute_logprob!!( return gibbs_state(model, sampler, state, vi_new) end +AbstractMCMC.setparams!!(::VarInfo, vi::VarInfo) = vi +function AbstractMCMC.setparams!!(state, vi::VarInfo) + # In the fallback implementation we guess that `state` has a field called `vi` we can + # set. Fingers crossed! + try + return Accessors.set(state, Accessors.PropertyLens{:vi}(), vi) + catch + error( + "Unable to set `state.vi` for a $(typeof(state)). " * + "Consider writing a method for setparams!! for this type.", + ) + end +end + function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, + varnames, samplers, states, - varinfos, + vi, index; kwargs..., ) # Needs to do a a few things. sampler_local = samplers[index] state_local = states[index] - varinfo_local = varinfos[index] - - # Make sure that all `varinfos` are linked. - varinfos_invlinked = map(varinfos) do vi - # NOTE: This is immutable linking! - # TODO: Do we need the `istrans` check here or should we just always use `invlink`? - # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 - DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - end - varinfo_local_invlinked = varinfos_invlinked[index] + varnames_local = _maybevec(varnames[index]) + + vi = DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi # 1. Create conditional model. # Construct the conditional model. # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, # otherwise we're conditioning on values which are not in the support of the # distributions. - model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) + model_local = make_conditional(model, varnames_local, vi) + varinfo_local = subset(vi, varnames_local) # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] state_previous = states[index == 1 ? length(states) : index - 1] + state_local = AbstractMCMC.setparams!!(state_local, varinfo_local) # 1. Re-run the sampler if needed. if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous @@ -532,15 +489,6 @@ function gibbs_step_inner( AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) ) - # 3. Extract the new varinfo. - # Return the resulting state and invlinked `varinfo`. - varinfo_local_state = varinfo(new_state_local) - varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) - DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) - else - varinfo_local_state - end - - # TODO: alternatively, we can return `states_new, varinfos_new, index_new` - return (new_state_local, varinfo_local_state_invlinked) + new_vi = merge(vi, varinfo(new_state_local)) + return new_vi, new_state_local end From 19598c4bd8d7f15db86fa1035228308307443e7e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 17 Oct 2024 17:17:19 +0100 Subject: [PATCH 12/25] Fix the Gibbs sampler more --- src/mcmc/abstractmcmc.jl | 2 +- src/mcmc/gibbs.jl | 262 ++++++++++++++++++++++++++------------- 2 files changed, 180 insertions(+), 84 deletions(-) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index 7e6c64d11..c6ca61a9c 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -65,7 +65,7 @@ function recompute_logprob!!( rng::Random.AbstractRNG, # TODO: Do we need the `rng` here? model::DynamicPPL.Model, sampler::DynamicPPL.Sampler{<:ExternalSampler}, - state, + state, # TODO(mhauru) Could we type constrain this to TuringState? ) # Re-using the log-density function from the `state` and updating only the `model` field, # since the `model` might now contain different conditioning values. diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 7581eeedd..aa01a3b1f 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -127,52 +127,33 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) end """ - condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) + condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) -Return a `GibbsContext` with the given values treated as conditioned. - -# Arguments -- `context::DynamicPPL.AbstractContext`: The context to condition. -- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. - If multiple values are provided, we recursively condition on each of them. +Return a `GibbsContext` with the values extracted from the given `varinfo` treated as +conditioned. """ -condition_gibbs(context::DynamicPPL.AbstractContext) = context -# For `NamedTuple` and `AbstractDict` we just construct the context. function condition_gibbs( - context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict} + context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo ) - return GibbsContext(values, context) -end -# If we get more than one argument, we just recurse. -function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) - return condition_gibbs(condition_gibbs(context, value), values...) + # TODO(mhauru) Maybe use preferred_value_type to return NamedTuples in some cases. + # If not, then remove preferred_value_type. + vals = DynamicPPL.OrderedDict(k => varinfo[k] for k in keys(varinfo)) + return GibbsContext(vals, context) end -# For `DynamicPPL.AbstractVarInfo` we just extract the values. """ - condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) + make_conditional(model, target_variables, varinfo) -Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. -""" -function condition_gibbs( - context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo -) - return condition_gibbs( - context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)) - ) -end -function condition_gibbs( - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo, - varinfos::DynamicPPL.AbstractVarInfo..., -) - return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) -end -# Allow calling this on a `DynamicPPL.Model` directly. -function condition_gibbs(model::DynamicPPL.Model, values...) - return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) -end +Return a new, conditioned model for a component of a Gibbs sampler. +# Arguments +- `model::DynamicPPL.Model`: The model to condition. +- `target_variables::AbstractVector{<:VarName}`: The target variables of the component +sampler. These will _not_ conditioned. +- `varinfo::DynamicPPL.AbstractVarInfo`: Values for all variables in the model. All the +values in `varinfo` but not in `target_variables` will be conditioned to the values they +have in `varinfo`. +""" function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) @@ -180,7 +161,8 @@ function make_conditional( x -> !(any(Iterators.map(vn -> subsumes(vn, x), target_variables))), keys(varinfo) ) vi_filtered = subset(varinfo, not_target_variables) - return condition_gibbs(model, vi_filtered) + gibbs_context = condition_gibbs(model.context, vi_filtered) + return DynamicPPL.contextualize(model, gibbs_context) end # HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` @@ -188,41 +170,6 @@ end wrap_algorithm_maybe(x) = x wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) -""" - gibbs_state(model, sampler, state, varinfo) - -Return an updated state for a component sampler. - -This takes into account changes caused by other Gibbs components. - -# Arguments -- `model`: model targeted by the Gibbs sampler. -- `sampler`: the sampler for this Gibbs component. -- `state`: the state of `sampler` computed in the previous iteration. -- `varinfo`: the current values of the variables relevant for this sampler. -""" -gibbs_state(model, sampler, state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo -function gibbs_state(model, sampler, state::PGState, varinfo::AbstractVarInfo) - return PGState(varinfo, state.rng) -end - -# Update state in Gibbs sampling -function gibbs_state( - model::Model, spl::Sampler{<:Hamiltonian}, state::HMCState, varinfo::AbstractVarInfo -) - # Update hamiltonian - θ_new = varinfo[:] - hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_new)) - - # Update the parameter values in `state.z`. - # TODO: Avoid mutation - resize!(state.z.θ, length(θ_new)) - state.z.θ .= θ_new - z = state.z - - return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) -end - """ Gibbs @@ -349,6 +296,7 @@ function DynamicPPL.initialstep( kwargs..., ) vi_local = varinfo(new_state_local) + # TODO(mhauru) Can we remove the invlinking? vi_local = if DynamicPPL.istrans(vi_local) DynamicPPL.invlink(vi_local, sampler_local, model_local) else @@ -428,25 +376,159 @@ function recompute_logprob!!( DynamicPPL.SamplingContext(rng, sampler_rerun), ) ) - # Update the state we're about to use if need be. - # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - return gibbs_state(model, sampler, state, vi_new) + return setlogp!!(state, vi_new.logp[]) end -AbstractMCMC.setparams!!(::VarInfo, vi::VarInfo) = vi -function AbstractMCMC.setparams!!(state, vi::VarInfo) +# TODO(mhauru) Would really like to type constraint this to something like AbstractMCMCState +# if such a thing existed. +function DynamicPPL.setlogp!!(state, logp) + try + new_vi = setlogp!!(state.vi, logp) + if new_vi !== state.vi + return Accessors.set(state, Accessors.PropertyLens{:vi}(), new_vi) + else + return state + end + catch + error( + "Unable to set `state.vi` for a $(typeof(state)). " * + "Consider writing a method for `setlogp!!` for this type.", + ) + end +end + +function DynamicPPL.setlogp!!(state::TuringState, logp) + return TuringState(setlogp!!(state.state, logp), logp) +end + +# TODO(mhauru) In the general case, which arguments are really needed for reset_state!!? +# The current list is a guess, but I think some might be unnecessary. +""" + reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous) + +Return an updated state for a component sampler. + +This takes into account changes caused by other Gibbs components. The default implementation +is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to +do, a method should be implemented for the specific type of `state`. + +# Arguments +- `model::DynamicPPL.Model`: The model as seen by this component sampler. Variables not +sampled by this component sampler have been conditioned with a `GibbsContext`. +- `sampler::DynamicPPL.Sampler`: The current component sampler. +- `state`: The state of this component sampler from its previous iteration. +- `varinfo::DynamicPPL.AbstractVarInfo`: The current `VarInfo`, subsetted to the variables +sampled by this component sampler. +- `sampler_previous::DynamicPPL.Sampler`: The previous sampler in the Gibbs chain. +- `state_previous`: The state returned by the previous sampler. + +# Returns +An updated state of the same type as `state`. It should have variables set to the values in +`varinfo`, and any other relevant updates done. +""" +function reset_state!!( + model, sampler, state, varinfo::AbstractVarInfo, sampler_previous, state_previous +) # In the fallback implementation we guess that `state` has a field called `vi` we can # set. Fingers crossed! try - return Accessors.set(state, Accessors.PropertyLens{:vi}(), vi) + return Accessors.set(state, Accessors.PropertyLens{:vi}(), varinfo) catch error( "Unable to set `state.vi` for a $(typeof(state)). " * - "Consider writing a method for setparams!! for this type.", + "Consider writing a method for reset_state!! for this type.", ) end end +function reset_state!!( + model, + sampler, + state::AbstractVarInfo, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + return varinfo +end + +function reset_state!!( + model, + sampler, + state::TuringState, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + new_inner_state = reset_state!!( + model, sampler, state.state, varinfo, sampler_previous, state_previous + ) + return TuringState(new_inner_state, state.logdensity) +end + +function reset_state!!( + model, + sampler, + state::HMCState, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + θ_new = varinfo[:] + hamiltonian = get_hamiltonian(model, sampler, varinfo, state, length(θ_new)) + + # Update the parameter values in `state.z`. + # TODO: Avoid mutation + z = state.z + resize!(z.θ, length(θ_new)) + z.θ .= θ_new + return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) +end + +function reset_state!!( + model, + sampler, + state::AdvancedHMC.HMCState, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + hamiltonian = AdvancedHMC.Hamiltonian( + state.metric, DynamicPPL.LogDensityFunction(model) + ) + θ_new = varinfo[:] + # Set the momentum to zero, since we have no idea what it should be at the new parameter + # values. + return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, θ_new, zero(θ_new) + ) +end + +function reset_state!!( + model, + sampler, + state::AdvancedMH.Transition, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + # TODO(mhauru) Setting the last argument like this seems a bit suspect, since the + # current values for the parameters might not have come from this sampler at all. + # I don't see a better way though. + return AdvancedMH.Transition(varinfo[:], varinfo.logp[], state.accepted) +end + +function reset_state!!( + model, + sampler, + state::PGState, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + return PGState(varinfo, state.rng) +end + function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -462,6 +544,7 @@ function gibbs_step_inner( state_local = states[index] varnames_local = _maybevec(varnames[index]) + # TODO(mhauru) Can we remove the invlinking? vi = DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi # 1. Create conditional model. @@ -471,13 +554,24 @@ function gibbs_step_inner( # distributions. model_local = make_conditional(model, varnames_local, vi) varinfo_local = subset(vi, varnames_local) + # If the varinfo of the previous state from this sampler is linked, we should link the + # new varinfo too. + if DynamicPPL.istrans(varinfo(state_local)) + varinfo_local = DynamicPPL.link(varinfo_local, sampler_local, model_local) + end # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] state_previous = states[index == 1 ? length(states) : index - 1] - state_local = AbstractMCMC.setparams!!(state_local, varinfo_local) - # 1. Re-run the sampler if needed. + state_local = reset_state!!( + model_local, + sampler_local, + state_local, + varinfo_local, + sampler_previous, + state_previous, + ) if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous ) @@ -489,6 +583,8 @@ function gibbs_step_inner( AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) ) - new_vi = merge(vi, varinfo(new_state_local)) + new_vi_local = varinfo(new_state_local) + new_vi = merge(vi, new_vi_local) + new_vi = setlogp!!(new_vi, new_vi_local.logp[]) return new_vi, new_state_local end From 71f26ca09dbd7128b0a9e75bea66b734f440fae3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 18 Oct 2024 11:09:01 +0100 Subject: [PATCH 13/25] Remove mentions of old Gibbs sampler from MH docs Co-authored-by: Penelope Yong --- src/mcmc/mh.jl | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index ffc064eb1..edec84365 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -20,7 +20,6 @@ Construct a Metropolis-Hastings algorithm. The arguments `space` can be - Blank (i.e. `MH()`), in which case `MH` defaults to using the prior for each parameter as the proposal distribution. -- A set of one or more symbols to sample with `MH` in conjunction with `Gibbs`, i.e. `Gibbs(MH(:m), PG(10, :s))` - An iterable of pairs or tuples mapping a `Symbol` to a `AdvancedMH.Proposal`, `Distribution`, or `Function` that generates returns a conditional proposal distribution. - A covariance matrix to use as for mean-zero multivariate normal proposals. @@ -41,22 +40,6 @@ chain = sample(gdemo(1.5, 2.0), MH(), 1_000) mean(chain) ``` -Alternatively, you can specify particular parameters to sample if you want to combine sampling -from multiple samplers: - -```julia -@model function gdemo(x, y) - s² ~ InverseGamma(2,3) - m ~ Normal(0, sqrt(s²)) - x ~ Normal(m, sqrt(s²)) - y ~ Normal(m, sqrt(s²)) -end - -# Samples s with MH and m with PG -chain = sample(gdemo(1.5, 2.0), Gibbs(MH(:s), PG(10, :m)), 1_000) -mean(chain) -``` - Using custom distributions defaults to using static MH: ```julia From d0f57acf249abc2c4463f3fe93fc8b07d3738472 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 21 Oct 2024 14:29:42 +0100 Subject: [PATCH 14/25] Bump DPPL to 0.28.6 --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index e23cb5f0b..1b0003ab2 100644 --- a/Project.toml +++ b/Project.toml @@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.28.2" +DynamicPPL = "0.28.6" Compat = "4.15.0" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" diff --git a/test/Project.toml b/test/Project.toml index 67292d2af..7d463eb8a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -45,7 +45,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.28" +DynamicPPL = "0.28.6" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" From 74b57e716915d5d01b51c3588eda79583d1b7ace Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Oct 2024 11:32:13 +0000 Subject: [PATCH 15/25] Redesign GibbsContext, work in progress --- Project.toml | 2 +- src/mcmc/gibbs.jl | 173 ++++++++++++++++++++++++++++------------------ 2 files changed, 107 insertions(+), 68 deletions(-) diff --git a/Project.toml b/Project.toml index c9b9c71de..ab43ccbc3 100644 --- a/Project.toml +++ b/Project.toml @@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.30" +DynamicPPL = "0.30.1" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.8.8" diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index aa01a3b1f..21e062fd1 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -10,44 +10,80 @@ # rather than only for the "true" observations. # - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline # rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - values::Values +struct GibbsContext{ + VNs,Values,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext +} <: DynamicPPL.AbstractContext + target_varnames::VNs + conditioned_values::Values + global_varinfo::GVI context::Ctx end -GibbsContext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) +function GibbsContext(target_varnames, conditioned_values, global_varinfo) + return GibbsContext( + target_varnames, conditioned_values, global_varinfo, DynamicPPL.DefaultContext() + ) +end DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) - return GibbsContext(context.values, childcontext) + return GibbsContext( + context.target_varnames, + context.conditioned_values, + Ref(context.global_varinfo[]), + childcontext, + ) end # has and get function has_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.hasvalue(context.values, vn) + return DynamicPPL.hasvalue(context.conditioned_values, vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(has_conditioned_gibbs, context), vns) end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.getvalue(context.values, vn) + return DynamicPPL.getvalue(context.conditioned_values, vn) end function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return map(Base.Fix1(get_conditioned_gibbs, context), vns) end +function is_target_varname(context::GibbsContext, vn::VarName) + return Iterators.any( + Iterators.map(target -> subsumes(target, vn), context.target_varnames) + ) +end + # Tilde pipeline function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Short-circuits the tilde assume if `vn` is present in `context`. if has_conditioned_gibbs(context, vn) value = get_conditioned_gibbs(context, vn) + # TODO(mhauru) Is the call to logpdf correct if context.context is not + # DefaultContext? return value, logpdf(right, value), vi + elseif is_target_varname(context, vn) + # Fall back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + value, lp, new_global_vi = DynamicPPL.tilde_assume( + DynamicPPL.SamplingContext( + DynamicPPL.SampleFromPrior(), DynamicPPL.childcontext(context) + ), + right, + vn, + context.global_varinfo[], + ) + context.global_varinfo[] = new_global_vi + return value, lp, vi end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) end function DynamicPPL.tilde_assume( @@ -56,13 +92,30 @@ function DynamicPPL.tilde_assume( # Short-circuits the tilde assume if `vn` is present in `context`. if has_conditioned_gibbs(context, vn) value = get_conditioned_gibbs(context, vn) + # TODO(mhauru) Is the call to logpdf correct if context.context is not + # DefaultContext? return value, logpdf(right, value), vi + elseif is_target_varname(context, vn) + # Fall back to the default behavior. + return DynamicPPL.tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, vn, vi + ) + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + value, lp, new_global_vi = DynamicPPL.tilde_assume( + DynamicPPL.SamplingContext( + rng, DynamicPPL.SampleFromPrior(), DynamicPPL.childcontext(context) + ), + right, + vn, + context.global_varinfo[], + ) + context.global_varinfo[] = new_global_vi + return value, lp, vi end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, vn, vi - ) end # Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. @@ -126,21 +179,6 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict end -""" - condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) - -Return a `GibbsContext` with the values extracted from the given `varinfo` treated as -conditioned. -""" -function condition_gibbs( - context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo -) - # TODO(mhauru) Maybe use preferred_value_type to return NamedTuples in some cases. - # If not, then remove preferred_value_type. - vals = DynamicPPL.OrderedDict(k => varinfo[k] for k in keys(varinfo)) - return GibbsContext(vals, context) -end - """ make_conditional(model, target_variables, varinfo) @@ -157,12 +195,15 @@ have in `varinfo`. function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) + # We want to condition all the variables in keys(varinfo) that are not subsumed by any + # of the target variables. not_target_variables = filter( x -> !(any(Iterators.map(vn -> subsumes(vn, x), target_variables))), keys(varinfo) ) vi_filtered = subset(varinfo, not_target_variables) - gibbs_context = condition_gibbs(model.context, vi_filtered) - return DynamicPPL.contextualize(model, gibbs_context) + vals = DynamicPPL.OrderedDict(k => vi_filtered[k] for k in keys(vi_filtered)) + gibbs_context = GibbsContext(target_variables, vals, Ref(varinfo), model.context) + return DynamicPPL.contextualize(model, gibbs_context), gibbs_context end # HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` @@ -261,29 +302,21 @@ function DynamicPPL.initialstep( samplers = alg.samplers # 1. Run the model once to get the varnames present + initial values to condition on. - vi_base = DynamicPPL.VarInfo(rng, model) - - # Simple way of setting the initial parameters: set them in the `vi_base` - # if they are given so they propagate to the subset varinfos used by each sampler. + vi = DynamicPPL.VarInfo(rng, model) if initial_params !== nothing - vi_base = DynamicPPL.unflatten(vi_base, initial_params) - end - - # Create the varinfos for each sampler. - local_varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) - initial_params_all = if initial_params === nothing - fill(nothing, length(varnames)) - else - # Extract from the `vi_base`, which should have the values set correctly from above. - map(vi -> vi[:], local_varinfos) + vi = DynamicPPL.unflatten(vi, initial_params) end - # 2. Construct a varinfo for every vn + sampler combo. + # Initialise each component sampler in turn, collect all their states. states = [] - for (varnames_local, sampler_local, initial_params_local) in - zip(varnames, samplers, initial_params_all) + for (varnames_local, sampler_local) in zip(varnames, samplers) + varnames_local = _maybevec(varnames_local) + # Get the initial values for this component sampler. + vi_local = DynamicPPL.subset(vi, varnames_local) + initial_params_local = initial_params === nothing ? nothing : vi_local[:] + # Construct the conditional model. - model_local = make_conditional(model, _maybevec(varnames_local), vi_base) + model_local, context_local = make_conditional(model, varnames_local, vi) # Take initial step. _, new_state_local = AbstractMCMC.step( @@ -295,17 +328,21 @@ function DynamicPPL.initialstep( initial_params=initial_params_local, kwargs..., ) - vi_local = varinfo(new_state_local) + new_vi_local = varinfo(new_state_local) # TODO(mhauru) Can we remove the invlinking? - vi_local = if DynamicPPL.istrans(vi_local) - DynamicPPL.invlink(vi_local, sampler_local, model_local) + new_vi_local = if DynamicPPL.istrans(new_vi_local) + DynamicPPL.invlink(new_vi_local, sampler_local, model_local) else - vi_local + new_vi_local end - vi_base = merge(vi_base, vi_local) + # This merges in any new variables that were introduced during the step, but that + # were not in the domain of the current sampler. + vi = merge(vi, context_local.global_varinfo[]) + # This merges the latest values for all the variables in the current sampler. + vi = merge(vi, new_vi_local) push!(states, new_state_local) end - return Transition(model, vi_base), GibbsState(vi_base, states) + return Transition(model, vi), GibbsState(vi, states) end function AbstractMCMC.step( @@ -328,8 +365,6 @@ function AbstractMCMC.step( vi, new_state_local = gibbs_step_inner( rng, model, varnames, samplers, states, vi, index; kwargs... ) - - # Update the `states` states = Accessors.setindex(states, new_state_local, index) end return Transition(model, vi), GibbsState(vi, states) @@ -379,7 +414,7 @@ function recompute_logprob!!( return setlogp!!(state, vi_new.logp[]) end -# TODO(mhauru) Would really like to type constraint this to something like AbstractMCMCState +# TODO(mhauru) Would really like to type constrain this to something like AbstractMCMCState # if such a thing existed. function DynamicPPL.setlogp!!(state, logp) try @@ -441,6 +476,8 @@ function reset_state!!( end end +# Some samplers use a VarInfo directly as the state. In that case, there's little to do in +# `reset_state!!`. function reset_state!!( model, sampler, @@ -539,7 +576,6 @@ function gibbs_step_inner( index; kwargs..., ) - # Needs to do a a few things. sampler_local = samplers[index] state_local = states[index] varnames_local = _maybevec(varnames[index]) @@ -547,13 +583,10 @@ function gibbs_step_inner( # TODO(mhauru) Can we remove the invlinking? vi = DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - # 1. Create conditional model. - # Construct the conditional model. - # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, - # otherwise we're conditioning on values which are not in the support of the - # distributions. - model_local = make_conditional(model, varnames_local, vi) + # Construct the conditional model and the varinfo that this sampler should use. + model_local, context_local = make_conditional(model, varnames_local, vi) varinfo_local = subset(vi, varnames_local) + # TODO(mhauru) Can we remove the below, if get rid of all the invlinking? # If the varinfo of the previous state from this sampler is linked, we should link the # new varinfo too. if DynamicPPL.istrans(varinfo(state_local)) @@ -564,6 +597,8 @@ function gibbs_step_inner( sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] state_previous = states[index == 1 ? length(states) : index - 1] + # Set the state of the current sampler, accounting for any changes made by other + # samplers. state_local = reset_state!!( model_local, sampler_local, @@ -578,13 +613,17 @@ function gibbs_step_inner( state_local = recompute_logprob!!(rng, model_local, sampler_local, state_local) end - # 2. Take step with local sampler. + # Take a step with the local sampler. new_state_local = last( AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) ) new_vi_local = varinfo(new_state_local) - new_vi = merge(vi, new_vi_local) + # This merges in any new variables that were introduced during the step, but that + # were not in the domain of the current sampler. + new_vi = merge(vi, context_local.global_varinfo[]) + # This merges the latest values for all the variables in the current sampler. + new_vi = merge(new_vi, new_vi_local) new_vi = setlogp!!(new_vi, new_vi_local.logp[]) return new_vi, new_state_local end From b16daf5eff519f08e40b712a0aa1cfc3a95146f0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Oct 2024 15:34:33 +0000 Subject: [PATCH 16/25] Fixing new Gibbs, adding a broken test --- src/mcmc/gibbs.jl | 132 ++++++++++++++++++++++++++------------------- test/mcmc/gibbs.jl | 35 ++++++++++-- 2 files changed, 110 insertions(+), 57 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 21e062fd1..21d71d9b2 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -10,42 +10,35 @@ # rather than only for the "true" observations. # - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline # rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{ - VNs,Values,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext -} <: DynamicPPL.AbstractContext +struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: + DynamicPPL.AbstractContext target_varnames::VNs - conditioned_values::Values global_varinfo::GVI context::Ctx end -function GibbsContext(target_varnames, conditioned_values, global_varinfo) - return GibbsContext( - target_varnames, conditioned_values, global_varinfo, DynamicPPL.DefaultContext() - ) +function GibbsContext(target_varnames, global_varinfo) + return GibbsContext(target_varnames, global_varinfo, DynamicPPL.DefaultContext()) end DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) return GibbsContext( - context.target_varnames, - context.conditioned_values, - Ref(context.global_varinfo[]), - childcontext, + context.target_varnames, Ref(context.global_varinfo[]), childcontext ) end # has and get function has_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.hasvalue(context.conditioned_values, vn) + return DynamicPPL.haskey(context.global_varinfo[], vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(has_conditioned_gibbs, context), vns) end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.getvalue(context.conditioned_values, vn) + return context.global_varinfo[][vn] end function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return map(Base.Fix1(get_conditioned_gibbs, context), vns) @@ -57,26 +50,30 @@ function is_target_varname(context::GibbsContext, vn::VarName) ) end +function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(is_target_varname, context), vns) +end + # Tilde pipeline function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) + if is_target_varname(context, vn) + # Fall back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) + elseif has_conditioned_gibbs(context, vn) + # Short-circuits the tilde assume if `vn` is present in `context`. value = get_conditioned_gibbs(context, vn) # TODO(mhauru) Is the call to logpdf correct if context.context is not # DefaultContext? return value, logpdf(right, value), vi - elseif is_target_varname(context, vn) - # Fall back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. + prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.tilde_assume( - DynamicPPL.SamplingContext( - DynamicPPL.SampleFromPrior(), DynamicPPL.childcontext(context) - ), + DynamicPPL.childcontext(context), + prior_sampler, right, vn, context.global_varinfo[], @@ -89,26 +86,27 @@ end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi ) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - # TODO(mhauru) Is the call to logpdf correct if context.context is not - # DefaultContext? - return value, logpdf(right, value), vi - elseif is_target_varname(context, vn) + if is_target_varname(context, vn) # Fall back to the default behavior. return DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) + elseif has_conditioned_gibbs(context, vn) + # Short-circuits the tilde assume if `vn` is present in `context`. + value = get_conditioned_gibbs(context, vn) + # TODO(mhauru) Is the call to logpdf correct if context.context is not + # DefaultContext? + return value, logpdf(right, value), vi else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. + prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.tilde_assume( - DynamicPPL.SamplingContext( - rng, DynamicPPL.SampleFromPrior(), DynamicPPL.childcontext(context) - ), + rng, + DynamicPPL.childcontext(context), + prior_sampler, right, vn, context.global_varinfo[], @@ -137,31 +135,64 @@ function reconstruct_getvalue( end function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) + if is_target_varname(context, vns) + # Fall back to the default behavior. + return DynamicPPL.dot_tilde_assume( + DynamicPPL.childcontext(context), right, left, vns, vi + ) + elseif has_conditioned_gibbs(context, vns) + # Short-circuit the tilde assume if `vn` is present in `context`. value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) return value, broadcast_logpdf(right, value), vi + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + prior_sampler = DynamicPPL.SampleFromPrior() + value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( + DynamicPPL.childcontext(context), + prior_sampler, + right, + left, + vns, + context.global_varinfo[], + ) + context.global_varinfo[] = new_global_vi + return value, lp, vi end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume( - DynamicPPL.childcontext(context), right, left, vns, vi - ) end function DynamicPPL.dot_tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi ) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) + if is_target_varname(context, vns) + # Fall back to the default behavior. + return DynamicPPL.dot_tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi + ) + elseif has_conditioned_gibbs(context, vns) + # Short-circuit the tilde assume if `vn` is present in `context`. value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) return value, broadcast_logpdf(right, value), vi + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + prior_sampler = DynamicPPL.SampleFromPrior() + value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( + rng, + DynamicPPL.childcontext(context), + prior_sampler, + right, + left, + vns, + context.global_varinfo[], + ) + context.global_varinfo[] = new_global_vi + return value, lp, vi end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi - ) end """ @@ -195,14 +226,7 @@ have in `varinfo`. function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) - # We want to condition all the variables in keys(varinfo) that are not subsumed by any - # of the target variables. - not_target_variables = filter( - x -> !(any(Iterators.map(vn -> subsumes(vn, x), target_variables))), keys(varinfo) - ) - vi_filtered = subset(varinfo, not_target_variables) - vals = DynamicPPL.OrderedDict(k => vi_filtered[k] for k in keys(vi_filtered)) - gibbs_context = GibbsContext(target_variables, vals, Ref(varinfo), model.context) + gibbs_context = GibbsContext(target_variables, Ref(varinfo), model.context) return DynamicPPL.contextualize(model, gibbs_context), gibbs_context end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index fc3ea6352..c11f29162 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -15,7 +15,7 @@ using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff import Mooncake -using Test: @test, @test_deprecated, @testset +using Test: @test, @test_broken, @test_deprecated, @testset using Turing using Turing: Inference using Turing.Inference: AdvancedHMC, AdvancedMH @@ -135,8 +135,16 @@ has_dot_assume(::DynamicPPL.Model) = true Random.seed!(200) for alg in [ # The new syntax for specifying a sampler to run twice for one variable. - Gibbs(s => MH(), s => MH(), m => HMC(0.2, 4; adtype=adbackend)), - Gibbs(s => MH(), m => HMC(0.2, 4), m => HMC(0.2, 4); adtype=adbackend), + Gibbs( + @varname(s) => MH(), + @varname(s) => MH(), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + ), + Gibbs( + @varname(s) => MH(), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + ), ] chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_gdemo(chain; atol=0.15) @@ -200,6 +208,27 @@ has_dot_assume(::DynamicPPL.Model) = true sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) end + @testset "dynamic model with dot tilde" begin + @model function dynamic_model_with_dot_tilde(num_zs=10) + z = Vector(undef, num_zs) + z .~ Exponential(1.0) + num_ms = Int(round(sum(z))) + m = Vector(undef, num_ms) + return m .~ Normal(1.0, 1.0) + end + model = dynamic_model_with_dot_tilde() + # TODO(mhauru) This is broken because of + # https://github.com/TuringLang/DynamicPPL.jl/issues/700. + @test_broken ( + sample( + model, + Gibbs(; z=NUTS(; adtype=adbackend), m=HMC(0.01, 4; adtype=adbackend)), + 100, + ); + true + ) + end + @testset "Demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) From af802dc9e6715d7dc00b151db8b6721a5eaad3f6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Oct 2024 16:43:15 +0000 Subject: [PATCH 17/25] Document and clean up GibbsContext --- src/mcmc/gibbs.jl | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 21d71d9b2..32ad22db3 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -10,10 +10,31 @@ # rather than only for the "true" observations. # - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline # rather than the `observe` pipeline for the conditioned variables. +""" + GibbsContext(target_varnames, global_varinfo, context) + +A context used in the implementation of the Turing.jl Gibbs sampler. + +There will be one `GibbsContext` for each iteration of a component sampler. +""" struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + """ + a collection of `VarName`s that are the ones the current component sampler is sampling. + For them, `GibbsContext` will just pass tilde_assume calls to its child context. + For other variables, their values will be fixed to the values they have in + `global_varinfo`. + """ target_varnames::VNs + """ + a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both + those fixed and those being sampled. We use a `Ref` because this field may need to be + updated if new variables are introduced. + """ global_varinfo::GVI + """ + the child context that tilde calls will eventually be passed onto. + """ context::Ctx end @@ -34,7 +55,14 @@ function has_conditioned_gibbs(context::GibbsContext, vn::VarName) return DynamicPPL.haskey(context.global_varinfo[], vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(has_conditioned_gibbs, context), vns) + num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns)) + if (num_conditioned != 0) && (num_conditioned != length(vns)) + error( + "Some but not all of the variables in `vns` have been conditioned on. " * + "Having mixed conditioning like this is not supported in GibbsContext.", + ) + end + return num_conditioned > 0 end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) @@ -51,7 +79,14 @@ function is_target_varname(context::GibbsContext, vn::VarName) end function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(is_target_varname, context), vns) + num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns)) + if (num_target != 0) && (num_target != length(vns)) + error( + "Some but not all of the variables in `vns` are target variables. " * + "Having mixed targeting like this is not supported in GibbsContext.", + ) + end + return num_target > 0 end # Tilde pipeline From e58d93560794ec77cbac639a936542449ffe2a1a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Oct 2024 17:03:57 +0000 Subject: [PATCH 18/25] Code style and docs improvements to Gibbs --- src/mcmc/gibbs.jl | 79 +++++++++++++++++------------------------------ 1 file changed, 29 insertions(+), 50 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 32ad22db3..9034e6708 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -95,7 +95,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Fall back to the default behavior. return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) elseif has_conditioned_gibbs(context, vn) - # Short-circuits the tilde assume if `vn` is present in `context`. + # Short-circuit the tilde assume if `vn` is present in `context`. value = get_conditioned_gibbs(context, vn) # TODO(mhauru) Is the call to logpdf correct if context.context is not # DefaultContext? @@ -105,10 +105,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. - prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.tilde_assume( DynamicPPL.childcontext(context), - prior_sampler, + DynamicPPL.SampleFromPrior(), right, vn, context.global_varinfo[], @@ -118,30 +117,24 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) end end +# As above but with an RNG. function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi ) + # See comment in the above, rng-less version of this method for an explanation. if is_target_varname(context, vn) - # Fall back to the default behavior. return DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) elseif has_conditioned_gibbs(context, vn) - # Short-circuits the tilde assume if `vn` is present in `context`. value = get_conditioned_gibbs(context, vn) - # TODO(mhauru) Is the call to logpdf correct if context.context is not - # DefaultContext? + # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? return value, logpdf(right, value), vi else - # If the varname has not been conditioned on, nor is it a target variable, its - # presumably a new variable that should be sampled from its prior. We need to add - # this new variable to the global `varinfo` of the context, but not to the local one - # being used by the current sampler. - prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), - prior_sampler, + DynamicPPL.SampleFromPrior(), right, vn, context.global_varinfo[], @@ -169,21 +162,18 @@ function reconstruct_getvalue( return reduce(hcat, x[2:end]; init=x[1]) end +# Like the above tilde_assume methods, but with dot_tilde_assume and broadcasting of logpdf. +# See comments there for more details. function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) if is_target_varname(context, vns) - # Fall back to the default behavior. return DynamicPPL.dot_tilde_assume( DynamicPPL.childcontext(context), right, left, vns, vi ) elseif has_conditioned_gibbs(context, vns) - # Short-circuit the tilde assume if `vn` is present in `context`. value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? return value, broadcast_logpdf(right, value), vi else - # If the varname has not been conditioned on, nor is it a target variable, its - # presumably a new variable that should be sampled from its prior. We need to add - # this new variable to the global `varinfo` of the context, but not to the local one - # being used by the current sampler. prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( DynamicPPL.childcontext(context), @@ -198,23 +188,19 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi end end +# As above but with an RNG. function DynamicPPL.dot_tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi ) if is_target_varname(context, vns) - # Fall back to the default behavior. return DynamicPPL.dot_tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi ) elseif has_conditioned_gibbs(context, vns) - # Short-circuit the tilde assume if `vn` is present in `context`. value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? return value, broadcast_logpdf(right, value), vi else - # If the varname has not been conditioned on, nor is it a target variable, its - # presumably a new variable that should be sampled from its prior. We need to add - # this new variable to the global `varinfo` of the context, but not to the local one - # being used by the current sampler. prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( rng, @@ -230,21 +216,6 @@ function DynamicPPL.dot_tilde_assume( end end -""" - preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) - -Returns the preferred value type for a variable with the given `varinfo`. -""" -preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict -preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple -function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) - # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. - namedtuple_compatible = all(varinfo.metadata) do md - eltype(md.vns) <: VarName{<:Any,typeof(identity)} - end - return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict -end - """ make_conditional(model, target_variables, varinfo) @@ -253,10 +224,15 @@ Return a new, conditioned model for a component of a Gibbs sampler. # Arguments - `model::DynamicPPL.Model`: The model to condition. - `target_variables::AbstractVector{<:VarName}`: The target variables of the component -sampler. These will _not_ conditioned. +sampler. These will _not_ be conditioned. - `varinfo::DynamicPPL.AbstractVarInfo`: Values for all variables in the model. All the values in `varinfo` but not in `target_variables` will be conditioned to the values they have in `varinfo`. + +# Returns +- A new model with the variables _not_ in `target_variables` conditioned. +- The `GibbsContext` object that will be used to condition the variables. This is necessary +because evaluation can mutate its `global_varinfo` field, which we need to access later. """ function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo @@ -360,7 +336,7 @@ function DynamicPPL.initialstep( varnames = alg.varnames samplers = alg.samplers - # 1. Run the model once to get the varnames present + initial values to condition on. + # Run the model once to get the varnames present + initial values to condition on. vi = DynamicPPL.VarInfo(rng, model) if initial_params !== nothing vi = DynamicPPL.unflatten(vi, initial_params) @@ -371,10 +347,13 @@ function DynamicPPL.initialstep( for (varnames_local, sampler_local) in zip(varnames, samplers) varnames_local = _maybevec(varnames_local) # Get the initial values for this component sampler. - vi_local = DynamicPPL.subset(vi, varnames_local) - initial_params_local = initial_params === nothing ? nothing : vi_local[:] + initial_params_local = if initial_params === nothing + nothing + else + DynamicPPL.subset(vi, varnames_local)[:] + end - # Construct the conditional model. + # Construct the conditioned model. model_local, context_local = make_conditional(model, varnames_local, vi) # Take initial step. @@ -397,7 +376,7 @@ function DynamicPPL.initialstep( # This merges in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, context_local.global_varinfo[]) - # This merges the latest values for all the variables in the current sampler. + # This merges the new values for all the variables sampled by the current sampler. vi = merge(vi, new_vi_local) push!(states, new_state_local) end @@ -473,8 +452,8 @@ function recompute_logprob!!( return setlogp!!(state, vi_new.logp[]) end -# TODO(mhauru) Would really like to type constrain this to something like AbstractMCMCState -# if such a thing existed. +# TODO(mhauru) Would really like to type constrain the first argument to something like +# AbstractMCMCState if such a thing existed. function DynamicPPL.setlogp!!(state, logp) try new_vi = setlogp!!(state.vi, logp) @@ -496,11 +475,11 @@ function DynamicPPL.setlogp!!(state::TuringState, logp) end # TODO(mhauru) In the general case, which arguments are really needed for reset_state!!? -# The current list is a guess, but I think some might be unnecessary. +# The current list is a guess, and I think some are unnecessary. """ reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous) -Return an updated state for a component sampler. +Return an updated state for a Gibbs component sampler. This takes into account changes caused by other Gibbs components. The default implementation is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to From b8c3dcdbac31651fd5a6bf9b43699d712f436cc4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 29 Oct 2024 10:29:36 +0000 Subject: [PATCH 19/25] Change how AdvancedHMC Gibbs state treats momenta --- src/mcmc/gibbs.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 9034e6708..ab3a827cd 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -572,10 +572,18 @@ function reset_state!!( state.metric, DynamicPPL.LogDensityFunction(model) ) θ_new = varinfo[:] - # Set the momentum to zero, since we have no idea what it should be at the new parameter - # values. + # Modify the momentum to have the right number of elements, if the number of position + # variables has changed. Any new dimensions will be set to zero momentum. + # Note that there's no guarantee that any new variables are at the end of the parameter + # list, so we may end up mismatching momenta and parameters. This shouldn't be of + # consequence though, since the momentum will get resampled anyway. + # Frankly, we could probably just as well set the momenta to zero, but that made + # ForwardDiff crash for some reason I (mhauru) didn't bother to investigate. + momenta_old = state.transition.z.r + momenta_new = zero(θ_new) + momenta_new[1:length(momenta_old)] .= momenta_old return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, θ_new, zero(θ_new) + hamiltonian, θ_new, momenta_new ) end From da0b740e561f2a476b87fdc60b3454a6b2b16477 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 29 Oct 2024 11:31:03 +0000 Subject: [PATCH 20/25] Remove unnecessary invlinking --- src/mcmc/gibbs.jl | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index ab3a827cd..b2895d1e7 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -367,12 +367,6 @@ function DynamicPPL.initialstep( kwargs..., ) new_vi_local = varinfo(new_state_local) - # TODO(mhauru) Can we remove the invlinking? - new_vi_local = if DynamicPPL.istrans(new_vi_local) - DynamicPPL.invlink(new_vi_local, sampler_local, model_local) - else - new_vi_local - end # This merges in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, context_local.global_varinfo[]) @@ -626,18 +620,9 @@ function gibbs_step_inner( state_local = states[index] varnames_local = _maybevec(varnames[index]) - # TODO(mhauru) Can we remove the invlinking? - vi = DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - # Construct the conditional model and the varinfo that this sampler should use. model_local, context_local = make_conditional(model, varnames_local, vi) varinfo_local = subset(vi, varnames_local) - # TODO(mhauru) Can we remove the below, if get rid of all the invlinking? - # If the varinfo of the previous state from this sampler is linked, we should link the - # new varinfo too. - if DynamicPPL.istrans(varinfo(state_local)) - varinfo_local = DynamicPPL.link(varinfo_local, sampler_local, model_local) - end # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] From d984d2bbceec25e0e345c20c8eeb66fc96ebb8bc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 29 Oct 2024 13:50:43 +0000 Subject: [PATCH 21/25] Change how AdvancedHMC Gibbs state treats momenta, again --- src/mcmc/gibbs.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index b2895d1e7..8ee034e02 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -566,16 +566,13 @@ function reset_state!!( state.metric, DynamicPPL.LogDensityFunction(model) ) θ_new = varinfo[:] - # Modify the momentum to have the right number of elements, if the number of position - # variables has changed. Any new dimensions will be set to zero momentum. - # Note that there's no guarantee that any new variables are at the end of the parameter - # list, so we may end up mismatching momenta and parameters. This shouldn't be of - # consequence though, since the momentum will get resampled anyway. - # Frankly, we could probably just as well set the momenta to zero, but that made - # ForwardDiff crash for some reason I (mhauru) didn't bother to investigate. + # Set the momentum to some arbitrary value, making sure it has the right number of + # components. We could try to do something clever here to only reset momenta related to + # new variables, but it'll be resampled in the next iteration anyway. + # TODO(mhauru) Would prefer to set it to zeros rather than ones, but that makes + # ForwardDiff crash for some reason. Should investigate and report as a ForwardDiff bug. momenta_old = state.transition.z.r - momenta_new = zero(θ_new) - momenta_new[1:length(momenta_old)] .= momenta_old + momenta_new = ones(eltype(momenta_old), length(θ_new)) return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( hamiltonian, θ_new, momenta_new ) From d52af52251d59b9f2bd08f5aa97c2e0f278fb562 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 1 Nov 2024 18:54:40 +0000 Subject: [PATCH 22/25] Use setparams!! rather than reset_state!! --- src/mcmc/gibbs.jl | 205 +++++++++++++++------------------------------- src/mcmc/hmc.jl | 5 +- 2 files changed, 70 insertions(+), 140 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 8ee034e02..d1b88f946 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -50,9 +50,16 @@ function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) ) end +get_global_varinfo(context::GibbsContext) = context.global_varinfo[] + +function set_global_varinfo!(context::GibbsContext, new_global_varinfo) + context.global_varinfo[] = new_global_varinfo + return nothing +end + # has and get function has_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.haskey(context.global_varinfo[], vn) + return DynamicPPL.haskey(get_global_varinfo(context), vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns)) @@ -66,7 +73,7 @@ function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) - return context.global_varinfo[][vn] + return get_global_varinfo(context)[vn] end function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return map(Base.Fix1(get_conditioned_gibbs, context), vns) @@ -110,9 +117,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) DynamicPPL.SampleFromPrior(), right, vn, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -137,9 +144,9 @@ function DynamicPPL.tilde_assume( DynamicPPL.SampleFromPrior(), right, vn, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -181,9 +188,9 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi right, left, vns, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -209,9 +216,9 @@ function DynamicPPL.dot_tilde_assume( right, left, vns, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -468,139 +475,71 @@ function DynamicPPL.setlogp!!(state::TuringState, logp) return TuringState(setlogp!!(state.state, logp), logp) end -# TODO(mhauru) In the general case, which arguments are really needed for reset_state!!? -# The current list is a guess, and I think some are unnecessary. -""" - reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous) - -Return an updated state for a Gibbs component sampler. - -This takes into account changes caused by other Gibbs components. The default implementation -is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to -do, a method should be implemented for the specific type of `state`. +# Some samplers use a VarInfo directly as the state. In that case, there's little to do in +# `setparams!!`. +function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVector) + return DynamicPPL.unflatten(state, params) +end -# Arguments -- `model::DynamicPPL.Model`: The model as seen by this component sampler. Variables not -sampled by this component sampler have been conditioned with a `GibbsContext`. -- `sampler::DynamicPPL.Sampler`: The current component sampler. -- `state`: The state of this component sampler from its previous iteration. -- `varinfo::DynamicPPL.AbstractVarInfo`: The current `VarInfo`, subsetted to the variables -sampled by this component sampler. -- `sampler_previous::DynamicPPL.Sampler`: The previous sampler in the Gibbs chain. -- `state_previous`: The state returned by the previous sampler. +function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVarInfo) + return params +end -# Returns -An updated state of the same type as `state`. It should have variables set to the values in -`varinfo`, and any other relevant updates done. -""" -function reset_state!!( - model, sampler, state, varinfo::AbstractVarInfo, sampler_previous, state_previous +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, + state::TuringState, + params::Union{AbstractVector,AbstractVarInfo}, ) - # In the fallback implementation we guess that `state` has a field called `vi` we can - # set. Fingers crossed! - try - return Accessors.set(state, Accessors.PropertyLens{:vi}(), varinfo) - catch - error( - "Unable to set `state.vi` for a $(typeof(state)). " * - "Consider writing a method for reset_state!! for this type.", - ) - end + new_inner_state = AbstractMCMC.setparams!!(model, state.state, params) + return TuringState(new_inner_state, state.logdensity) end -# Some samplers use a VarInfo directly as the state. In that case, there's little to do in -# `reset_state!!`. -function reset_state!!( - model, - sampler, - state::AbstractVarInfo, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - return varinfo +# Unless some other treatment has been specified for this state type, just flatten the +# AbstractVarInfo. This method exists because some sampler types need to override this +# behavior. +function AbstractMCMC.setparams!!(model::DynamicPPL.Model, state, params::AbstractVarInfo) + return AbstractMCMC.setparams!!(model, state, params[:]) end -function reset_state!!( - model, - sampler, - state::TuringState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::HMCState, params::AbstractVarInfo ) - new_inner_state = reset_state!!( - model, sampler, state.state, varinfo, sampler_previous, state_previous + θ_new = params[:] + hamiltonian = get_hamiltonian(model, state.sampler, params, state, length(θ_new)) + + # Update the parameter values in `state.z`. + # TODO: Avoid mutation + z = state.z + resize!(z.θ, length(θ_new)) + z.θ .= θ_new + return HMCState( + params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler ) - return TuringState(new_inner_state, state.logdensity) end -function reset_state!!( - model, - sampler, - state::HMCState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::HMCState, params::AbstractVector ) - θ_new = varinfo[:] - hamiltonian = get_hamiltonian(model, sampler, varinfo, state, length(θ_new)) + θ_new = params + vi = DynamicPPL.unflatten(state.vi, params) + hamiltonian = get_hamiltonian(model, state.sampler, vi, state, length(θ_new)) # Update the parameter values in `state.z`. # TODO: Avoid mutation z = state.z resize!(z.θ, length(θ_new)) z.θ .= θ_new - return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) + return HMCState(vi, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler) end -function reset_state!!( - model, - sampler, - state::AdvancedHMC.HMCState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::PGState, params::AbstractVarInfo ) - hamiltonian = AdvancedHMC.Hamiltonian( - state.metric, DynamicPPL.LogDensityFunction(model) - ) - θ_new = varinfo[:] - # Set the momentum to some arbitrary value, making sure it has the right number of - # components. We could try to do something clever here to only reset momenta related to - # new variables, but it'll be resampled in the next iteration anyway. - # TODO(mhauru) Would prefer to set it to zeros rather than ones, but that makes - # ForwardDiff crash for some reason. Should investigate and report as a ForwardDiff bug. - momenta_old = state.transition.z.r - momenta_new = ones(eltype(momenta_old), length(θ_new)) - return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, θ_new, momenta_new - ) + return PGState(params, state.rng) end -function reset_state!!( - model, - sampler, - state::AdvancedMH.Transition, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - # TODO(mhauru) Setting the last argument like this seems a bit suspect, since the - # current values for the parameters might not have come from this sampler at all. - # I don't see a better way though. - return AdvancedMH.Transition(varinfo[:], varinfo.logp[], state.accepted) -end - -function reset_state!!( - model, - sampler, - state::PGState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - return PGState(varinfo, state.rng) +function AbstractMCMC.setparams!!(state::PGState, params::AbstractVector) + return PGState(DynamicPPL.unflatten(state.vi, params), state.rng) end function gibbs_step_inner( @@ -609,7 +548,7 @@ function gibbs_step_inner( varnames, samplers, states, - vi, + global_vi, index; kwargs..., ) @@ -618,8 +557,8 @@ function gibbs_step_inner( varnames_local = _maybevec(varnames[index]) # Construct the conditional model and the varinfo that this sampler should use. - model_local, context_local = make_conditional(model, varnames_local, vi) - varinfo_local = subset(vi, varnames_local) + model_local, context_local = make_conditional(model, varnames_local, global_vi) + varinfo_local = subset(global_vi, varnames_local) # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] @@ -627,14 +566,7 @@ function gibbs_step_inner( # Set the state of the current sampler, accounting for any changes made by other # samplers. - state_local = reset_state!!( - model_local, - sampler_local, - state_local, - varinfo_local, - sampler_previous, - state_previous, - ) + state_local = AbstractMCMC.setparams!!(model_local, state_local, varinfo_local) if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous ) @@ -647,11 +579,8 @@ function gibbs_step_inner( ) new_vi_local = varinfo(new_state_local) - # This merges in any new variables that were introduced during the step, but that - # were not in the domain of the current sampler. - new_vi = merge(vi, context_local.global_varinfo[]) - # This merges the latest values for all the variables in the current sampler. - new_vi = merge(new_vi, new_vi_local) - new_vi = setlogp!!(new_vi, new_vi_local.logp[]) - return new_vi, new_state_local + # Merge the latest values for all the variables in the current sampler. + new_global_vi = merge(get_global_varinfo(context_local), new_vi_local) + new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local)) + return new_global_vi, new_state_local end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index d01ef274a..ab018e787 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -15,6 +15,7 @@ struct HMCState{ hamiltonian::THam z::PhType adaptor::TAdapt + sampler::Sampler{<:Hamiltonian} end ### @@ -229,7 +230,7 @@ function DynamicPPL.initialstep( end transition = Transition(model, vi, t) - state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) + state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor, spl) return transition, state end @@ -275,7 +276,7 @@ function AbstractMCMC.step( # Compute next transition and state. transition = Transition(model, vi, t) - newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor) + newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, spl) return transition, newstate end From 508ac61ad52bfdb2bdd7d4621e2ae6e62cafe3e2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 4 Nov 2024 15:54:00 +0000 Subject: [PATCH 23/25] Don't overload setparams\!\! with VarInfo --- src/mcmc/gibbs.jl | 77 ++++++++++++++++++++--------------------------- src/mcmc/hmc.jl | 5 ++- 2 files changed, 34 insertions(+), 48 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index d1b88f946..9e1a70543 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -475,73 +475,58 @@ function DynamicPPL.setlogp!!(state::TuringState, logp) return TuringState(setlogp!!(state.state, logp), logp) end -# Some samplers use a VarInfo directly as the state. In that case, there's little to do in -# `setparams!!`. -function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVector) - return DynamicPPL.unflatten(state, params) -end - -function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVarInfo) - return params -end +""" + setparams_varinfo!!(model, sampler::Sampler, state, params::AbstractVarInfo) -function AbstractMCMC.setparams!!( - model::DynamicPPL.Model, - state::TuringState, - params::Union{AbstractVector,AbstractVarInfo}, -) - new_inner_state = AbstractMCMC.setparams!!(model, state.state, params) - return TuringState(new_inner_state, state.logdensity) -end +A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an +`AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to +`AbstractMCMC.setparams!!(model, state, params[:])`. -# Unless some other treatment has been specified for this state type, just flatten the -# AbstractVarInfo. This method exists because some sampler types need to override this -# behavior. -function AbstractMCMC.setparams!!(model::DynamicPPL.Model, state, params::AbstractVarInfo) +`model` is typically a `DynamicPPL.Model`, but can also be e.g. an +`AbstractMCMC.LogDensityModel`. +""" +function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo) return AbstractMCMC.setparams!!(model, state, params[:]) end -function AbstractMCMC.setparams!!( - model::DynamicPPL.Model, state::HMCState, params::AbstractVarInfo +# Some samplers use a VarInfo directly as the state. In that case, there's little to do in +# `setparams_varinfo!!`. +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler, state::VarInfo, params::AbstractVarInfo ) - θ_new = params[:] - hamiltonian = get_hamiltonian(model, state.sampler, params, state, length(θ_new)) + return params +end - # Update the parameter values in `state.z`. - # TODO: Avoid mutation - z = state.z - resize!(z.θ, length(θ_new)) - z.θ .= θ_new - return HMCState( - params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler, state::TuringState, params::AbstractVarInfo +) + logdensity = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype) + new_inner_state = setparams_varinfo!!( + AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params ) + return TuringState(new_inner_state, logdensity) end -function AbstractMCMC.setparams!!( - model::DynamicPPL.Model, state::HMCState, params::AbstractVector +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler, state::HMCState, params::AbstractVarInfo ) - θ_new = params - vi = DynamicPPL.unflatten(state.vi, params) - hamiltonian = get_hamiltonian(model, state.sampler, vi, state, length(θ_new)) + θ_new = params[:] + hamiltonian = get_hamiltonian(model, sampler, params, state, length(θ_new)) # Update the parameter values in `state.z`. # TODO: Avoid mutation z = state.z resize!(z.θ, length(θ_new)) z.θ .= θ_new - return HMCState(vi, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler) + return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor) end -function AbstractMCMC.setparams!!( - model::DynamicPPL.Model, state::PGState, params::AbstractVarInfo +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler, state::PGState, params::AbstractVarInfo ) return PGState(params, state.rng) end -function AbstractMCMC.setparams!!(state::PGState, params::AbstractVector) - return PGState(DynamicPPL.unflatten(state.vi, params), state.rng) -end - function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -566,7 +551,9 @@ function gibbs_step_inner( # Set the state of the current sampler, accounting for any changes made by other # samplers. - state_local = AbstractMCMC.setparams!!(model_local, state_local, varinfo_local) + state_local = setparams_varinfo!!( + model_local, sampler_local, state_local, varinfo_local + ) if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous ) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index ab018e787..d01ef274a 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -15,7 +15,6 @@ struct HMCState{ hamiltonian::THam z::PhType adaptor::TAdapt - sampler::Sampler{<:Hamiltonian} end ### @@ -230,7 +229,7 @@ function DynamicPPL.initialstep( end transition = Transition(model, vi, t) - state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor, spl) + state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) return transition, state end @@ -276,7 +275,7 @@ function AbstractMCMC.step( # Compute next transition and state. transition = Transition(model, vi, t) - newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, spl) + newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor) return transition, newstate end From 6ff7c59aae9ff321a0e31ba809b66ff7a7788df2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 5 Nov 2024 10:15:46 +0000 Subject: [PATCH 24/25] A fix for ESS in Gibbs --- src/mcmc/gibbs.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 9e1a70543..e13b71f8b 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -374,6 +374,12 @@ function DynamicPPL.initialstep( kwargs..., ) new_vi_local = varinfo(new_state_local) + # TODO(mhauru) Remove the below loop once samplers no longer depend on selectors. + # For some reason not having this in place was causing trouble for ESS, but not for + # other samplers. I didn't get to the bottom of it. + for vn in keys(new_vi_local) + DynamicPPL.setgid!(new_vi_local, sampler_local.selector, vn) + end # This merges in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, context_local.global_varinfo[]) @@ -544,6 +550,12 @@ function gibbs_step_inner( # Construct the conditional model and the varinfo that this sampler should use. model_local, context_local = make_conditional(model, varnames_local, global_vi) varinfo_local = subset(global_vi, varnames_local) + # TODO(mhauru) Remove the below loop once samplers no longer depend on selectors. + # For some reason not having this in place was causing trouble for ESS, but not for + # other samplers. I didn't get to the bottom of it. + for vn in keys(varinfo_local) + DynamicPPL.setgid!(varinfo_local, sampler_local.selector, vn) + end # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] From 934c03efd281c092c0c69e54b49325a3ba21f5b6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 5 Nov 2024 16:02:16 +0000 Subject: [PATCH 25/25] Remove recompute_logprob!! --- src/mcmc/abstractmcmc.jl | 45 ------------------ src/mcmc/gibbs.jl | 100 ++++++--------------------------------- test/mcmc/gibbs.jl | 4 +- 3 files changed, 18 insertions(+), 131 deletions(-) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index c6ca61a9c..8dfee52b4 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -56,51 +56,6 @@ function setvarinfo( ) end -""" - recompute_logprob!!(rng, model, sampler, state) - -Recompute the log-probability of the `model` based on the given `state` and return the resulting state. -""" -function recompute_logprob!!( - rng::Random.AbstractRNG, # TODO: Do we need the `rng` here? - model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:ExternalSampler}, - state, # TODO(mhauru) Could we type constrain this to TuringState? -) - # Re-using the log-density function from the `state` and updating only the `model` field, - # since the `model` might now contain different conditioning values. - f = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype) - # Recompute the log-probability with the new `model`. - state_inner = recompute_logprob!!( - rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state - ) - return state_to_turing(f, state_inner) -end - -function recompute_logprob!!( - rng::Random.AbstractRNG, - model::AbstractMCMC.LogDensityModel, - sampler::AdvancedHMC.AbstractHMCSampler, - state::AdvancedHMC.HMCState, -) - # Construct hamiltionian. - hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) - # Re-compute the log-probability and gradient. - return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, state.transition.z.θ, state.transition.z.r - ) -end - -function recompute_logprob!!( - rng::Random.AbstractRNG, - model::AbstractMCMC.LogDensityModel, - sampler::AdvancedMH.MetropolisHastings, - state::AdvancedMH.Transition, -) - logdensity = model.logdensity - return Accessors.@set state.lp = LogDensityProblems.logdensity(logdensity, state.params) -end - # TODO: Do we also support `resume`, etc? function AbstractMCMC.step( rng::Random.AbstractRNG, diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index e13b71f8b..d44184da2 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -407,80 +407,17 @@ function AbstractMCMC.step( # TODO: move this into a recursive function so we can unroll when reasonable? for index in 1:length(samplers) # Take the inner step. + sampler_local = samplers[index] + state_local = states[index] + varnames_local = _maybevec(varnames[index]) vi, new_state_local = gibbs_step_inner( - rng, model, varnames, samplers, states, vi, index; kwargs... + rng, model, varnames_local, sampler_local, state_local, vi; kwargs... ) states = Accessors.setindex(states, new_state_local, index) end return Transition(model, vi), GibbsState(vi, states) end -# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. -function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) - # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide - # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact - # same `selector` as before but now with `rerun` set to `true` if needed. - return Accessors.@set sampler.selector.rerun = true -end - -# Interface we need a sampler to implement to work as a component in a Gibbs sampler. -""" - gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - -Check if the log-probability of the destination model needs to be recomputed. - -Defaults to `true` -""" -function gibbs_requires_recompute_logprob( - model_dst, sampler_dst, sampler_src, state_dst, state_src -) - return true -end - -# TODO: Remove `rng`? -function recompute_logprob!!( - rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, state -) - vi = varinfo(state) - # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, - # e.g. log-likelihood in the scenario of `ESS`. - # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. - sampler_rerun = make_rerun_sampler(model, sampler) - # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed - # `varinfo`, even if `varinfo` was linked. - vi_new = last( - DynamicPPL.evaluate!!( - model, - vi, - # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. - DynamicPPL.SamplingContext(rng, sampler_rerun), - ) - ) - return setlogp!!(state, vi_new.logp[]) -end - -# TODO(mhauru) Would really like to type constrain the first argument to something like -# AbstractMCMCState if such a thing existed. -function DynamicPPL.setlogp!!(state, logp) - try - new_vi = setlogp!!(state.vi, logp) - if new_vi !== state.vi - return Accessors.set(state, Accessors.PropertyLens{:vi}(), new_vi) - else - return state - end - catch - error( - "Unable to set `state.vi` for a $(typeof(state)). " * - "Consider writing a method for `setlogp!!` for this type.", - ) - end -end - -function DynamicPPL.setlogp!!(state::TuringState, logp) - return TuringState(setlogp!!(state.state, logp), logp) -end - """ setparams_varinfo!!(model, sampler::Sampler, state, params::AbstractVarInfo) @@ -536,17 +473,12 @@ end function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, - varnames, - samplers, - states, - global_vi, - index; + varnames_local, + sampler_local, + state_local, + global_vi; kwargs..., ) - sampler_local = samplers[index] - state_local = states[index] - varnames_local = _maybevec(varnames[index]) - # Construct the conditional model and the varinfo that this sampler should use. model_local, context_local = make_conditional(model, varnames_local, global_vi) varinfo_local = subset(global_vi, varnames_local) @@ -557,20 +489,18 @@ function gibbs_step_inner( DynamicPPL.setgid!(varinfo_local, sampler_local.selector, vn) end - # Extract the previous sampler and state. - sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] - state_previous = states[index == 1 ? length(states) : index - 1] - + # TODO(mhauru) The below may be overkill. If the varnames for this sampler are not + # sampled by other samplers, we don't need to `setparams`, but could rather simply + # recompute the log probability. More over, in some cases the recomputation could also + # be avoided, if e.g. the previous sampler has done all the necessary work already. + # However, we've judged that doing any caching or other tricks to avoid this now would + # be premature optimization. In most use cases of Gibbs a single model call here is not + # going to be a significant expense anyway. # Set the state of the current sampler, accounting for any changes made by other # samplers. state_local = setparams_varinfo!!( model_local, sampler_local, state_local, varinfo_local ) - if gibbs_requires_recompute_logprob( - model_local, sampler_local, sampler_previous, state_local, state_previous - ) - state_local = recompute_logprob!!(rng, model_local, sampler_local, state_local) - end # Take a step with the local sampler. new_state_local = last( diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index c11f29162..65f192dc7 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -183,6 +183,8 @@ has_dot_assume(::DynamicPPL.Model) = true end @testset "dynamic model" begin + # TODO(mhauru) We should check that the results of the sampling are correct. + # Currently we just check that this doesn't crash. @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} N = length(y) rpm = DirichletProcess(alpha) @@ -204,7 +206,7 @@ has_dot_assume(::DynamicPPL.Model) = true end model = imm(Random.randn(100), 1.0) # https://github.com/TuringLang/Turing.jl/issues/1725 - # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100); + # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) end