Skip to content

Commit

Permalink
Merge branch 'master' into py/cherry-pick-0.28.5
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm authored Nov 5, 2024
2 parents 3ad1f1d + b96e368 commit 83c8a28
Showing 1 changed file with 7 additions and 79 deletions.
86 changes: 7 additions & 79 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,86 +108,14 @@ function DynamicPPL.generated_quantities(
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
if DynamicPPL.supports_varname_indexing(chain)
varname_pairs = _varname_pairs_with_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
else
varname_pairs = _varname_pairs_without_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
end
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
return fixed_model()
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
end
end

"""
_varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.
This implementation assumes `chain` can be indexed using variable names, and is the
preffered implementation.
"""
function _varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
vns = DynamicPPL.varnames(chain)
vn_parents = Iterators.map(vns) do vn
# The call nested_setindex_maybe! is used to handle cases where vn is not
# the variable name used in the model, but rather subsumed by one. Except
# for the subsumption part, this could be
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
DynamicPPL.nested_setindex_maybe!(
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
)
end
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
vn_parent => varinfo[vn_parent]
end
return varname_pairs
end

"""
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.
The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
won't catch all cases. We should get rid of this if we can.
"""
# TODO(mhauru) See docstring above.
function _vcat_subsumed_values(vn_string, values, key_strings)
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
end

"""
_varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.
This implementation does not assume that `chain` can be indexed using variable names. It is
thus not guaranteed to work in cases where the variable names have complex subsumption
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
"""
function _varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
values = chain.value[sample_idx, :, chain_idx]
keys = Base.keys(chain)
keys_strings = map(string, keys)
varname_pairs = [
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
vn in Base.keys(varinfo)
]
return varname_pairs
end

end

0 comments on commit 83c8a28

Please sign in to comment.