Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for generated_quantities #534

Merged
merged 31 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9735066
added method for extracting the child lens from a varname subsumed by
torfjelde Sep 5, 2023
306946d
added nested_getindex and nested_setindex! for VarInfo
torfjelde Sep 5, 2023
af2a30f
added ConstructionBase.setproperties implementation for `Cholesky`
torfjelde Sep 5, 2023
ce0b2e7
fixed minor formatting issue
torfjelde Sep 5, 2023
3a8201c
added `supports_varname_indexing` for chains and use this in generate…
torfjelde Sep 6, 2023
dc7c675
use a private method rather than overloading getindex for Chains
torfjelde Sep 6, 2023
e4d964e
removed getindex overloads in nested_index testing
torfjelde Sep 6, 2023
a18b435
moved generated_quantities tests to test/model.jl
torfjelde Sep 6, 2023
34e422a
Apply suggestions from code review
torfjelde Sep 6, 2023
c4b2556
will now also correctly set variables to be resampled, etc.
torfjelde Sep 6, 2023
3b27435
Update test/model.jl
torfjelde Sep 6, 2023
ef67495
Update src/varinfo.jl
torfjelde Sep 6, 2023
df2d8e3
added Compat as a test dep so we can methods such as stack
torfjelde Sep 6, 2023
035d592
improved overload of ConstructionBase.setproperties
torfjelde Sep 6, 2023
24076b5
Apply suggestions from code review
torfjelde Sep 6, 2023
c604503
added docstring to remove_parent_lens
torfjelde Sep 6, 2023
ad4a5bd
removed methods which are not useful for the purpose of this PR
torfjelde Sep 6, 2023
378897e
noticed we're incorrectly using chain rather than chain_params in gen…
torfjelde Sep 6, 2023
220646e
Update ext/DynamicPPLMCMCChainsExt.jl
torfjelde Sep 6, 2023
08dd71a
fixed doctests
torfjelde Sep 6, 2023
c09b780
added Requires.jl
torfjelde Sep 6, 2023
44335d4
Update src/DynamicPPL.jl
torfjelde Sep 6, 2023
521775f
bump patch version
torfjelde Sep 6, 2023
3594017
Merge remote-tracking branch 'origin/torfjelde/nested-get-and-setindx…
torfjelde Sep 6, 2023
2ae113c
Update src/DynamicPPL.jl
torfjelde Sep 7, 2023
e625fc5
moved new generated_quantities functionality into setval_and_resample!
torfjelde Sep 7, 2023
6d16806
Apply suggestions from code review
torfjelde Sep 7, 2023
1e42770
Update ext/DynamicPPLMCMCChainsExt.jl
torfjelde Sep 7, 2023
628809c
Update src/chains.jl
torfjelde Sep 7, 2023
8b9e3e0
bump compat entry for ConstructionBase.jl
torfjelde Sep 8, 2023
46e5a94
Merge remote-tracking branch 'origin/torfjelde/nested-get-and-setindx…
torfjelde Sep 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.16"
version = "0.23.17"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -16,19 +16,11 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
Expand All @@ -39,9 +31,19 @@ ConstructionBase = "1"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
LogDensityProblems = "2"
MacroTools = "0.5.6"
MCMCChains = "6"
MacroTools = "0.5.6"
OrderedCollections = "1"
Requires = "1"
Setfield = "0.7.1, 0.8, 1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
54 changes: 48 additions & 6 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,57 @@
module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
else
using ..DynamicPPL: DynamicPPL
using ..MCMCChains: MCMCChains
end

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
end

function DynamicPPL.generated_quantities(model::DynamicPPL.Model, chain::MCMCChains.Chains)
chain_parameters = MCMCChains.get_sections(chain, :parameters)
# TODO: Add proper overload of `Base.getindex` to Turing.jl?
function _getindex(c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx)
DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using $vn.")
return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
end

function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, chain_parameters, sample_idx, chain_idx)
model(varinfo)
if DynamicPPL.supports_varname_indexing(chain)
# First we need to set every variable to be resampled.
for vn in keys(varinfo)
DynamicPPL.set_flag!(varinfo, vn, "del")
end
# Then we set the variables in `varinfo` from `chain`.
for vn in keys(chain.info.varname_to_symbol)
vn_updated = DynamicPPL.nested_setindex_maybe!(
varinfo, _getindex(chain, sample_idx, vn, chain_idx), vn
)

# Unset the `del` flag if we found something.
if vn_updated !== nothing
# NOTE: This will be triggered even if only a subset of a variable has been set!
DynamicPPL.unset_flag!(varinfo, vn_updated, "del")
end
end
else
# NOTE: This can be quite unreliable (but will warn the uesr in that case).
# Hence the above path is much more preferable.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
end
# TODO: Some of the variables can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to `model`.
model(deepcopy(varinfo))
end
end

Expand Down
12 changes: 12 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,16 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")

if !isdefined(Base, :get_extension)
using Requires
end

function __init__()
@static if !isdefined(Base, :get_extension)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
end
end

end # module
3 changes: 3 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1257,6 +1257,9 @@
end
end

# HACK: Be better.
supports_varname_indexing(chain::AbstractChains) = false

Check warning on line 1261 in src/model.jl

View check run for this annotation

Codecov / codecov/patch

src/model.jl#L1261

Added line #L1261 was not covered by tests

"""
generated_quantities(model::Model, parameters::NamedTuple)
generated_quantities(model::Model, values, keys)
Expand Down
59 changes: 59 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,42 @@
return current_parent, current_child, condition(current_parent)
end

"""
remove_parent_lens(vn_parent::VarName, vn_child::VarName)

