Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 8, 2023
1 parent bb2a3f6 commit bf23604
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
return nothing
end
model = demo_subsetting_varinfo()
vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]

@testset "$(short_varinfo_name(varinfo))" for varinfo in [
VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext()))
Expand All @@ -483,11 +484,11 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test isempty(
setdiff(
keys(varinfo),
[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])],
vns,
),
)

@testset "$(convert(Vector{VarName}, vns))" for vns in [
@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [
[@varname(s)],
[@varname(m)],
[@varname(x[1])],
Expand All @@ -504,11 +505,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
[@varname(m), @varname(x[1]), @varname(x[2])],
[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])],
]
varinfo_subset = subset(varinfo, vns)
# Should now only contain the variables in `vns`.
@test isempty(setdiff(keys(varinfo_subset), vns))
varinfo_subset = subset(varinfo, vns_subset)
# Should now only contain the variables in `vns_subset`.
@test isempty(setdiff(keys(varinfo_subset), vns_subset))
# Values should be the same.
@test [varinfo_subset[vn] for vn in vns] == [varinfo[vn] for vn in vns]
@test [varinfo_subset[vn] for vn in vns_subset] == [varinfo[vn] for vn in vns_subset]

# `merge` with the original.
varinfo_merged = merge(varinfo, varinfo_subset)
vns_merged = keys(varinfo_merged)
# Should be equivalent.
@test union(vns_merged, vns) == intersect(vns_merged, vns)
# Values should be the same.
@test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns]
end
end
end
Expand All @@ -519,9 +528,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
VarInfo(model),
last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext()))
]

vns = DynamicPPL.TestUtils.varnames(model)

@testset "with itself" begin
# Merging itself should be a no-op.
varinfo_merged = merge(varinfo, varinfo)
Expand Down

0 comments on commit bf23604

Please sign in to comment.