Skip to content

Commit

Permalink
Use setparams!! rather than reset_state!!
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Nov 1, 2024
1 parent 15ee270 commit d52af52
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 140 deletions.
205 changes: 67 additions & 138 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,16 @@ function DynamicPPL.setchildcontext(context::GibbsContext, childcontext)
)
end

get_global_varinfo(context::GibbsContext) = context.global_varinfo[]

function set_global_varinfo!(context::GibbsContext, new_global_varinfo)
context.global_varinfo[] = new_global_varinfo
return nothing

Check warning on line 57 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L55-L57

Added lines #L55 - L57 were not covered by tests
end

# has and get
function has_conditioned_gibbs(context::GibbsContext, vn::VarName)
return DynamicPPL.haskey(context.global_varinfo[], vn)
return DynamicPPL.haskey(get_global_varinfo(context), vn)
end
function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName})
num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns))
Expand All @@ -66,7 +73,7 @@ function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa
end

function get_conditioned_gibbs(context::GibbsContext, vn::VarName)
return context.global_varinfo[][vn]
return get_global_varinfo(context)[vn]
end
function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName})
return map(Base.Fix1(get_conditioned_gibbs, context), vns)

Check warning on line 79 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L78-L79

Added lines #L78 - L79 were not covered by tests
Expand Down Expand Up @@ -110,9 +117,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
DynamicPPL.SampleFromPrior(),
right,
vn,
context.global_varinfo[],
get_global_varinfo(context),
)
context.global_varinfo[] = new_global_vi
set_global_varinfo!(context, new_global_vi)
return value, lp, vi

Check warning on line 123 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L122-L123

Added lines #L122 - L123 were not covered by tests
end
end
Expand All @@ -137,9 +144,9 @@ function DynamicPPL.tilde_assume(
DynamicPPL.SampleFromPrior(),
right,
vn,
context.global_varinfo[],
get_global_varinfo(context),
)
context.global_varinfo[] = new_global_vi
set_global_varinfo!(context, new_global_vi)
return value, lp, vi

Check warning on line 150 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L149-L150

Added lines #L149 - L150 were not covered by tests
end
end
Expand Down Expand Up @@ -181,9 +188,9 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi
right,
left,
vns,
context.global_varinfo[],
get_global_varinfo(context),
)
context.global_varinfo[] = new_global_vi
set_global_varinfo!(context, new_global_vi)
return value, lp, vi

Check warning on line 194 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L193-L194

Added lines #L193 - L194 were not covered by tests
end
end
Expand All @@ -209,9 +216,9 @@ function DynamicPPL.dot_tilde_assume(
right,
left,
vns,
context.global_varinfo[],
get_global_varinfo(context),
)
context.global_varinfo[] = new_global_vi
set_global_varinfo!(context, new_global_vi)
return value, lp, vi

Check warning on line 222 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L221-L222

Added lines #L221 - L222 were not covered by tests
end
end
Expand Down Expand Up @@ -468,139 +475,71 @@ function DynamicPPL.setlogp!!(state::TuringState, logp)
return TuringState(setlogp!!(state.state, logp), logp)

Check warning on line 475 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L474-L475

Added lines #L474 - L475 were not covered by tests
end

# TODO(mhauru) In the general case, which arguments are really needed for reset_state!!?
# The current list is a guess, and I think some are unnecessary.
"""
reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous)
Return an updated state for a Gibbs component sampler.
This takes into account changes caused by other Gibbs components. The default implementation
is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to
do, a method should be implemented for the specific type of `state`.
# Some samplers use a VarInfo directly as the state. In that case, there's little to do in
# `setparams!!`.
function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVector)
return DynamicPPL.unflatten(state, params)
end

# Arguments
- `model::DynamicPPL.Model`: The model as seen by this component sampler. Variables not
sampled by this component sampler have been conditioned with a `GibbsContext`.
- `sampler::DynamicPPL.Sampler`: The current component sampler.
- `state`: The state of this component sampler from its previous iteration.
- `varinfo::DynamicPPL.AbstractVarInfo`: The current `VarInfo`, subsetted to the variables
sampled by this component sampler.
- `sampler_previous::DynamicPPL.Sampler`: The previous sampler in the Gibbs chain.
- `state_previous`: The state returned by the previous sampler.
function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVarInfo)
return params

Check warning on line 485 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L484-L485

Added lines #L484 - L485 were not covered by tests
end

# Returns
An updated state of the same type as `state`. It should have variables set to the values in
`varinfo`, and any other relevant updates done.
"""
function reset_state!!(
model, sampler, state, varinfo::AbstractVarInfo, sampler_previous, state_previous
function AbstractMCMC.setparams!!(

Check warning on line 488 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L488

Added line #L488 was not covered by tests
model::DynamicPPL.Model,
state::TuringState,
params::Union{AbstractVector,AbstractVarInfo},
)
# In the fallback implementation we guess that `state` has a field called `vi` we can
# set. Fingers crossed!
try
return Accessors.set(state, Accessors.PropertyLens{:vi}(), varinfo)
catch
error(
"Unable to set `state.vi` for a $(typeof(state)). " *
"Consider writing a method for reset_state!! for this type.",
)
end
new_inner_state = AbstractMCMC.setparams!!(model, state.state, params)
return TuringState(new_inner_state, state.logdensity)

Check warning on line 494 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L493-L494

Added lines #L493 - L494 were not covered by tests
end

# Some samplers use a VarInfo directly as the state. In that case, there's little to do in
# `reset_state!!`.
function reset_state!!(
model,
sampler,
state::AbstractVarInfo,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
)
return varinfo
# Unless some other treatment has been specified for this state type, just flatten the
# AbstractVarInfo. This method exists because some sampler types need to override this
# behavior.
function AbstractMCMC.setparams!!(model::DynamicPPL.Model, state, params::AbstractVarInfo)
return AbstractMCMC.setparams!!(model, state, params[:])
end

function reset_state!!(
model,
sampler,
state::TuringState,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
function AbstractMCMC.setparams!!(

Check warning on line 504 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L504

Added line #L504 was not covered by tests
model::DynamicPPL.Model, state::HMCState, params::AbstractVarInfo
)
new_inner_state = reset_state!!(
model, sampler, state.state, varinfo, sampler_previous, state_previous
θ_new = params[:]
hamiltonian = get_hamiltonian(model, state.sampler, params, state, length(θ_new))

Check warning on line 508 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L507-L508

Added lines #L507 - L508 were not covered by tests

# Update the parameter values in `state.z`.
# TODO: Avoid mutation
z = state.z
resize!(z.θ, length(θ_new))
z.θ .= θ_new
return HMCState(

Check warning on line 515 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L512-L515

Added lines #L512 - L515 were not covered by tests
params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler
)
return TuringState(new_inner_state, state.logdensity)
end

function reset_state!!(
model,
sampler,
state::HMCState,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
function AbstractMCMC.setparams!!(

Check warning on line 520 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L520

Added line #L520 was not covered by tests
model::DynamicPPL.Model, state::HMCState, params::AbstractVector
)
θ_new = varinfo[:]
hamiltonian = get_hamiltonian(model, sampler, varinfo, state, length(θ_new))
θ_new = params
vi = DynamicPPL.unflatten(state.vi, params)
hamiltonian = get_hamiltonian(model, state.sampler, vi, state, length(θ_new))

Check warning on line 525 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L523-L525

Added lines #L523 - L525 were not covered by tests

# Update the parameter values in `state.z`.
# TODO: Avoid mutation
z = state.z
resize!(z.θ, length(θ_new))
z.θ .= θ_new
return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor)
return HMCState(vi, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler)

Check warning on line 532 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L529-L532

Added lines #L529 - L532 were not covered by tests
end

function reset_state!!(
model,
sampler,
state::AdvancedHMC.HMCState,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
function AbstractMCMC.setparams!!(
model::DynamicPPL.Model, state::PGState, params::AbstractVarInfo
)
hamiltonian = AdvancedHMC.Hamiltonian(
state.metric, DynamicPPL.LogDensityFunction(model)
)
θ_new = varinfo[:]
# Set the momentum to some arbitrary value, making sure it has the right number of
# components. We could try to do something clever here to only reset momenta related to
# new variables, but it'll be resampled in the next iteration anyway.
# TODO(mhauru) Would prefer to set it to zeros rather than ones, but that makes
# ForwardDiff crash for some reason. Should investigate and report as a ForwardDiff bug.
momenta_old = state.transition.z.r
momenta_new = ones(eltype(momenta_old), length(θ_new))
return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, θ_new, momenta_new
)
return PGState(params, state.rng)
end

function reset_state!!(
model,
sampler,
state::AdvancedMH.Transition,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
)
# TODO(mhauru) Setting the last argument like this seems a bit suspect, since the
# current values for the parameters might not have come from this sampler at all.
# I don't see a better way though.
return AdvancedMH.Transition(varinfo[:], varinfo.logp[], state.accepted)
end

function reset_state!!(
model,
sampler,
state::PGState,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
)
return PGState(varinfo, state.rng)
function AbstractMCMC.setparams!!(state::PGState, params::AbstractVector)
return PGState(DynamicPPL.unflatten(state.vi, params), state.rng)

Check warning on line 542 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L541-L542

Added lines #L541 - L542 were not covered by tests
end

function gibbs_step_inner(
Expand All @@ -609,7 +548,7 @@ function gibbs_step_inner(
varnames,
samplers,
states,
vi,
global_vi,
index;
kwargs...,
)
Expand All @@ -618,23 +557,16 @@ function gibbs_step_inner(
varnames_local = _maybevec(varnames[index])

# Construct the conditional model and the varinfo that this sampler should use.
model_local, context_local = make_conditional(model, varnames_local, vi)
varinfo_local = subset(vi, varnames_local)
model_local, context_local = make_conditional(model, varnames_local, global_vi)
varinfo_local = subset(global_vi, varnames_local)

# Extract the previous sampler and state.
sampler_previous = samplers[index == 1 ? length(samplers) : index - 1]
state_previous = states[index == 1 ? length(states) : index - 1]

# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = reset_state!!(
model_local,
sampler_local,
state_local,
varinfo_local,
sampler_previous,
state_previous,
)
state_local = AbstractMCMC.setparams!!(model_local, state_local, varinfo_local)
if gibbs_requires_recompute_logprob(
model_local, sampler_local, sampler_previous, state_local, state_previous
)
Expand All @@ -647,11 +579,8 @@ function gibbs_step_inner(
)

new_vi_local = varinfo(new_state_local)
# This merges in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
new_vi = merge(vi, context_local.global_varinfo[])
# This merges the latest values for all the variables in the current sampler.
new_vi = merge(new_vi, new_vi_local)
new_vi = setlogp!!(new_vi, new_vi_local.logp[])
return new_vi, new_state_local
# Merge the latest values for all the variables in the current sampler.
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))
return new_global_vi, new_state_local
end
5 changes: 3 additions & 2 deletions src/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct HMCState{
hamiltonian::THam
z::PhType
adaptor::TAdapt
sampler::Sampler{<:Hamiltonian}
end

###
Expand Down Expand Up @@ -229,7 +230,7 @@ function DynamicPPL.initialstep(
end

transition = Transition(model, vi, t)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor, spl)

return transition, state
end
Expand Down Expand Up @@ -275,7 +276,7 @@ function AbstractMCMC.step(

# Compute next transition and state.
transition = Transition(model, vi, t)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, spl)

return transition, newstate
end
Expand Down

0 comments on commit d52af52

Please sign in to comment.