Remove the parent lens `vn_parent` from `vn_child`.

# Examples
```jldoctest
julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a))
(@lens _.a)

julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a[1]))
(@lens _.a[1])

julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1]))
(@lens _[1])

julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1].b))
(@lens _[1].b)

julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a))
ERROR: Could not find x.a in x.a

julia> DynamicPPL.remove_parent_lens(@varname(x.a[2]), @varname(x.a[1]))
ERROR: Could not find x.a[2] in x.a[1]
```
"""
function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
_, child, issuccess = splitlens(getlens(vn_child)) do lens
l = lens === nothing ? Setfield.IdentityLens() : lens
VarName(vn_child, l) == vn_parent
end

issuccess || error("Could not find $vn_parent in $vn_child")
return child
end

# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233
# and https://github.com/JuliaFolds/BangBang.jl/pull/238.
# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.
Expand Down Expand Up @@ -1045,3 +1081,26 @@
for I in CartesianIndices(x) if I[1] <= I[2]
)
end

# TODO: Remove as soon as https://github.com/JuliaObjects/ConstructionBase.jl/pull/80 goes through.
ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, ::NamedTuple{()}) = C

Check warning on line 1086 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1086

Added line #L1086 was not covered by tests
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:L,)})
return LinearAlgebra.Cholesky(
C.uplo === 'U' ? permutedims(patch.L) : patch.L, C.uplo, C.info
)
end
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:U,)})
return LinearAlgebra.Cholesky(

Check warning on line 1093 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1092-L1093

Added lines #L1092 - L1093 were not covered by tests
C.uplo === 'L' ? permutedims(patch.U) : patch.U, C.uplo, C.info
)
end
function ConstructionBase.setproperties(

Check warning on line 1097 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1097

Added line #L1097 was not covered by tests
C::LinearAlgebra.Cholesky, patch::NamedTuple{(:UL,)}
)
return LinearAlgebra.Cholesky(patch.UL, C.uplo, C.info)

Check warning on line 1100 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1100

Added line #L1100 was not covered by tests
end
@nospecialize function ConstructionBase.setproperties(

Check warning on line 1102 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1102

Added line #L1102 was not covered by tests
C::LinearAlgebra.Cholesky, patch::NamedTuple
)
return error("Can only patch one of :L, :U, :UL at the time")

Check warning on line 1105 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1105

Added line #L1105 was not covered by tests
end
38 changes: 37 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,41 @@
return Expr(:||, false, out...)
end

function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)

Check warning on line 1068 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1067-L1068

Added lines #L1067 - L1068 were not covered by tests
end
function nested_setindex_maybe!(
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
) where {names,sym}
return if sym in names
_nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
else
nothing
end
end
function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName)
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
vns = md.vns
if vn in vns
setindex!(vi, val, vn)
return vn

Check warning on line 1084 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1083-L1084

Added lines #L1083 - L1084 were not covered by tests
end

# Otherwise, we need to check if either of the `vns` subsumes `vn`.
i = findfirst(Base.Fix2(subsumes, vn), vns)
i === nothing && return nothing

vn_parent = vns[i]
dist = getdist(md, vn_parent)
val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here.
# Split the varname into its tail lens.
lens = remove_parent_lens(vn_parent, vn)
# Update the value for the parent.
val_parent_updated = set!!(val_parent, lens, val)
setindex!(vi, val_parent_updated, vn_parent)
return vn_parent
end

# The default getindex & setindex!() for get & set values
# NOTE: vi[vn] will always transform the variable to its original space and Julia type
getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn))
Expand Down Expand Up @@ -1131,7 +1166,8 @@
"""
setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi)
function BangBang.setindex!!(vi::VarInfo, val, vn::VarName)
return (setindex!(vi, val, vn); return vi)
setindex!(vi, val, vn)
return vi
end

"""
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Expand All @@ -24,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
AbstractMCMC = "2.1, 3.0, 4"
AbstractPPL = "0.6"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27"
Expand Down
51 changes: 51 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,55 @@ end
@test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)
end
end

@testset "generated_quantities on `LKJCholesky`" begin
n = 10
d = 2
model = DynamicPPL.TestUtils.demo_lkjchol(d)
xs = [model().x for _ in 1:n]

# Extract varnames and values.
vns_and_vals_xs = map(
collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs
)
vns = map(first, first(vns_and_vals_xs))
vals = map(vns_and_vals_xs) do vns_and_vals
map(last, vns_and_vals)
end

# Construct the chain.
syms = map(Symbol, vns)
vns_to_syms = OrderedDict{VarName,Any}(zip(vns, syms))

chain = MCMCChains.Chains(
permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,)
)
display(chain)

# Test!
results = generated_quantities(model, chain)
for (x_true, result) in zip(xs, results)
@test x_true.UL == result.x.UL
end

# With variables that aren't in the `model`.
vns_to_syms_with_extra = let d = deepcopy(vns_to_syms)
d[@varname(y)] = :y
d
end
vals_with_extra = map(enumerate(vals)) do (i, v)
vcat(v, i)
end
chain_with_extra = MCMCChains.Chains(
permutedims(stack(vals_with_extra)),
vcat(syms, [:y]);
info=(varname_to_symbol=vns_to_syms_with_extra,),
)
display(chain_with_extra)
# Test!
results = generated_quantities(model, chain_with_extra)
for (x_true, result) in zip(xs, results)
@test x_true.UL == result.x.UL
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using MCMCChains
using Tracker
using Zygote
using Setfield
using Compat

using Distributed
using LinearAlgebra
Expand Down
Loading