From 7ac5b2fdadf505fbd2aa3708b67581b23207a69d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 00:06:22 +0100 Subject: [PATCH 1/9] Add from_dict --- src/InferenceObjects.jl | 4 ++- src/from_dict.jl | 70 +++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 6 ++++ 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 src/from_dict.jl diff --git a/src/InferenceObjects.jl b/src/InferenceObjects.jl index 4830bf58..2e984a18 100644 --- a/src/InferenceObjects.jl +++ b/src/InferenceObjects.jl @@ -27,7 +27,8 @@ const SCHEMA_GROUPS_DICT = Dict(n => i for (i, n) in enumerate(SCHEMA_GROUPS)) const DEFAULT_SAMPLE_DIMS = Dimensions.key2dim((:draw, :chain)) export Dataset, InferenceData -export convert_to_dataset, convert_to_inference_data, from_namedtuple, namedtuple_to_dataset +export convert_to_dataset, + convert_to_inference_data, from_dict, from_namedtuple, namedtuple_to_dataset include("utils.jl") include("dimensions.jl") @@ -36,5 +37,6 @@ include("inference_data.jl") include("convert_dataset.jl") include("convert_inference_data.jl") include("from_namedtuple.jl") +include("from_dict.jl") end # module diff --git a/src/from_dict.jl b/src/from_dict.jl new file mode 100644 index 00000000..d32bf984 --- /dev/null +++ b/src/from_dict.jl @@ -0,0 +1,70 @@ +""" + from_dict(posterior::AbstractDict; kwargs...) -> InferenceData + +Convert a `Dict` to an `InferenceData`. + +# Arguments + + - `posterior`: The data to be converted. Its strings must be `Symbol` or `AbstractString`, + and its values must be arrays. + +# Keywords + + - `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution + - `sample_stats::Any=nothing`: Statistics of the posterior sampling process + - `predictions::Any=nothing`: Out-of-sample predictions for the posterior. + - `prior::Dict=nothing`: Draws from the prior + - `prior_predictive::Any=nothing`: Draws from the prior predictive distribution + - `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process + - `observed_data::NamedTuple`: Observed data on which the `posterior` is + conditional. It should only contain data which is modeled as a random variable. Keys + are parameter names and values. + - `constant_data::NamedTuple`: Model constants, data included in the model + which is not modeled as a random variable. Keys are parameter names and values. + - `predictions_constant_data::NamedTuple`: Constants relevant to the model + predictions (i.e. new `x` values in a linear regression). + - `log_likelihood`: Pointwise log-likelihood for the data. It is recommended + to use this argument as a `NamedTuple` whose keys are observed variable names and whose + values are log likelihood arrays. + - `library`: Name of library that generated the draws + - `coords`: Map from named dimension to named indices + - `dims`: Map from variable name to names of its dimensions + +# Returns + + - `InferenceData`: The data with groups corresponding to the provided data + +# Examples + +```@example +using InferenceObjects +nchains = 2 +ndraws = 100 + +data = Dict( + :x => rand(ndraws, nchains), + :y => randn(2, ndraws, nchains), + :z => randn(3, 2, ndraws, nchains), +) +idata = from_dict(data) +``` +""" +from_dict + +function from_dict(posterior::Dict; prior=nothing, kwargs...) + nt = as_namedtuple(posterior) + nt_prior = prior === nothing ? prior : as_namedtuple(prior) + return from_namedtuple(nt; prior=nt_prior, kwargs...) +end + +""" + convert_to_inference_data(obj::AbstractDict; kwargs...) -> InferenceData + +Convert `obj` to an [`InferenceData`](@ref). See [`from_namedtuple`](@ref) for a description +of `obj` possibilities and `kwargs`. +""" +function convert_to_inference_data(data::AbstractDict; group=:posterior, kwargs...) + group = Symbol(group) + group === :posterior && return from_dict(data; kwargs...) + return from_dict(; group => data, kwargs...) +end diff --git a/src/utils.jl b/src/utils.jl index 2270ff12..5f89d1b1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -67,3 +67,9 @@ function rekey(d::NamedTuple, keymap) new_keys = map(k -> get(keymap, k, k), keys(d)) return NamedTuple{new_keys}(values(d)) end + +as_namedtuple(dict::AbstractDict{Symbol}) = NamedTuple(dict) +function as_namedtuple(dict::AbstractDict{<:AbstractString}) + return NamedTuple(Symbol(k) => v for (k, v) in dict) +end +as_namedtuple(nt::NamedTuple) = nt From 89db7db37e3869e9fe975c6b2282e3a1b1eb3b98 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 00:06:31 +0100 Subject: [PATCH 2/9] Fix from_namedtuple docstring --- src/from_namedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/from_namedtuple.jl b/src/from_namedtuple.jl index 70a5c03b..102dcfbb 100644 --- a/src/from_namedtuple.jl +++ b/src/from_namedtuple.jl @@ -33,7 +33,7 @@ whose first dimensions correspond to the dimensions of the containers. - `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution - `sample_stats::Any=nothing`: Statistics of the posterior sampling process - `predictions::Any=nothing`: Out-of-sample predictions for the posterior. - - `prior::Any=nothing`: Draws from the prior + - `prior=nothing`: Draws from the prior. Accepts the same types as `posterior`. - `prior_predictive::Any=nothing`: Draws from the prior predictive distribution - `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process - `observed_data::NamedTuple`: Observed data on which the `posterior` is From 94021133bd1d7bbcd2e13a0959aebd2cb5920ae4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 00:16:01 +0100 Subject: [PATCH 3/9] Generalize function signature --- src/from_dict.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/from_dict.jl b/src/from_dict.jl index d32bf984..92b968a7 100644 --- a/src/from_dict.jl +++ b/src/from_dict.jl @@ -51,8 +51,10 @@ idata = from_dict(data) """ from_dict -function from_dict(posterior::Dict; prior=nothing, kwargs...) - nt = as_namedtuple(posterior) +function from_dict( + posterior::Union{<:AbstractDict,Nothing}=nothing; prior=nothing, kwargs... +) + nt = posterior === nothing ? posterior : as_namedtuple(posterior) nt_prior = prior === nothing ? prior : as_namedtuple(prior) return from_namedtuple(nt; prior=nt_prior, kwargs...) end From 5e59a5a9aef67a9937e6a65e45d2f071a6727288 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 00:16:13 +0100 Subject: [PATCH 4/9] Add from_dict tests --- test/from_dict.jl | 70 +++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 71 insertions(+) create mode 100644 test/from_dict.jl diff --git a/test/from_dict.jl b/test/from_dict.jl new file mode 100644 index 00000000..b3680f2c --- /dev/null +++ b/test/from_dict.jl @@ -0,0 +1,70 @@ +using InferenceObjects, Test + +@testset "from_dict" begin + nchains, ndraws = 4, 10 + sizes = (x=(), y=(2,), z=(3, 5)) + dims = (y=[:yx], z=[:zx, :zy]) + coords = (yx=["y1", "y2"], zx=1:3, zy=1:5) + + dicts = [ + "Dict{Symbol}" => + Dict(Symbol(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)), + "Dict{String}" => + Dict(string(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)), + ] + + @testset "posterior::$(type)" for (type, dict) in dicts + @test_broken @inferred from_dict(dict; dims, coords, library="MyLib") + idata1 = from_dict(dict; dims, coords, library="MyLib") + idata2 = convert_to_inference_data(dict; dims, coords, library="MyLib") + test_idata_approx_equal(idata1, idata2) + end + + @testset "$(group)" for group in [ + :posterior_predictive, :sample_stats, :predictions, :log_likelihood + ] + library = "MyLib" + @testset "::$(type)" for (type, dict) in dicts + idata1 = from_dict(dict; group => dict, dims, coords, library) + test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords) + + idata2 = from_dict(dict; group => (:x,), dims, coords, library) + test_idata_group_correct(idata2, :posterior, (:y, :z); library, dims, coords) + test_idata_group_correct(idata2, group, (:x,); library, dims, coords) + end + end + + @testset "$(group)" for group in [:prior_predictive, :sample_stats_prior] + library = "MyLib" + @testset "::$(type)" for (type, dict) in dicts + idata1 = from_dict(; prior=dict, group => dict, dims, coords, library) + test_idata_group_correct(idata1, :prior, keys(sizes); library, dims, coords) + test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords) + + idata2 = from_dict(; prior=dict, group => (:x,), dims, coords, library) + test_idata_group_correct(idata2, :prior, (:y, :z); library, dims, coords) + test_idata_group_correct(idata2, group, (:x,); library, dims, coords) + end + end + + @testset "$(group)" for group in + [:observed_data, :constant_data, :predictions_constant_data] + _, dict = dicts[1] + library = "MyLib" + dims = (; w=[:wx]) + coords = (; wx=1:2) + idata1 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library) + test_idata_group_correct(idata1, :posterior, keys(sizes); library, dims, coords) + test_idata_group_correct( + idata1, group, (:w,); library, dims, coords, default_dims=() + ) + + # ensure that dims are matched to named tuple keys + # https://github.com/arviz-devs/ArviZ.jl/issues/96 + idata2 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library) + test_idata_group_correct(idata2, :posterior, keys(sizes); library, dims, coords) + test_idata_group_correct( + idata2, group, (:w,); library, dims, coords, default_dims=() + ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3004af19..55a52053 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,4 +9,5 @@ using InferenceObjects, Test include("convert_dataset.jl") include("convert_inference_data.jl") include("from_namedtuple.jl") + include("from_dict.jl") end From 262a1ac26bc06c0b08ba0977eed001e6f969c8a5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 00:18:57 +0100 Subject: [PATCH 5/9] Test also with OrderedDict --- Project.toml | 4 +++- test/from_dict.jl | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 49160eb0..8d35b386 100644 --- a/Project.toml +++ b/Project.toml @@ -12,11 +12,13 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Compat = "3.46.0, 4.2.0" DimensionalData = "0.20, 0.21, 0.22, 0.23" OffsetArrays = "1" +OrderedCollections = "1" julia = "1.6" [extras] OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "OffsetArrays"] +test = ["Test", "OffsetArrays", "OrderedCollections"] diff --git a/test/from_dict.jl b/test/from_dict.jl index b3680f2c..13e82519 100644 --- a/test/from_dict.jl +++ b/test/from_dict.jl @@ -1,4 +1,4 @@ -using InferenceObjects, Test +using InferenceObjects, OrderedCollections, Test @testset "from_dict" begin nchains, ndraws = 4, 10 @@ -9,7 +9,7 @@ using InferenceObjects, Test dicts = [ "Dict{Symbol}" => Dict(Symbol(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)), - "Dict{String}" => + "OrderedDict{String}" => Dict(string(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)), ] From 655b460ab183553c8b756e9755a60c824e72632f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 00:21:28 +0100 Subject: [PATCH 6/9] Test as_namedtuple --- test/utils.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index a1c0b4f5..a6786eec 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,4 @@ -using InferenceObjects, Test +using InferenceObjects, OrderedDict, Test module TestSubModule end @@ -40,4 +40,10 @@ module TestSubModule end @test new == Dict(:y => 3, :a => 4, :z => 5) end end + + @testset "as_namedtuple" begin + @test InferenceObjects.as_namedtuple(OrderedDict(:x => 3, :y => 4)) === (x=3, y=4) + @test InferenceObjects.as_namedtuple(OrderedDict("x" => 4, "y" => 5)) === (x=4, y=5) + @test InferenceObjects.as_namedtuple((y=6, x=7)) === (y=6, x=7) + end end From eb2cd74e67756800493dc462ab4e34198ba37ef5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 00:21:42 +0100 Subject: [PATCH 7/9] Increment patch number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8d35b386..ceb2dbde 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "InferenceObjects" uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00" authors = ["Seth Axen and contributors"] -version = "0.2.4" +version = "0.2.5" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From f499070502cc1639ae570056e9f9910eb2cfb7db Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 00:22:46 +0100 Subject: [PATCH 8/9] Add from_dict to docs --- docs/src/inference_data.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/inference_data.md b/docs/src/inference_data.md index 2e92f204..5c3ba1af 100644 --- a/docs/src/inference_data.md +++ b/docs/src/inference_data.md @@ -33,6 +33,7 @@ That is, iterating over an `InferenceData` iterates over its groups. ```@docs convert_to_inference_data +from_dict from_namedtuple ``` From 102052e711c2fbe59a8a73c9fe8de428245553bc Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 09:01:48 +0100 Subject: [PATCH 9/9] Correctly import package --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index a6786eec..a15bd2a0 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,4 @@ -using InferenceObjects, OrderedDict, Test +using InferenceObjects, OrderedCollections, Test module TestSubModule end