-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add from_dict * Fix from_namedtuple docstring * Generalize function signature * Add from_dict tests * Test also with OrderedDict * Test as_namedtuple * Increment patch number * Add from_dict to docs * Correctly import package
- Loading branch information
Showing
9 changed files
with
165 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "InferenceObjects" | ||
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00" | ||
authors = ["Seth Axen <[email protected]> and contributors"] | ||
version = "0.2.4" | ||
version = "0.2.5" | ||
|
||
[deps] | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
|
@@ -12,11 +12,13 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" | |
Compat = "3.46.0, 4.2.0" | ||
DimensionalData = "0.20, 0.21, 0.22, 0.23" | ||
OffsetArrays = "1" | ||
OrderedCollections = "1" | ||
julia = "1.6" | ||
|
||
[extras] | ||
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" | ||
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["Test", "OffsetArrays"] | ||
test = ["Test", "OffsetArrays", "OrderedCollections"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
from_dict(posterior::AbstractDict; kwargs...) -> InferenceData | ||
Convert a `Dict` to an `InferenceData`. | ||
# Arguments | ||
- `posterior`: The data to be converted. Its strings must be `Symbol` or `AbstractString`, | ||
and its values must be arrays. | ||
# Keywords | ||
- `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution | ||
- `sample_stats::Any=nothing`: Statistics of the posterior sampling process | ||
- `predictions::Any=nothing`: Out-of-sample predictions for the posterior. | ||
- `prior::Dict=nothing`: Draws from the prior | ||
- `prior_predictive::Any=nothing`: Draws from the prior predictive distribution | ||
- `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process | ||
- `observed_data::NamedTuple`: Observed data on which the `posterior` is | ||
conditional. It should only contain data which is modeled as a random variable. Keys | ||
are parameter names and values. | ||
- `constant_data::NamedTuple`: Model constants, data included in the model | ||
which is not modeled as a random variable. Keys are parameter names and values. | ||
- `predictions_constant_data::NamedTuple`: Constants relevant to the model | ||
predictions (i.e. new `x` values in a linear regression). | ||
- `log_likelihood`: Pointwise log-likelihood for the data. It is recommended | ||
to use this argument as a `NamedTuple` whose keys are observed variable names and whose | ||
values are log likelihood arrays. | ||
- `library`: Name of library that generated the draws | ||
- `coords`: Map from named dimension to named indices | ||
- `dims`: Map from variable name to names of its dimensions | ||
# Returns | ||
- `InferenceData`: The data with groups corresponding to the provided data | ||
# Examples | ||
```@example | ||
using InferenceObjects | ||
nchains = 2 | ||
ndraws = 100 | ||
data = Dict( | ||
:x => rand(ndraws, nchains), | ||
:y => randn(2, ndraws, nchains), | ||
:z => randn(3, 2, ndraws, nchains), | ||
) | ||
idata = from_dict(data) | ||
``` | ||
""" | ||
from_dict | ||
|
||
function from_dict( | ||
posterior::Union{<:AbstractDict,Nothing}=nothing; prior=nothing, kwargs... | ||
) | ||
nt = posterior === nothing ? posterior : as_namedtuple(posterior) | ||
nt_prior = prior === nothing ? prior : as_namedtuple(prior) | ||
return from_namedtuple(nt; prior=nt_prior, kwargs...) | ||
end | ||
|
||
""" | ||
convert_to_inference_data(obj::AbstractDict; kwargs...) -> InferenceData | ||
Convert `obj` to an [`InferenceData`](@ref). See [`from_namedtuple`](@ref) for a description | ||
of `obj` possibilities and `kwargs`. | ||
""" | ||
function convert_to_inference_data(data::AbstractDict; group=:posterior, kwargs...) | ||
group = Symbol(group) | ||
group === :posterior && return from_dict(data; kwargs...) | ||
return from_dict(; group => data, kwargs...) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
using InferenceObjects, OrderedCollections, Test | ||
|
||
@testset "from_dict" begin | ||
nchains, ndraws = 4, 10 | ||
sizes = (x=(), y=(2,), z=(3, 5)) | ||
dims = (y=[:yx], z=[:zx, :zy]) | ||
coords = (yx=["y1", "y2"], zx=1:3, zy=1:5) | ||
|
||
dicts = [ | ||
"Dict{Symbol}" => | ||
Dict(Symbol(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)), | ||
"OrderedDict{String}" => | ||
Dict(string(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)), | ||
] | ||
|
||
@testset "posterior::$(type)" for (type, dict) in dicts | ||
@test_broken @inferred from_dict(dict; dims, coords, library="MyLib") | ||
idata1 = from_dict(dict; dims, coords, library="MyLib") | ||
idata2 = convert_to_inference_data(dict; dims, coords, library="MyLib") | ||
test_idata_approx_equal(idata1, idata2) | ||
end | ||
|
||
@testset "$(group)" for group in [ | ||
:posterior_predictive, :sample_stats, :predictions, :log_likelihood | ||
] | ||
library = "MyLib" | ||
@testset "::$(type)" for (type, dict) in dicts | ||
idata1 = from_dict(dict; group => dict, dims, coords, library) | ||
test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords) | ||
|
||
idata2 = from_dict(dict; group => (:x,), dims, coords, library) | ||
test_idata_group_correct(idata2, :posterior, (:y, :z); library, dims, coords) | ||
test_idata_group_correct(idata2, group, (:x,); library, dims, coords) | ||
end | ||
end | ||
|
||
@testset "$(group)" for group in [:prior_predictive, :sample_stats_prior] | ||
library = "MyLib" | ||
@testset "::$(type)" for (type, dict) in dicts | ||
idata1 = from_dict(; prior=dict, group => dict, dims, coords, library) | ||
test_idata_group_correct(idata1, :prior, keys(sizes); library, dims, coords) | ||
test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords) | ||
|
||
idata2 = from_dict(; prior=dict, group => (:x,), dims, coords, library) | ||
test_idata_group_correct(idata2, :prior, (:y, :z); library, dims, coords) | ||
test_idata_group_correct(idata2, group, (:x,); library, dims, coords) | ||
end | ||
end | ||
|
||
@testset "$(group)" for group in | ||
[:observed_data, :constant_data, :predictions_constant_data] | ||
_, dict = dicts[1] | ||
library = "MyLib" | ||
dims = (; w=[:wx]) | ||
coords = (; wx=1:2) | ||
idata1 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library) | ||
test_idata_group_correct(idata1, :posterior, keys(sizes); library, dims, coords) | ||
test_idata_group_correct( | ||
idata1, group, (:w,); library, dims, coords, default_dims=() | ||
) | ||
|
||
# ensure that dims are matched to named tuple keys | ||
# https://github.com/arviz-devs/ArviZ.jl/issues/96 | ||
idata2 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library) | ||
test_idata_group_correct(idata2, :posterior, keys(sizes); library, dims, coords) | ||
test_idata_group_correct( | ||
idata2, group, (:w,); library, dims, coords, default_dims=() | ||
) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84a5aa8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
84a5aa8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/72135
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: