diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f3c5171b0..a0e11b65b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -77,49 +77,11 @@ function tilde_assume( return tilde_assume(rng, childcontext(context), args...) end -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(PriorContext(), right, vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi -) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) -end - -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(LikelihoodContext(), right, vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - vn, - vi, -) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) -end function tilde_assume(::LikelihoodContext, right, vn, vi) - return assume(NoDist(right), vn, vi) + return assume(nodist(right), vn, vi) end function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) - return assume(rng, sampler, NoDist(right), vn, vi) + return assume(rng, sampler, nodist(right), vn, vi) end function tilde_assume(context::PrefixContext, right, vn, vi) @@ -328,37 +290,6 @@ function dot_tilde_assume( end # `LikelihoodContext` -function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) - else - dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - left, - vn, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) - else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) - end -end - function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) return dot_assume(nodist(right), left, vn, vi) end @@ -368,46 +299,16 @@ function dot_tilde_assume( return dot_assume(rng, sampler, nodist(right), vn, left, vi) end -# `PriorContext` -function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) - else - dot_tilde_assume(PriorContext(), right, left, vn, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - left, - vn, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) - else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) - end -end - # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) - return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi) + return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) end -function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi) +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi +) return dot_tilde_assume( - rng, context.context, sampler, right, prefix.(Ref(context), vn), vi + rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi ) end diff --git a/src/contexts.jl b/src/contexts.jl index 53b454df6..5da4208b5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -53,7 +53,7 @@ DefaultContext() julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior julia> DynamicPPL.childcontext(ctx_prior) -PriorContext{Nothing}(nothing) +PriorContext() ``` """ setchildcontext @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext{Nothing}(nothing) +PriorContext() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end NodeTrait(context::DefaultContext) = IsLeaf() """ - struct PriorContext{Tvars} <: AbstractContext - vars::Tvars - end + PriorContext <: AbstractContext -The `PriorContext` enables the computation of the log prior of the parameters `vars` when -running the model. +A leaf context resulting in the exclusion of likelihood terms when running the model. """ -struct PriorContext{Tvars} <: AbstractContext - vars::Tvars -end -PriorContext() = PriorContext(nothing) +struct PriorContext <: AbstractContext end NodeTrait(context::PriorContext) = IsLeaf() """ - struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars - end + LikelihoodContext <: AbstractContext -The `LikelihoodContext` enables the computation of the log likelihood of the parameters when -running the model. `vars` can be used to evaluate the log likelihood for specific values -of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. +A leaf context resulting in the exclusion of prior terms when running the model. """ -struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars -end -LikelihoodContext() = LikelihoodContext(nothing) +struct LikelihoodContext <: AbstractContext end NodeTrait(context::LikelihoodContext) = IsLeaf() """ diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index d8968a68e..c631b6f19 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -42,6 +42,14 @@ Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist) +# NOTE(torfjelde): Need this to avoid stack overflow. +function Distributions.rand!( + rng::Random.AbstractRNG, + d::NoDist{Distributions.ArrayLikeVariate{N}}, + x::AbstractArray{<:Real,N}, +) where {N} + return Distributions.rand!(rng, d.dist, x) +end Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) diff --git a/src/test_utils.jl b/src/test_utils.jl index 6199138aa..bf3aecbb3 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1039,24 +1039,7 @@ function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; return test_sampler_on_demo_models(sampler, args...; kwargs...) end -""" - test_context_interface(context) - -Test that `context` implements the `AbstractContext` interface. -""" -function test_context_interface(context) - # Is a subtype of `AbstractContext`. - @test context isa DynamicPPL.AbstractContext - # Should implement `NodeTrait.` - @test DynamicPPL.NodeTrait(context) isa Union{DynamicPPL.IsParent,DynamicPPL.IsLeaf} - # If it's a parent. - if DynamicPPL.NodeTrait(context) == DynamicPPL.IsParent - # Should implement `childcontext` and `setchildcontext` - @test DynamicPPL.setchildcontext(context, DynamicPPL.childcontext(context)) == - context - end -end - +# Testing for contexts. """ Context that multiplies each log-prior by mod used to test whether varwise_logpriors respects child-context. @@ -1097,4 +1080,66 @@ function DynamicPPL.dot_tilde_observe( return logp * context.mod, vi end +# Dummy context to test nested behaviors. +struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + context::C +end +TestParentContext() = TestParentContext(DefaultContext()) +DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::TestParentContext) = context.context +DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) +function Base.show(io::IO, c::TestParentContext) + return print(io, "TestParentContext(", DynamicPPL.childcontext(c), ")") +end + +""" + test_context(context::AbstractContext, model::Model) + +Test that `context` correctly implements the `AbstractContext` interface for `model`. + +This method ensures that `context` +- Correctly implements the `AbstractContext` interface. +- Correctly implements the tilde-pipeline. +""" +function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + # `NodeTrait`. + node_trait = DynamicPPL.NodeTrait(context) + # Throw error immediately if it it's missing a `NodeTrait` implementation. + node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || + throw(ValueError("Invalid NodeTrait: $node_trait")) + + # The interface methods. + if node_trait isa DynamicPPL.IsParent + # `childcontext` and `setchildcontext` + # With new child context + childcontext_new = TestParentContext() + @test DynamicPPL.childcontext( + DynamicPPL.setchildcontext(context, childcontext_new) + ) == childcontext_new + end + + # To see change, let's make sure we're using a different leaf context than the current. + leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + PriorContext() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new + + # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). + # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. + # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the + # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. + # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. + # Untyped varinfo. + varinfo_untyped = DynamicPPL.VarInfo() + @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) + @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) + # Typed varinfo. + varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped) + @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) + @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) +end + end diff --git a/test/contexts.jl b/test/contexts.jl index 4ec9ff945..0f6628440 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -18,16 +18,6 @@ using DynamicPPL: using EnzymeCore -# Dummy context to test nested behaviors. -struct ParentContext{C<:AbstractContext} <: AbstractContext - context::C -end -ParentContext() = ParentContext(DefaultContext()) -DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::ParentContext) = context.context -DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) -Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") - # TODO: Should we maybe put this in DPPL itself? function Base.iterate(context::AbstractContext) if NodeTrait(context) isa IsLeaf @@ -63,88 +53,24 @@ end child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] parent_contexts = [ - ParentContext(DefaultContext()), + DynamicPPL.TestUtils.TestParentContext(DefaultContext()), SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), PointwiseLogdensityContext(), ConditionContext((x=1.0,)), - ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), + ConditionContext( + (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) + ), ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), ConditionContext((x=[1.0, missing],)), ] contexts = vcat(child_contexts, parent_contexts) - @testset "NodeTrait" begin - @testset "$context" for context in contexts - # Every `context` should have a `NodeTrait`. - @test NodeTrait(context) isa NodeTrait - end - end - - @testset "leafcontext" begin - @testset "$context" for context in child_contexts - @test leafcontext(context) === context - end - - @testset "$context" for context in parent_contexts - @test NodeTrait(leafcontext(context)) isa IsLeaf - end - end - - @testset "setleafcontext" begin - @testset "$context" for context in child_contexts - # Setting to itself should return itself. - @test setleafcontext(context, context) === context - - # Setting to a different context should return that context. - new_leaf = context isa DefaultContext ? PriorContext() : DefaultContext() - @test setleafcontext(context, new_leaf) === new_leaf - - # Also works for parent contexts. - new_leaf = ParentContext(context) - @test setleafcontext(context, new_leaf) === new_leaf - end - - @testset "$context" for context in parent_contexts - # Leaf contexts. - new_leaf = - leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() - @test leafcontext(setleafcontext(context, new_leaf)) === new_leaf - - # Setting parent contexts as "leaf" means that the new leaf should be - # the leaf of the parent context we just set as the leaf. - new_leaf = ParentContext(( - leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() - )) - @test leafcontext(setleafcontext(context, new_leaf)) === leafcontext(new_leaf) - end - end - - # `IsParent` interface. - @testset "childcontext" begin - @testset "$context" for context in parent_contexts - @test childcontext(context) isa AbstractContext - end - end - - @testset "setchildcontext" begin - @testset "nested contexts" begin - # Both of the following should result in the same context. - context1 = ParentContext(ParentContext(ParentContext())) - context2 = setchildcontext( - ParentContext(), setchildcontext(ParentContext(), ParentContext()) - ) - @test context1 === context2 - end - - @testset "$context" for context in parent_contexts - # Setting the child context to a leaf should now change the `leafcontext` accordingly. - new_leaf = - leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() - new_context = setchildcontext(context, new_leaf) - @test childcontext(new_context) === leafcontext(new_context) === new_leaf + @testset "$(context)" for context in contexts + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + DynamicPPL.TestUtils.test_context(context, model) end end @@ -213,32 +139,45 @@ end end end - @testset "Evaluation" begin - @testset "$context" for context in contexts - # Just making sure that we can actually sample with each of the contexts. - @test (gdemo_default(SamplingContext(context)); true) - end - end - @testset "PrefixContext" begin - ctx = @inferred PrefixContext{:f}( - PrefixContext{:e}( - PrefixContext{:d}( - PrefixContext{:c}( - PrefixContext{:b}(PrefixContext{:a}(DefaultContext())) + @testset "prefixing" begin + ctx = @inferred PrefixContext{:f}( + PrefixContext{:e}( + PrefixContext{:d}( + PrefixContext{:c}( + PrefixContext{:b}(PrefixContext{:a}(DefaultContext())) + ), ), ), - ), - ) - vn = VarName{:x}() - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) - - vn = VarName{:x}(((1,),)) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + ) + vn = VarName{:x}() + vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) + @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") + @test getoptic(vn_prefixed) === getoptic(vn) + + vn = VarName{:x}(((1,),)) + vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) + @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") + @test getoptic(vn_prefixed) === getoptic(vn) + end + + context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # Sample with the context. + varinfo = DynamicPPL.VarInfo() + DynamicPPL.evaluate!!(model, varinfo, context) + # Extract the resulting symbols. + vns_varinfo_syms = Set(map(DynamicPPL.getsym, keys(varinfo))) + + # Extract the ground truth symbols. + vns_syms = Set([ + Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for + vn in DynamicPPL.TestUtils.varnames(model) + ]) + + # Check that all variables are prefixed correctly. + @test vns_syms == vns_varinfo_syms + end end @testset "SamplingContext" begin diff --git a/test/debug_utils.jl b/test/debug_utils.jl index b1897aa9b..5c309da3a 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -3,7 +3,7 @@ # HACK: Require a model to instantiate it, so let's just grab one. model = first(DynamicPPL.TestUtils.DEMO_MODELS) context = DynamicPPL.DebugUtils.DebugContext(model) - DynamicPPL.TestUtils.test_context_interface(context) + DynamicPPL.TestUtils.test_context(context) end @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS