Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for generated_quantities #534

Merged
merged 31 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9735066
added method for extracting the child lens from a varname subsumed by
torfjelde Sep 5, 2023
306946d
added nested_getindex and nested_setindex! for VarInfo
torfjelde Sep 5, 2023
af2a30f
added ConstructionBase.setproperties implementation for `Cholesky`
torfjelde Sep 5, 2023
ce0b2e7
fixed minor formatting issue
torfjelde Sep 5, 2023
3a8201c
added `supports_varname_indexing` for chains and use this in generate…
torfjelde Sep 6, 2023
dc7c675
use a private method rather than overloading getindex for Chains
torfjelde Sep 6, 2023
e4d964e
removed getindex overloads in nested_index testing
torfjelde Sep 6, 2023
a18b435
moved generated_quantities tests to test/model.jl
torfjelde Sep 6, 2023
34e422a
Apply suggestions from code review
torfjelde Sep 6, 2023
c4b2556
will now also correctly set variables to be resampled, etc.
torfjelde Sep 6, 2023
3b27435
Update test/model.jl
torfjelde Sep 6, 2023
ef67495
Update src/varinfo.jl
torfjelde Sep 6, 2023
df2d8e3
added Compat as a test dep so we can methods such as stack
torfjelde Sep 6, 2023
035d592
improved overload of ConstructionBase.setproperties
torfjelde Sep 6, 2023
24076b5
Apply suggestions from code review
torfjelde Sep 6, 2023
c604503
added docstring to remove_parent_lens
torfjelde Sep 6, 2023
ad4a5bd
removed methods which are not useful for the purpose of this PR
torfjelde Sep 6, 2023
378897e
noticed we're incorrectly using chain rather than chain_params in gen…
torfjelde Sep 6, 2023
220646e
Update ext/DynamicPPLMCMCChainsExt.jl
torfjelde Sep 6, 2023
08dd71a
fixed doctests
torfjelde Sep 6, 2023
c09b780
added Requires.jl
torfjelde Sep 6, 2023
44335d4
Update src/DynamicPPL.jl
torfjelde Sep 6, 2023
521775f
bump patch version
torfjelde Sep 6, 2023
3594017
Merge remote-tracking branch 'origin/torfjelde/nested-get-and-setindx…
torfjelde Sep 6, 2023
2ae113c
Update src/DynamicPPL.jl
torfjelde Sep 7, 2023
e625fc5
moved new generated_quantities functionality into setval_and_resample!
torfjelde Sep 7, 2023
6d16806
Apply suggestions from code review
torfjelde Sep 7, 2023
1e42770
Update ext/DynamicPPLMCMCChainsExt.jl
torfjelde Sep 7, 2023
628809c
Update src/chains.jl
torfjelde Sep 7, 2023
8b9e3e0
bump compat entry for ConstructionBase.jl
torfjelde Sep 8, 2023
46e5a94
Merge remote-tracking branch 'origin/torfjelde/nested-get-and-setindx…
torfjelde Sep 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,48 @@ module DynamicPPLMCMCChainsExt
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
end

# TODO: Add proper overload of `Base.getindex` to Turing.jl?
function _getindex(c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx)
DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using $vn.")
return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
end

function DynamicPPL.generated_quantities(model::DynamicPPL.Model, chain::MCMCChains.Chains)
chain_parameters = MCMCChains.get_sections(chain, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, chain_parameters, sample_idx, chain_idx)
model(varinfo)
if DynamicPPL.supports_varname_indexing(chain)
# First we need to set every variable to be resampled.
for vn in keys(varinfo)
DynamicPPL.set_flag!(varinfo, vn, "del")
end
# Then we set the variables in `varinfo` from `chain`.
for vn in keys(chain.info.varname_to_symbol)
vn_updated = DynamicPPL.nested_setindex_maybe!(
varinfo, _getindex(chain, sample_idx, vn, chain_idx), vn
)

# Unset the `del` flag if we found something.
if vn_updated !== nothing
# NOTE: This will be triggered even if only a subset of a variable has been set!
DynamicPPL.unset_flag!(varinfo, vn_updated, "del")
end
end
else
# NOTE: This can be quite unreliable (but will warn the uesr in that case).
# Hence the above path is much more preferable.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
end
# TODO: Some of the variables can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to `model`.
model(deepcopy(varinfo))
end
end

Expand Down
3 changes: 3 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1257,6 +1257,9 @@ function generated_quantities(model::Model, chain::AbstractChains)
end
end

# HACK: Be better.
supports_varname_indexing(chain::AbstractChains) = false

"""
generated_quantities(model::Model, parameters::NamedTuple)
generated_quantities(model::Model, values, keys)
Expand Down
28 changes: 28 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,19 @@ function splitlens(condition, lens)
return current_parent, current_child, condition(current_parent)
end

function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
_, child, issuccess = splitlens(getlens(vn_child)) do lens
l = lens === nothing ? Setfield.IdentityLens() : lens
VarName(vn_child, l) == vn_parent
end

if !issuccess
error("Could not find $vn_parent in $vn_child")
end

return child
end

# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233
# and https://github.com/JuliaFolds/BangBang.jl/pull/238.
# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.
Expand Down Expand Up @@ -1045,3 +1058,18 @@ function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTrian
for I in CartesianIndices(x) if I[1] <= I[2]
)
end

# TODO: Remove as soon as https://github.com/JuliaObjects/ConstructionBase.jl/pull/80 goes through.
ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, ::NamedTuple{()}) = C
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:L,)})
return LinearAlgebra.Cholesky(C.uplo === 'U' ? permutedims(patch.L) : patch.L, C.uplo, C.info)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:U,)})
return LinearAlgebra.Cholesky(C.uplo === 'L' ? permutedims(patch.U) : patch.U, C.uplo, C.info)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:UL,)})
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
return LinearAlgebra.Cholesky(patch.UL, C.uplo, C.info)
end
@nospecialize function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple)
error("Can only patch one of :L, :U, :UL at the time")
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
74 changes: 73 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,77 @@ end
return Expr(:||, false, out...)
end

"""
nested_getindex(vi::VarInfo, vn)

Return the value corresponding to `vn` in `vi`.
"""
nested_getindex(vi::VarInfo, vn::VarName) = _nested_getindex(vi, getmetadata(vi, vn), vn)
function _nested_getindex(varinfo::VarInfo, md::Metadata, vn::VarName)
yebai marked this conversation as resolved.
Show resolved Hide resolved
# Check if `vn` is in `md.vns`.
vns = md.vns
vn in vns && return getindex(varinfo, vn)

# If that's not the case, we check if `vn` is subsumed by any of `md.vns`.
i = findfirst(Base.Fix2(subsumes, vn), vns)
i === nothing && error(KeyError(vn))

# If `vn` is subsumed, we reconstruct the value from `md`, and act
# on the reconstructed value.
vn_parent = vns[i]
dist = getdist(md, vn_parent)
val = getindex(varinfo, vn_parent, dist)
# Split the varname into its tail lens.
lens = remove_parent_lens(vn_parent, vn)
# Get the value using `lens`.
return get(val, lens)
end

function nested_setindex!(vi::VarInfo, val, vn::VarName)
res, vn_updated = nested_setindex_maybe!(vi, val, vn)
vn_updated !== nothing || error(KeyError(vn))
return res
end
function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
end
function nested_setindex_maybe!(
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
) where {names,sym}
return if sym in names
_nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
else
nothing
end
end
function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName)
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
vns = md.vns
if vn in vns
setindex!(vi, val, vn)
return vn
end

# Otherwise, we need to check if either of the `vns` subsumes `vn`.
i = findfirst(Base.Fix2(subsumes, vn), vns)
i === nothing && return nothing

vn_parent = vns[i]
dist = getdist(md, vn_parent)
val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here.
# Split the varname into its tail lens.
lens = remove_parent_lens(vn_parent, vn)
# Update the value for the parent.
val_parent_updated = set!!(val_parent, lens, val)
setindex!(vi, val_parent_updated, vn_parent)
return vn_parent
end

function nested_setindex!!(vi::VarInfo, val, vn::VarName)
nested_setindex!(vi, val, vn)
return vi
end

# The default getindex & setindex!() for get & set values
# NOTE: vi[vn] will always transform the variable to its original space and Julia type
getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn))
Expand Down Expand Up @@ -1131,7 +1202,8 @@ The value(s) may or may not be transformed to Euclidean space.
"""
setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi)
function BangBang.setindex!!(vi::VarInfo, val, vn::VarName)
return (setindex!(vi, val, vn); return vi)
setindex!(vi, val, vn)
return vi
end

"""
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Expand All @@ -24,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
AbstractMCMC = "2.1, 3.0, 4"
AbstractPPL = "0.6"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27"
Expand Down
51 changes: 51 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,55 @@ end
@test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)
end
end

@testset "generated_quantities on `LKJCholesky`" begin
n = 10
d = 2
model = DynamicPPL.TestUtils.demo_lkjchol(d)
xs = [model().x for _ in 1:n]

# Extract varnames and values.
vns_and_vals_xs = map(
collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs
)
vns = map(first, first(vns_and_vals_xs))
vals = map(vns_and_vals_xs) do vns_and_vals
map(last, vns_and_vals)
end

# Construct the chain.
syms = map(Symbol, vns)
vns_to_syms = OrderedDict{VarName,Any}(zip(vns, syms))

chain = MCMCChains.Chains(
permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,)
)
display(chain)

# Test!
results = generated_quantities(model, chain)
for (x_true, result) in zip(xs, results)
@test x_true.UL == result.x.UL
end

# With variables that aren't in the `model`.
vns_to_syms_with_extra = let d = deepcopy(vns_to_syms)
d[@varname(y)] = :y
d
end
vals_with_extra = map(enumerate(vals)) do (i, v)
vcat(v, i)
end
chain_with_extra = MCMCChains.Chains(
permutedims(stack(vals_with_extra)),
vcat(syms, [:y]);
info=(varname_to_symbol=vns_to_syms_with_extra,),
)
display(chain_with_extra)
# Test!
results = generated_quantities(model, chain_with_extra)
for (x_true, result) in zip(xs, results)
@test x_true.UL == result.x.UL
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using MCMCChains
using Tracker
using Zygote
using Setfield
using Compat

using Distributed
using LinearAlgebra
Expand Down
Loading