Skip to content

Commit

Permalink
Revert to using setval_and_resample! in generated_quantities (#709)
Browse files Browse the repository at this point in the history
* revert to using `setval_and_resample!` in `generated_quantities` for now

* added TODO so we don't forget

* bump patch version

---------

Co-authored-by: Penelope Yong <[email protected]>
  • Loading branch information
torfjelde and penelopeysm authored Nov 4, 2024
1 parent 99def88 commit b96e368
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 80 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.30.2"
version = "0.30.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
86 changes: 7 additions & 79 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,86 +108,14 @@ function DynamicPPL.generated_quantities(
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
if DynamicPPL.supports_varname_indexing(chain)
varname_pairs = _varname_pairs_with_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
else
varname_pairs = _varname_pairs_without_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
end
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
return fixed_model()
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
end
end

"""
_varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.
This implementation assumes `chain` can be indexed using variable names, and is the
preffered implementation.
"""
function _varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
vns = DynamicPPL.varnames(chain)
vn_parents = Iterators.map(vns) do vn
# The call nested_setindex_maybe! is used to handle cases where vn is not
# the variable name used in the model, but rather subsumed by one. Except
# for the subsumption part, this could be
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
DynamicPPL.nested_setindex_maybe!(
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
)
end
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
vn_parent => varinfo[vn_parent]
end
return varname_pairs
end

"""
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.
The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
won't catch all cases. We should get rid of this if we can.
"""
# TODO(mhauru) See docstring above.
function _vcat_subsumed_values(vn_string, values, key_strings)
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
end

"""
_varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.
This implementation does not assume that `chain` can be indexed using variable names. It is
thus not guaranteed to work in cases where the variable names have complex subsumption
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
"""
function _varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
values = chain.value[sample_idx, :, chain_idx]
keys = Base.keys(chain)
keys_strings = map(string, keys)
varname_pairs = [
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
vn in Base.keys(varinfo)
]
return varname_pairs
end

end

2 comments on commit b96e368

@penelopeysm
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Fix regression in generated_quantities not fixing the values of array variables.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118672

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.30.3 -m "<description of version>" b96e368874215125a222a65a33ceeed54a4bb1ee
git push origin v0.30.3

Please sign in to comment.