From 4affc28b341f4763bd1abc8523e4e209e9f6aa6e Mon Sep 17 00:00:00 2001 From: Jaime RZ Date: Wed, 16 Aug 2023 23:45:07 +0200 Subject: [PATCH] MH Constructor (#2037) * first draft * abstractcontext + tests * bug * externalsampler() in tests * Name Tupple problems * moving stuff to DynamicPPL RP * using new DynamicPPL PR * mistakenly removed line * specific constructors * no StaticMH RWMH * Bump bijectors compat (#2052) * CompatHelper: bump compat for Bijectors to 0.13, (keep existing compat) * Update Project.toml * Replacement for #2039 (#2040) * Fix testset for external samplers * Update abstractmcmc.jl * Update test/contrib/inference/abstractmcmc.jl Co-authored-by: Tor Erlend Fjelde * Update test/contrib/inference/abstractmcmc.jl Co-authored-by: Tor Erlend Fjelde * Update FillArrays compat to 1.4.1 (#2035) * Update FillArrays compat to 1.4.0 * Update test compat * Try to enable ReverseDiff tests * Update Project.toml * Update Project.toml * Bump version * Revert dependencies on FillArrays (#2042) * Update Project.toml * Update Project.toml * Fix redundant definition of `getstats` (#2044) * Fix redundant definition of `getstats` * Update Inference.jl * Revert "Update Inference.jl" This reverts commit e4f51c24fa7450d625d18b21ca3a273bb2d736d0. * Bump version --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Transfer some test utility function into DynamicPPL (#2049) * Update OptimInterface.jl * Only run optimisation tests in numerical stage. * fix function lookup after moving functions --------- Co-authored-by: Xianda Sun * Move Optim support to extension (#2051) * Move Optim support to extension * More imports * Update Project.toml --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --------- Co-authored-by: CompatHelper Julia Co-authored-by: haris organtzidis Co-authored-by: Tor Erlend Fjelde Co-authored-by: David Widmann Co-authored-by: Xianda Sun Co-authored-by: Cameron Pfiffer * Bugfixes. * Add TODO. * Update mh.jl * Update Inference.jl * Removed obsolete exports. * removed unnecessary import of extract_priors * added missing ) in MH tests * fixed incorrect referneces to AdvancedMH in tests * improve ESLogDensityFunction * remove hardcoding of SimpleVarInfo * added fixme comment * minor style changes * fixed issues with MH with RandomWalkProposal being used as an external sampler * fixed accidental typo * move definitions of unflatten for NamedTuple * improved TODO * Update Project.toml --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: CompatHelper Julia Co-authored-by: haris organtzidis Co-authored-by: Tor Erlend Fjelde Co-authored-by: David Widmann Co-authored-by: Xianda Sun Co-authored-by: Cameron Pfiffer Co-authored-by: Hong Ge --- src/Turing.jl | 1 - src/inference/Inference.jl | 18 ++++++++++++++++++ src/inference/mh.jl | 15 +++++++++++++++ test/inference/mh.jl | 6 ++++++ 4 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/Turing.jl b/src/Turing.jl index 11dcbdb6f..33286a665 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -73,7 +73,6 @@ export @model, # modelling Prior, # Sampling from the prior MH, # classic sampling - RWMH, Emcee, ESS, Gibbs, diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 73f190dcc..45ae434a4 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -99,6 +99,24 @@ Wrap a sampler so it can be used as an inference algorithm. """ externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler) +""" + ESLogDensityFunction + +A log density function for the External sampler. + +""" +const ESLogDensityFunction{M<:Model,S<:Sampler{<:ExternalSampler},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.DefaultContext} +function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple) + return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x)) +end + +# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL. +function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) + set_namedtuple!(deepcopy(vi), θ) + return vi +end +DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation) + # Algorithm for sampling from the prior struct Prior <: InferenceAlgorithm end diff --git a/src/inference/mh.jl b/src/inference/mh.jl index ddbeaa2c5..dd97efd18 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -188,6 +188,20 @@ function MH(space...) return MH{tuple(syms...), typeof(proposals)}(proposals) end +# Some of the proposals require working in unconstrained space. +transform_maybe(proposal::AMH.Proposal) = proposal +function transform_maybe(proposal::AMH.RandomWalkProposal) + return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal)) +end + +function MH(model::Model; proposal_type=AMH.StaticProposal) + priors = DynamicPPL.extract_priors(model) + props = Tuple([proposal_type(prop) for prop in values(priors)]) + vars = Tuple(map(Symbol, collect(keys(priors)))) + priors = map(transform_maybe, NamedTuple{vars}(props)) + return AMH.MetropolisHastings(priors) +end + ##################### # Utility functions # ##################### @@ -346,6 +360,7 @@ end function should_link(varinfo, sampler, proposal::AdvancedMH.RandomWalkProposal) return true end +# FIXME: This won't be hit unless `vals` are all the exactly same concrete type of `AdvancedMH.RandomWalkProposal`! function should_link( varinfo, sampler, diff --git a/test/inference/mh.jl b/test/inference/mh.jl index 8e52aec9b..94f9aa992 100644 --- a/test/inference/mh.jl +++ b/test/inference/mh.jl @@ -17,6 +17,12 @@ s4 = Gibbs(MH(:m), MH(:s)) c4 = sample(gdemo_default, s4, N) + + s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) + c5 = sample(gdemo_default, s5, N) + + s6 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.StaticProposal)) + c6 = sample(gdemo_default, s6, N) end @numerical_testset "mh inference" begin Random.seed!(125)