Skip to content

Commit

Permalink
Add from_dict (#35)
Browse files Browse the repository at this point in the history
* Add from_dict

* Fix from_namedtuple docstring

* Generalize function signature

* Add from_dict tests

* Test also with OrderedDict

* Test as_namedtuple

* Increment patch number

* Add from_dict to docs

* Correctly import package
  • Loading branch information
sethaxen authored Nov 13, 2022
1 parent c060241 commit 84a5aa8
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 5 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InferenceObjects"
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.2.4"
version = "0.2.5"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -12,11 +12,13 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Compat = "3.46.0, 4.2.0"
DimensionalData = "0.20, 0.21, 0.22, 0.23"
OffsetArrays = "1"
OrderedCollections = "1"
julia = "1.6"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "OffsetArrays"]
test = ["Test", "OffsetArrays", "OrderedCollections"]
1 change: 1 addition & 0 deletions docs/src/inference_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ That is, iterating over an `InferenceData` iterates over its groups.

```@docs
convert_to_inference_data
from_dict
from_namedtuple
```

Expand Down
4 changes: 3 additions & 1 deletion src/InferenceObjects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ const SCHEMA_GROUPS_DICT = Dict(n => i for (i, n) in enumerate(SCHEMA_GROUPS))
const DEFAULT_SAMPLE_DIMS = Dimensions.key2dim((:draw, :chain))

export Dataset, InferenceData
export convert_to_dataset, convert_to_inference_data, from_namedtuple, namedtuple_to_dataset
export convert_to_dataset,
convert_to_inference_data, from_dict, from_namedtuple, namedtuple_to_dataset

include("utils.jl")
include("dimensions.jl")
Expand All @@ -36,5 +37,6 @@ include("inference_data.jl")
include("convert_dataset.jl")
include("convert_inference_data.jl")
include("from_namedtuple.jl")
include("from_dict.jl")

