Skip to content

Commit

Permalink
added impl of varname_and_value_leaves
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Aug 31, 2023
1 parent 549d9b1 commit 5a43012
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ DynamicPPL.reconstruct
```@docs
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
DynamicPPL.varname_and_value_leaves
```

#### `SimpleVarInfo`
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Setfield: Setfield
using ZygoteRules: ZygoteRules
using LogDensityProblems: LogDensityProblems

using LinearAlgebra: Cholesky
using LinearAlgebra: LinearAlgebra, Cholesky

using DocStringExtensions

Expand Down
119 changes: 119 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -870,3 +870,122 @@ function varname_leaves(vn::VarName, val::NamedTuple)
end
return Iterators.flatten(iter)
end

"""
varname_and_value_leaves(vn::VarName, val)
Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
# Examples
```jldoctest varname-and-value-leaves
julia> using DynamicPPL: varname_and_value_leaves
julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
(x[1], 1)
(x[2], 2)
julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
(x[1:2][1], 1)
(x[1:2][2], 2)
julia> x = (y = 1, z = [[2.0], [3.0]]);
julia> foreach(println, varname_and_value_leaves(@varname(x), x))
(x.y, 1)
(x.z[1][1], 2.0)
(x.z[2][1], 3.0)
```
There are also some special handling for certain types:
```jldoctest varname-and-value-leaves
julia> using LinearAlgebra
julia> x = reshape(1:4, 2, 2);
julia> # `LowerTriangular`
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
(x[1,1], 1)
(x[2,1], 2)
(x[2,2], 4)
julia> # `UpperTriangular`
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
(x[1,1], 1)
(x[1,2], 3)
(x[2,2], 4)
```
"""
function varname_and_value_leaves(vn::VarName, x)
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
end

# Simple struct used to represent a varname-value pair even if we use
# something like `Iterators.flatten`.
struct Leaf{T}
value::T
end

Leaf(xs...) = Leaf(xs)

# Allows us to just use `Leaf` to "terminate" recursion in `Iterators.flatten`.
Base.iterate(leaf::Leaf) = leaf, leaf
Base.iterate(::Leaf, _) = nothing

# Convenience.
value(leaf::Leaf) = leaf.value

# Leaf-types.
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
function varname_and_value_leaves_inner(
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
val[I],
) for I in CartesianIndices(val)
)
end
# Containers.
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_and_value_leaves_inner(
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
val[I],
) for I in CartesianIndices(val)
)
end
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
lens = DynamicPPL.Setfield.PropertyLens{sym}()
varname_and_value_leaves_inner(vn lens, get(val, lens))
end

return Iterators.flatten(iter)
end
# Special types.
function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)

Check warning on line 968 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L968

Added line #L968 was not covered by tests
# TODO: Or do we use `PDMat` here?
return varname_and_value_leaves_inner(vn, x.UL)

Check warning on line 970 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L970

Added line #L970 was not covered by tests
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
x[I],
)
# Iteration over the lower-triangular indices.
for I in CartesianIndices(x) if I[1] >= I[2]
)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
x[I],
)
# Iteration over the upper-triangular indices.
for I in CartesianIndices(x) if I[1] <= I[2]
)
end

0 comments on commit 5a43012

Please sign in to comment.