Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 8, 2023
1 parent bf23604 commit 743c8b6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 28 deletions.
36 changes: 16 additions & 20 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,11 @@ end
Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable.
"""
Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =_merge(varinfo_left, varinfo_right)
Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) =_merge(varinfo_left, varinfo_right)
Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =
_merge(varinfo_left, varinfo_right)
function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo)
return _merge(varinfo_left, varinfo_right)
end

function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
Expand All @@ -314,9 +317,8 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
end

function merge_metadata(
metadata_left::NamedTuple{names_left},
metadata_right::NamedTuple{names_right}
) where {names_left, names_right}
metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right}
) where {names_left,names_right}
# TODO: Improve this. Maybe make `@generated`?
metadata = map(names_left) do sym
if sym in names_right
Expand All @@ -332,7 +334,9 @@ function merge_metadata(
end
end

return NamedTuple{(names_left..., names_right_only...)}(tuple(metadata..., metadata_right_only...))
return NamedTuple{(names_left..., names_right_only...)}(
tuple(metadata..., metadata_right_only...)
)
end

function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
Expand Down Expand Up @@ -361,13 +365,13 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)

# Initialize required fields for `metadata`.
vns = VarName[]
idcs = Dict{VarName, Int}()
idcs = Dict{VarName,Int}()
ranges = Vector{UnitRange{Int}}()
vals = T[]
dists = D[]
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
orders = Int[]
flags = Dict{String, BitVector}()
flags = Dict{String,BitVector}()
# Initialize the `flags`.
for k in union(keys(metadata_left.flags), keys(metadata_right.flags))
flags[k] = BitVector()
Expand Down Expand Up @@ -442,16 +446,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
end
end

return Metadata(
idcs,
vns,
ranges,
vals,
dists,
gids,
orders,
flags,
)
return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags)
end

const VarView = Union{Int,UnitRange,Vector{Int}}
Expand Down Expand Up @@ -1601,7 +1596,6 @@ run before sampling `vn`.
getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn)
getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)]


#######################################
# Rand & replaying method for VarInfo #
#######################################
Expand All @@ -1614,7 +1608,9 @@ Check whether `vn` has a true value for `flag` in `vi`.
function is_flagged(vi::VarInfo, vn::VarName, flag::String)
return is_flagged(getmetadata(vi, vn), vn, flag)
end
is_flagged(metadata::Metadata, vn::VarName, flag::String) = metadata.flags[flag][getidx(metadata, vn)]
function is_flagged(metadata::Metadata, vn::VarName, flag::String)
return metadata.flags[flag][getidx(metadata, vn)]
end

"""
unset_flag!(vi::VarInfo, vn::VarName, flag::String)
Expand Down
13 changes: 5 additions & 8 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,12 +481,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
]

# All variables.
@test isempty(
setdiff(
keys(varinfo),
vns,
),
)
@test isempty(setdiff(keys(varinfo), vns))

@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [
[@varname(s)],
Expand Down Expand Up @@ -526,7 +521,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(short_varinfo_name(varinfo))" for varinfo in [
VarInfo(model),
last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext()))
last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())),
]
vns = DynamicPPL.TestUtils.varnames(model)
@testset "with itself" begin
Expand All @@ -551,7 +546,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)

@testset "with different value" begin
x = DynamicPPL.TestUtils.rand(model)
varinfo_changed = DynamicPPL.TestUtils.update_values!!(deepcopy(varinfo), x, vns)
varinfo_changed = DynamicPPL.TestUtils.update_values!!(
deepcopy(varinfo), x, vns
)
# After `merge`, we should have the same values as `x`.
varinfo_merged = merge(varinfo, varinfo_changed)
DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns)
Expand Down

0 comments on commit 743c8b6

Please sign in to comment.