Attach varname_to_symbol mapping to Chains #2078

Merged: 6 commits, Sep 12, 2023
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.29"
+version = "0.29.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -56,7 +56,7 @@ Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
-DynamicPPL = "0.23.15"
+DynamicPPL = "0.23.17"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
2 changes: 1 addition & 1 deletion ext/TuringOptimExt.jl
@@ -253,7 +253,7 @@ function _optimize(
Turing.Inference.getparams(model, f.varinfo),
DynamicPPL.getlogp(f.varinfo)
)]
-varnames, _ = Turing.Inference._params_to_array(model, ts)
+varnames = map(Symbol, first(Turing.Inference._params_to_array(model, ts)))

# Store the parameters and their names in an array.
vmat = NamedArrays.NamedArray(vals, varnames)
23 changes: 14 additions & 9 deletions src/mcmc/Inference.jl
@@ -310,18 +310,17 @@ end


function _params_to_array(model::DynamicPPL.Model, ts::Vector)
-# TODO: Do we really need to use `Symbol` here?
-names_set = OrderedSet{Symbol}()
+names_set = OrderedSet{VarName}()
# Extract the parameter names and values from each transition.
dicts = map(ts) do t
nms_and_vs = getparams(model, t)
-nms = map(Symbol ∘ first, nms_and_vs)
+nms = map(first, nms_and_vs)
vs = map(last, nms_and_vs)
for nm in nms
push!(names_set, nm)
end
# Convert the names and values to a single dictionary.
-return Dict(nms[j] => vs[j] for j in 1:length(vs))
+return OrderedDict(zip(nms, vs))
end
names = collect(names_set)
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
@@ -379,29 +378,35 @@ function AbstractMCMC.bundle_samples(
save_state = false,
stats = missing,
sort_chain = false,
+include_varname_to_symbol = true,
discard_initial = 0,
thinning = 1,
kwargs...
)
# Convert transitions to array format.
# Also retrieve the variable names.
-nms, vals = _params_to_array(model, ts)
+varnames, vals = _params_to_array(model, ts)
+varnames_symbol = map(Symbol, varnames)

# Get the values of the extra parameters in each transition.
extra_params, extra_values = get_transition_extras(ts)

# Extract names & construct param array.
-nms = [nms; extra_params]
+nms = [varnames_symbol; extra_params]
parray = hcat(vals, extra_values)

# Get the average or final log evidence, if it exists.
le = getlogevidence(ts, spl, state)

# Set up the info tuple.
+info = NamedTuple()
+
+if include_varname_to_symbol
+info = merge(info, (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),))
+end
+
if save_state
-info = (model = model, sampler = spl, samplerstate = state)
-else
-info = NamedTuple()
+info = merge(info, (model = model, sampler = spl, samplerstate = state))
end

# Merge in the timing info, if available
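
The `_params_to_array` hunk above keeps parameter names in first-seen order and fills parameters absent from a transition with `missing`. A minimal Base-only sketch of that idea follows. It is an illustration under stated assumptions, not the Turing implementation: `params_to_array_sketch` is a hypothetical name, plain strings stand in for `VarName`, and a hand-rolled name vector replaces `OrderedSet`.

```julia
# Hypothetical stand-in for `_params_to_array`: each "transition" is a
# vector of name => value pairs; names are plain strings, not `VarName`s.
function params_to_array_sketch(transitions)
    param_names = String[]                # first-seen order, like an ordered set
    dicts = map(transitions) do t
        for (nm, _) in t
            nm in param_names || push!(param_names, nm)
        end
        Dict(t)                           # name => value lookup for this transition
    end
    # One row per transition, one column per name; absent names become `missing`.
    vals = [get(dicts[i], nm, missing) for i in eachindex(dicts), nm in param_names]
    return param_names, vals
end

ts = [["x[1]" => 1.0, "x[2]" => 2.0], ["x[1]" => 3.0]]
param_names, vals = params_to_array_sketch(ts)
symbols = map(Symbol, param_names)        # conversion now happens after collection
```

Deferring the `Symbol` conversion until after the names are collected is what lets the caller keep both the original names and their symbol form, which is the pairing the PR stores as `varname_to_symbol`.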
10 changes: 5 additions & 5 deletions src/mcmc/emcee.jl
@@ -109,7 +109,8 @@ function AbstractMCMC.bundle_samples(
params_vec = map(Base.Fix1(_params_to_array, model), samples)

# Extract names and values separately.
-nms = params_vec[1][1]
+varnames = params_vec[1][1]
+varnames_symbol = map(Symbol, varnames)
vals_vec = [p[2] for p in params_vec]

# Get the values of the extra parameters in each transition.
@@ -120,7 +121,7 @@ function AbstractMCMC.bundle_samples(
extra_values_vec = [e[2] for e in extra_vec]

# Extract names & construct param array.
-nms = [nms; extra_params]
+nms = [varnames_symbol; extra_params]
# `hcat` first to ensure we get the right `eltype`.
x = hcat(first(vals_vec), first(extra_values_vec))
# Pre-allocate to minimize memory usage.
@@ -133,10 +134,9 @@ function AbstractMCMC.bundle_samples(
le = getlogevidence(samples, state, spl)

# Set up the info tuple.
+info = (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),)
Member Author

Should we make this inclusion optional? Forcing it could be annoying for serialization of the chains, etc. @yebai

Member

Can we skip this information while serializing chains? I think we might want to carry such information in chains by default so functions like predict/generated_quantities can work.

Member Author

If we skip it in serialization, then those methods won't work after deserialization anyways 😕

And yes I agree it should be enabled by default, but I was also thinking it would be useful to provide a way to disable it.

Member

Is there any reason why we can't serialise/deserialise VarName?

Member Author

You can, but it means that you can no longer do `using MCMCChains; deserialize(...)`; you need to have done `using Turing` first. IMO this is quite annoying 😕

Member

VarName is defined in AbstractPPL, which is lightweight. Can we make MCMCChains depend on AbstractPPL?

Member Author

Potentially in the longer run, but I don't think we should do that when MCMCChains doesn't use any functionality from AbstractPPL.
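
The thread above weighs attaching the mapping by default against letting callers opt out. A minimal sketch of the merge-based `info` construction used in the `src/mcmc/Inference.jl` hunk is given here, under stated assumptions: `build_info` is a hypothetical helper, not a Turing function, and a plain `Dict` with string keys stands in for `OrderedDict{VarName,Symbol}`.

```julia
# Hypothetical helper, not part of Turing: builds the chain `info` tuple
# incrementally so each piece can be switched on or off independently.
function build_info(varnames, symbols; include_varname_to_symbol=true,
                    save_state=false, state=nothing)
    info = NamedTuple()
    if include_varname_to_symbol
        info = merge(info, (varname_to_symbol = Dict(zip(varnames, symbols)),))
    end
    if save_state
        # The real code also stores `model` and `sampler`; omitted in this sketch.
        info = merge(info, (samplerstate = state,))
    end
    return info
end

varnames = ["x[1]", "x[2]"]
symbols = map(Symbol, varnames)
with_map = build_info(varnames, symbols)
without_map = build_info(varnames, symbols; include_varname_to_symbol=false)
```

Starting from an empty `NamedTuple` and merging in each optional piece avoids the earlier `if`/`else` that overwrote `info`, so the mapping and the saved state no longer exclude one another.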

if save_state
-info = (model = model, sampler = spl, samplerstate = state)
-else
-info = NamedTuple()
+info = merge(info, (model = model, sampler = spl, samplerstate = state))
end

# Concretize the array before giving it to MCMCChains.