From 748823833478d7a7e87a52d8626c7b2ead352b12 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Nov 2024 14:58:14 +0100 Subject: [PATCH 1/9] fixed incorrect implementation of `dot_tilde_assume` for `PrefixContext` --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f3c5171b0..d7c24fafb 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -402,12 +402,12 @@ end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) - return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi) + return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) end function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi) return dot_tilde_assume( - rng, context.context, sampler, right, prefix.(Ref(context), vn), vi + rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi ) end From 1d211c5c8ef5bda879cb9a09b2838b1faa2545c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Nov 2024 14:58:33 +0100 Subject: [PATCH 2/9] removed `vars` field from `PriorContext` and `LikelihoodContext` as it's no longer used functionality (was dropped when we dropped the logprob-macro) --- src/context_implementations.jl | 101 --------------------------------- src/contexts.jl | 29 +++------- 2 files changed, 8 insertions(+), 122 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d7c24fafb..50919e77e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -77,44 +77,6 @@ function tilde_assume( return tilde_assume(rng, childcontext(context), args...) end -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(PriorContext(), right, vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi -) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) -end - -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(LikelihoodContext(), right, vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - vn, - vi, -) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) -end function tilde_assume(::LikelihoodContext, right, vn, vi) return assume(NoDist(right), vn, vi) end @@ -328,37 +290,6 @@ function dot_tilde_assume( end # `LikelihoodContext` -function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) - else - dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - left, - vn, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) - else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) - end -end - function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) return dot_assume(nodist(right), left, vn, vi) end @@ -368,38 +299,6 @@ function dot_tilde_assume( return dot_assume(rng, sampler, nodist(right), vn, left, vi) end -# `PriorContext` -function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) - else - dot_tilde_assume(PriorContext(), right, left, vn, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - left, - vn, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) - else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) - end -end - # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) diff --git a/src/contexts.jl b/src/contexts.jl index 53b454df6..5da4208b5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -53,7 +53,7 @@ DefaultContext() julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior julia> DynamicPPL.childcontext(ctx_prior) -PriorContext{Nothing}(nothing) +PriorContext() ``` """ setchildcontext @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext{Nothing}(nothing) +PriorContext() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end NodeTrait(context::DefaultContext) = IsLeaf() """ - struct PriorContext{Tvars} <: AbstractContext - vars::Tvars - end + PriorContext <: AbstractContext -The `PriorContext` enables the computation of the log prior of the parameters `vars` when -running the model. +A leaf context resulting in the exclusion of likelihood terms when running the model. """ -struct PriorContext{Tvars} <: AbstractContext - vars::Tvars -end -PriorContext() = PriorContext(nothing) +struct PriorContext <: AbstractContext end NodeTrait(context::PriorContext) = IsLeaf() """ - struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars - end + LikelihoodContext <: AbstractContext -The `LikelihoodContext` enables the computation of the log likelihood of the parameters when -running the model. `vars` can be used to evaluate the log likelihood for specific values -of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. +A leaf context resulting in the exclusion of prior terms when running the model. """ -struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars -end -LikelihoodContext() = LikelihoodContext(nothing) +struct LikelihoodContext <: AbstractContext end NodeTrait(context::LikelihoodContext) = IsLeaf() """ From f97828755ca867d420e6b406682495933c0086a6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 5 Nov 2024 08:09:35 +0100 Subject: [PATCH 3/9] replaced `NoDist` with `nodist` --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 50919e77e..ad507445a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -78,10 +78,10 @@ function tilde_assume( end function tilde_assume(::LikelihoodContext, right, vn, vi) - return assume(NoDist(right), vn, vi) + return assume(nodist(right), vn, vi) end function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) - return assume(rng, sampler, NoDist(right), vn, vi) + return assume(rng, sampler, nodist(right), vn, vi) end function tilde_assume(context::PrefixContext, right, vn, vi) From ae8411291bccd4d00d19b686c2b65a86e8befa2d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 5 Nov 2024 08:09:49 +0100 Subject: [PATCH 4/9] fixed method ambiguity issue --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ad507445a..86bf132e5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -304,7 +304,7 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) end -function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi) +function dot_tilde_assume(rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi) return dot_tilde_assume( rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi ) From 881e28ddbdf2944a8619d946c40ad6034e653a4e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 5 Nov 2024 08:10:01 +0100 Subject: [PATCH 5/9] added missing `Distributions.rand!` definition for `NoDist` --- src/distribution_wrappers.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index d8968a68e..471886688 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -42,6 +42,10 @@ Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist) +# NOTE(torfjelde): Need this to avoid stack overflow. +function Distributions.rand!(rng::Random.AbstractRNG, d::NoDist{Distributions.ArrayLikeVariate{N}}, x::AbstractArray{<:Real, N}) where {N} + return Distributions.rand!(rng, d.dist, x) +end Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) From 1f6fc6229aeed40015aa23ff4d19278d2e50145f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 5 Nov 2024 08:22:26 +0100 Subject: [PATCH 6/9] added more elaborate testing of evaluation of contexts --- test/contexts.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 4ec9ff945..10b35845a 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -215,8 +215,20 @@ end @testset "Evaluation" begin @testset "$context" for context in contexts - # Just making sure that we can actually sample with each of the contexts. - @test (gdemo_default(SamplingContext(context)); true) + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the + # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. + # Untyped varinfo. + varinfo_untyped = DynamicPPL.VarInfo() + # With `SamplingContext`. + @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) + # Without `SamplingContext`. + @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) + # Typed varinfo. + varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped) + @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) + @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) + end end end From 55e24e92ab3158c56b8d5ae251df4d0fe26155aa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 5 Nov 2024 08:37:27 +0100 Subject: [PATCH 7/9] added `DynamicPPL.TestUtils.test_context` for testing contexts and replaced much of the `test/contexts.jl` with calls to this method --- src/test_utils.jl | 71 ++++++++++++++++++++++-------- test/contexts.jl | 105 +++----------------------------------------- test/debug_utils.jl | 2 +- 3 files changed, 59 insertions(+), 119 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 6199138aa..ee6daba67 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1039,24 +1039,7 @@ function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; return test_sampler_on_demo_models(sampler, args...; kwargs...) end -""" - test_context_interface(context) - -Test that `context` implements the `AbstractContext` interface. -""" -function test_context_interface(context) - # Is a subtype of `AbstractContext`. - @test context isa DynamicPPL.AbstractContext - # Should implement `NodeTrait.` - @test DynamicPPL.NodeTrait(context) isa Union{DynamicPPL.IsParent,DynamicPPL.IsLeaf} - # If it's a parent. - if DynamicPPL.NodeTrait(context) == DynamicPPL.IsParent - # Should implement `childcontext` and `setchildcontext` - @test DynamicPPL.setchildcontext(context, DynamicPPL.childcontext(context)) == - context - end -end - +# Testing for contexts. """ Context that multiplies each log-prior by mod used to test whether varwise_logpriors respects child-context. @@ -1097,4 +1080,56 @@ function DynamicPPL.dot_tilde_observe( return logp * context.mod, vi end +# Dummy context to test nested behaviors. +struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + context::C +end +TestParentContext() = TestParentContext(DefaultContext()) +DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::TestParentContext) = context.context +DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) +Base.show(io::IO, c::TestParentContext) = print(io, "TestParentContext(", DynamicPPL.childcontext(c), ")") + +""" + test_context(context::AbstractContext, model::Model) + +Test that `context` correctly implements the `AbstractContext` interface for `model`. + +This method ensures that `context` +- Correctly implements the `AbstractContext` interface. +- Correctly implements the tilde-pipeline. +""" +function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + # `NodeTrait`. + node_trait = DynamicPPL.NodeTrait(context) + # Throw error immediately if it it's missing a `NodeTrait` implementation. + node_trait isa Union{DynamicPPL.IsLeaf, DynamicPPL.IsParent} || throw(ValueError("Invalid NodeTrait: $node_trait")) + + # The interface methods. + if node_trait isa DynamicPPL.IsParent + # `childcontext` and `setchildcontext` + # With new child context + childcontext_new = TestParentContext() + @test DynamicPPL.childcontext(DynamicPPL.setchildcontext(context, childcontext_new)) == childcontext_new + end + + # To see change, let's make sure we're using a different leaf context than the current. + leafcontext_new = DynamicPPL.leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == leafcontext_new + + # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). + # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. + # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the + # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. + # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. + # Untyped varinfo. + varinfo_untyped = DynamicPPL.VarInfo() + @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) + @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) + # Typed varinfo. + varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped) + @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) + @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) +end + end diff --git a/test/contexts.jl b/test/contexts.jl index 10b35845a..2767bb1ab 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -18,16 +18,6 @@ using DynamicPPL: using EnzymeCore -# Dummy context to test nested behaviors. -struct ParentContext{C<:AbstractContext} <: AbstractContext - context::C -end -ParentContext() = ParentContext(DefaultContext()) -DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::ParentContext) = context.context -DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) -Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") - # TODO: Should we maybe put this in DPPL itself? function Base.iterate(context::AbstractContext) if NodeTrait(context) isa IsLeaf @@ -63,88 +53,22 @@ end child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] parent_contexts = [ - ParentContext(DefaultContext()), + DynamicPPL.TestUtils.TestParentContext(DefaultContext()), SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), PointwiseLogdensityContext(), ConditionContext((x=1.0,)), - ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), + ConditionContext((x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,)))), ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), ConditionContext((x=[1.0, missing],)), ] contexts = vcat(child_contexts, parent_contexts) - @testset "NodeTrait" begin - @testset "$context" for context in contexts - # Every `context` should have a `NodeTrait`. - @test NodeTrait(context) isa NodeTrait - end - end - - @testset "leafcontext" begin - @testset "$context" for context in child_contexts - @test leafcontext(context) === context - end - - @testset "$context" for context in parent_contexts - @test NodeTrait(leafcontext(context)) isa IsLeaf - end - end - - @testset "setleafcontext" begin - @testset "$context" for context in child_contexts - # Setting to itself should return itself. - @test setleafcontext(context, context) === context - - # Setting to a different context should return that context. - new_leaf = context isa DefaultContext ? PriorContext() : DefaultContext() - @test setleafcontext(context, new_leaf) === new_leaf - - # Also works for parent contexts. - new_leaf = ParentContext(context) - @test setleafcontext(context, new_leaf) === new_leaf - end - - @testset "$context" for context in parent_contexts - # Leaf contexts. - new_leaf = - leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() - @test leafcontext(setleafcontext(context, new_leaf)) === new_leaf - - # Setting parent contexts as "leaf" means that the new leaf should be - # the leaf of the parent context we just set as the leaf. - new_leaf = ParentContext(( - leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() - )) - @test leafcontext(setleafcontext(context, new_leaf)) === leafcontext(new_leaf) - end - end - - # `IsParent` interface. - @testset "childcontext" begin - @testset "$context" for context in parent_contexts - @test childcontext(context) isa AbstractContext - end - end - - @testset "setchildcontext" begin - @testset "nested contexts" begin - # Both of the following should result in the same context. - context1 = ParentContext(ParentContext(ParentContext())) - context2 = setchildcontext( - ParentContext(), setchildcontext(ParentContext(), ParentContext()) - ) - @test context1 === context2 - end - - @testset "$context" for context in parent_contexts - # Setting the child context to a leaf should now change the `leafcontext` accordingly. - new_leaf = - leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() - new_context = setchildcontext(context, new_leaf) - @test childcontext(new_context) === leafcontext(new_context) === new_leaf + @testset "$(context)" for context in contexts + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + DynamicPPL.TestUtils.test_context(context, model) end end @@ -213,25 +137,6 @@ end end end - @testset "Evaluation" begin - @testset "$context" for context in contexts - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the - # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. - # Untyped varinfo. - varinfo_untyped = DynamicPPL.VarInfo() - # With `SamplingContext`. - @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) - # Without `SamplingContext`. - @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) - # Typed varinfo. - varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped) - @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) - @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) - end - end - end - @testset "PrefixContext" begin ctx = @inferred PrefixContext{:f}( PrefixContext{:e}( diff --git a/test/debug_utils.jl b/test/debug_utils.jl index b1897aa9b..5c309da3a 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -3,7 +3,7 @@ # HACK: Require a model to instantiate it, so let's just grab one. model = first(DynamicPPL.TestUtils.DEMO_MODELS) context = DynamicPPL.DebugUtils.DebugContext(model) - DynamicPPL.TestUtils.test_context_interface(context) + DynamicPPL.TestUtils.test_context(context) end @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS From e5b7e44959bb81e4bacc778dcc603b3d125edfcd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 5 Nov 2024 08:46:28 +0100 Subject: [PATCH 8/9] added proper testing for PrefixContext of all demo models --- test/contexts.jl | 51 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 2767bb1ab..0491fca3e 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -138,24 +138,43 @@ end end @testset "PrefixContext" begin - ctx = @inferred PrefixContext{:f}( - PrefixContext{:e}( - PrefixContext{:d}( - PrefixContext{:c}( - PrefixContext{:b}(PrefixContext{:a}(DefaultContext())) + @testset "prefixing" begin + ctx = @inferred PrefixContext{:f}( + PrefixContext{:e}( + PrefixContext{:d}( + PrefixContext{:c}( + PrefixContext{:b}(PrefixContext{:a}(DefaultContext())) + ), ), ), - ), - ) - vn = VarName{:x}() - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) - - vn = VarName{:x}(((1,),)) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + ) + vn = VarName{:x}() + vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) + @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") + @test getoptic(vn_prefixed) === getoptic(vn) + + vn = VarName{:x}(((1,),)) + vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) + @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") + @test getoptic(vn_prefixed) === getoptic(vn) + end + + context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # Sample with the context. + varinfo = DynamicPPL.VarInfo() + DynamicPPL.evaluate!!(model, varinfo, context) + # Extract the resulting symbols. + vns_varinfo_syms = Set(map(DynamicPPL.getsym, keys(varinfo))) + + # Extract the ground truth symbols. + vns_syms = Set([ + Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for vn in DynamicPPL.TestUtils.varnames(model) + ]) + + # Check that all variables are prefixed correctly. + @test vns_syms == vns_varinfo_syms + end end @testset "SamplingContext" begin From e1c8fd104c21bc2c2d704914560905eed6a78f5b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 5 Nov 2024 08:49:25 +0100 Subject: [PATCH 9/9] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 4 +++- src/distribution_wrappers.jl | 6 +++++- src/test_utils.jl | 20 +++++++++++++++----- test/contexts.jl | 7 +++++-- 4 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 86bf132e5..a0e11b65b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -304,7 +304,9 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) end -function dot_tilde_assume(rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi) +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi +) return dot_tilde_assume( rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi ) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 471886688..c631b6f19 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -43,7 +43,11 @@ Base.size(dist::NoDist) = Base.size(dist.dist) Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist) # NOTE(torfjelde): Need this to avoid stack overflow. -function Distributions.rand!(rng::Random.AbstractRNG, d::NoDist{Distributions.ArrayLikeVariate{N}}, x::AbstractArray{<:Real, N}) where {N} +function Distributions.rand!( + rng::Random.AbstractRNG, + d::NoDist{Distributions.ArrayLikeVariate{N}}, + x::AbstractArray{<:Real,N}, +) where {N} return Distributions.rand!(rng, d.dist, x) end Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 diff --git a/src/test_utils.jl b/src/test_utils.jl index ee6daba67..bf3aecbb3 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1088,7 +1088,9 @@ TestParentContext() = TestParentContext(DefaultContext()) DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::TestParentContext) = context.context DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) -Base.show(io::IO, c::TestParentContext) = print(io, "TestParentContext(", DynamicPPL.childcontext(c), ")") +function Base.show(io::IO, c::TestParentContext) + return print(io, "TestParentContext(", DynamicPPL.childcontext(c), ")") +end """ test_context(context::AbstractContext, model::Model) @@ -1103,19 +1105,27 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod # `NodeTrait`. node_trait = DynamicPPL.NodeTrait(context) # Throw error immediately if it it's missing a `NodeTrait` implementation. - node_trait isa Union{DynamicPPL.IsLeaf, DynamicPPL.IsParent} || throw(ValueError("Invalid NodeTrait: $node_trait")) + node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || + throw(ValueError("Invalid NodeTrait: $node_trait")) # The interface methods. if node_trait isa DynamicPPL.IsParent # `childcontext` and `setchildcontext` # With new child context childcontext_new = TestParentContext() - @test DynamicPPL.childcontext(DynamicPPL.setchildcontext(context, childcontext_new)) == childcontext_new + @test DynamicPPL.childcontext( + DynamicPPL.setchildcontext(context, childcontext_new) + ) == childcontext_new end # To see change, let's make sure we're using a different leaf context than the current. - leafcontext_new = DynamicPPL.leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() - @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == leafcontext_new + leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + PriorContext() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. diff --git a/test/contexts.jl b/test/contexts.jl index 0491fca3e..0f6628440 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -59,7 +59,9 @@ end PrefixContext{:x}(DefaultContext()), PointwiseLogdensityContext(), ConditionContext((x=1.0,)), - ConditionContext((x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,)))), + ConditionContext( + (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) + ), ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), ConditionContext((x=[1.0, missing],)), ] @@ -169,7 +171,8 @@ end # Extract the ground truth symbols. vns_syms = Set([ - Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for vn in DynamicPPL.TestUtils.varnames(model) + Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for + vn in DynamicPPL.TestUtils.varnames(model) ]) # Check that all variables are prefixed correctly.