Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for generated_quantities #534

Merged
merged 31 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9735066
added method for extracting the child lens from a varname subsumed by
torfjelde Sep 5, 2023
306946d
added nested_getindex and nested_setindex! for VarInfo
torfjelde Sep 5, 2023
af2a30f
added ConstructionBase.setproperties implementation for `Cholesky`
torfjelde Sep 5, 2023
ce0b2e7
fixed minor formatting issue
torfjelde Sep 5, 2023
3a8201c
added `supports_varname_indexing` for chains and use this in generate…
torfjelde Sep 6, 2023
dc7c675
use a private method rather than overloading getindex for Chains
torfjelde Sep 6, 2023
e4d964e
removed getindex overloads in nested_index testing
torfjelde Sep 6, 2023
a18b435
moved generated_quantities tests to test/model.jl
torfjelde Sep 6, 2023
34e422a
Apply suggestions from code review
torfjelde Sep 6, 2023
c4b2556
will now also correctly set variables to be resampled, etc.
torfjelde Sep 6, 2023
3b27435
Update test/model.jl
torfjelde Sep 6, 2023
ef67495
Update src/varinfo.jl
torfjelde Sep 6, 2023
df2d8e3
added Compat as a test dep so we can methods such as stack
torfjelde Sep 6, 2023
035d592
improved overload of ConstructionBase.setproperties
torfjelde Sep 6, 2023
24076b5
Apply suggestions from code review
torfjelde Sep 6, 2023
c604503
added docstring to remove_parent_lens
torfjelde Sep 6, 2023
ad4a5bd
removed methods which are not useful for the purpose of this PR
torfjelde Sep 6, 2023
378897e
noticed we're incorrectly using chain rather than chain_params in gen…
torfjelde Sep 6, 2023
220646e
Update ext/DynamicPPLMCMCChainsExt.jl
torfjelde Sep 6, 2023
08dd71a
fixed doctests
torfjelde Sep 6, 2023
c09b780
added Requires.jl
torfjelde Sep 6, 2023
44335d4
Update src/DynamicPPL.jl
torfjelde Sep 6, 2023
521775f
bump patch version
torfjelde Sep 6, 2023
3594017
Merge remote-tracking branch 'origin/torfjelde/nested-get-and-setindx…
torfjelde Sep 6, 2023
2ae113c
Update src/DynamicPPL.jl
torfjelde Sep 7, 2023
e625fc5
moved new generated_quantities functionality into setval_and_resample!
torfjelde Sep 7, 2023
6d16806
Apply suggestions from code review
torfjelde Sep 7, 2023
1e42770
Update ext/DynamicPPLMCMCChainsExt.jl
torfjelde Sep 7, 2023
628809c
Update src/chains.jl
torfjelde Sep 7, 2023
8b9e3e0
bump compat entry for ConstructionBase.jl
torfjelde Sep 8, 2023
46e5a94
Merge remote-tracking branch 'origin/torfjelde/nested-get-and-setindx…
torfjelde Sep 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 13 additions & 27 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,20 @@ else
end

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
function _check_varname_indexing(c::MCMCChains.Chains)
DynamicPPL.supports_varname_indexing(c) || error("Chains do not support indexing using $vn.")
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

# TODO: Add proper overload of `Base.getindex` to Turing.jl?
function _getindex(c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx)
DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using $vn.")
# A few methods needed.
DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) = _has_varname_to_symbol(chain.info)
function DynamicPPL.getindex_varname(c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
_check_varname_indexing(c)
return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
end
function DynamicPPL.varnames(c::MCMCChains.Chains)
_check_varname_indexing(c)
return keys(c.info.varname_to_symbol)
end

function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
Expand All @@ -27,28 +31,10 @@ function DynamicPPL.generated_quantities(
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
if DynamicPPL.supports_varname_indexing(chain)
# First we need to set every variable to be resampled.
for vn in keys(varinfo)
DynamicPPL.set_flag!(varinfo, vn, "del")
end
# Then we set the variables in `varinfo` from `chain`.
for vn in keys(chain.info.varname_to_symbol)
vn_updated = DynamicPPL.nested_setindex_maybe!(
varinfo, _getindex(chain, sample_idx, vn, chain_idx), vn
)
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)

# Unset the `del` flag if we found something.
if vn_updated !== nothing
# NOTE: This will be triggered even if only a subset of a variable has been set!
DynamicPPL.unset_flag!(varinfo, vn_updated, "del")
end
end
else
# NOTE: This can be quite unreliable (but will warn the uesr in that case).
# Hence the above path is much more preferable.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
end
# TODO: Some of the variables can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to `model`.
model(deepcopy(varinfo))
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
# Necessary forward declarations
include("utils.jl")
include("selector.jl")
include("chains.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
Expand All @@ -180,8 +181,8 @@
end

@static if !isdefined(Base, :get_extension)
function __init__()
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(

Check warning on line 185 in src/DynamicPPL.jl

View check run for this annotation

Codecov / codecov/patch

src/DynamicPPL.jl#L184-L185

Added lines #L184 - L185 were not covered by tests
"../ext/DynamicPPLMCMCChainsExt.jl"
)
end
Expand Down
25 changes: 25 additions & 0 deletions src/chains.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
supports_chain_indexing(chain::AbstractChains)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

Return `true` if `chain` supports indexing using `VarName` in place of the
variable name index.
"""
supports_varname_indexing(::AbstractChains) = false

Check warning on line 7 in src/chains.jl

View check run for this annotation

Codecov / codecov/patch

src/chains.jl#L7

Added line #L7 was not covered by tests

"""
getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx)

Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`.

Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function getindex_varname end

"""
varnames(chains::AbstractChains)

Return an iterator over the varnames present in `chains`.

Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function varnames end
3 changes: 0 additions & 3 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1257,9 +1257,6 @@ function generated_quantities(model::Model, chain::AbstractChains)
end
end

# HACK: Be better.
supports_varname_indexing(chain::AbstractChains) = false

"""
generated_quantities(model::Model, parameters::NamedTuple)
generated_quantities(model::Model, values, keys)
Expand Down
21 changes: 20 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1064,8 +1064,8 @@
return Expr(:||, false, out...)
end

function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)

Check warning on line 1068 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1067-L1068

Added lines #L1067 - L1068 were not covered by tests
end
function nested_setindex_maybe!(
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
Expand All @@ -1080,8 +1080,8 @@
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
vns = md.vns
if vn in vns
setindex!(vi, val, vn)
return vn

Check warning on line 1084 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1083-L1084

Added lines #L1083 - L1084 were not covered by tests
end

# Otherwise, we need to check if either of the `vns` subsumes `vn`.
Expand Down Expand Up @@ -1636,7 +1636,26 @@
function setval_and_resample!(
vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
)
return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
if supports_varname_indexing(chains)
# First we need to set every variable to be resampled.
for vn in keys(vi)
set_flag!(vi, vn, "del")
end
# Then we set the variables in `varinfo` from `chain`.
for vn in varnames(chains)
vn_updated = nested_setindex_maybe!(
vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
)

# Unset the `del` flag if we found something.
if vn_updated !== nothing
# NOTE: This will be triggered even if only a subset of a variable has been set!
unset_flag!(vi, vn_updated, "del")
end
end
else
setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end
end

function _setval_and_resample_kernel!(
Expand Down
Loading