end # module
72 changes: 72 additions & 0 deletions src/from_dict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
from_dict(posterior::AbstractDict; kwargs...) -> InferenceData
Convert a `Dict` to an `InferenceData`.
# Arguments
- `posterior`: The data to be converted. Its strings must be `Symbol` or `AbstractString`,
and its values must be arrays.
# Keywords
- `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution
- `sample_stats::Any=nothing`: Statistics of the posterior sampling process
- `predictions::Any=nothing`: Out-of-sample predictions for the posterior.
- `prior::Dict=nothing`: Draws from the prior
- `prior_predictive::Any=nothing`: Draws from the prior predictive distribution
- `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process
- `observed_data::NamedTuple`: Observed data on which the `posterior` is
conditional. It should only contain data which is modeled as a random variable. Keys
are parameter names and values.
- `constant_data::NamedTuple`: Model constants, data included in the model
which is not modeled as a random variable. Keys are parameter names and values.
- `predictions_constant_data::NamedTuple`: Constants relevant to the model
predictions (i.e. new `x` values in a linear regression).
- `log_likelihood`: Pointwise log-likelihood for the data. It is recommended
to use this argument as a `NamedTuple` whose keys are observed variable names and whose
values are log likelihood arrays.
- `library`: Name of library that generated the draws
- `coords`: Map from named dimension to named indices
- `dims`: Map from variable name to names of its dimensions
# Returns
- `InferenceData`: The data with groups corresponding to the provided data
# Examples
```@example
using InferenceObjects
nchains = 2
ndraws = 100
data = Dict(
:x => rand(ndraws, nchains),
:y => randn(2, ndraws, nchains),
:z => randn(3, 2, ndraws, nchains),
)
idata = from_dict(data)
```
"""
from_dict

function from_dict(
posterior::Union{<:AbstractDict,Nothing}=nothing; prior=nothing, kwargs...
)
nt = posterior === nothing ? posterior : as_namedtuple(posterior)
nt_prior = prior === nothing ? prior : as_namedtuple(prior)
return from_namedtuple(nt; prior=nt_prior, kwargs...)
end

"""
convert_to_inference_data(obj::AbstractDict; kwargs...) -> InferenceData
Convert `obj` to an [`InferenceData`](@ref). See [`from_namedtuple`](@ref) for a description
of `obj` possibilities and `kwargs`.
"""
function convert_to_inference_data(data::AbstractDict; group=:posterior, kwargs...)
group = Symbol(group)
group === :posterior && return from_dict(data; kwargs...)
return from_dict(; group => data, kwargs...)
end
2 changes: 1 addition & 1 deletion src/from_namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ whose first dimensions correspond to the dimensions of the containers.
- `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution
- `sample_stats::Any=nothing`: Statistics of the posterior sampling process
- `predictions::Any=nothing`: Out-of-sample predictions for the posterior.
- `prior::Any=nothing`: Draws from the prior
- `prior=nothing`: Draws from the prior. Accepts the same types as `posterior`.
- `prior_predictive::Any=nothing`: Draws from the prior predictive distribution
- `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process
- `observed_data::NamedTuple`: Observed data on which the `posterior` is
Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,9 @@ function rekey(d::NamedTuple, keymap)
new_keys = map(k -> get(keymap, k, k), keys(d))
return NamedTuple{new_keys}(values(d))
end

as_namedtuple(dict::AbstractDict{Symbol}) = NamedTuple(dict)
function as_namedtuple(dict::AbstractDict{<:AbstractString})
return NamedTuple(Symbol(k) => v for (k, v) in dict)
end
as_namedtuple(nt::NamedTuple) = nt
70 changes: 70 additions & 0 deletions test/from_dict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using InferenceObjects, OrderedCollections, Test

@testset "from_dict" begin
nchains, ndraws = 4, 10
sizes = (x=(), y=(2,), z=(3, 5))
dims = (y=[:yx], z=[:zx, :zy])
coords = (yx=["y1", "y2"], zx=1:3, zy=1:5)

dicts = [
"Dict{Symbol}" =>
Dict(Symbol(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)),
"OrderedDict{String}" =>
Dict(string(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)),
]

@testset "posterior::$(type)" for (type, dict) in dicts
@test_broken @inferred from_dict(dict; dims, coords, library="MyLib")
idata1 = from_dict(dict; dims, coords, library="MyLib")
idata2 = convert_to_inference_data(dict; dims, coords, library="MyLib")
test_idata_approx_equal(idata1, idata2)
end

@testset "$(group)" for group in [
:posterior_predictive, :sample_stats, :predictions, :log_likelihood
]
library = "MyLib"
@testset "::$(type)" for (type, dict) in dicts
idata1 = from_dict(dict; group => dict, dims, coords, library)
test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords)

idata2 = from_dict(dict; group => (:x,), dims, coords, library)
test_idata_group_correct(idata2, :posterior, (:y, :z); library, dims, coords)
test_idata_group_correct(idata2, group, (:x,); library, dims, coords)
end
end

@testset "$(group)" for group in [:prior_predictive, :sample_stats_prior]
library = "MyLib"
@testset "::$(type)" for (type, dict) in dicts
idata1 = from_dict(; prior=dict, group => dict, dims, coords, library)
test_idata_group_correct(idata1, :prior, keys(sizes); library, dims, coords)
test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords)

idata2 = from_dict(; prior=dict, group => (:x,), dims, coords, library)
test_idata_group_correct(idata2, :prior, (:y, :z); library, dims, coords)
test_idata_group_correct(idata2, group, (:x,); library, dims, coords)
end
end

@testset "$(group)" for group in
[:observed_data, :constant_data, :predictions_constant_data]
_, dict = dicts[1]
library = "MyLib"
dims = (; w=[:wx])
coords = (; wx=1:2)
idata1 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library)
test_idata_group_correct(idata1, :posterior, keys(sizes); library, dims, coords)
test_idata_group_correct(
idata1, group, (:w,); library, dims, coords, default_dims=()
)

# ensure that dims are matched to named tuple keys
# https://github.com/arviz-devs/ArviZ.jl/issues/96
idata2 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library)
test_idata_group_correct(idata2, :posterior, keys(sizes); library, dims, coords)
test_idata_group_correct(
idata2, group, (:w,); library, dims, coords, default_dims=()
)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ using InferenceObjects, Test
include("convert_dataset.jl")
include("convert_inference_data.jl")
include("from_namedtuple.jl")
include("from_dict.jl")
end
8 changes: 7 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using InferenceObjects, Test
using InferenceObjects, OrderedCollections, Test

module TestSubModule end

Expand Down Expand Up @@ -40,4 +40,10 @@ module TestSubModule end
@test new == Dict(:y => 3, :a => 4, :z => 5)
end
end

@testset "as_namedtuple" begin
@test InferenceObjects.as_namedtuple(OrderedDict(:x => 3, :y => 4)) === (x=3, y=4)
@test InferenceObjects.as_namedtuple(OrderedDict("x" => 4, "y" => 5)) === (x=4, y=5)
@test InferenceObjects.as_namedtuple((y=6, x=7)) === (y=6, x=7)
end
end

2 comments on commit 84a5aa8

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/72135

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.5 -m "<description of version>" 84a5aa8666e6f8bac25414c6eeabaf57a546a3f2
git push origin v0.2.5

Please sign in to comment.