diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 8ee034e02..d1b88f946 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -50,9 +50,16 @@ function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) ) end +get_global_varinfo(context::GibbsContext) = context.global_varinfo[] + +function set_global_varinfo!(context::GibbsContext, new_global_varinfo) + context.global_varinfo[] = new_global_varinfo + return nothing +end + # has and get function has_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.haskey(context.global_varinfo[], vn) + return DynamicPPL.haskey(get_global_varinfo(context), vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns)) @@ -66,7 +73,7 @@ function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) - return context.global_varinfo[][vn] + return get_global_varinfo(context)[vn] end function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return map(Base.Fix1(get_conditioned_gibbs, context), vns) @@ -110,9 +117,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) DynamicPPL.SampleFromPrior(), right, vn, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -137,9 +144,9 @@ function DynamicPPL.tilde_assume( DynamicPPL.SampleFromPrior(), right, vn, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -181,9 +188,9 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi right, left, vns, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -209,9 +216,9 @@ function DynamicPPL.dot_tilde_assume( right, left, vns, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -468,139 +475,71 @@ function DynamicPPL.setlogp!!(state::TuringState, logp) return TuringState(setlogp!!(state.state, logp), logp) end -# TODO(mhauru) In the general case, which arguments are really needed for reset_state!!? -# The current list is a guess, and I think some are unnecessary. -""" - reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous) - -Return an updated state for a Gibbs component sampler. - -This takes into account changes caused by other Gibbs components. The default implementation -is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to -do, a method should be implemented for the specific type of `state`. +# Some samplers use a VarInfo directly as the state. In that case, there's little to do in +# `setparams!!`. +function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVector) + return DynamicPPL.unflatten(state, params) +end -# Arguments -- `model::DynamicPPL.Model`: The model as seen by this component sampler. Variables not -sampled by this component sampler have been conditioned with a `GibbsContext`. -- `sampler::DynamicPPL.Sampler`: The current component sampler. -- `state`: The state of this component sampler from its previous iteration. -- `varinfo::DynamicPPL.AbstractVarInfo`: The current `VarInfo`, subsetted to the variables -sampled by this component sampler. -- `sampler_previous::DynamicPPL.Sampler`: The previous sampler in the Gibbs chain. -- `state_previous`: The state returned by the previous sampler. +function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVarInfo) + return params +end -# Returns -An updated state of the same type as `state`. It should have variables set to the values in -`varinfo`, and any other relevant updates done. -""" -function reset_state!!( - model, sampler, state, varinfo::AbstractVarInfo, sampler_previous, state_previous +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, + state::TuringState, + params::Union{AbstractVector,AbstractVarInfo}, ) - # In the fallback implementation we guess that `state` has a field called `vi` we can - # set. Fingers crossed! - try - return Accessors.set(state, Accessors.PropertyLens{:vi}(), varinfo) - catch - error( - "Unable to set `state.vi` for a $(typeof(state)). " * - "Consider writing a method for reset_state!! for this type.", - ) - end + new_inner_state = AbstractMCMC.setparams!!(model, state.state, params) + return TuringState(new_inner_state, state.logdensity) end -# Some samplers use a VarInfo directly as the state. In that case, there's little to do in -# `reset_state!!`. -function reset_state!!( - model, - sampler, - state::AbstractVarInfo, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - return varinfo +# Unless some other treatment has been specified for this state type, just flatten the +# AbstractVarInfo. This method exists because some sampler types need to override this +# behavior. +function AbstractMCMC.setparams!!(model::DynamicPPL.Model, state, params::AbstractVarInfo) + return AbstractMCMC.setparams!!(model, state, params[:]) end -function reset_state!!( - model, - sampler, - state::TuringState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::HMCState, params::AbstractVarInfo ) - new_inner_state = reset_state!!( - model, sampler, state.state, varinfo, sampler_previous, state_previous + θ_new = params[:] + hamiltonian = get_hamiltonian(model, state.sampler, params, state, length(θ_new)) + + # Update the parameter values in `state.z`. + # TODO: Avoid mutation + z = state.z + resize!(z.θ, length(θ_new)) + z.θ .= θ_new + return HMCState( + params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler ) - return TuringState(new_inner_state, state.logdensity) end -function reset_state!!( - model, - sampler, - state::HMCState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::HMCState, params::AbstractVector ) - θ_new = varinfo[:] - hamiltonian = get_hamiltonian(model, sampler, varinfo, state, length(θ_new)) + θ_new = params + vi = DynamicPPL.unflatten(state.vi, params) + hamiltonian = get_hamiltonian(model, state.sampler, vi, state, length(θ_new)) # Update the parameter values in `state.z`. # TODO: Avoid mutation z = state.z resize!(z.θ, length(θ_new)) z.θ .= θ_new - return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) + return HMCState(vi, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler) end -function reset_state!!( - model, - sampler, - state::AdvancedHMC.HMCState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::PGState, params::AbstractVarInfo ) - hamiltonian = AdvancedHMC.Hamiltonian( - state.metric, DynamicPPL.LogDensityFunction(model) - ) - θ_new = varinfo[:] - # Set the momentum to some arbitrary value, making sure it has the right number of - # components. We could try to do something clever here to only reset momenta related to - # new variables, but it'll be resampled in the next iteration anyway. - # TODO(mhauru) Would prefer to set it to zeros rather than ones, but that makes - # ForwardDiff crash for some reason. Should investigate and report as a ForwardDiff bug. - momenta_old = state.transition.z.r - momenta_new = ones(eltype(momenta_old), length(θ_new)) - return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, θ_new, momenta_new - ) + return PGState(params, state.rng) end -function reset_state!!( - model, - sampler, - state::AdvancedMH.Transition, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - # TODO(mhauru) Setting the last argument like this seems a bit suspect, since the - # current values for the parameters might not have come from this sampler at all. - # I don't see a better way though. - return AdvancedMH.Transition(varinfo[:], varinfo.logp[], state.accepted) -end - -function reset_state!!( - model, - sampler, - state::PGState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - return PGState(varinfo, state.rng) +function AbstractMCMC.setparams!!(state::PGState, params::AbstractVector) + return PGState(DynamicPPL.unflatten(state.vi, params), state.rng) end function gibbs_step_inner( @@ -609,7 +548,7 @@ function gibbs_step_inner( varnames, samplers, states, - vi, + global_vi, index; kwargs..., ) @@ -618,8 +557,8 @@ function gibbs_step_inner( varnames_local = _maybevec(varnames[index]) # Construct the conditional model and the varinfo that this sampler should use. - model_local, context_local = make_conditional(model, varnames_local, vi) - varinfo_local = subset(vi, varnames_local) + model_local, context_local = make_conditional(model, varnames_local, global_vi) + varinfo_local = subset(global_vi, varnames_local) # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] @@ -627,14 +566,7 @@ function gibbs_step_inner( # Set the state of the current sampler, accounting for any changes made by other # samplers. - state_local = reset_state!!( - model_local, - sampler_local, - state_local, - varinfo_local, - sampler_previous, - state_previous, - ) + state_local = AbstractMCMC.setparams!!(model_local, state_local, varinfo_local) if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous ) @@ -647,11 +579,8 @@ function gibbs_step_inner( ) new_vi_local = varinfo(new_state_local) - # This merges in any new variables that were introduced during the step, but that - # were not in the domain of the current sampler. - new_vi = merge(vi, context_local.global_varinfo[]) - # This merges the latest values for all the variables in the current sampler. - new_vi = merge(new_vi, new_vi_local) - new_vi = setlogp!!(new_vi, new_vi_local.logp[]) - return new_vi, new_state_local + # Merge the latest values for all the variables in the current sampler. + new_global_vi = merge(get_global_varinfo(context_local), new_vi_local) + new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local)) + return new_global_vi, new_state_local end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index d01ef274a..ab018e787 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -15,6 +15,7 @@ struct HMCState{ hamiltonian::THam z::PhType adaptor::TAdapt + sampler::Sampler{<:Hamiltonian} end ### @@ -229,7 +230,7 @@ function DynamicPPL.initialstep( end transition = Transition(model, vi, t) - state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) + state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor, spl) return transition, state end @@ -275,7 +276,7 @@ function AbstractMCMC.step( # Compute next transition and state. transition = Transition(model, vi, t) - newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor) + newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, spl) return transition, newstate end