Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add from_dict #35

Merged
merged 9 commits into from
Nov 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InferenceObjects"
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.2.4"
version = "0.2.5"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -12,11 +12,13 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Compat = "3.46.0, 4.2.0"
DimensionalData = "0.20, 0.21, 0.22, 0.23"
OffsetArrays = "1"
OrderedCollections = "1"
julia = "1.6"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "OffsetArrays"]
test = ["Test", "OffsetArrays", "OrderedCollections"]
1 change: 1 addition & 0 deletions docs/src/inference_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ That is, iterating over an `InferenceData` iterates over its groups.

```@docs
convert_to_inference_data
from_dict
from_namedtuple
```

Expand Down
4 changes: 3 additions & 1 deletion src/InferenceObjects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ const SCHEMA_GROUPS_DICT = Dict(n => i for (i, n) in enumerate(SCHEMA_GROUPS))
const DEFAULT_SAMPLE_DIMS = Dimensions.key2dim((:draw, :chain))

export Dataset, InferenceData
export convert_to_dataset, convert_to_inference_data, from_namedtuple, namedtuple_to_dataset
export convert_to_dataset,
convert_to_inference_data, from_dict, from_namedtuple, namedtuple_to_dataset

include("utils.jl")
include("dimensions.jl")
Expand All @@ -36,5 +37,6 @@ include("inference_data.jl")
include("convert_dataset.jl")
include("convert_inference_data.jl")
include("from_namedtuple.jl")
include("from_dict.jl")

end # module
72 changes: 72 additions & 0 deletions src/from_dict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
from_dict(posterior::AbstractDict; kwargs...) -> InferenceData

Convert a `Dict` to an `InferenceData`.

# Arguments

- `posterior`: The data to be converted. Its strings must be `Symbol` or `AbstractString`,
and its values must be arrays.

# Keywords

- `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution
- `sample_stats::Any=nothing`: Statistics of the posterior sampling process
- `predictions::Any=nothing`: Out-of-sample predictions for the posterior.
- `prior::Dict=nothing`: Draws from the prior
- `prior_predictive::Any=nothing`: Draws from the prior predictive distribution
- `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process
- `observed_data::NamedTuple`: Observed data on which the `posterior` is
conditional. It should only contain data which is modeled as a random variable. Keys
are parameter names and values.
- `constant_data::NamedTuple`: Model constants, data included in the model
which is not modeled as a random variable. Keys are parameter names and values.
- `predictions_constant_data::NamedTuple`: Constants relevant to the model
predictions (i.e. new `x` values in a linear regression).
- `log_likelihood`: Pointwise log-likelihood for the data. It is recommended
to use this argument as a `NamedTuple` whose keys are observed variable names and whose
values are log likelihood arrays.
- `library`: Name of library that generated the draws
- `coords`: Map from named dimension to named indices
- `dims`: Map from variable name to names of its dimensions

# Returns

- `InferenceData`: The data with groups corresponding to the provided data

# Examples

```@example
using InferenceObjects
nchains = 2
ndraws = 100

data = Dict(
:x => rand(ndraws, nchains),
:y => randn(2, ndraws, nchains),
:z => randn(3, 2, ndraws, nchains),
)
idata = from_dict(data)
```
"""
from_dict

function from_dict(
posterior::Union{<:AbstractDict,Nothing}=nothing; prior=nothing, kwargs...
)
nt = posterior === nothing ? posterior : as_namedtuple(posterior)
nt_prior = prior === nothing ? prior : as_namedtuple(prior)
return from_namedtuple(nt; prior=nt_prior, kwargs...)
end

"""
convert_to_inference_data(obj::AbstractDict; kwargs...) -> InferenceData

Convert `obj` to an [`InferenceData`](@ref). See [`from_namedtuple`](@ref) for a description
of `obj` possibilities and `kwargs`.
"""
function convert_to_inference_data(data::AbstractDict; group=:posterior, kwargs...)
group = Symbol(group)
group === :posterior && return from_dict(data; kwargs...)
return from_dict(; group => data, kwargs...)
end
2 changes: 1 addition & 1 deletion src/from_namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ whose first dimensions correspond to the dimensions of the containers.
- `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution
- `sample_stats::Any=nothing`: Statistics of the posterior sampling process
- `predictions::Any=nothing`: Out-of-sample predictions for the posterior.
- `prior::Any=nothing`: Draws from the prior
- `prior=nothing`: Draws from the prior. Accepts the same types as `posterior`.
- `prior_predictive::Any=nothing`: Draws from the prior predictive distribution
- `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process
- `observed_data::NamedTuple`: Observed data on which the `posterior` is
Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,9 @@ function rekey(d::NamedTuple, keymap)
new_keys = map(k -> get(keymap, k, k), keys(d))
return NamedTuple{new_keys}(values(d))
end

as_namedtuple(dict::AbstractDict{Symbol}) = NamedTuple(dict)
function as_namedtuple(dict::AbstractDict{<:AbstractString})
return NamedTuple(Symbol(k) => v for (k, v) in dict)
end
as_namedtuple(nt::NamedTuple) = nt
70 changes: 70 additions & 0 deletions test/from_dict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using InferenceObjects, OrderedCollections, Test

@testset "from_dict" begin
nchains, ndraws = 4, 10
sizes = (x=(), y=(2,), z=(3, 5))
dims = (y=[:yx], z=[:zx, :zy])
coords = (yx=["y1", "y2"], zx=1:3, zy=1:5)

dicts = [
"Dict{Symbol}" =>
Dict(Symbol(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)),
"OrderedDict{String}" =>
Dict(string(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)),
]

@testset "posterior::$(type)" for (type, dict) in dicts
@test_broken @inferred from_dict(dict; dims, coords, library="MyLib")
idata1 = from_dict(dict; dims, coords, library="MyLib")
idata2 = convert_to_inference_data(dict; dims, coords, library="MyLib")
test_idata_approx_equal(idata1, idata2)
end

@testset "$(group)" for group in [
:posterior_predictive, :sample_stats, :predictions, :log_likelihood
]
library = "MyLib"
@testset "::$(type)" for (type, dict) in dicts
idata1 = from_dict(dict; group => dict, dims, coords, library)
test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords)

idata2 = from_dict(dict; group => (:x,), dims, coords, library)
test_idata_group_correct(idata2, :posterior, (:y, :z); library, dims, coords)
test_idata_group_correct(idata2, group, (:x,); library, dims, coords)
end
end

@testset "$(group)" for group in [:prior_predictive, :sample_stats_prior]
library = "MyLib"
@testset "::$(type)" for (type, dict) in dicts
idata1 = from_dict(; prior=dict, group => dict, dims, coords, library)
test_idata_group_correct(idata1, :prior, keys(sizes); library, dims, coords)
test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords)

idata2 = from_dict(; prior=dict, group => (:x,), dims, coords, library)
test_idata_group_correct(idata2, :prior, (:y, :z); library, dims, coords)
test_idata_group_correct(idata2, group, (:x,); library, dims, coords)
end
end

@testset "$(group)" for group in
[:observed_data, :constant_data, :predictions_constant_data]
_, dict = dicts[1]
library = "MyLib"
dims = (; w=[:wx])
coords = (; wx=1:2)
idata1 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library)
test_idata_group_correct(idata1, :posterior, keys(sizes); library, dims, coords)
test_idata_group_correct(
idata1, group, (:w,); library, dims, coords, default_dims=()
)

# ensure that dims are matched to named tuple keys
# https://github.com/arviz-devs/ArviZ.jl/issues/96
idata2 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library)
test_idata_group_correct(idata2, :posterior, keys(sizes); library, dims, coords)
test_idata_group_correct(
idata2, group, (:w,); library, dims, coords, default_dims=()
)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ using InferenceObjects, Test
include("convert_dataset.jl")
include("convert_inference_data.jl")
include("from_namedtuple.jl")
include("from_dict.jl")
end
8 changes: 7 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using InferenceObjects, Test
using InferenceObjects, OrderedCollections, Test

module TestSubModule end

Expand Down Expand Up @@ -40,4 +40,10 @@ module TestSubModule end
@test new == Dict(:y => 3, :a => 4, :z => 5)
end
end

@testset "as_namedtuple" begin
@test InferenceObjects.as_namedtuple(OrderedDict(:x => 3, :y => 4)) === (x=3, y=4)
@test InferenceObjects.as_namedtuple(OrderedDict("x" => 4, "y" => 5)) === (x=4, y=5)
@test InferenceObjects.as_namedtuple((y=6, x=7)) === (y=6, x=7)
end
end