From dcf1da9cccf71ea439f33c80e52cd075c8c164ee Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 12 Jul 2024 09:26:33 +0100 Subject: [PATCH 01/56] very incomplete draft --- src/AbstractMCMC.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index dc464d42..8408d3df 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,6 +80,20 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr """ struct MCMCSerial <: AbstractMCMCEnsemble end +""" + recompute_logprob!!(rng, model, sampler, state) + +Recompute the log-probability of the `model` based on the given `state` and return the resulting state. +""" +function recompute_logprob!!(rng, model, sampler, state) end + +""" + getparams + +TODO +""" +function getparams end + include("samplingstats.jl") include("logging.jl") include("interface.jl") From cdaa663b574666d61613725745364ffee839764a Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 12 Jul 2024 16:44:34 +0100 Subject: [PATCH 02/56] update `getparams` --- src/AbstractMCMC.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 8408d3df..3b38c583 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -88,11 +88,11 @@ Recompute the log-probability of the `model` based on the given `state` and retu function recompute_logprob!!(rng, model, sampler, state) end """ - getparams + getparams(state) -TODO +Returns the values of the parameters in the state. """ -function getparams end +function getparams(state) end include("samplingstats.jl") include("logging.jl") From 57275f50fca7a2cfb8a93f37832d7a35b94b52b5 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 18 Jul 2024 11:14:44 +0100 Subject: [PATCH 03/56] Upstream `condition` and `decondition` from `AbstractPPL` --- src/AbstractMCMC.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 3b38c583..a2aa7170 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,6 +80,40 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr """ struct MCMCSerial <: AbstractMCMCEnsemble end +""" + decondition(conditioned_model) + +Remove the conditioning (i.e., observation data) from `conditioned_model`, turning it into a +generative model over prior and observed variables. + +The invariant + +``` +m == condition(decondition(m), obs) +``` + +should hold for models `m` with conditioned variables `obs`. +""" +function decondition end + +""" + condition(model, observations) + +Condition the generative model `model` on some observed data, creating a new model of the (possibly +unnormalized) posterior distribution over them. + +`observations` can be of any supported internal trace type, or a fixed probability expression. + +The invariant + +``` +m = decondition(condition(m, obs)) +``` + +should hold for generative models `m` and arbitrary `obs`. 
+""" +function condition end + """ recompute_logprob!!(rng, model, sampler, state) From 26027ea6fd1babe201fa9a68d908cd621ea44565 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 22 Jul 2024 08:25:53 +0100 Subject: [PATCH 04/56] remove `condition` and `decondition` --- src/AbstractMCMC.jl | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index a2aa7170..3b38c583 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,40 +80,6 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr """ struct MCMCSerial <: AbstractMCMCEnsemble end -""" - decondition(conditioned_model) - -Remove the conditioning (i.e., observation data) from `conditioned_model`, turning it into a -generative model over prior and observed variables. - -The invariant - -``` -m == condition(decondition(m), obs) -``` - -should hold for models `m` with conditioned variables `obs`. -""" -function decondition end - -""" - condition(model, observations) - -Condition the generative model `model` on some observed data, creating a new model of the (possibly -unnormalized) posterior distribution over them. - -`observations` can be of any supported internal trace type, or a fixed probability expression. - -The invariant - -``` -m = decondition(condition(m, obs)) -``` - -should hold for generative models `m` and arbitrary `obs`. -""" -function condition end - """ recompute_logprob!!(rng, model, sampler, state) From 6ebab49b9943780c2f22ad387e3e342ed2bc5511 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 22 Jul 2024 08:29:07 +0100 Subject: [PATCH 05/56] add Compat to make new interface functions public --- Project.toml | 2 ++ src/AbstractMCMC.jl | 3 +++ 2 files changed, 5 insertions(+) diff --git a/Project.toml b/Project.toml index 97c90709..02df4e3d 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "5.2.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19, 0.4" +Compat = "4.15.0" ConsoleProgressMonitor = "0.1" FillArrays = "1" LogDensityProblems = "2" diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 3b38c583..0ae15774 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -1,6 +1,7 @@ module AbstractMCMC using BangBang: BangBang +using Compat using ConsoleProgressMonitor: ConsoleProgressMonitor using LogDensityProblems: LogDensityProblems using LoggingExtras: LoggingExtras @@ -21,6 +22,8 @@ export sample # Parallel sampling types export MCMCThreads, MCMCDistributed, MCMCSerial +@compat public recompute_logprob!!, getparams + """ AbstractChains From e1099f9faf9b6cfef732f6be948e4eae3fbe30d2 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 22 Jul 2024 08:32:43 +0100 Subject: [PATCH 06/56] bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 02df4e3d..b2a9a4c9 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." 
-version = "5.2.0" +version = "6.0.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From 95d781b1d6a9661953ce2d5202fb91a6848c03c3 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 22 Jul 2024 10:15:38 +0100 Subject: [PATCH 07/56] bump minor version instead --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b2a9a4c9..942d37c0 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "6.0.0" +version = "5.3.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From f05f2937fa706b87fac4608a06bdab971980b844 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 6 Aug 2024 15:29:03 +0100 Subject: [PATCH 08/56] unfinished gibbs example --- src/example.jl | 177 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 src/example.jl diff --git a/src/example.jl b/src/example.jl new file mode 100644 index 00000000..f49a2ba1 --- /dev/null +++ b/src/example.jl @@ -0,0 +1,177 @@ +using LogDensityProblems, Distributions, LinearAlgebra, Random +using OrderedCollections +## Define a simple GMM problem + +struct GMM{Tdata} + data::NamedTuple +end + +struct ConditionedGMM{conditioned_vars} + data::NamedTuple + conditioned_values::NamedTuple{conditioned_vars} +end + +function log_joint(;μ, w, z, x) + # μ is mean of each component + # w is weights of each component + # z is assignment of each data point + # x is data + + K = 2 + D = 2 + N = size(x, 1) + logp = .0 + + μ_prior = MvNormal(zeros(K), I) + logp += sum(logpdf(μ_prior, μ)) + + w_prior = Dirichlet(K, 1.0) + logp += logpdf(w_prior, w) + + z_prior = Categorical(w) + logp += sum([logpdf(z_prior, z[i]) for i in 1:N]) + + for i in 1:N + logp += logpdf(MvNormal(fill(μ[z[i]], D), I), x[i, :]) + end + + return logp +end + +function condition(gmm::GMM, conditioned_values::NamedTuple) + return ConditionedGMM(gmm.data, conditioned_values) +end + +function logdensity(gmm::ConditionedGMM{conditioned_vars}, params) where {conditioned_vars} + if conditioned_vars == (:μ, :w) + return log_joint(;μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data) + elseif conditioned_vars == (:z,) + return log_joint(;μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data) + else + throw(ArgumentError("condition group not supported")) + end +end + +function LogDensityProblems.logdensity(gmm::ConditionedGMM{conditioned_vars}, params_vec::AbstractVector) where {conditioned_vars} + if conditioned_vars == (:μ, :w) + params = (; z= params_vec) + elseif conditioned_vars == (:z,) + params = (; μ= params_vec[1:2], w= params_vec[3:4]) + else + throw(ArgumentError("condition group not supported")) + end + + return logdensity(gmm, params) +end + +function LogDensityProblems.dimension(gmm::ConditionedGMM{conditioned_vars}) where {conditioned_vars} + if conditioned_vars == (:μ, :w) + return size(gmm.data.x, 1) + elseif conditioned_vars == (:z,) + return size(gmm.data.x, 1) + else + throw(ArgumentError("condition group not supported")) + end +end + +struct Gibbs <: AbstractMCMC.AbstractSampler + sampler_map::OrderedDict +end + +# ! 
initialize the params here +struct GibbsState + "contains all the values of the model parameters" + values::NamedTuple + states::OrderedDict +end + +struct GibbsTransition + values::NamedTuple +end + +function AbstractMCMC.step( + rng::AbstractRNG, model, sampler::Gibbs, args...; initial_params::NamedTuple, kwargs... +) + states = OrderedDict() + for group in collect(keys(sampler.sampler_map)) + sampler = sampler.sampler_map[group] + cond_val = NamedTuple{group}([initial_params[g] for g in group]...) + trans, state = AbstractMCMC.step(rng, condition(model, cond_val), sampler, args...; kwargs...) + states[group] = state + end + return GibbsTransition(initial_params), GibbsState(initial_params, states) +end + +# questions is: when do we assume the logp from last iteration is not reliable anymore + +function AbstractMCMC.step( + rng::AbstractRNG, model, sampler::Gibbs, state::GibbsState, args...; kwargs... +) + for group in collect(keys(sampler.sampler_map)) + sampler = sampler.sampler_map[group] + state = state.states[group] + trans, state = AbstractMCMC.step(rng, condition(model, state.values[group]), sampler, state, args...; kwargs...) + # TODO: what values to condition on here? stored where? + state.states[group] = state + end + return +end + +# importance sampling +struct ImportanceSampling <: AbstractMCMC.AbstractSampler + "number of samples" + n::Int + proposal +end + +struct ImportanceSamplingState + +end + +struct ImportanceSamplingTransition + values +end + +# initial step +function AbstractMCMC.step( + rng::AbstractRNG, logdensity, sampler::ImportanceSampling, args...; kwargs... +) + +end + +function IS_step(rng::AbstractRNG, logdensity, sampler::ImportanceSampling, state::ImportanceSamplingState, args...; kwargs...) + proposals = rand(rng, sampler.proposal, sampler.n) + weights = logdensity.(proposals) .- log.(logpdf.(sampler.proposal, proposals)) + sample = rand(rng, Categorical(softmax(weights))) + return ImportanceSamplingTransition(proposals[sample]), ImportanceSamplingState() +end + + +function AbstractMCMC.step( + rng::AbstractRNG, logdensity, sampler::ImportanceSampling, state::ImportanceSamplingState, args...; kwargs... +) + return +end + +struct RWMH <: AbstractMCMC.AbstractSampler + proposal +end + +function AbstractMCMC.step( + rng::AbstractRNG, logdensity, sampler::RWMH, args...; kwargs... +) + proposal = rand(rng, sampler.proposal) + + acceptance_probability = min(1, exp(logdensity(proposal) - logdensity(args[1]))) + if rand(rng) < acceptance_probability + return proposal, nothing + else + return args[1], nothing + end +end + +function AbstractMCMC.step( + rng::AbstractRNG, logdensity, sampler::RWMH, state::RWMHState, args...; kwargs... 
+) + return +end From 590d37f371027fc7551f056427998064c647ca1c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 14 Aug 2024 18:35:57 +0100 Subject: [PATCH 09/56] some updates --- gibbs_example/Project.toml | 10 +++ gibbs_example/gibbs.jl | 45 ++++++++++ gibbs_example/gmm.jl | 134 ++++++++++++++++++++++++++++ gibbs_example/mh.jl | 64 ++++++++++++++ src/AbstractMCMC.jl | 20 ++++- src/example.jl | 177 ------------------------------------- 6 files changed, 270 insertions(+), 180 deletions(-) create mode 100644 gibbs_example/Project.toml create mode 100644 gibbs_example/gibbs.jl create mode 100644 gibbs_example/gmm.jl create mode 100644 gibbs_example/mh.jl delete mode 100644 src/example.jl diff --git a/gibbs_example/Project.toml b/gibbs_example/Project.toml new file mode 100644 index 00000000..1e8d8677 --- /dev/null +++ b/gibbs_example/Project.toml @@ -0,0 +1,10 @@ +[deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/gibbs_example/gibbs.jl b/gibbs_example/gibbs.jl new file mode 100644 index 00000000..e201a099 --- /dev/null +++ b/gibbs_example/gibbs.jl @@ -0,0 +1,45 @@ +using LogDensityProblems, Distributions, LinearAlgebra, Random +using OrderedCollections + +struct Gibbs <: AbstractMCMC.AbstractSampler + sampler_map::OrderedDict +end + +struct GibbsState + values::NamedTuple + states::OrderedDict +end + +struct GibbsTransition + values::NamedTuple +end + +function AbstractMCMC.step( + rng::AbstractRNG, model, sampler::Gibbs, args...; initial_params::NamedTuple, kwargs... +) + states = OrderedDict() + for group in keys(sampler.sampler_map) + sampler = sampler.sampler_map[group] + cond_val = NamedTuple{group}([initial_params[g] for g in group]...) + trans, state = AbstractMCMC.step( + rng, condition(model, cond_val), sampler, args...; kwargs... + ) + states[group] = state + end + return GibbsTransition(initial_params), GibbsState(initial_params, states) +end + +function AbstractMCMC.step( + rng::AbstractRNG, model, sampler::Gibbs, state::GibbsState, args...; kwargs... +) + for group in collect(keys(sampler.sampler_map)) + sampler = sampler.sampler_map[group] + state = state.states[group] + trans, state = AbstractMCMC.step( + rng, condition(model, state.values[group]), sampler, state, args...; kwargs... + ) + # TODO: what values to condition on here? stored where? 
+ state.states[group] = state + end + return nothing +end diff --git a/gibbs_example/gmm.jl b/gibbs_example/gmm.jl new file mode 100644 index 00000000..b40b4579 --- /dev/null +++ b/gibbs_example/gmm.jl @@ -0,0 +1,134 @@ +using LogDensityProblems + +abstract type AbstractGMM end + +struct GMM <: AbstractGMM + data::NamedTuple +end + +struct ConditionedGMM{conditioned_vars} <: AbstractGMM + data::NamedTuple + conditioned_values::NamedTuple{conditioned_vars} +end + +function log_joint(; μ, w, z, x) + # μ is mean of each component + # w is weights of each component + # z is assignment of each data point + # x is data + + K = 2 # assume we know the number of components + D = 2 # dimension of each data point + N = size(x, 2) # number of data points + logp = 0.0 + + μ_prior = MvNormal(zeros(K), I) + logp += logpdf(μ_prior, μ) + + w_prior = Dirichlet(K, 1.0) + logp += logpdf(w_prior, w) + + z_prior = Categorical(w) + logp += sum([logpdf(z_prior, z[i]) for i in 1:N]) + + obs_priors = [MvNormal(fill(μₖ, D), I) for μₖ in μ] + for i in 1:N + logp += logpdf(obs_priors[z[i]], x[:, i]) + end + + return logp +end + +function condition(gmm::GMM, conditioned_values::NamedTuple) + return ConditionedGMM(gmm.data, conditioned_values) +end + +function _logdensity(gmm::ConditionedGMM{(:μ, :w)}, params) + return log_joint(; + μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x + ) +end +function _logdensity(gmm::ConditionedGMM{(:z,)}, params) + return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x) +end + +function LogDensityProblems.logdensity( + gmm::ConditionedGMM{(:μ, :w)}, params_vec::AbstractVector +) + return _logdensity(gmm, (; z=params_vec)) +end +function LogDensityProblems.logdensity( + gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector +) + return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4])) +end + +function LogDensityProblems.dimension(gmm::ConditionedGMM{(:μ, :w)}) + return size(gmm.data.x, 1) +end +function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)}) + return size(gmm.data.x, 1) +end + +## test using Turing + +# data generation + +using Distributions +using FillArrays +using LinearAlgebra +using Random + +w = [0.5, 0.5] +μ = [-3.5, 0.5] +mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w) + +N = 60 +x = rand(mixturemodel, N); + +# Turing model from https://turinglang.org/docs/tutorials/01-gaussian-mixture-model/ +using Turing + +@model function gaussian_mixture_model(x) + # Draw the parameters for each of the K=2 clusters from a standard normal distribution. + K = 2 + μ ~ MvNormal(Zeros(K), I) + + # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1. + w ~ Dirichlet(K, 1.0) + # Alternatively, one could use a fixed set of weights. + # w = fill(1/K, K) + + # Construct categorical distribution of assignments. + distribution_assignments = Categorical(w) + + # Construct multivariate normal distributions of each cluster. + D, N = size(x) + distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ] + + # Draw assignments for each datum and generate it from the multivariate normal distribution. 
+    k = Vector{Int}(undef, N)
+    for i in 1:N
+        k[i] ~ distribution_assignments
+        x[:, i] ~ distribution_clusters[k[i]]
+    end
+
+    return μ, w, k, __varinfo__
+end
+
+model = gaussian_mixture_model(x);
+
+using Test
+# full model
+μ, w, k, vi = model()
+@test log_joint(; μ=μ, w=w, z=k, x=x) ≈ DynamicPPL.getlogp(vi)
+
+gmm = GMM((; x=x))
+
+# cond model on μ, w
+μ, w, k, vi = (DynamicPPL.condition(model, (μ=μ, w=w)))()
+@test _logdensity(condition(gmm, (; μ=μ, w=w)), (; z=k)) ≈ DynamicPPL.getlogp(vi)
+
+# cond model on z (the assignments are named `k` in the Turing model)
+μ, w, k, vi = (DynamicPPL.condition(model, (; k=k)))()
+@test _logdensity(condition(gmm, (; z=k)), (; μ=μ, w=w)) ≈ DynamicPPL.getlogp(vi)
diff --git a/gibbs_example/mh.jl b/gibbs_example/mh.jl
new file mode 100644
index 00000000..39d1c69f
--- /dev/null
+++ b/gibbs_example/mh.jl
@@ -0,0 +1,64 @@
+struct RWMH <: AbstractMCMC.AbstractSampler
+    σ
+end
+
+struct MHTransition{T} where {T}
+    params::T
+end
+
+struct MHState{T} where {T}
+    params::T
+    logp::Float64
+end
+
+getparams(state::MHState) = state.params
+setparams!!(state::MHState, params) = MHState(params, state.logp)
+getlogp(state::MHState) = state.logp
+setlogp!!(state::MHState, logp) = MHState(state.params, logp)
+
+function AbstractMCMC.step(rng::AbstractRNG, logdensity, sampler::RWMH, args...; kwargs...)
+    params = rand(rng, LogDensityProblems.dimension(logdensity))
+    return MHTransition(params),
+    MHState(params, LogDensityProblems.logdensity(logdensity, params))
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG, logdensity, sampler::RWMH, state::MHState, args...; kwargs...
+)
+    params = getparams(state)
+    proposal_dist = MvNormal(params, sampler.σ)
+    proposal = rand(rng, proposal_dist)
+    logp_proposal = logpdf(proposal_dist, proposal)
+    accepted = log(rand(rng)) < log1pexp(min(0, logp_proposal - getlogp(state)))
+    if accepted
+        return MHTransition(proposal), MHState(proposal, logp_proposal)
+    else
+        return MHTransition(params), MHState(params, getlogp(state))
+    end
+end
+
+struct PriorMH <: AbstractMCMC.AbstractSampler
+    prior_dist
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG, logdensity, sampler::PriorMH, args...; kwargs...
+)
+    params = rand(rng, sampler.prior_dist)
+    return MHTransition(params), MHState(params, logdensity(params))
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG, logdensity, sampler::PriorMH, state::MHState, args...; kwargs...
+)
+    params = getparams(state)
+    proposal_dist = sampler.prior_dist
+    proposal = rand(rng, proposal_dist)
+    logp_proposal = logpdf(proposal_dist, proposal)
+    accepted = log(rand(rng)) < log1pexp(min(0, logp_proposal - getlogp(state)))
+    if accepted
+        return MHTransition(proposal), MHState(proposal, logp_proposal)
+    else
+        return MHTransition(params), MHState(params, getlogp(state))
+    end
+end
diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl
index 0ae15774..b347f4bf 100644
--- a/src/AbstractMCMC.jl
+++ b/src/AbstractMCMC.jl
@@ -84,11 +84,18 @@ struct MCMCSerial <: AbstractMCMCEnsemble end
 
 """
-    recompute_logprob!!(rng, model, sampler, state)
+    get_logprob(state)
 
-Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
+Returns the log-probability of the last sampling step, stored in `state`.
 """
-function recompute_logprob!!(rng, model, sampler, state) end
+function get_logprob(state) end
+
+"""
+    set_logprob!!(state, logprob)
+
+Set the log-probability of the last sampling step, stored in `state`.
+""" +function set_logprob!!(state, logprob) end """ getparams(state) @@ -97,6 +104,13 @@ Returns the values of the parameters in the state. """ function getparams(state) end +""" + setparams!(state, params) + +Set the values of the parameters in the state. +""" +function setparams!!(state, params) end + include("samplingstats.jl") include("logging.jl") include("interface.jl") diff --git a/src/example.jl b/src/example.jl deleted file mode 100644 index f49a2ba1..00000000 --- a/src/example.jl +++ /dev/null @@ -1,177 +0,0 @@ -using LogDensityProblems, Distributions, LinearAlgebra, Random -using OrderedCollections -## Define a simple GMM problem - -struct GMM{Tdata} - data::NamedTuple -end - -struct ConditionedGMM{conditioned_vars} - data::NamedTuple - conditioned_values::NamedTuple{conditioned_vars} -end - -function log_joint(;μ, w, z, x) - # μ is mean of each component - # w is weights of each component - # z is assignment of each data point - # x is data - - K = 2 - D = 2 - N = size(x, 1) - logp = .0 - - μ_prior = MvNormal(zeros(K), I) - logp += sum(logpdf(μ_prior, μ)) - - w_prior = Dirichlet(K, 1.0) - logp += logpdf(w_prior, w) - - z_prior = Categorical(w) - logp += sum([logpdf(z_prior, z[i]) for i in 1:N]) - - for i in 1:N - logp += logpdf(MvNormal(fill(μ[z[i]], D), I), x[i, :]) - end - - return logp -end - -function condition(gmm::GMM, conditioned_values::NamedTuple) - return ConditionedGMM(gmm.data, conditioned_values) -end - -function logdensity(gmm::ConditionedGMM{conditioned_vars}, params) where {conditioned_vars} - if conditioned_vars == (:μ, :w) - return log_joint(;μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data) - elseif conditioned_vars == (:z,) - return log_joint(;μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data) - else - throw(ArgumentError("condition group not supported")) - end -end - -function LogDensityProblems.logdensity(gmm::ConditionedGMM{conditioned_vars}, params_vec::AbstractVector) where {conditioned_vars} - if conditioned_vars == (:μ, :w) - params = (; z= params_vec) - elseif conditioned_vars == (:z,) - params = (; μ= params_vec[1:2], w= params_vec[3:4]) - else - throw(ArgumentError("condition group not supported")) - end - - return logdensity(gmm, params) -end - -function LogDensityProblems.dimension(gmm::ConditionedGMM{conditioned_vars}) where {conditioned_vars} - if conditioned_vars == (:μ, :w) - return size(gmm.data.x, 1) - elseif conditioned_vars == (:z,) - return size(gmm.data.x, 1) - else - throw(ArgumentError("condition group not supported")) - end -end - -struct Gibbs <: AbstractMCMC.AbstractSampler - sampler_map::OrderedDict -end - -# ! initialize the params here -struct GibbsState - "contains all the values of the model parameters" - values::NamedTuple - states::OrderedDict -end - -struct GibbsTransition - values::NamedTuple -end - -function AbstractMCMC.step( - rng::AbstractRNG, model, sampler::Gibbs, args...; initial_params::NamedTuple, kwargs... -) - states = OrderedDict() - for group in collect(keys(sampler.sampler_map)) - sampler = sampler.sampler_map[group] - cond_val = NamedTuple{group}([initial_params[g] for g in group]...) - trans, state = AbstractMCMC.step(rng, condition(model, cond_val), sampler, args...; kwargs...) 
- states[group] = state - end - return GibbsTransition(initial_params), GibbsState(initial_params, states) -end - -# questions is: when do we assume the logp from last iteration is not reliable anymore - -function AbstractMCMC.step( - rng::AbstractRNG, model, sampler::Gibbs, state::GibbsState, args...; kwargs... -) - for group in collect(keys(sampler.sampler_map)) - sampler = sampler.sampler_map[group] - state = state.states[group] - trans, state = AbstractMCMC.step(rng, condition(model, state.values[group]), sampler, state, args...; kwargs...) - # TODO: what values to condition on here? stored where? - state.states[group] = state - end - return -end - -# importance sampling -struct ImportanceSampling <: AbstractMCMC.AbstractSampler - "number of samples" - n::Int - proposal -end - -struct ImportanceSamplingState - -end - -struct ImportanceSamplingTransition - values -end - -# initial step -function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::ImportanceSampling, args...; kwargs... -) - -end - -function IS_step(rng::AbstractRNG, logdensity, sampler::ImportanceSampling, state::ImportanceSamplingState, args...; kwargs...) - proposals = rand(rng, sampler.proposal, sampler.n) - weights = logdensity.(proposals) .- log.(logpdf.(sampler.proposal, proposals)) - sample = rand(rng, Categorical(softmax(weights))) - return ImportanceSamplingTransition(proposals[sample]), ImportanceSamplingState() -end - - -function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::ImportanceSampling, state::ImportanceSamplingState, args...; kwargs... -) - return -end - -struct RWMH <: AbstractMCMC.AbstractSampler - proposal -end - -function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::RWMH, args...; kwargs... -) - proposal = rand(rng, sampler.proposal) - - acceptance_probability = min(1, exp(logdensity(proposal) - logdensity(args[1]))) - if rand(rng) < acceptance_probability - return proposal, nothing - else - return args[1], nothing - end -end - -function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::RWMH, state::RWMHState, args...; kwargs... 
-) - return -end From 3afc232ab5875f923a85b2cea7b69b63377bd881 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 15 Aug 2024 13:04:05 +0100 Subject: [PATCH 10/56] more progress; still need to deal with w being on simplex --- gibbs_example/Project.toml | 1 + gibbs_example/gibbs.jl | 100 ++++++++++++++++++++----- gibbs_example/gmm.jl | 56 +++++++++++++- gibbs_example/mh.jl | 149 ++++++++++++++++++++++++++++++------- 4 files changed, 258 insertions(+), 48 deletions(-) diff --git a/gibbs_example/Project.toml b/gibbs_example/Project.toml index 1e8d8677..81b2b669 100644 --- a/gibbs_example/Project.toml +++ b/gibbs_example/Project.toml @@ -6,5 +6,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/gibbs_example/gibbs.jl b/gibbs_example/gibbs.jl index e201a099..1c68643f 100644 --- a/gibbs_example/gibbs.jl +++ b/gibbs_example/gibbs.jl @@ -1,12 +1,18 @@ +using AbstractMCMC using LogDensityProblems, Distributions, LinearAlgebra, Random using OrderedCollections +## + +# TODO: introduce some kind of parameter format, for instance, a flattened vector +# then define some kind of function to transform the flattened vector into model's representation + struct Gibbs <: AbstractMCMC.AbstractSampler sampler_map::OrderedDict end struct GibbsState - values::NamedTuple + vi::NamedTuple states::OrderedDict end @@ -15,31 +21,91 @@ struct GibbsTransition end function AbstractMCMC.step( - rng::AbstractRNG, model, sampler::Gibbs, args...; initial_params::NamedTuple, kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + spl::Gibbs, + args...; + initial_params::NamedTuple, + kwargs..., ) states = OrderedDict() - for group in keys(sampler.sampler_map) - sampler = sampler.sampler_map[group] - cond_val = NamedTuple{group}([initial_params[g] for g in group]...) - trans, state = AbstractMCMC.step( - rng, condition(model, cond_val), sampler, args...; kwargs... + for group in keys(spl.sampler_map) + sub_spl = spl.sampler_map[group] + + vars_to_be_conditioned_on = setdiff(keys(initial_params), group) + cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}( + Tuple([initial_params[g] for g in vars_to_be_conditioned_on]) + ) + params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group])) + sub_state = last( + AbstractMCMC.step( + rng, + AbstractMCMC.LogDensityModel( + condition(logdensity_model.logdensity, cond_val) + ), + sub_spl, + args...; + initial_params=flatten(params_val), + kwargs..., + ), ) - states[group] = state + states[group] = sub_state end return GibbsTransition(initial_params), GibbsState(initial_params, states) end function AbstractMCMC.step( - rng::AbstractRNG, model, sampler::Gibbs, state::GibbsState, args...; kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + spl::Gibbs, + state::GibbsState, + args...; + kwargs..., ) - for group in collect(keys(sampler.sampler_map)) - sampler = sampler.sampler_map[group] - state = state.states[group] - trans, state = AbstractMCMC.step( - rng, condition(model, state.values[group]), sampler, state, args...; kwargs... 
+ vi = state.vi + for group in keys(spl.sampler_map) + for (group, sub_state) in state.states + vi = merge(vi, unflatten(getparams(sub_state), group)) + end + sub_spl = spl.sampler_map[group] + sub_state = state.states[group] + group_complement = setdiff(keys(vi), group) + cond_val = NamedTuple{Tuple(group_complement)}( + Tuple([vi[g] for g in group_complement]) + ) + sub_state = last( + AbstractMCMC.step( + rng, + AbstractMCMC.LogDensityModel( + condition(logdensity_model.logdensity, cond_val) + ), + sub_spl, + sub_state, + args...; + kwargs..., + ), ) - # TODO: what values to condition on here? stored where? - state.states[group] = state + state.states[group] = sub_state end - return nothing + for sub_state in values(state.states) + vi = merge(vi, getparams(sub_state)) + end + return GibbsTransition(vi), GibbsState(vi, state.states) end + +## tests + +gmm = GMM((; x=x)) + +samples = sample( + gmm, + Gibbs( + OrderedDict( + (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), + (:w,) => PriorMH(Dirichlet(2, 1.0)), + (:μ, :w) => RWMH(1), + ), + ), + 10000; + initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]), +) diff --git a/gibbs_example/gmm.jl b/gibbs_example/gmm.jl index b40b4579..7a8eab79 100644 --- a/gibbs_example/gmm.jl +++ b/gibbs_example/gmm.jl @@ -29,6 +29,7 @@ function log_joint(; μ, w, z, x) logp += logpdf(w_prior, w) z_prior = Categorical(w) + logp += sum([logpdf(z_prior, z[i]) for i in 1:N]) obs_priors = [MvNormal(fill(μₖ, D), I) for μₖ in μ] @@ -43,33 +44,80 @@ function condition(gmm::GMM, conditioned_values::NamedTuple) return ConditionedGMM(gmm.data, conditioned_values) end -function _logdensity(gmm::ConditionedGMM{(:μ, :w)}, params) +function _logdensity(gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, params) return log_joint(; μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x ) end + function _logdensity(gmm::ConditionedGMM{(:z,)}, params) return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x) end function LogDensityProblems.logdensity( - gmm::ConditionedGMM{(:μ, :w)}, params_vec::AbstractVector + gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, + params_vec::AbstractVector, ) + @assert length(params_vec) == 60 return _logdensity(gmm, (; z=params_vec)) end function LogDensityProblems.logdensity( gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector ) + @assert length(params_vec) == 4 "length(params_vec) = $(length(params_vec))" return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4])) end -function LogDensityProblems.dimension(gmm::ConditionedGMM{(:μ, :w)}) - return size(gmm.data.x, 1) +function LogDensityProblems.dimension(gmm::GMM) + return 4 + size(gmm.data.x, 1) +end + +function LogDensityProblems.dimension( + gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}} +) + return 4 end + function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)}) return size(gmm.data.x, 1) end +function LogDensityProblems.capabilities(::GMM) + return LogDensityProblems.LogDensityOrder{0}() +end + +function LogDensityProblems.capabilities(::ConditionedGMM) + return LogDensityProblems.LogDensityOrder{0}() +end + +function flatten(nt::NamedTuple) + if Set(keys(nt)) == Set([:μ, :w]) + return vcat(nt.μ, nt.w) + elseif Set(keys(nt)) == Set([:z]) + return nt.z + else + error() + end +end + +function unflatten(vec::AbstractVector, group::Tuple) + if Set(group) == Set([:μ, :w]) + return (; μ=vec[1:2], w=vec[3:4]) + elseif 
Set(group) == Set([:z]) + return (; z=vec) + else + error() + end +end + +# sampler's states to internal representation +# ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?) + +# the point here is that the parameter values are not changed, but because the context was changed, the logprob need to be recomputed +function recompute_logprob!!(gmm::ConditionedGMM, vals, state) + return setlogp!(state, _logdensity(gmm, vals)) +end + ## test using Turing # data generation diff --git a/gibbs_example/mh.jl b/gibbs_example/mh.jl index 39d1c69f..b2fa91dc 100644 --- a/gibbs_example/mh.jl +++ b/gibbs_example/mh.jl @@ -1,13 +1,9 @@ -struct RWMH <: AbstractMCMC.AbstractSampler - σ -end - -struct MHTransition{T} where {T} - params::T +struct MHTransition{T} + params::Vector{T} end -struct MHState{T} where {T} - params::T +struct MHState{T} + params::Vector{T} logp::Float64 end @@ -16,21 +12,43 @@ setparams!!(state::MHState, params) = MHState(params, state.logp) getlogp(state::MHState) = state.logp setlogp!!(state::MHState, logp) = MHState(state.params, logp) -function AbstractMCMC.step(rng::AbstractRNG, logdensity, sampler::RWMH, args...; kwargs...) - params = rand(rng, LogDensityProblems.dimension(logdensity)) - return MHTransition(params), - MHState(params, LogDensityProblems.logdensity(logdensity, params)) +struct RWMH <: AbstractMCMC.AbstractSampler + σ::Float64 end function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::RWMH, state::MHState, args...; kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::RWMH, + args...; + initial_params, + kwargs..., ) - params = getparams(state) - proposal_dist = MvNormal(params, sampler.σ) - proposal = rand(rng, proposal_dist) - logp_proposal = logpdf(proposal_dist, proposal) - accepted = log(rand(rng)) < log1pexp(min(0, logp_proposal - getlogp(state))) - if accepted + return MHTransition(initial_params), + MHState( + initial_params, + only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), + ) +end + +function AbstractMCMC.step( + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::RWMH, + state::MHState, + args...; + kwargs..., +) + params = state.params + proposal_dist = MvNormal(zeros(length(params)), sampler.σ) + proposal = params .+ rand(rng, proposal_dist) + logp_proposal = only( + LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) + ) + + log_acceptance_ratio = min(0, logp_proposal - getlogp(state)) + + if log(rand(rng)) < log_acceptance_ratio return MHTransition(proposal), MHState(proposal, logp_proposal) else return MHTransition(params), MHState(params, getlogp(state)) @@ -38,27 +56,104 @@ function AbstractMCMC.step( end struct PriorMH <: AbstractMCMC.AbstractSampler - prior_dist + prior_dist::Distribution end function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::PriorMH, args...; kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::PriorMH, + args...; + initial_params, + kwargs..., ) - params = rand(rng, sampler.prior_dist) - return MHTransition(params), MHState(params, logdensity(params)) + return MHTransition(initial_params), + MHState( + initial_params, + only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), + ) end function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::PriorMH, state::MHState, args...; kwargs... 
+    rng::AbstractRNG,
+    logdensity_model::AbstractMCMC.LogDensityModel,
+    sampler::PriorMH,
+    state::MHState,
+    args...;
+    kwargs...,
 )
     params = getparams(state)
     proposal_dist = sampler.prior_dist
     proposal = rand(rng, proposal_dist)
-    logp_proposal = logpdf(proposal_dist, proposal)
-    accepted = log(rand(rng)) < log1pexp(min(0, logp_proposal - getlogp(state)))
-    if accepted
+    logp_proposal = only(
+        LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
+    )
+
+    log_acceptance_ratio = min(
+        0,
+        logp_proposal - getlogp(state) + logpdf(proposal_dist, params) -
+        logpdf(proposal_dist, proposal),
+    )
+
+    if log(rand(rng)) < log_acceptance_ratio
         return MHTransition(proposal), MHState(proposal, logp_proposal)
     else
         return MHTransition(params), MHState(params, getlogp(state))
     end
 end
+
+## tests
+
+# for RWMH
+# sample from Normal(10, 1)
+struct NormalLogDensity end
+LogDensityProblems.logdensity(l::NormalLogDensity, x) = logpdf(Normal(10, 1), only(x))
+LogDensityProblems.dimension(l::NormalLogDensity) = 1
+function LogDensityProblems.capabilities(::NormalLogDensity)
+    return LogDensityProblems.LogDensityOrder{1}()
+end
+
+# for PriorMH
+# sample from Categorical([0.2, 0.6, 0.2])
+struct CategoricalLogDensity end
+function LogDensityProblems.logdensity(l::CategoricalLogDensity, x)
+    return logpdf(Categorical([0.2, 0.6, 0.2]), only(x))
+end
+LogDensityProblems.dimension(l::CategoricalLogDensity) = 1
+function LogDensityProblems.capabilities(::CategoricalLogDensity)
+    return LogDensityProblems.LogDensityOrder{0}()
+end
+
+##
+
+using StatsPlots
+
+samples = AbstractMCMC.sample(
+    Random.default_rng(), NormalLogDensity(), RWMH(1), 100000; initial_params=[0.0]
+)
+_samples = map(t -> only(t.params), samples)
+
+histogram(
+    _samples;
+    normalize=:pdf,
+    label="Samples",
+    title="RWMH Sampling of Normal(10, 1)",
+)
+plot!(Normal(10, 1); linewidth=2, label="Ground Truth")
+
+samples = AbstractMCMC.sample(
+    Random.default_rng(),
+    CategoricalLogDensity(),
+    PriorMH(product_distribution([Categorical([0.3, 0.3, 0.4])])),
+    100000;
+    initial_params=[1],
+)
+_samples = map(t -> only(t.params), samples)
+
+histogram(
+    _samples;
+    normalize=:probability,
+    label="Samples",
+    title="MH From Prior Sampling of Categorical([0.2, 0.6, 0.2])",
+)
+plot!(Categorical([0.2, 0.6, 0.2]); linewidth=2, label="Ground Truth")

From 55dbab52933a7fb90019e7a64ef9a1a8203c6ba0 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Thu, 15 Aug 2024 13:04:40 +0100
Subject: [PATCH 11/56] bit of format

---
 gibbs_example/mh.jl | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/gibbs_example/mh.jl b/gibbs_example/mh.jl
index b2fa91dc..1029a2ee 100644
--- a/gibbs_example/mh.jl
+++ b/gibbs_example/mh.jl
@@ -133,12 +133,7 @@ samples = AbstractMCMC.sample(
 )
 _samples = map(t -> only(t.params), samples)
 
-histogram(
-    _samples;
-    normalize=:pdf,
-    label="Samples",
-    title="RWMH Sampling of Normal(10, 1)",
-)
+histogram(_samples; normalize=:pdf, label="Samples", title="RWMH Sampling of Normal(10, 1)")
 plot!(Normal(10, 1); linewidth=2, label="Ground Truth")
 
 samples = AbstractMCMC.sample(

From 67ff8e80e2018e6a6d9c3e40397a2522f89595 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Thu, 15 Aug 2024 14:01:08 +0100
Subject: [PATCH 12/56] results are wrong

---
 gibbs_example/gibbs.jl | 26 +++++++++-------
 gibbs_example/gmm.jl   | 71 ++++++++++++-----------------------------
 2 files changed, 28 insertions(+), 69 deletions(-)

diff --git a/gibbs_example/gibbs.jl b/gibbs_example/gibbs.jl
index 1c68643f..2cd9643d 100644
--- 
a/gibbs_example/gibbs.jl +++ b/gibbs_example/gibbs.jl @@ -4,9 +4,6 @@ using OrderedCollections ## -# TODO: introduce some kind of parameter format, for instance, a flattened vector -# then define some kind of function to transform the flattened vector into model's representation - struct Gibbs <: AbstractMCMC.AbstractSampler sampler_map::OrderedDict end @@ -73,12 +70,12 @@ function AbstractMCMC.step( cond_val = NamedTuple{Tuple(group_complement)}( Tuple([vi[g] for g in group_complement]) ) + cond_logdensity = condition(logdensity_model.logdensity, cond_val) + sub_state = recompute_logprob!!(cond_logdensity, getparams(sub_state), sub_state) sub_state = last( AbstractMCMC.step( rng, - AbstractMCMC.LogDensityModel( - condition(logdensity_model.logdensity, cond_val) - ), + AbstractMCMC.LogDensityModel(cond_logdensity), sub_spl, sub_state, args...; @@ -87,8 +84,8 @@ function AbstractMCMC.step( ) state.states[group] = sub_state end - for sub_state in values(state.states) - vi = merge(vi, getparams(sub_state)) + for (group, sub_state) in state.states + vi = merge(vi, unflatten(getparams(sub_state), group)) end return GibbsTransition(vi), GibbsState(vi, state.states) end @@ -103,9 +100,16 @@ samples = sample( OrderedDict( (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), (:w,) => PriorMH(Dirichlet(2, 1.0)), - (:μ, :w) => RWMH(1), + (:μ,) => RWMH(1), ), ), - 10000; + 100000; initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]), -) +); + +z_samples = [sample.values.z for sample in samples][20001:end] +μ_samples = [sample.values.μ for sample in samples][20001:end] +w_samples = [sample.values.w for sample in samples][20001:end] + +mean(μ_samples) +mean(w_samples) diff --git a/gibbs_example/gmm.jl b/gibbs_example/gmm.jl index 7a8eab79..a843eafc 100644 --- a/gibbs_example/gmm.jl +++ b/gibbs_example/gmm.jl @@ -44,42 +44,16 @@ function condition(gmm::GMM, conditioned_values::NamedTuple) return ConditionedGMM(gmm.data, conditioned_values) end -function _logdensity(gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, params) - return log_joint(; - μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x - ) -end - -function _logdensity(gmm::ConditionedGMM{(:z,)}, params) - return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x) -end - -function LogDensityProblems.logdensity( - gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, - params_vec::AbstractVector, -) - @assert length(params_vec) == 60 - return _logdensity(gmm, (; z=params_vec)) -end -function LogDensityProblems.logdensity( - gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector -) - @assert length(params_vec) == 4 "length(params_vec) = $(length(params_vec))" - return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4])) -end - -function LogDensityProblems.dimension(gmm::GMM) - return 4 + size(gmm.data.x, 1) -end - -function LogDensityProblems.dimension( - gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}} -) - return 4 -end - -function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)}) - return size(gmm.data.x, 1) +function LogDensityProblems.logdensity(gmm::ConditionedGMM{names}, params::AbstractVector) where {names} + if Set(names) == Set([:μ, :w]) # conditioned on μ, w, so params are z + return log_joint(; μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params, x=gmm.data.x) + elseif Set(names) == Set([:z, :w]) # conditioned on z, w, so params are μ + return log_joint(; μ=params, 
w=gmm.conditioned_values.w, z=gmm.conditioned_values.z, x=gmm.data.x) + elseif Set(names) == Set([:z, :μ]) # conditioned on z, μ, so params are w + return log_joint(; μ=gmm.conditioned_values.μ, w=params, z=gmm.conditioned_values.z, x=gmm.data.x) + else + error("Unsupported conditioning configuration.") + end end function LogDensityProblems.capabilities(::GMM) @@ -91,41 +65,22 @@ function LogDensityProblems.capabilities(::ConditionedGMM) end function flatten(nt::NamedTuple) - if Set(keys(nt)) == Set([:μ, :w]) - return vcat(nt.μ, nt.w) - elseif Set(keys(nt)) == Set([:z]) - return nt.z - else - error() - end + return only(values(nt)) end function unflatten(vec::AbstractVector, group::Tuple) - if Set(group) == Set([:μ, :w]) - return (; μ=vec[1:2], w=vec[3:4]) - elseif Set(group) == Set([:z]) - return (; z=vec) - else - error() - end + return NamedTuple((only(group) => vec,)) end -# sampler's states to internal representation -# ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?) - -# the point here is that the parameter values are not changed, but because the context was changed, the logprob need to be recomputed function recompute_logprob!!(gmm::ConditionedGMM, vals, state) - return setlogp!(state, _logdensity(gmm, vals)) + return setlogp!!(state, LogDensityProblems.logdensity(gmm, vals)) end ## test using Turing # data generation -using Distributions using FillArrays -using LinearAlgebra -using Random w = [0.5, 0.5] μ = [-3.5, 0.5] From f758a4cc0dfc5eb0f935ecdb98a25d421132967c Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 15 Aug 2024 21:04:16 +0800 Subject: [PATCH 13/56] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- gibbs_example/gmm.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/gibbs_example/gmm.jl b/gibbs_example/gmm.jl index a843eafc..5d345401 100644 --- a/gibbs_example/gmm.jl +++ b/gibbs_example/gmm.jl @@ -44,13 +44,21 @@ function condition(gmm::GMM, conditioned_values::NamedTuple) return ConditionedGMM(gmm.data, conditioned_values) end -function LogDensityProblems.logdensity(gmm::ConditionedGMM{names}, params::AbstractVector) where {names} +function LogDensityProblems.logdensity( + gmm::ConditionedGMM{names}, params::AbstractVector +) where {names} if Set(names) == Set([:μ, :w]) # conditioned on μ, w, so params are z - return log_joint(; μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params, x=gmm.data.x) + return log_joint(; + μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params, x=gmm.data.x + ) elseif Set(names) == Set([:z, :w]) # conditioned on z, w, so params are μ - return log_joint(; μ=params, w=gmm.conditioned_values.w, z=gmm.conditioned_values.z, x=gmm.data.x) + return log_joint(; + μ=params, w=gmm.conditioned_values.w, z=gmm.conditioned_values.z, x=gmm.data.x + ) elseif Set(names) == Set([:z, :μ]) # conditioned on z, μ, so params are w - return log_joint(; μ=gmm.conditioned_values.μ, w=params, z=gmm.conditioned_values.z, x=gmm.data.x) + return log_joint(; + μ=gmm.conditioned_values.μ, w=params, z=gmm.conditioned_values.z, x=gmm.data.x + ) else error("Unsupported conditioning configuration.") end From 7d0ba7cf883f9ccad8f495776277b67bbeb9f9cf Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 22 Aug 2024 12:38:43 +0100 Subject: [PATCH 14/56] add hierarchical normal problem --- gibbs_example/gibbs.jl | 44 +++++++++++++++++++++- 
gibbs_example/hier_normal.jl | 73 ++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 gibbs_example/hier_normal.jl diff --git a/gibbs_example/gibbs.jl b/gibbs_example/gibbs.jl index 2cd9643d..438343b2 100644 --- a/gibbs_example/gibbs.jl +++ b/gibbs_example/gibbs.jl @@ -92,6 +92,41 @@ end ## tests +# generate data +N = 100 # Number of data points +mu_true = 0.5 # True mean +tau2_true = 2.0 # True variance + +# Generate data based on true parameters +x_data = rand(Normal(mu_true, sqrt(tau2_true)), N) + +# Store the generated data in the HierNormal structure +hn = HierNormal((x=x_data,)) + +## + +samples = sample( + hn, + Gibbs( + OrderedDict( + (:mu,) => RWMH(1), + (:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])), + ), + ), + 100000; + initial_params=(mu=[0.0], tau2=[1.0]), +) + +mu_samples = [sample.values.mu for sample in samples][20001:end] +tau2_samples = [sample.values.tau2 for sample in samples][20001:end] + +mean(mu_samples) +mean(tau2_samples) + +## + +# this is too difficult of a problem + gmm = GMM((; x=x)) samples = sample( @@ -104,12 +139,17 @@ samples = sample( ), ), 100000; - initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]), + initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]), ); z_samples = [sample.values.z for sample in samples][20001:end] μ_samples = [sample.values.μ for sample in samples][20001:end] -w_samples = [sample.values.w for sample in samples][20001:end] +w_samples = [sample.values.w for sample in samples][20001:end]; + +# thin these samples +z_samples = z_samples[1:100:end] +μ_samples = μ_samples[1:100:end] +w_samples = w_samples[1:100:end]; mean(μ_samples) mean(w_samples) diff --git a/gibbs_example/hier_normal.jl b/gibbs_example/hier_normal.jl new file mode 100644 index 00000000..00de48cf --- /dev/null +++ b/gibbs_example/hier_normal.jl @@ -0,0 +1,73 @@ +using LogDensityProblems + +abstract type AbstractHierNormal end + +struct HierNormal <: AbstractHierNormal + data::NamedTuple +end + +struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal + data::NamedTuple + conditioned_values::NamedTuple{conditioned_vars} +end + +function log_joint(; mu, tau2, x) + # mu is the mean + # tau2 is the variance + # x is data + + # μ ~ Normal(0, 1) + # τ² ~ InverseGamma(1, 1) + # xᵢ ~ Normal(μ, √τ²) + + logp = 0.0 + mu = only(mu) + tau2 = only(tau2) + + mu_prior = Normal(0, 1) + logp += logpdf(mu_prior, mu) + + tau2_prior = InverseGamma(1, 1) + logp += logpdf(tau2_prior, tau2) + + obs_prior = Normal(mu, sqrt(tau2)) + logp += sum(logpdf(obs_prior, xi) for xi in x) + + return logp +end + +function condition(hn::HierNormal, conditioned_values::NamedTuple) + return ConditionedHierNormal(hn.data, conditioned_values) +end + +function LogDensityProblems.logdensity( + hn::ConditionedHierNormal{names}, params::AbstractVector +) where {names} + if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2 + return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x) + elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu + return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x) + else + error("Unsupported conditioning configuration.") + end +end + +function LogDensityProblems.capabilities(::HierNormal) + return LogDensityProblems.LogDensityOrder{0}() +end + +function LogDensityProblems.capabilities(::ConditionedHierNormal) + return LogDensityProblems.LogDensityOrder{0}() +end + +function 
flatten(nt::NamedTuple)
+    return only(values(nt))
+end
+
+function unflatten(vec::AbstractVector, group::Tuple)
+    return NamedTuple((only(group) => vec,))
+end
+
+function recompute_logprob!!(hn::ConditionedHierNormal, vals, state)
+    return setlogp!!(state, LogDensityProblems.logdensity(hn, vals))
+end

From 1ab6dd95123da28250610c2ce0610770db67880f Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Fri, 23 Aug 2024 13:34:06 +0100
Subject: [PATCH 15/56] some updates; add doc

---
 docs/src/gibbs.md            | 347 +++++++++++++++++++++++++++++++++++
 gibbs_example/gibbs.jl       |  10 +-
 gibbs_example/hier_normal.jl |   2 +-
 gibbs_example/mh.jl          |  10 +-
 src/AbstractMCMC.jl          |   8 +-
 5 files changed, 362 insertions(+), 15 deletions(-)
 create mode 100644 docs/src/gibbs.md

diff --git a/docs/src/gibbs.md b/docs/src/gibbs.md
new file mode 100644
index 00000000..ee1d279a
--- /dev/null
+++ b/docs/src/gibbs.md
@@ -0,0 +1,347 @@
# `state` Interface

We encourage sampler packages to implement the following interface functions for the `state` type(s) they maintain:

```@docs
get_logprob
set_logprob!!
get_params
set_params!!
```

These functions provide a minimal interface for interacting with the `state` datatype, without requiring a sampler package to expose its internal state representation.

## Using the `state` Interface for block sampling within Gibbs

In this section, we demonstrate how a `model` package may use this `state` interface to support a Gibbs sampler that performs block sampling with different inference algorithms.

We consider a simple hierarchical model with a normal likelihood and unknown mean and variance parameters.

```math
\begin{align}
\mu &\sim \text{Normal}(0, 1) \\
\tau^2 &\sim \text{InverseGamma}(1, 1) \\
x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2})
\end{align}
```

We can write the log joint probability function as follows. For simplicity in the steps below, we assume that the `mu` and `tau2` parameters are one-element vectors and that `x` is the data vector.

```julia
function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64})
    # mu is the mean
    # tau2 is the variance
    # x is data

    # μ ~ Normal(0, 1)
    # τ² ~ InverseGamma(1, 1)
    # xᵢ ~ Normal(μ, √τ²)

    logp = 0.0
    mu = only(mu)
    tau2 = only(tau2)

    mu_prior = Normal(0, 1)
    logp += logpdf(mu_prior, mu)

    tau2_prior = InverseGamma(1, 1)
    logp += logpdf(tau2_prior, tau2)

    obs_prior = Normal(mu, sqrt(tau2))
    logp += sum(logpdf(obs_prior, xi) for xi in x)

    return logp
end
```

To use the `LogDensityProblems` interface, we create a simple type for this model.

```julia
abstract type AbstractHierNormal end

struct HierNormal <: AbstractHierNormal
    data::NamedTuple
end

struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal
    data::NamedTuple
    conditioned_values::NamedTuple{conditioned_vars}
end
```

where `ConditionedHierNormal` is a type that represents the model conditioned on some variables. The conditioning itself is implemented by

```julia
function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple)
    return ConditionedHierNormal(hn.data, conditioned_values)
end
```
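For example (a hypothetical sketch, with made-up observations), conditioning on a value of `mu` yields a model whose only free parameter is `tau2`:

```julia
hn = HierNormal((x=[0.1, -0.3, 1.2],))            # hypothetical observations
cond_hn = AbstractPPL.condition(hn, (mu=[0.5],))  # a ConditionedHierNormal{(:mu,)}
```

Then we can write down the `LogDensityProblems` interface for the conditioned model.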
```julia
function LogDensityProblems.logdensity(
    hn::ConditionedHierNormal{names}, params::AbstractVector
) where {names}
    if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2
        return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x)
    elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu
        return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x)
    else
        error("Unsupported conditioning configuration.")
    end
end

function LogDensityProblems.capabilities(::HierNormal)
    return LogDensityProblems.LogDensityOrder{0}()
end

function LogDensityProblems.capabilities(::ConditionedHierNormal)
    return LogDensityProblems.LogDensityOrder{0}()
end
```

The model should also define a function that allows recomputation of the log probability given a sampler state.
The reason for this is that, when we break the joint probability down into conditional probabilities, each conditional probability problem is conditioned on the values of the other variables.
Between Gibbs sweeps, the values of those variables may change, and we then need to recompute the log probability of the current state.

A recomputation function could use the `state` interface to return a new state with the updated log probability, e.g.

```julia
function recompute_logprob!!(hn::ConditionedHierNormal, vals, state)
    return AbstractMCMC.set_logprob!!(state, LogDensityProblems.logdensity(hn, vals))
end
```

where the model doesn't need to know the details of the `state` type, as long as it can evaluate the log density and call `set_logprob!!`.

## Sampler Packages

To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: random walk and static proposal.

Although the interface doesn't force the sampler to implement `Transition` and `State` types, in practice it has been the convention to do so.

Here we define some bare-minimum types to represent the transitions and states.

```julia
struct MHTransition{T}
    params::Vector{T}
end

struct MHState{T}
    params::Vector{T}
    logp::Float64
end
```

Next we define the four `state` interface functions.

```julia
AbstractMCMC.get_params(state::MHState) = state.params
AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp)
AbstractMCMC.get_logprob(state::MHState) = state.logp
AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp)
```

These are the functions that were used in the `recompute_logprob!!` function above.

It is then straightforward to implement the samplers against the `AbstractMCMC` interface, using `get_logprob` to read the log probability of the current state.
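For instance, the accessors compose as follows (a minimal sketch with hypothetical values):

```julia
state = MHState([0.0], -1.2)                    # params and their log probability
state = AbstractMCMC.set_params!!(state, [0.5]) # params replaced, logp carried over
AbstractMCMC.get_logprob(state)                 # returns -1.2
```

The two Metropolis-Hastings samplers below build directly on these accessors.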
+
+```julia
+struct RWMH <: AbstractMCMC.AbstractSampler
+    σ::Float64
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG,
+    logdensity_model::AbstractMCMC.LogDensityModel,
+    sampler::RWMH,
+    args...;
+    initial_params,
+    kwargs...,
+)
+    return MHTransition(initial_params),
+    MHState(
+        initial_params,
+        only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)),
+    )
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG,
+    logdensity_model::AbstractMCMC.LogDensityModel,
+    sampler::RWMH,
+    state::MHState,
+    args...;
+    kwargs...,
+)
+    params = state.params
+    proposal_dist = MvNormal(zeros(length(params)), sampler.σ)
+    proposal = params .+ rand(rng, proposal_dist)
+    logp_proposal = only(
+        LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
+    )
+
+    log_acceptance_ratio = min(0, logp_proposal - get_logprob(state))
+
+    if log(rand(rng)) < log_acceptance_ratio
+        return MHTransition(proposal), MHState(proposal, logp_proposal)
+    else
+        return MHTransition(params), MHState(params, get_logprob(state))
+    end
+end
+```
+
+```julia
+struct PriorMH <: AbstractMCMC.AbstractSampler
+    prior_dist::Distribution
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG,
+    logdensity_model::AbstractMCMC.LogDensityModel,
+    sampler::PriorMH,
+    args...;
+    initial_params,
+    kwargs...,
+)
+    return MHTransition(initial_params),
+    MHState(
+        initial_params,
+        only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)),
+    )
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG,
+    logdensity_model::AbstractMCMC.LogDensityModel,
+    sampler::PriorMH,
+    state::MHState,
+    args...;
+    kwargs...,
+)
+    params = get_params(state)
+    proposal_dist = sampler.prior_dist
+    proposal = rand(rng, proposal_dist)
+    logp_proposal = only(
+        LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
+    )
+
+    log_acceptance_ratio = min(
+        0,
+        logp_proposal - get_logprob(state) + logpdf(proposal_dist, params) -
+        logpdf(proposal_dist, proposal),
+    )
+
+    if log(rand(rng)) < log_acceptance_ratio
+        return MHTransition(proposal), MHState(proposal, logp_proposal)
+    else
+        return MHTransition(params), MHState(params, get_logprob(state))
+    end
+end
+```
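+
+Before moving on to the Gibbs sampler, it can be useful to smoke-test one of these samplers in isolation. The snippet below is purely illustrative and not part of the package; `NormalLogDensity` is a toy target defined here just for this check, wrapped in the `LogDensityProblems` interface like the models above:
+
+```julia
+using Statistics: mean
+
+# A toy target: the log-density of Normal(10, 1).
+struct NormalLogDensity end
+function LogDensityProblems.logdensity(::NormalLogDensity, x)
+    return logpdf(Normal(10, 1), only(x))
+end
+LogDensityProblems.dimension(::NormalLogDensity) = 1
+function LogDensityProblems.capabilities(::NormalLogDensity)
+    return LogDensityProblems.LogDensityOrder{0}()
+end
+
+# Run the random walk sampler on its own and check the posterior mean.
+samples = sample(NormalLogDensity(), RWMH(1.0), 10_000; initial_params=[0.0])
+mean(only(t.params) for t in samples)  # should be roughly 10
+```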
+
+Finally, we can proceed to implement the Gibbs sampler.
+
+```julia
+struct Gibbs <: AbstractMCMC.AbstractSampler
+    sampler_map::OrderedDict
+end
+
+struct GibbsState
+    vi::NamedTuple
+    states::OrderedDict
+end
+
+struct GibbsTransition
+    values::NamedTuple
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG,
+    logdensity_model::AbstractMCMC.LogDensityModel,
+    spl::Gibbs,
+    args...;
+    initial_params::NamedTuple,
+    kwargs...,
+)
+    states = OrderedDict()
+    for group in keys(spl.sampler_map)
+        sub_spl = spl.sampler_map[group]
+
+        vars_to_be_conditioned_on = setdiff(keys(initial_params), group)
+        cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}(
+            Tuple([initial_params[g] for g in vars_to_be_conditioned_on])
+        )
+        params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group]))
+        sub_state = last(
+            AbstractMCMC.step(
+                rng,
+                AbstractMCMC.LogDensityModel(
+                    condition(logdensity_model.logdensity, cond_val)
+                ),
+                sub_spl,
+                args...;
+                initial_params=flatten(params_val),
+                kwargs...,
+            ),
+        )
+        states[group] = sub_state
+    end
+    return GibbsTransition(initial_params), GibbsState(initial_params, states)
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG,
+    logdensity_model::AbstractMCMC.LogDensityModel,
+    spl::Gibbs,
+    state::GibbsState,
+    args...;
+    kwargs...,
+)
+    vi = state.vi
+    for group in keys(spl.sampler_map)
+        for (group, sub_state) in state.states
+            vi = merge(vi, unflatten(get_params(sub_state), group))
+        end
+        sub_spl = spl.sampler_map[group]
+        sub_state = state.states[group]
+        group_complement = setdiff(keys(vi), group)
+        cond_val = NamedTuple{Tuple(group_complement)}(
+            Tuple([vi[g] for g in group_complement])
+        )
+        cond_logdensity = condition(logdensity_model.logdensity, cond_val)
+        sub_state = recompute_logprob!!(cond_logdensity, get_params(sub_state), sub_state)
+        sub_state = last(
+            AbstractMCMC.step(
+                rng,
+                AbstractMCMC.LogDensityModel(cond_logdensity),
+                sub_spl,
+                sub_state,
+                args...;
+                kwargs...,
+            ),
+        )
+        state.states[group] = sub_state
+    end
+    for (group, sub_state) in state.states
+        vi = merge(vi, unflatten(get_params(sub_state), group))
+    end
+    return GibbsTransition(vi), GibbsState(vi, state.states)
+end
+```
+
+Some points worth noting:
+
+1. We are using `OrderedDict` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps.
+2. For each conditional probability problem, we need to store the sampler states for each variable group and also the values of all the variables from the last iteration.
+3. The first step of the Gibbs sampler is to set up the states for each conditional probability problem.
+4. In the subsequent steps, the Gibbs sampler does a sweep over all the conditional probability problems and updates the sampler state for each. Each step of the sweep does the following (the parameter round trip this relies on is sketched after this list):
+    - first update the values from the last step of the sweep into `vi`, which stores the values of all variables at the current point of the Gibbs sweep
+    - condition on the values of all variables that are not in the current group
+    - recompute the log probability of the current state, because the values of the variables that are not in the current group may have changed
+    - perform a step of the sampler for the conditional probability problem, and update the sampler state
+    - update `vi` with the new values from the sampler state
+
+Again, the `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and to acquire the values of the parameters from individual sampler states.
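+
+To make the parameter round trip concrete, here is a small REPL sketch (illustrative, not a doctest) of the `flatten`/`unflatten` conversions the sweep relies on; these are the single-variable helpers defined alongside the `HierNormal` model code in this example:
+
+```julia
+julia> flatten((mu=[0.0],))      # the sub-sampler only ever sees a flat vector
+1-element Vector{Float64}:
+ 0.0
+
+julia> unflatten([0.5], (:mu,))  # ...which the Gibbs sweep merges back into `vi`
+(mu = [0.5],)
+```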
diff --git a/gibbs_example/gibbs.jl b/gibbs_example/gibbs.jl index 438343b2..25062d4e 100644 --- a/gibbs_example/gibbs.jl +++ b/gibbs_example/gibbs.jl @@ -1,6 +1,8 @@ using AbstractMCMC -using LogDensityProblems, Distributions, LinearAlgebra, Random +using Distributions +using LogDensityProblems using OrderedCollections +using Random ## @@ -62,7 +64,7 @@ function AbstractMCMC.step( vi = state.vi for group in keys(spl.sampler_map) for (group, sub_state) in state.states - vi = merge(vi, unflatten(getparams(sub_state), group)) + vi = merge(vi, unflatten(get_params(sub_state), group)) end sub_spl = spl.sampler_map[group] sub_state = state.states[group] @@ -71,7 +73,7 @@ function AbstractMCMC.step( Tuple([vi[g] for g in group_complement]) ) cond_logdensity = condition(logdensity_model.logdensity, cond_val) - sub_state = recompute_logprob!!(cond_logdensity, getparams(sub_state), sub_state) + sub_state = recompute_logprob!!(cond_logdensity, get_params(sub_state), sub_state) sub_state = last( AbstractMCMC.step( rng, @@ -85,7 +87,7 @@ function AbstractMCMC.step( state.states[group] = sub_state end for (group, sub_state) in state.states - vi = merge(vi, unflatten(getparams(sub_state), group)) + vi = merge(vi, unflatten(get_params(sub_state), group)) end return GibbsTransition(vi), GibbsState(vi, state.states) end diff --git a/gibbs_example/hier_normal.jl b/gibbs_example/hier_normal.jl index 00de48cf..fa3b47bc 100644 --- a/gibbs_example/hier_normal.jl +++ b/gibbs_example/hier_normal.jl @@ -69,5 +69,5 @@ function unflatten(vec::AbstractVector, group::Tuple) end function recompute_logprob!!(hn::ConditionedHierNormal, vals, state) - return setlogp!!(state, LogDensityProblems.logdensity(hn, vals)) + return AbstractMCMC.set_logprob!!(state, LogDensityProblems.logdensity(hn, vals)) end diff --git a/gibbs_example/mh.jl b/gibbs_example/mh.jl index 1029a2ee..926394c0 100644 --- a/gibbs_example/mh.jl +++ b/gibbs_example/mh.jl @@ -7,10 +7,10 @@ struct MHState{T} logp::Float64 end -getparams(state::MHState) = state.params -setparams!!(state::MHState, params) = MHState(params, state.logp) -getlogp(state::MHState) = state.logp -setlogp!!(state::MHState, logp) = MHState(state.params, logp) +AbstractMCMC.get_params(state::MHState) = state.params +AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp) +AbstractMCMC.get_logprob(state::MHState) = state.logp +AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp) struct RWMH <: AbstractMCMC.AbstractSampler σ::Float64 @@ -82,7 +82,7 @@ function AbstractMCMC.step( args...; kwargs..., ) - params = getparams(state) + params = get_params(state) proposal_dist = sampler.prior_dist proposal = rand(rng, proposal_dist) logp_proposal = only( diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index b347f4bf..f87997bf 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -22,8 +22,6 @@ export sample # Parallel sampling types export MCMCThreads, MCMCDistributed, MCMCSerial -@compat public recompute_logprob!!, getparams - """ AbstractChains @@ -98,18 +96,18 @@ Set the log-probability of the last sampling step, stored in `state`. function set_logprob!!(state, logprob) end """ - getparams(state) + get_params(state) Returns the values of the parameters in the state. """ -function getparams(state) end +function get_params(state) end """ setparams!(state, params) Set the values of the parameters in the state. 
""" -function setparams!!(state, params) end +function set_params!!(state, params) end include("samplingstats.jl") include("logging.jl") From 923c1167706ceb35e57a59f525989a4f6abcb8c1 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 Aug 2024 13:34:45 +0100 Subject: [PATCH 16/56] move folder into test --- {gibbs_example => test/gibbs_example}/Project.toml | 0 {gibbs_example => test/gibbs_example}/gibbs.jl | 0 {gibbs_example => test/gibbs_example}/gmm.jl | 0 {gibbs_example => test/gibbs_example}/hier_normal.jl | 0 {gibbs_example => test/gibbs_example}/mh.jl | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename {gibbs_example => test/gibbs_example}/Project.toml (100%) rename {gibbs_example => test/gibbs_example}/gibbs.jl (100%) rename {gibbs_example => test/gibbs_example}/gmm.jl (100%) rename {gibbs_example => test/gibbs_example}/hier_normal.jl (100%) rename {gibbs_example => test/gibbs_example}/mh.jl (100%) diff --git a/gibbs_example/Project.toml b/test/gibbs_example/Project.toml similarity index 100% rename from gibbs_example/Project.toml rename to test/gibbs_example/Project.toml diff --git a/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl similarity index 100% rename from gibbs_example/gibbs.jl rename to test/gibbs_example/gibbs.jl diff --git a/gibbs_example/gmm.jl b/test/gibbs_example/gmm.jl similarity index 100% rename from gibbs_example/gmm.jl rename to test/gibbs_example/gmm.jl diff --git a/gibbs_example/hier_normal.jl b/test/gibbs_example/hier_normal.jl similarity index 100% rename from gibbs_example/hier_normal.jl rename to test/gibbs_example/hier_normal.jl diff --git a/gibbs_example/mh.jl b/test/gibbs_example/mh.jl similarity index 100% rename from gibbs_example/mh.jl rename to test/gibbs_example/mh.jl From 63028d36a2865a6673538c0dd8c62a3f99fdb16d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 Aug 2024 13:51:52 +0100 Subject: [PATCH 17/56] setup as a test --- Project.toml | 6 +- docs/src/gibbs.md | 46 ++++++++++++++ test/gibbs_example/Project.toml | 11 ---- test/gibbs_example/gibbs.jl | 88 +++++++++++++------------- test/gibbs_example/gmm.jl | 88 +++++++++++++------------- test/gibbs_example/mh.jl | 106 ++++++++++++++++---------------- 6 files changed, 193 insertions(+), 152 deletions(-) delete mode 100644 test/gibbs_example/Project.toml diff --git a/Project.toml b/Project.toml index 942d37c0..b7378daf 100644 --- a/Project.toml +++ b/Project.toml @@ -10,10 +10,12 @@ BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -34,10 +36,12 @@ Transducers = "0.4.30" julia = "1.6" [extras] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = 
["FillArrays", "IJulia", "Statistics", "Test"] +test = ["FillArrays", "Distributions", "IJulia", "OrderedCollections", "Statistics", "Test"] diff --git a/docs/src/gibbs.md b/docs/src/gibbs.md index ee1d279a..db8430b8 100644 --- a/docs/src/gibbs.md +++ b/docs/src/gibbs.md @@ -345,3 +345,49 @@ Some points worth noting: - update the `vi` with the new values from the sampler state Again, the `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and acquire the values of the parameters from individual sampler states. + +Now we can use the Gibbs sampler to sample from the hierarchical normal model. + +First we generate some data, + +```julia +N = 100 # Number of data points +mu_true = 0.5 # True mean +tau2_true = 2.0 # True variance + +x_data = rand(Normal(mu_true, sqrt(tau2_true)), N) +``` + +``` + +Then we can create a `HierNormal` model with the data. + +```julia +hn = HierNormal((x=x_data,)) +``` + +sampling is easy: we use random walk MH for `mu` and prior MH for `tau2`, because `tau2` has support on positive real numbers. + +```julia +samples = sample( + hn, + Gibbs( + OrderedDict( + (:mu,) => RWMH(1), + (:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])), + ), + ), + 100000; + initial_params=(mu=[0.0], tau2=[1.0]), +) +``` + +Then we can extract the samples and compute the mean of the samples. + +```julia +mu_samples = [sample.values.mu for sample in samples][20001:end] +tau2_samples = [sample.values.tau2 for sample in samples][20001:end] + +mean(mu_samples) +mean(tau2_samples) +``` diff --git a/test/gibbs_example/Project.toml b/test/gibbs_example/Project.toml deleted file mode 100644 index 81b2b669..00000000 --- a/test/gibbs_example/Project.toml +++ /dev/null @@ -1,11 +0,0 @@ -[deps] -AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 25062d4e..87c1c90a 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -3,8 +3,11 @@ using Distributions using LogDensityProblems using OrderedCollections using Random +using Test -## +include("hier_normal.jl") +# include("gmm.jl") +include("mh.jl") struct Gibbs <: AbstractMCMC.AbstractSampler sampler_map::OrderedDict @@ -64,7 +67,7 @@ function AbstractMCMC.step( vi = state.vi for group in keys(spl.sampler_map) for (group, sub_state) in state.states - vi = merge(vi, unflatten(get_params(sub_state), group)) + vi = merge(vi, unflatten(AbstractMCMC.get_params(sub_state), group)) end sub_spl = spl.sampler_map[group] sub_state = state.states[group] @@ -73,7 +76,7 @@ function AbstractMCMC.step( Tuple([vi[g] for g in group_complement]) ) cond_logdensity = condition(logdensity_model.logdensity, cond_val) - sub_state = recompute_logprob!!(cond_logdensity, get_params(sub_state), sub_state) + sub_state = recompute_logprob!!(cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state) sub_state = last( AbstractMCMC.step( rng, @@ -87,15 +90,15 @@ function AbstractMCMC.step( state.states[group] = sub_state end for 
(group, sub_state) in state.states - vi = merge(vi, unflatten(get_params(sub_state), group)) + vi = merge(vi, unflatten(AbstractMCMC.get_params(sub_state), group)) end return GibbsTransition(vi), GibbsState(vi, state.states) end -## tests +## tests with hierarchical normal model # generate data -N = 100 # Number of data points +N = 1000 # Number of data points mu_true = 0.5 # True mean tau2_true = 2.0 # True variance @@ -105,8 +108,6 @@ x_data = rand(Normal(mu_true, sqrt(tau2_true)), N) # Store the generated data in the HierNormal structure hn = HierNormal((x=x_data,)) -## - samples = sample( hn, Gibbs( @@ -115,43 +116,46 @@ samples = sample( (:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])), ), ), - 100000; + 200000; initial_params=(mu=[0.0], tau2=[1.0]), ) -mu_samples = [sample.values.mu for sample in samples][20001:end] -tau2_samples = [sample.values.tau2 for sample in samples][20001:end] - -mean(mu_samples) -mean(tau2_samples) - -## - -# this is too difficult of a problem - -gmm = GMM((; x=x)) - -samples = sample( - gmm, - Gibbs( - OrderedDict( - (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), - (:w,) => PriorMH(Dirichlet(2, 1.0)), - (:μ,) => RWMH(1), - ), - ), - 100000; - initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]), -); +mu_samples = [sample.values.mu for sample in samples][40001:end] +tau2_samples = [sample.values.tau2 for sample in samples][40001:end] -z_samples = [sample.values.z for sample in samples][20001:end] -μ_samples = [sample.values.μ for sample in samples][20001:end] -w_samples = [sample.values.w for sample in samples][20001:end]; +mu_mean = mean(mu_samples)[1] +tau2_mean = mean(tau2_samples)[1] -# thin these samples -z_samples = z_samples[1:100:end] -μ_samples = μ_samples[1:100:end] -w_samples = w_samples[1:100:end]; +@testset "hierarchical normal with gibbs" begin + @test mu_mean ≈ mu_true atol = 0.1 + @test tau2_mean ≈ tau2_true atol = 0.3 +end -mean(μ_samples) -mean(w_samples) +## test with gmm -- too hard, doesn't converge + +# gmm = GMM((; x=x)) + +# samples = sample( +# gmm, +# Gibbs( +# OrderedDict( +# (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), +# (:w,) => PriorMH(Dirichlet(2, 1.0)), +# (:μ,) => RWMH(1), +# ), +# ), +# 100000; +# initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]), +# ); + +# z_samples = [sample.values.z for sample in samples][20001:end] +# μ_samples = [sample.values.μ for sample in samples][20001:end] +# w_samples = [sample.values.w for sample in samples][20001:end]; + +# # thin these samples +# z_samples = z_samples[1:100:end] +# μ_samples = μ_samples[1:100:end] +# w_samples = w_samples[1:100:end]; + +# mean(μ_samples) +# mean(w_samples) diff --git a/test/gibbs_example/gmm.jl b/test/gibbs_example/gmm.jl index 5d345401..f5139307 100644 --- a/test/gibbs_example/gmm.jl +++ b/test/gibbs_example/gmm.jl @@ -1,5 +1,3 @@ -using LogDensityProblems - abstract type AbstractGMM end struct GMM <: AbstractGMM @@ -81,65 +79,65 @@ function unflatten(vec::AbstractVector, group::Tuple) end function recompute_logprob!!(gmm::ConditionedGMM, vals, state) - return setlogp!!(state, LogDensityProblems.logdensity(gmm, vals)) + return set_logp!!(state, LogDensityProblems.logdensity(gmm, vals)) end ## test using Turing -# data generation +# # data generation -using FillArrays +# using FillArrays -w = [0.5, 0.5] -μ = [-3.5, 0.5] -mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w) +# w = 
[0.5, 0.5] +# μ = [-3.5, 0.5] +# mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w) -N = 60 -x = rand(mixturemodel, N); +# N = 60 +# x = rand(mixturemodel, N); -# Turing model from https://turinglang.org/docs/tutorials/01-gaussian-mixture-model/ -using Turing +# # Turing model from https://turinglang.org/docs/tutorials/01-gaussian-mixture-model/ +# using Turing -@model function gaussian_mixture_model(x) - # Draw the parameters for each of the K=2 clusters from a standard normal distribution. - K = 2 - μ ~ MvNormal(Zeros(K), I) +# @model function gaussian_mixture_model(x) +# # Draw the parameters for each of the K=2 clusters from a standard normal distribution. +# K = 2 +# μ ~ MvNormal(Zeros(K), I) - # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1. - w ~ Dirichlet(K, 1.0) - # Alternatively, one could use a fixed set of weights. - # w = fill(1/K, K) +# # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1. +# w ~ Dirichlet(K, 1.0) +# # Alternatively, one could use a fixed set of weights. +# # w = fill(1/K, K) - # Construct categorical distribution of assignments. - distribution_assignments = Categorical(w) +# # Construct categorical distribution of assignments. +# distribution_assignments = Categorical(w) - # Construct multivariate normal distributions of each cluster. - D, N = size(x) - distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ] +# # Construct multivariate normal distributions of each cluster. +# D, N = size(x) +# distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ] - # Draw assignments for each datum and generate it from the multivariate normal distribution. - k = Vector{Int}(undef, N) - for i in 1:N - k[i] ~ distribution_assignments - x[:, i] ~ distribution_clusters[k[i]] - end +# # Draw assignments for each datum and generate it from the multivariate normal distribution. 
+# k = Vector{Int}(undef, N) +# for i in 1:N +# k[i] ~ distribution_assignments +# x[:, i] ~ distribution_clusters[k[i]] +# end - return μ, w, k, __varinfo__ -end +# return μ, w, k, __varinfo__ +# end -model = gaussian_mixture_model(x); +# model = gaussian_mixture_model(x); -using Test -# full model -μ, w, k, vi = model() -@test log_joint(; μ=μ, w=w, z=k, x=x) ≈ DynamicPPL.getlogp(vi) +# using Test +# # full model +# μ, w, k, vi = model() +# @test log_joint(; μ=μ, w=w, z=k, x=x) ≈ DynamicPPL.getlogp(vi) -gmm = GMM((; x=x)) +# gmm = GMM((; x=x)) -# cond model on μ, w -μ, w, k, vi = (DynamicPPL.condition(model, (μ=μ, w=w)))() -@test _logdensity(condition(gmm, (; μ=μ, w=w)), (; z=k)) ≈ DynamicPPL.getlogp(vi) +# # cond model on μ, w +# μ, w, k, vi = (DynamicPPL.condition(model, (μ=μ, w=w)))() +# @test _logdensity(condition(gmm, (; μ=μ, w=w)), (; z=k)) ≈ DynamicPPL.getlogp(vi) -# cond model on z -μ, w, k, vi = (DynamicPPL.condition(model, (z = k)))() -@test _logdensity(condition(gmm, (; z=k)), (; μ=μ, w=w)) ≈ DynamicPPL.getlogp(vi) +# # cond model on z +# μ, w, k, vi = (DynamicPPL.condition(model, (z = k)))() +# @test _logdensity(condition(gmm, (; z=k)), (; μ=μ, w=w)) ≈ DynamicPPL.getlogp(vi) diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 926394c0..43b27b77 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -46,12 +46,12 @@ function AbstractMCMC.step( LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) ) - log_acceptance_ratio = min(0, logp_proposal - getlogp(state)) + log_acceptance_ratio = min(0, logp_proposal - AbstractMCMC.get_logprob(state)) if log(rand(rng)) < log_acceptance_ratio return MHTransition(proposal), MHState(proposal, logp_proposal) else - return MHTransition(params), MHState(params, getlogp(state)) + return MHTransition(params), MHState(params, AbstractMCMC.get_logprob(state)) end end @@ -82,7 +82,7 @@ function AbstractMCMC.step( args...; kwargs..., ) - params = get_params(state) + params = AbstractMCMC.get_params(state) proposal_dist = sampler.prior_dist proposal = rand(rng, proposal_dist) logp_proposal = only( @@ -91,64 +91,64 @@ function AbstractMCMC.step( log_acceptance_ratio = min( 0, - logp_proposal - getlogp(state) + logpdf(proposal_dist, params) - + logp_proposal - AbstractMCMC.get_logprob(state) + logpdf(proposal_dist, params) - logpdf(proposal_dist, proposal), ) if log(rand(rng)) < log_acceptance_ratio return MHTransition(proposal), MHState(proposal, logp_proposal) else - return MHTransition(params), MHState(params, getlogp(state)) + return MHTransition(params), MHState(params, AbstractMCMC.get_logprob(state)) end end ## tests -# for RWMH -# sample from Normal(10, 1) -struct NormalLogDensity end -LogDensityProblems.logdensity(l::NormalLogDensity, x) = logpdf(Normal(10, 1), only(x)) -LogDensityProblems.dimension(l::NormalLogDensity) = 1 -function LogDensityProblems.capabilities(::NormalLogDensity) - return LogDensityProblems.LogDensityOrder{1}() -end - -# for PriorMH -# sample from Categorical([0.2, 0.5, 0.3]) -struct CategoricalLogDensity end -function LogDensityProblems.logdensity(l::CategoricalLogDensity, x) - return logpdf(Categorical([0.2, 0.6, 0.2]), only(x)) -end -LogDensityProblems.dimension(l::CategoricalLogDensity) = 1 -function LogDensityProblems.capabilities(::CategoricalLogDensity) - return LogDensityProblems.LogDensityOrder{0}() -end - -## - -using StatsPlots - -samples = AbstractMCMC.sample( - Random.default_rng(), NormalLogDensity(), RWMH(1), 100000; initial_params=[0.0] -) -_samples = map(t 
-> only(t.params), samples) - -histogram(_samples; normalize=:pdf, label="Samples", title="RWMH Sampling of Normal(10, 1)") -plot!(Normal(10, 1); linewidth=2, label="Ground Truth") - -samples = AbstractMCMC.sample( - Random.default_rng(), - CategoricalLogDensity(), - PriorMH(product_distribution([Categorical([0.3, 0.3, 0.4])])), - 100000; - initial_params=[1], -) -_samples = map(t -> only(t.params), samples) - -histogram( - _samples; - normalize=:probability, - label="Samples", - title="MH From Prior Sampling of Categorical([0.3, 0.3, 0.4])", -) -plot!(Categorical([0.2, 0.6, 0.2]); linewidth=2, label="Ground Truth") +# # for RWMH +# # sample from Normal(10, 1) +# struct NormalLogDensity end +# LogDensityProblems.logdensity(l::NormalLogDensity, x) = logpdf(Normal(10, 1), only(x)) +# LogDensityProblems.dimension(l::NormalLogDensity) = 1 +# function LogDensityProblems.capabilities(::NormalLogDensity) +# return LogDensityProblems.LogDensityOrder{1}() +# end + +# # for PriorMH +# # sample from Categorical([0.2, 0.5, 0.3]) +# struct CategoricalLogDensity end +# function LogDensityProblems.logdensity(l::CategoricalLogDensity, x) +# return logpdf(Categorical([0.2, 0.6, 0.2]), only(x)) +# end +# LogDensityProblems.dimension(l::CategoricalLogDensity) = 1 +# function LogDensityProblems.capabilities(::CategoricalLogDensity) +# return LogDensityProblems.LogDensityOrder{0}() +# end + +# ## + +# using StatsPlots + +# samples = AbstractMCMC.sample( +# Random.default_rng(), NormalLogDensity(), RWMH(1), 100000; initial_params=[0.0] +# ) +# _samples = map(t -> only(t.params), samples) + +# histogram(_samples; normalize=:pdf, label="Samples", title="RWMH Sampling of Normal(10, 1)") +# plot!(Normal(10, 1); linewidth=2, label="Ground Truth") + +# samples = AbstractMCMC.sample( +# Random.default_rng(), +# CategoricalLogDensity(), +# PriorMH(product_distribution([Categorical([0.3, 0.3, 0.4])])), +# 100000; +# initial_params=[1], +# ) +# _samples = map(t -> only(t.params), samples) + +# histogram( +# _samples; +# normalize=:probability, +# label="Samples", +# title="MH From Prior Sampling of Categorical([0.3, 0.3, 0.4])", +# ) +# plot!(Categorical([0.2, 0.6, 0.2]); linewidth=2, label="Ground Truth") From 44de81ce138d8ac79adae8e29b874858b4912d05 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 Aug 2024 13:53:10 +0100 Subject: [PATCH 18/56] add to doc --- docs/make.jl | 2 +- docs/src/gibbs.md | 2 +- test/runtests.jl | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 9395d2a0..a2adb8e9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,7 +8,7 @@ makedocs(; sitename="AbstractMCMC", format=Documenter.HTML(), modules=[AbstractMCMC], - pages=["Home" => "index.md", "api.md", "design.md"], + pages=["Home" => "index.md", "api.md", "design.md", "gibbs.md"], checkdocs=:exports, ) diff --git a/docs/src/gibbs.md b/docs/src/gibbs.md index db8430b8..3ff87fa0 100644 --- a/docs/src/gibbs.md +++ b/docs/src/gibbs.md @@ -1,4 +1,4 @@ -# `state` Interface +# `state` Interface Functions We encourage sampler packages to implement the following interface functions for the `state` type(s) they maintain: diff --git a/test/runtests.jl b/test/runtests.jl index 909ae8b3..afc804b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,4 +24,5 @@ include("utils.jl") include("stepper.jl") include("transducer.jl") include("logdensityproblems.jl") + include("gibbs_example/gibbs.jl") end From be43178b10295bb4d07e52c571a64ea093f12958 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 
Aug 2024 13:53:54 +0100 Subject: [PATCH 19/56] format --- test/gibbs_example/gibbs.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 87c1c90a..da84acee 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -76,7 +76,9 @@ function AbstractMCMC.step( Tuple([vi[g] for g in group_complement]) ) cond_logdensity = condition(logdensity_model.logdensity, cond_val) - sub_state = recompute_logprob!!(cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state) + sub_state = recompute_logprob!!( + cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state + ) sub_state = last( AbstractMCMC.step( rng, From 1a6e0d573a467c1fa55e5c2b36ce3aba462a24ae Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 Aug 2024 13:58:49 +0100 Subject: [PATCH 20/56] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b7378daf..ff588ba0 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.3.0" +version = "5.4.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From 6b60b72b8900f6454c4cda242d13a3334c856a77 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 Aug 2024 13:59:38 +0100 Subject: [PATCH 21/56] reverse version bump -- already done --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ff588ba0..b7378daf 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.4.0" +version = "5.3.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From c58b39a77b216ebc1879dfc3386c2749cdbdd7ad Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 Aug 2024 14:00:14 +0100 Subject: [PATCH 22/56] remove dep on `Compat` --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index b7378daf..472aed89 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ version = "5.3.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -24,7 +23,6 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19, 0.4" -Compat = "4.15.0" ConsoleProgressMonitor = "0.1" FillArrays = "1" LogDensityProblems = "2" From ac0ce7a3cd1504d1b09e7f061ba173b8c4cb3f4c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 Aug 2024 14:05:50 +0100 Subject: [PATCH 23/56] updates to doc --- docs/src/gibbs.md | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/src/gibbs.md b/docs/src/gibbs.md index 3ff87fa0..5d635d76 100644 --- a/docs/src/gibbs.md +++ b/docs/src/gibbs.md @@ -116,6 +116,19 @@ end where the model doesn't need to know the details of the `state` type, as long as it can access the `log_joint` function. 
+Additionally, we define a couple of helper functions to transform between the sampler representation and the model representation of the parameters values. +In this simple example, the model representation is a vector, and the sampler representation is a named tuple. + +```julia +function flatten(nt::NamedTuple) + return only(values(nt)) +end + +function unflatten(vec::AbstractVector, group::Tuple) + return NamedTuple((only(group) => vec,)) +end +``` + ## Sampler Packages To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: random walk and static proposal. @@ -358,15 +371,13 @@ tau2_true = 2.0 # True variance x_data = rand(Normal(mu_true, sqrt(tau2_true)), N) ``` -``` - -Then we can create a `HierNormal` model with the data. +Then we can create a `HierNormal` model, with the data we just generated. ```julia hn = HierNormal((x=x_data,)) ``` -sampling is easy: we use random walk MH for `mu` and prior MH for `tau2`, because `tau2` has support on positive real numbers. +Using Gibbs sampling allows us to use random walk MH for `mu` and prior MH for `tau2`, because `tau2` has support only on positive real numbers. ```julia samples = sample( From 280eaf1095e8549ceae3d47ca81c9b10c0a1f91b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 8 Sep 2024 15:10:46 +0100 Subject: [PATCH 24/56] update gibbs to add to the src folder --- Project.toml | 4 +- src/AbstractMCMC.jl | 6 +- src/gibbs.jl | 222 ++++++++++++++++++++++++++++++ test/gibbs_example/gibbs.jl | 219 +++++++++-------------------- test/gibbs_example/gmm.jl | 60 -------- test/gibbs_example/hier_normal.jl | 14 +- test/gibbs_example/mh.jl | 4 +- 7 files changed, 297 insertions(+), 232 deletions(-) create mode 100644 src/gibbs.jl diff --git a/Project.toml b/Project.toml index 472aed89..79f1f60f 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" -OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -37,9 +36,8 @@ julia = "1.6" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" -OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["FillArrays", "Distributions", "IJulia", "OrderedCollections", "Statistics", "Test"] +test = ["FillArrays", "Distributions", "IJulia", "Statistics", "Test"] diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index f87997bf..0fde7b90 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -1,7 +1,6 @@ module AbstractMCMC using BangBang: BangBang -using Compat using ConsoleProgressMonitor: ConsoleProgressMonitor using LogDensityProblems: LogDensityProblems using LoggingExtras: LoggingExtras @@ -81,6 +80,10 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr """ struct MCMCSerial <: AbstractMCMCEnsemble end +function condition end + +function recompute_logprob!! 
end + """ get_logprob(state) @@ -116,5 +119,6 @@ include("sample.jl") include("stepper.jl") include("transducer.jl") include("logdensityproblems.jl") +include("gibbs.jl") end # module AbstractMCMC diff --git a/src/gibbs.jl b/src/gibbs.jl new file mode 100644 index 00000000..d5c3b2bc --- /dev/null +++ b/src/gibbs.jl @@ -0,0 +1,222 @@ +""" + Gibbs(sampler_map::NamedTuple) + +An interface for block sampling in Markov Chain Monte Carlo (MCMC). + +Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems. +It allows different sampling methods to be applied to different parameters. +""" +struct Gibbs <: AbstractMCMC.AbstractSampler + sampler_map::NamedTuple + parameter_names::Tuple{Vararg{Symbol}} + + function Gibbs(sampler_map::NamedTuple) + parameter_names = Tuple(keys(sampler_map)) + return new(sampler_map, parameter_names) + end +end + +struct GibbsState + """ + `trace` contains the values of the values of _all_ parameters up to the last iteration. + """ + trace::NamedTuple + + """ + `mcmc_states` maps parameters to their sampler-specific MCMC states. + """ + mcmc_states::NamedTuple + + """ + `variable_sizes` maps parameters to their sizes. + """ + variable_sizes::NamedTuple +end + +struct GibbsTransition + """ + Realizations of the parameters, this is considered a "sample" in the MCMC chain. + """ + values::NamedTuple +end + +""" + flatten(trace::Union{NamedTuple,OrderedCollections.OrderedDict}) + +Flatten all the values in the trace into a single vector. + +# Examples + +```jldoctest +julia> flatten((a=[1,2], b=[3,4,5])) +[1, 2, 3, 4, 5] + +julia> flatten(OrderedCollections.OrderedDict(:x=>[1.0,2.0], :y=>[3.0,4.0,5.0])) +[1.0, 2.0, 3.0, 4.0, 5.0] +``` +""" +function flatten(trace::NamedTuple) + return reduce(vcat, vec.(values(trace))) +end + +""" + unflatten(vec::AbstractVector, group_names_and_sizes::NamedTuple) + +Reverse operation of flatten. Reshape the vector into the original arrays using size information. + +# Examples + +```jldoctest +julia> unflatten([1,2,3,4,5], (a=(2,), b=(3,))) +(a=[1,2], b=[3,4,5]) + +julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,))) +(x=[1.0 3.0; 2.0 4.0], y=[5.0,6.0]) +``` +""" +function unflatten(vec::AbstractVector, variable_sizes::NamedTuple) + result = Dict{Symbol,Array}() + start_idx = 1 + for (name, size) in pairs(variable_sizes) + end_idx = start_idx + prod(size) - 1 + result[name] = reshape(vec[start_idx:end_idx], size...) + start_idx = end_idx + 1 + end + + # ensure the order of the keys is the same as the one in variable_sizes + return NamedTuple{Tuple(keys(variable_sizes))}([ + result[name] for name in keys(variable_sizes) + ]) +end + +""" + update_trace(trace::NamedTuple, gibbs_state::GibbsState) + +Update the trace with the values from the MCMC states of the sub-problems. 
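+
+# Examples
+
+An illustrative sketch rather than a doctest; it assumes a hypothetical sampler
+state `sub_state` whose parameter values are `[1.0, 2.0]`:
+
+```julia
+# trace, sub-states, and sizes are passed positionally to GibbsState
+gibbs_state = GibbsState((mu=[0.0, 0.0],), (mu=sub_state,), (mu=(2,),))
+update_trace(NamedTuple(), gibbs_state)  # returns (mu = [1.0, 2.0],)
+```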
+""" +function update_trace(trace::NamedTuple, gibbs_state::GibbsState) + for parameter_variable in keys(gibbs_state.mcmc_states) + sub_state = gibbs_state.mcmc_states[parameter_variable] + trace = merge( + trace, + unflatten( + AbstractMCMC.get_params(sub_state), + NamedTuple{(parameter_variable,)}(( + gibbs_state.variable_sizes[parameter_variable], + )), + ), + ) + end + return trace +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::Gibbs, + args...; + initial_params::NamedTuple, + kwargs..., +) + if Set(keys(initial_params)) != Set(sampler.parameter_names) + throw( + ArgumentError( + "initial_params must contain all parameters in the model, expected $(sampler.parameter_names), got $(keys(initial_params))", + ), + ) + end + + mcmc_states = Dict{Symbol,Any}() + variable_sizes = Dict{Symbol,Tuple}() + for parameter_variable in sampler.parameter_names + sub_sampler = sampler.sampler_map[parameter_variable] + + variables_to_be_conditioned_on = setdiff( + sampler.parameter_names, (parameter_variable,) + ) + conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( + Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) + ) + sub_problem_parameters_values = NamedTuple{(parameter_variable,)}(( + initial_params[parameter_variable], + )) + + # LogDensityProblems' `logdensity` function expects a single vector of real numbers + # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values + # and unflatten after the sampling step + variable_sizes[parameter_variable] = Tuple(size(initial_params[parameter_variable])) + flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values) + + sub_state = last( + AbstractMCMC.step( + rng, + AbstractMCMC.LogDensityModel( + AbstractMCMC.condition( + logdensity_model.logdensity, conditioning_variables_values + ), + ), + sub_sampler, + args...; + initial_params=flattened_sub_problem_parameters_values, + kwargs..., + ), + ) + mcmc_states[parameter_variable] = sub_state + end + + gibbs_state = GibbsState( + initial_params, NamedTuple(mcmc_states), NamedTuple(variable_sizes) + ) + trace = update_trace(NamedTuple(), gibbs_state) + return GibbsTransition(trace), gibbs_state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::Gibbs, + gibbs_state::GibbsState, + args...; + kwargs..., +) + (; trace, mcmc_states, variable_sizes) = gibbs_state + mcmc_states_dict = Dict( + keys(mcmc_states) .=> [mcmc_states[k] for k in keys(mcmc_states)] + ) + for parameter_variable in sampler.parameter_names + sub_sampler = sampler.sampler_map[parameter_variable] + sub_state = mcmc_states[parameter_variable] + variables_to_be_conditioned_on = setdiff( + sampler.parameter_names, (parameter_variable,) + ) + conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( + Tuple([trace[g] for g in variables_to_be_conditioned_on]) + ) + cond_logdensity = AbstractMCMC.condition( + logdensity_model.logdensity, conditioning_variables_values + ) + + # recompute the logdensity stored in the mcmc state, because the values might have been updated in other sub-problems + sub_state = AbstractMCMC.recompute_logprob!!( + cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state + ) + + sub_state = last( + AbstractMCMC.step( + rng, + AbstractMCMC.LogDensityModel(cond_logdensity), + sub_sampler, + sub_state, + args...; + kwargs..., + ), + 
) + mcmc_states_dict[parameter_variable] = sub_state + trace = update_trace(trace, gibbs_state) + end + + mcmc_states = NamedTuple{Tuple(keys(mcmc_states_dict))}( + Tuple([mcmc_states_dict[k] for k in keys(mcmc_states_dict)]) + ) + return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes) +end diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index da84acee..417446b1 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -1,163 +1,72 @@ -using AbstractMCMC -using Distributions -using LogDensityProblems -using OrderedCollections -using Random -using Test - -include("hier_normal.jl") -# include("gmm.jl") include("mh.jl") - -struct Gibbs <: AbstractMCMC.AbstractSampler - sampler_map::OrderedDict -end - -struct GibbsState - vi::NamedTuple - states::OrderedDict -end - -struct GibbsTransition - values::NamedTuple -end - -function AbstractMCMC.step( - rng::AbstractRNG, - logdensity_model::AbstractMCMC.LogDensityModel, - spl::Gibbs, - args...; - initial_params::NamedTuple, - kwargs..., -) - states = OrderedDict() - for group in keys(spl.sampler_map) - sub_spl = spl.sampler_map[group] - - vars_to_be_conditioned_on = setdiff(keys(initial_params), group) - cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}( - Tuple([initial_params[g] for g in vars_to_be_conditioned_on]) - ) - params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group])) - sub_state = last( - AbstractMCMC.step( - rng, - AbstractMCMC.LogDensityModel( - condition(logdensity_model.logdensity, cond_val) - ), - sub_spl, - args...; - initial_params=flatten(params_val), - kwargs..., - ), - ) - states[group] = sub_state - end - return GibbsTransition(initial_params), GibbsState(initial_params, states) -end - -function AbstractMCMC.step( - rng::AbstractRNG, - logdensity_model::AbstractMCMC.LogDensityModel, - spl::Gibbs, - state::GibbsState, - args...; - kwargs..., -) - vi = state.vi - for group in keys(spl.sampler_map) - for (group, sub_state) in state.states - vi = merge(vi, unflatten(AbstractMCMC.get_params(sub_state), group)) - end - sub_spl = spl.sampler_map[group] - sub_state = state.states[group] - group_complement = setdiff(keys(vi), group) - cond_val = NamedTuple{Tuple(group_complement)}( - Tuple([vi[g] for g in group_complement]) - ) - cond_logdensity = condition(logdensity_model.logdensity, cond_val) - sub_state = recompute_logprob!!( - cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state - ) - sub_state = last( - AbstractMCMC.step( - rng, - AbstractMCMC.LogDensityModel(cond_logdensity), - sub_spl, - sub_state, - args...; - kwargs..., - ), - ) - state.states[group] = sub_state - end - for (group, sub_state) in state.states - vi = merge(vi, unflatten(AbstractMCMC.get_params(sub_state), group)) - end - return GibbsTransition(vi), GibbsState(vi, state.states) -end - -## tests with hierarchical normal model - -# generate data -N = 1000 # Number of data points -mu_true = 0.5 # True mean -tau2_true = 2.0 # True variance - -# Generate data based on true parameters -x_data = rand(Normal(mu_true, sqrt(tau2_true)), N) - -# Store the generated data in the HierNormal structure -hn = HierNormal((x=x_data,)) - -samples = sample( - hn, - Gibbs( - OrderedDict( - (:mu,) => RWMH(1), - (:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])), - ), - ), - 200000; - initial_params=(mu=[0.0], tau2=[1.0]), -) - -mu_samples = [sample.values.mu for sample in samples][40001:end] -tau2_samples = [sample.values.tau2 for sample in samples][40001:end] - 
-mu_mean = mean(mu_samples)[1] -tau2_mean = mean(tau2_samples)[1] +# include("gmm.jl") +include("hier_normal.jl") @testset "hierarchical normal with gibbs" begin + # generate data + N = 1000 # Number of data points + mu_true = 0.5 # True mean + tau2_true = 2.0 # True variance + x_data = rand(Distributions.Normal(mu_true, sqrt(tau2_true)), N) + + # Store the generated data in the HierNormal structure + hn = HierNormal((x=x_data,)) + + samples = sample( + hn, + AbstractMCMC.Gibbs(( + mu=RWMH(1), tau2=PriorMH(product_distribution([InverseGamma(1, 1)])) + )), + 200000; + initial_params=(mu=[0.0], tau2=[1.0]), + ) + + warmup = 40000 + thin = 10 + thinned_samples = samples[(warmup + 1):thin:end] + mu_samples = [sample.values.mu for sample in thinned_samples] + tau2_samples = [sample.values.tau2 for sample in thinned_samples] + + mu_mean = only(mean(mu_samples)) + tau2_mean = only(mean(tau2_samples)) + @test mu_mean ≈ mu_true atol = 0.1 @test tau2_mean ≈ tau2_true atol = 0.3 end -## test with gmm -- too hard, doesn't converge - -# gmm = GMM((; x=x)) - -# samples = sample( -# gmm, -# Gibbs( -# OrderedDict( -# (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), -# (:w,) => PriorMH(Dirichlet(2, 1.0)), -# (:μ,) => RWMH(1), +# This is too difficult to sample, disable for now +# @testset "gmm with gibbs" begin +# w = [0.5, 0.5] +# μ = [-3.5, 0.5] +# mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w) + +# N = 60 +# x = rand(mixturemodel, N) + +# gmm = GMM((; x=x)) + +# samples = sample( +# gmm, +# Gibbs( +# ( +# z = PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), +# w = PriorMH(Dirichlet(2, 1.0)), +# μ = RWMH(1), +# ), # ), -# ), -# 100000; -# initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]), -# ); - -# z_samples = [sample.values.z for sample in samples][20001:end] -# μ_samples = [sample.values.μ for sample in samples][20001:end] -# w_samples = [sample.values.w for sample in samples][20001:end]; - -# # thin these samples -# z_samples = z_samples[1:100:end] -# μ_samples = μ_samples[1:100:end] -# w_samples = w_samples[1:100:end]; - -# mean(μ_samples) -# mean(w_samples) +# 100000; +# initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]), +# ) + +# z_samples = [sample.values.z for sample in samples][20001:end] +# μ_samples = [sample.values.μ for sample in samples][20001:end] +# w_samples = [sample.values.w for sample in samples][20001:end] + +# # thin these samples +# z_samples = z_samples[1:100:end] +# μ_samples = μ_samples[1:100:end] +# w_samples = w_samples[1:100:end] + +# mean(μ_samples) +# mean(w_samples) +# end diff --git a/test/gibbs_example/gmm.jl b/test/gibbs_example/gmm.jl index f5139307..5bc01b31 100644 --- a/test/gibbs_example/gmm.jl +++ b/test/gibbs_example/gmm.jl @@ -81,63 +81,3 @@ end function recompute_logprob!!(gmm::ConditionedGMM, vals, state) return set_logp!!(state, LogDensityProblems.logdensity(gmm, vals)) end - -## test using Turing - -# # data generation - -# using FillArrays - -# w = [0.5, 0.5] -# μ = [-3.5, 0.5] -# mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w) - -# N = 60 -# x = rand(mixturemodel, N); - -# # Turing model from https://turinglang.org/docs/tutorials/01-gaussian-mixture-model/ -# using Turing - -# @model function gaussian_mixture_model(x) -# # Draw the parameters for each of the K=2 clusters from a standard normal distribution. 
-# K = 2 -# μ ~ MvNormal(Zeros(K), I) - -# # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1. -# w ~ Dirichlet(K, 1.0) -# # Alternatively, one could use a fixed set of weights. -# # w = fill(1/K, K) - -# # Construct categorical distribution of assignments. -# distribution_assignments = Categorical(w) - -# # Construct multivariate normal distributions of each cluster. -# D, N = size(x) -# distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ] - -# # Draw assignments for each datum and generate it from the multivariate normal distribution. -# k = Vector{Int}(undef, N) -# for i in 1:N -# k[i] ~ distribution_assignments -# x[:, i] ~ distribution_clusters[k[i]] -# end - -# return μ, w, k, __varinfo__ -# end - -# model = gaussian_mixture_model(x); - -# using Test -# # full model -# μ, w, k, vi = model() -# @test log_joint(; μ=μ, w=w, z=k, x=x) ≈ DynamicPPL.getlogp(vi) - -# gmm = GMM((; x=x)) - -# # cond model on μ, w -# μ, w, k, vi = (DynamicPPL.condition(model, (μ=μ, w=w)))() -# @test _logdensity(condition(gmm, (; μ=μ, w=w)), (; z=k)) ≈ DynamicPPL.getlogp(vi) - -# # cond model on z -# μ, w, k, vi = (DynamicPPL.condition(model, (z = k)))() -# @test _logdensity(condition(gmm, (; z=k)), (; μ=μ, w=w)) ≈ DynamicPPL.getlogp(vi) diff --git a/test/gibbs_example/hier_normal.jl b/test/gibbs_example/hier_normal.jl index fa3b47bc..deba5336 100644 --- a/test/gibbs_example/hier_normal.jl +++ b/test/gibbs_example/hier_normal.jl @@ -1,5 +1,3 @@ -using LogDensityProblems - abstract type AbstractHierNormal end struct HierNormal <: AbstractHierNormal @@ -36,7 +34,7 @@ function log_joint(; mu, tau2, x) return logp end -function condition(hn::HierNormal, conditioned_values::NamedTuple) +function AbstractMCMC.condition(hn::HierNormal, conditioned_values::NamedTuple) return ConditionedHierNormal(hn.data, conditioned_values) end @@ -60,14 +58,6 @@ function LogDensityProblems.capabilities(::ConditionedHierNormal) return LogDensityProblems.LogDensityOrder{0}() end -function flatten(nt::NamedTuple) - return only(values(nt)) -end - -function unflatten(vec::AbstractVector, group::Tuple) - return NamedTuple((only(group) => vec,)) -end - -function recompute_logprob!!(hn::ConditionedHierNormal, vals, state) +function AbstractMCMC.recompute_logprob!!(hn::ConditionedHierNormal, vals, state) return AbstractMCMC.set_logprob!!(state, LogDensityProblems.logdensity(hn, vals)) end diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 43b27b77..965419db 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -1,3 +1,5 @@ +using Distributions + struct MHTransition{T} params::Vector{T} end @@ -56,7 +58,7 @@ function AbstractMCMC.step( end struct PriorMH <: AbstractMCMC.AbstractSampler - prior_dist::Distribution + prior_dist::Distributions.Distribution end function AbstractMCMC.step( From b262ea99f3de532ffc7b34d098835e86139f2238 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 8 Sep 2024 19:05:19 +0100 Subject: [PATCH 25/56] update mh code --- docs/src/gibbs.md | 16 ++--- test/gibbs_example/gibbs.jl | 8 +-- test/gibbs_example/mh.jl | 128 +++++++----------------------------- 3 files changed, 37 insertions(+), 115 deletions(-) diff --git a/docs/src/gibbs.md b/docs/src/gibbs.md index 5d635d76..4b9ef8f9 100644 --- a/docs/src/gibbs.md +++ b/docs/src/gibbs.md @@ -162,14 +162,14 @@ These are the functions that was used in the `recompute_logprob!!` function abov It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can 
use `get_logprob` to easily read the log probability of the current state. ```julia -struct RWMH <: AbstractMCMC.AbstractSampler +struct RandomWalkMH <: AbstractMCMC.AbstractSampler σ::Float64 end function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::RWMH, + sampler::RandomWalkMH, args...; initial_params, kwargs..., @@ -184,7 +184,7 @@ end function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::RWMH, + sampler::RandomWalkMH, state::MHState, args...; kwargs..., @@ -207,14 +207,14 @@ end ``` ```julia -struct PriorMH <: AbstractMCMC.AbstractSampler +struct IndependentMH <: AbstractMCMC.AbstractSampler prior_dist::Distribution end function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::PriorMH, + sampler::IndependentMH, args...; initial_params, kwargs..., @@ -229,7 +229,7 @@ end function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::PriorMH, + sampler::IndependentMH, state::MHState, args...; kwargs..., @@ -384,8 +384,8 @@ samples = sample( hn, Gibbs( OrderedDict( - (:mu,) => RWMH(1), - (:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])), + (:mu,) => RandomWalkMH(1), + (:tau2,) => IndependentMH(product_distribution([InverseGamma(1, 1)])), ), ), 100000; diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 417446b1..13c0f3e1 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -15,7 +15,7 @@ include("hier_normal.jl") samples = sample( hn, AbstractMCMC.Gibbs(( - mu=RWMH(1), tau2=PriorMH(product_distribution([InverseGamma(1, 1)])) + mu=RandomWalkMH(1), tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])) )), 200000; initial_params=(mu=[0.0], tau2=[1.0]), @@ -49,9 +49,9 @@ end # gmm, # Gibbs( # ( -# z = PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), -# w = PriorMH(Dirichlet(2, 1.0)), -# μ = RWMH(1), +# z = IndependentMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), +# w = IndependentMH(Dirichlet(2, 1.0)), +# μ = RandomWalkMH(1), # ), # ), # 100000; diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 965419db..b41b61a9 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -1,12 +1,12 @@ using Distributions -struct MHTransition{T} +struct MHState{T} params::Vector{T} + logp::Float64 end -struct MHState{T} +struct MHTransition{T} params::Vector{T} - logp::Float64 end AbstractMCMC.get_params(state::MHState) = state.params @@ -14,14 +14,18 @@ AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp) AbstractMCMC.get_logprob(state::MHState) = state.logp AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp) -struct RWMH <: AbstractMCMC.AbstractSampler +struct RandomWalkMH <: AbstractMCMC.AbstractSampler σ::Float64 end +struct IndependentMH <: AbstractMCMC.AbstractSampler + proposal_dist::Distributions.Distribution +end + function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::RWMH, + sampler::Union{RandomWalkMH,IndependentMH}, args...; initial_params, kwargs..., @@ -36,121 +40,39 @@ end function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::RWMH, + sampler::Union{RandomWalkMH,IndependentMH}, state::MHState, args...; kwargs..., ) params = state.params - proposal_dist = MvNormal(zeros(length(params)), sampler.σ) - 
proposal = params .+ rand(rng, proposal_dist) + proposal_dist = + sampler isa RandomWalkMH ? MvNormal(state.params, sampler.σ) : sampler.proposal_dist + proposal = rand(rng, proposal_dist) logp_proposal = only( LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) ) - log_acceptance_ratio = min(0, logp_proposal - AbstractMCMC.get_logprob(state)) - - if log(rand(rng)) < log_acceptance_ratio + if log(rand(rng)) < + compute_log_acceptance_ratio(sampler, state, proposal, logp_proposal) return MHTransition(proposal), MHState(proposal, logp_proposal) else - return MHTransition(params), MHState(params, AbstractMCMC.get_logprob(state)) + return MHTransition(params), MHState(params, state.logp) end end -struct PriorMH <: AbstractMCMC.AbstractSampler - prior_dist::Distributions.Distribution -end - -function AbstractMCMC.step( - rng::AbstractRNG, - logdensity_model::AbstractMCMC.LogDensityModel, - sampler::PriorMH, - args...; - initial_params, - kwargs..., +function compute_log_acceptance_ratio( + ::RandomWalkMH, state::MHState, ::Vector{Float64}, logp_proposal::Float64 ) - return MHTransition(initial_params), - MHState( - initial_params, - only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), - ) + return min(0, logp_proposal - AbstractMCMC.get_logprob(state)) end -function AbstractMCMC.step( - rng::AbstractRNG, - logdensity_model::AbstractMCMC.LogDensityModel, - sampler::PriorMH, - state::MHState, - args...; - kwargs..., -) - params = AbstractMCMC.get_params(state) - proposal_dist = sampler.prior_dist - proposal = rand(rng, proposal_dist) - logp_proposal = only( - LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) - ) - - log_acceptance_ratio = min( +function compute_log_acceptance_ratio( + sampler::IndependentMH, state::MHState, proposal::Vector{T}, logp_proposal::Float64 +) where {T} + return min( 0, - logp_proposal - AbstractMCMC.get_logprob(state) + logpdf(proposal_dist, params) - - logpdf(proposal_dist, proposal), + logp_proposal - state.logp + logpdf(sampler.proposal_dist, state.params) - + logpdf(sampler.proposal_dist, proposal), ) - - if log(rand(rng)) < log_acceptance_ratio - return MHTransition(proposal), MHState(proposal, logp_proposal) - else - return MHTransition(params), MHState(params, AbstractMCMC.get_logprob(state)) - end end - -## tests - -# # for RWMH -# # sample from Normal(10, 1) -# struct NormalLogDensity end -# LogDensityProblems.logdensity(l::NormalLogDensity, x) = logpdf(Normal(10, 1), only(x)) -# LogDensityProblems.dimension(l::NormalLogDensity) = 1 -# function LogDensityProblems.capabilities(::NormalLogDensity) -# return LogDensityProblems.LogDensityOrder{1}() -# end - -# # for PriorMH -# # sample from Categorical([0.2, 0.5, 0.3]) -# struct CategoricalLogDensity end -# function LogDensityProblems.logdensity(l::CategoricalLogDensity, x) -# return logpdf(Categorical([0.2, 0.6, 0.2]), only(x)) -# end -# LogDensityProblems.dimension(l::CategoricalLogDensity) = 1 -# function LogDensityProblems.capabilities(::CategoricalLogDensity) -# return LogDensityProblems.LogDensityOrder{0}() -# end - -# ## - -# using StatsPlots - -# samples = AbstractMCMC.sample( -# Random.default_rng(), NormalLogDensity(), RWMH(1), 100000; initial_params=[0.0] -# ) -# _samples = map(t -> only(t.params), samples) - -# histogram(_samples; normalize=:pdf, label="Samples", title="RWMH Sampling of Normal(10, 1)") -# plot!(Normal(10, 1); linewidth=2, label="Ground Truth") - -# samples = AbstractMCMC.sample( -# Random.default_rng(), -# 
CategoricalLogDensity(), -# PriorMH(product_distribution([Categorical([0.3, 0.3, 0.4])])), -# 100000; -# initial_params=[1], -# ) -# _samples = map(t -> only(t.params), samples) - -# histogram( -# _samples; -# normalize=:probability, -# label="Samples", -# title="MH From Prior Sampling of Categorical([0.3, 0.3, 0.4])", -# ) -# plot!(Categorical([0.2, 0.6, 0.2]); linewidth=2, label="Ground Truth") From c47ade4f2118177cc4f8ccf9de9e8aa4457a717c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 8 Sep 2024 20:22:04 +0100 Subject: [PATCH 26/56] update code further --- src/AbstractMCMC.jl | 30 ------------------------------ src/gibbs.jl | 16 +++++++++++----- test/gibbs_example/gibbs.jl | 5 +++-- test/gibbs_example/gmm.jl | 4 ---- test/gibbs_example/mh.jl | 14 +++++++++----- 5 files changed, 23 insertions(+), 46 deletions(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 0fde7b90..6d18f962 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -82,36 +82,6 @@ struct MCMCSerial <: AbstractMCMCEnsemble end function condition end -function recompute_logprob!! end - -""" - get_logprob(state) - -Returns the log-probability of the last sampling step, stored in `state`. -""" -function get_logprob(state) end - -""" - set_logprob!(state, logprob) - -Set the log-probability of the last sampling step, stored in `state`. -""" -function set_logprob!!(state, logprob) end - -""" - get_params(state) - -Returns the values of the parameters in the state. -""" -function get_params(state) end - -""" - setparams!(state, params) - -Set the values of the parameters in the state. -""" -function set_params!!(state, params) end - include("samplingstats.jl") include("logging.jl") include("interface.jl") diff --git a/src/gibbs.jl b/src/gibbs.jl index d5c3b2bc..43818116 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -77,7 +77,8 @@ julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,))) function unflatten(vec::AbstractVector, variable_sizes::NamedTuple) result = Dict{Symbol,Array}() start_idx = 1 - for (name, size) in pairs(variable_sizes) + for name in keys(variable_sizes) + size = variable_sizes[name] end_idx = start_idx + prod(size) - 1 result[name] = reshape(vec[start_idx:end_idx], size...) 
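+        # advance the read cursor to the start of the next variable's block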
start_idx = end_idx + 1 @@ -100,7 +101,7 @@ function update_trace(trace::NamedTuple, gibbs_state::GibbsState) trace = merge( trace, unflatten( - AbstractMCMC.get_params(sub_state), + vec(sub_state), NamedTuple{(parameter_variable,)}(( gibbs_state.variable_sizes[parameter_variable], )), @@ -197,9 +198,14 @@ function AbstractMCMC.step( ) # recompute the logdensity stored in the mcmc state, because the values might have been updated in other sub-problems - sub_state = AbstractMCMC.recompute_logprob!!( - cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state - ) + updated_log_prob = LogDensityProblems.logdensity(cond_logdensity, sub_state) + + if !hasproperty(sub_state, :logp) + error( + "$(typeof(sub_state)) does not have a `:logp` field, which is required by Gibbs sampling", + ) + end + sub_state = BangBang.setproperty!!(sub_state, :logp, updated_log_prob) sub_state = last( AbstractMCMC.step( diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 13c0f3e1..87286b5a 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -5,7 +5,7 @@ include("hier_normal.jl") @testset "hierarchical normal with gibbs" begin # generate data N = 1000 # Number of data points - mu_true = 0.5 # True mean + mu_true = 5 # True mean tau2_true = 2.0 # True variance x_data = rand(Distributions.Normal(mu_true, sqrt(tau2_true)), N) @@ -15,7 +15,8 @@ include("hier_normal.jl") samples = sample( hn, AbstractMCMC.Gibbs(( - mu=RandomWalkMH(1), tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])) + mu=RandomWalkMH(0.3), + tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])), )), 200000; initial_params=(mu=[0.0], tau2=[1.0]), diff --git a/test/gibbs_example/gmm.jl b/test/gibbs_example/gmm.jl index 5bc01b31..2d59b0aa 100644 --- a/test/gibbs_example/gmm.jl +++ b/test/gibbs_example/gmm.jl @@ -77,7 +77,3 @@ end function unflatten(vec::AbstractVector, group::Tuple) return NamedTuple((only(group) => vec,)) end - -function recompute_logprob!!(gmm::ConditionedGMM, vals, state) - return set_logp!!(state, LogDensityProblems.logdensity(gmm, vals)) -end diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index b41b61a9..152ca6cb 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -9,10 +9,14 @@ struct MHTransition{T} params::Vector{T} end -AbstractMCMC.get_params(state::MHState) = state.params -AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp) -AbstractMCMC.get_logprob(state::MHState) = state.logp -AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp) +function AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state::MHState) + # recompute the logdensity, instead of using the one stored in the state + return AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) +end + +function Base.vec(state::MHState) + return state.params +end struct RandomWalkMH <: AbstractMCMC.AbstractSampler σ::Float64 @@ -64,7 +68,7 @@ end function compute_log_acceptance_ratio( ::RandomWalkMH, state::MHState, ::Vector{Float64}, logp_proposal::Float64 ) - return min(0, logp_proposal - AbstractMCMC.get_logprob(state)) + return min(0, logp_proposal - state.logp) end function compute_log_acceptance_ratio( From 8d29ad348563a933a08cadf0e11d7d774dcceb37 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 9 Sep 2024 08:47:23 +0100 Subject: [PATCH 27/56] fix test errors --- src/gibbs.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/gibbs.jl 
b/src/gibbs.jl index 43818116..fb8bbd64 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -47,7 +47,7 @@ Flatten all the values in the trace into a single vector. # Examples -```jldoctest +```jldoctest; setup = :(using AbstractMCMC: flatten) julia> flatten((a=[1,2], b=[3,4,5])) [1, 2, 3, 4, 5] @@ -66,7 +66,7 @@ Reverse operation of flatten. Reshape the vector into the original arrays using # Examples -```jldoctest +```jldoctest; setup = :(using AbstractMCMC: unflatten) julia> unflatten([1,2,3,4,5], (a=(2,), b=(3,))) (a=[1,2], b=[3,4,5]) @@ -180,7 +180,10 @@ function AbstractMCMC.step( args...; kwargs..., ) - (; trace, mcmc_states, variable_sizes) = gibbs_state + trace = gibbs_state.trace + mcmc_states = gibbs_state.mcmc_states + variable_sizes = gibbs_state.variable_sizes + mcmc_states_dict = Dict( keys(mcmc_states) .=> [mcmc_states[k] for k in keys(mcmc_states)] ) From c28a75ab5c4a5fc728eb79c631cae011eee9679f Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 9 Sep 2024 08:52:49 +0100 Subject: [PATCH 28/56] format --- src/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gibbs.jl b/src/gibbs.jl index fb8bbd64..df0ef7c8 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -183,7 +183,7 @@ function AbstractMCMC.step( trace = gibbs_state.trace mcmc_states = gibbs_state.mcmc_states variable_sizes = gibbs_state.variable_sizes - + mcmc_states_dict = Dict( keys(mcmc_states) .=> [mcmc_states[k] for k in keys(mcmc_states)] ) From 1382054b3be11cf28f231b636ff1c34be187d881 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 9 Sep 2024 10:37:10 +0100 Subject: [PATCH 29/56] fix doctest error --- src/gibbs.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/gibbs.jl b/src/gibbs.jl index df0ef7c8..bb5866c1 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -49,10 +49,13 @@ Flatten all the values in the trace into a single vector. ```jldoctest; setup = :(using AbstractMCMC: flatten) julia> flatten((a=[1,2], b=[3,4,5])) -[1, 2, 3, 4, 5] +5-element Vector{Int64}: + 1 + 2 + 3 + 4 + 5 -julia> flatten(OrderedCollections.OrderedDict(:x=>[1.0,2.0], :y=>[3.0,4.0,5.0])) -[1.0, 2.0, 3.0, 4.0, 5.0] ``` """ function flatten(trace::NamedTuple) @@ -68,10 +71,10 @@ Reverse operation of flatten. 
Reshape the vector into the original arrays using ```jldoctest; setup = :(using AbstractMCMC: unflatten) julia> unflatten([1,2,3,4,5], (a=(2,), b=(3,))) -(a=[1,2], b=[3,4,5]) +(a = [1, 2], b = [3, 4, 5]) julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,))) -(x=[1.0 3.0; 2.0 4.0], y=[5.0,6.0]) +(x = [1.0 3.0; 2.0 4.0], y = [5.0, 6.0]) ``` """ function unflatten(vec::AbstractVector, variable_sizes::NamedTuple) From 8962d40626a255eabe09ff37adb2436d5975a93b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 9 Sep 2024 10:46:52 +0100 Subject: [PATCH 30/56] tidy up --- test/gibbs_example/gmm.jl | 10 +--------- test/gibbs_example/hier_normal.jl | 4 ---- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/test/gibbs_example/gmm.jl b/test/gibbs_example/gmm.jl index 2d59b0aa..7cfe26f7 100644 --- a/test/gibbs_example/gmm.jl +++ b/test/gibbs_example/gmm.jl @@ -38,7 +38,7 @@ function log_joint(; μ, w, z, x) return logp end -function condition(gmm::GMM, conditioned_values::NamedTuple) +function AbstractMCMC.condition(gmm::GMM, conditioned_values::NamedTuple) return ConditionedGMM(gmm.data, conditioned_values) end @@ -69,11 +69,3 @@ end function LogDensityProblems.capabilities(::ConditionedGMM) return LogDensityProblems.LogDensityOrder{0}() end - -function flatten(nt::NamedTuple) - return only(values(nt)) -end - -function unflatten(vec::AbstractVector, group::Tuple) - return NamedTuple((only(group) => vec,)) -end diff --git a/test/gibbs_example/hier_normal.jl b/test/gibbs_example/hier_normal.jl index deba5336..6e807b34 100644 --- a/test/gibbs_example/hier_normal.jl +++ b/test/gibbs_example/hier_normal.jl @@ -57,7 +57,3 @@ end function LogDensityProblems.capabilities(::ConditionedHierNormal) return LogDensityProblems.LogDensityOrder{0}() end - -function AbstractMCMC.recompute_logprob!!(hn::ConditionedHierNormal, vals, state) - return AbstractMCMC.set_logprob!!(state, LogDensityProblems.logdensity(hn, vals)) -end From dc6001c596e7edabce8c4c9ebc81c78cedb4d4e7 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 17 Sep 2024 15:36:44 +0100 Subject: [PATCH 31/56] updates --- src/AbstractMCMC.jl | 2 ++ src/gibbs.jl | 11 +---------- test/gibbs_example/mh.jl | 12 +++++++++--- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 6d18f962..5b0883c1 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -82,6 +82,8 @@ struct MCMCSerial <: AbstractMCMCEnsemble end function condition end +function logdensity_and_state end + include("samplingstats.jl") include("logging.jl") include("interface.jl") diff --git a/src/gibbs.jl b/src/gibbs.jl index bb5866c1..c0b4a539 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -203,16 +203,7 @@ function AbstractMCMC.step( logdensity_model.logdensity, conditioning_variables_values ) - # recompute the logdensity stored in the mcmc state, because the values might have been updated in other sub-problems - updated_log_prob = LogDensityProblems.logdensity(cond_logdensity, sub_state) - - if !hasproperty(sub_state, :logp) - error( - "$(typeof(sub_state)) does not have a `:logp` field, which is required by Gibbs sampling", - ) - end - sub_state = BangBang.setproperty!!(sub_state, :logp, updated_log_prob) - + _, sub_state = AbstractMCMC.logdensity_and_state(cond_logdensity, sub_state) sub_state = last( AbstractMCMC.step( rng, diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 152ca6cb..88be1e18 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -9,9 +9,15 @@ struct 
MHTransition{T} params::Vector{T} end -function AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state::MHState) - # recompute the logdensity, instead of using the one stored in the state - return AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) +function AbstractMCMC.logdensity_and_state( + logdensity_function, state::MHState; recompute_logp::Bool=true +) + if recompute_logp + logp, substate = AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) + return logp, MHState(substate.params, logp) + else + return state.logp, state + end end function Base.vec(state::MHState) From e1941088edef4e4d86984307cc48398166d740f0 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Tue, 17 Sep 2024 22:40:42 +0800 Subject: [PATCH 32/56] Update test/gibbs_example/mh.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/gibbs_example/mh.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 88be1e18..45b4661c 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -13,7 +13,9 @@ function AbstractMCMC.logdensity_and_state( logdensity_function, state::MHState; recompute_logp::Bool=true ) if recompute_logp - logp, substate = AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) + logp, substate = AbstractMCMC.LogDensityProblems.logdensity( + logdensity_function, state.params + ) return logp, MHState(substate.params, logp) else return state.logp, state From 64eb0e4b69d853852cee68f97c1ad4438f1d1af2 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 17 Sep 2024 15:47:26 +0100 Subject: [PATCH 33/56] fix error --- test/gibbs_example/mh.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 45b4661c..06895ca9 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -13,9 +13,7 @@ function AbstractMCMC.logdensity_and_state( logdensity_function, state::MHState; recompute_logp::Bool=true ) if recompute_logp - logp, substate = AbstractMCMC.LogDensityProblems.logdensity( - logdensity_function, state.params - ) + logp = AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) return logp, MHState(substate.params, logp) else return state.logp, state From 9361c3902d19ae5a5549ee96d794786c9c2c6dbd Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 17 Sep 2024 15:54:54 +0100 Subject: [PATCH 34/56] typo fix --- test/gibbs_example/mh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 06895ca9..fae03611 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -14,7 +14,7 @@ function AbstractMCMC.logdensity_and_state( ) if recompute_logp logp = AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) - return logp, MHState(substate.params, logp) + return logp, MHState(state.params, logp) else return state.logp, state end From 39c4d8778eeaf8bcf32b67a3e892af4d63aa53fb Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Wed, 18 Sep 2024 21:29:06 +0800 Subject: [PATCH 35/56] Update src/gibbs.jl Co-authored-by: Tor Erlend Fjelde --- src/gibbs.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/gibbs.jl b/src/gibbs.jl index c0b4a539..8eb1a07e 100644 --- a/src/gibbs.jl +++ 
b/src/gibbs.jl @@ -6,14 +6,8 @@ An interface for block sampling in Markov Chain Monte Carlo (MCMC). Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems. It allows different sampling methods to be applied to different parameters. """ -struct Gibbs <: AbstractMCMC.AbstractSampler - sampler_map::NamedTuple - parameter_names::Tuple{Vararg{Symbol}} - - function Gibbs(sampler_map::NamedTuple) - parameter_names = Tuple(keys(sampler_map)) - return new(sampler_map, parameter_names) - end +struct Gibbs{NT} <: AbstractMCMC.AbstractSampler + sampler_map::NT end struct GibbsState From 7f889cf076f4d132272dc045806baa73fc9cfb24 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 20 Sep 2024 16:05:19 +0100 Subject: [PATCH 36/56] rename gibbs test file to prepare for moving --- src/gibbs.jl | 69 +++++++++---------- .../gibbs_example/{gibbs.jl => gibbs_test.jl} | 0 2 files changed, 34 insertions(+), 35 deletions(-) rename test/gibbs_example/{gibbs.jl => gibbs_test.jl} (100%) diff --git a/src/gibbs.jl b/src/gibbs.jl index 8eb1a07e..817eb89e 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -6,49 +6,50 @@ An interface for block sampling in Markov Chain Monte Carlo (MCMC). Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems. It allows different sampling methods to be applied to different parameters. """ -struct Gibbs{NT} <: AbstractMCMC.AbstractSampler +struct Gibbs{NT<:NamedTuple} <: AbstractMCMC.AbstractSampler sampler_map::NT end -struct GibbsState +struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple} """ - `trace` contains the values of the values of _all_ parameters up to the last iteration. + Contains the values of all parameters up to the last iteration. """ - trace::NamedTuple + trace::TraceNT """ - `mcmc_states` maps parameters to their sampler-specific MCMC states. + Maps parameters to their sampler-specific MCMC states. """ - mcmc_states::NamedTuple + mcmc_states::StateNT """ - `variable_sizes` maps parameters to their sizes. + Maps parameters to their sizes. """ - variable_sizes::NamedTuple + variable_sizes::SizeNT end -struct GibbsTransition +struct GibbsTransition{ValuesNT<:NamedTuple} """ Realizations of the parameters, this is considered a "sample" in the MCMC chain. """ - values::NamedTuple + values::ValuesNT end """ - flatten(trace::Union{NamedTuple,OrderedCollections.OrderedDict}) + flatten(trace::NamedTuple) -Flatten all the values in the trace into a single vector. +Flatten all the values in the trace into a single vector. Variable names information is discarded. # Examples ```jldoctest; setup = :(using AbstractMCMC: flatten) -julia> flatten((a=[1,2], b=[3,4,5])) -5-element Vector{Int64}: - 1 - 2 - 3 - 4 - 5 +julia> flatten((a=ones(2), b=ones(2, 2))) +6-element Vector{Float64}: + 1.0 + 1.0 + 1.0 + 1.0 + 1.0 + 1.0 ``` """ @@ -57,7 +58,7 @@ function flatten(trace::NamedTuple) end """ - unflatten(vec::AbstractVector, group_names_and_sizes::NamedTuple) + unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple}) Reverse operation of flatten. Reshape the vector into the original arrays using size information. 
@@ -71,20 +72,19 @@ julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,))) (x = [1.0 3.0; 2.0 4.0], y = [5.0, 6.0]) ``` """ -function unflatten(vec::AbstractVector, variable_sizes::NamedTuple) +function unflatten( + vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names} +) where {variable_names} result = Dict{Symbol,Array}() start_idx = 1 - for name in keys(variable_sizes) - size = variable_sizes[name] + for name in variable_names + size = variable_names_and_sizes[name] end_idx = start_idx + prod(size) - 1 result[name] = reshape(vec[start_idx:end_idx], size...) start_idx = end_idx + 1 end - # ensure the order of the keys is the same as the one in variable_sizes - return NamedTuple{Tuple(keys(variable_sizes))}([ - result[name] for name in keys(variable_sizes) - ]) + return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names])) end """ @@ -95,15 +95,14 @@ Update the trace with the values from the MCMC states of the sub-problems. function update_trace(trace::NamedTuple, gibbs_state::GibbsState) for parameter_variable in keys(gibbs_state.mcmc_states) sub_state = gibbs_state.mcmc_states[parameter_variable] - trace = merge( - trace, - unflatten( - vec(sub_state), - NamedTuple{(parameter_variable,)}(( - gibbs_state.variable_sizes[parameter_variable], - )), - ), + sub_state_params = vec(sub_state) + unflattened_sub_state_params = unflatten( + sub_state_params, + NamedTuple{(parameter_variable,)}(( + gibbs_state.variable_sizes[parameter_variable], + )), ) + trace = merge(trace, unflattened_sub_state_params) end return trace end diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs_test.jl similarity index 100% rename from test/gibbs_example/gibbs.jl rename to test/gibbs_example/gibbs_test.jl From 62a2332347fdb76c6413298fd3af8f2e44a0fbcd Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 20 Sep 2024 16:45:34 +0100 Subject: [PATCH 37/56] move gibbs.jl --- Project.toml | 4 +++- src/AbstractMCMC.jl | 4 ---- {src => test/gibbs_example}/gibbs.jl | 0 3 files changed, 3 insertions(+), 5 deletions(-) rename {src => test/gibbs_example}/gibbs.jl (100%) diff --git a/Project.toml b/Project.toml index 79f1f60f..6c6b39c0 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ desc = "A lightweight interface for common MCMC methods." 
version = "5.3.0" [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -21,6 +22,7 @@ TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] +AbstractPPL = "0.8" BangBang = "0.3.19, 0.4" ConsoleProgressMonitor = "0.1" FillArrays = "1" @@ -40,4 +42,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["FillArrays", "Distributions", "IJulia", "Statistics", "Test"] +test = ["AbstractPPL","FillArrays", "Distributions", "IJulia", "Statistics", "Test"] diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 5b0883c1..4c7813ff 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,10 +80,6 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr """ struct MCMCSerial <: AbstractMCMCEnsemble end -function condition end - -function logdensity_and_state end - include("samplingstats.jl") include("logging.jl") include("interface.jl") diff --git a/src/gibbs.jl b/test/gibbs_example/gibbs.jl similarity index 100% rename from src/gibbs.jl rename to test/gibbs_example/gibbs.jl From 6132f0cf25e39c5dd123e4751306a232542dc9a5 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 20 Sep 2024 19:32:09 +0100 Subject: [PATCH 38/56] update code --- test/gibbs_example/gibbs.jl | 65 ++++++++++++++----------------------- test/gibbs_example/mh.jl | 8 ++--- 2 files changed, 29 insertions(+), 44 deletions(-) diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 817eb89e..44f24811 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -1,36 +1,28 @@ +using AbstractMCMC, AbstractPPL +using BangBang.ConstructorBase: ConstructorBase + """ Gibbs(sampler_map::NamedTuple) -An interface for block sampling in Markov Chain Monte Carlo (MCMC). - -Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems. -It allows different sampling methods to be applied to different parameters. +A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter. """ -struct Gibbs{NT<:NamedTuple} <: AbstractMCMC.AbstractSampler - sampler_map::NT +struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler + sampler_map::T end struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple} - """ - Contains the values of all parameters up to the last iteration. - """ + "Contains the values of all parameters up to the last iteration." trace::TraceNT - """ - Maps parameters to their sampler-specific MCMC states. - """ + "Maps parameters to their sampler-specific MCMC states." mcmc_states::StateNT - """ - Maps parameters to their sizes. - """ + "Maps parameters to their sizes." variable_sizes::SizeNT end struct GibbsTransition{ValuesNT<:NamedTuple} - """ - Realizations of the parameters, this is considered a "sample" in the MCMC chain. - """ + "Realizations of the parameters, this is considered a \"sample\" in the MCMC chain." values::ValuesNT end @@ -95,7 +87,7 @@ Update the trace with the values from the MCMC states of the sub-problems. 
function update_trace(trace::NamedTuple, gibbs_state::GibbsState) for parameter_variable in keys(gibbs_state.mcmc_states) sub_state = gibbs_state.mcmc_states[parameter_variable] - sub_state_params = vec(sub_state) + sub_state_params = Base.vec(sub_state) unflattened_sub_state_params = unflatten( sub_state_params, NamedTuple{(parameter_variable,)}(( @@ -115,21 +107,19 @@ function AbstractMCMC.step( initial_params::NamedTuple, kwargs..., ) - if Set(keys(initial_params)) != Set(sampler.parameter_names) + if Set(keys(initial_params)) != Set(keys(sampler.sampler_map)) throw( ArgumentError( - "initial_params must contain all parameters in the model, expected $(sampler.parameter_names), got $(keys(initial_params))", + "initial_params must contain all parameters in the model, expected $(keys(sampler.sampler_map)), got $(keys(initial_params))", ), ) end - mcmc_states = Dict{Symbol,Any}() - variable_sizes = Dict{Symbol,Tuple}() - for parameter_variable in sampler.parameter_names + mcmc_states, variable_sizes = map(keys(sampler.sampler_map)) do parameter_variable sub_sampler = sampler.sampler_map[parameter_variable] variables_to_be_conditioned_on = setdiff( - sampler.parameter_names, (parameter_variable,) + keys(sampler.sampler_map), (parameter_variable,) ) conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) @@ -141,7 +131,6 @@ function AbstractMCMC.step( # LogDensityProblems' `logdensity` function expects a single vector of real numbers # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values # and unflatten after the sampling step - variable_sizes[parameter_variable] = Tuple(size(initial_params[parameter_variable])) flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values) sub_state = last( @@ -158,11 +147,13 @@ function AbstractMCMC.step( kwargs..., ), ) - mcmc_states[parameter_variable] = sub_state + (sub_state, Tuple(size(initial_params[parameter_variable]))) end gibbs_state = GibbsState( - initial_params, NamedTuple(mcmc_states), NamedTuple(variable_sizes) + initial_params, + NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states), + NamedTuple{Tuple(keys(sampler.sampler_map))}(variable_sizes), ) trace = update_trace(NamedTuple(), gibbs_state) return GibbsTransition(trace), gibbs_state @@ -176,14 +167,9 @@ function AbstractMCMC.step( args...; kwargs..., ) - trace = gibbs_state.trace - mcmc_states = gibbs_state.mcmc_states - variable_sizes = gibbs_state.variable_sizes + (; trace, mcmc_states, variable_sizes) = gibbs_state - mcmc_states_dict = Dict( - keys(mcmc_states) .=> [mcmc_states[k] for k in keys(mcmc_states)] - ) - for parameter_variable in sampler.parameter_names + mcmc_states = map(keys(sampler.sampler_map)) do parameter_variable sub_sampler = sampler.sampler_map[parameter_variable] sub_state = mcmc_states[parameter_variable] variables_to_be_conditioned_on = setdiff( @@ -196,7 +182,8 @@ function AbstractMCMC.step( logdensity_model.logdensity, conditioning_variables_values ) - _, sub_state = AbstractMCMC.logdensity_and_state(cond_logdensity, sub_state) + logp = LogDensityProblems.logdensity_and_state(cond_logdensity, sub_state) + sub_state = constructorof(typeof(sub_state))(; logp=logp) sub_state = last( AbstractMCMC.step( rng, @@ -207,12 +194,10 @@ function AbstractMCMC.step( kwargs..., ), ) - mcmc_states_dict[parameter_variable] = sub_state trace = update_trace(trace, gibbs_state) + sub_state end + mcmc_states = 
NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states) - mcmc_states = NamedTuple{Tuple(keys(mcmc_states_dict))}( - Tuple([mcmc_states_dict[k] for k in keys(mcmc_states_dict)]) - ) return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes) end diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index fae03611..24f1522b 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -24,12 +24,12 @@ function Base.vec(state::MHState) return state.params end -struct RandomWalkMH <: AbstractMCMC.AbstractSampler - σ::Float64 +struct RandomWalkMH{T} <: AbstractMCMC.AbstractSampler + σ::T end -struct IndependentMH <: AbstractMCMC.AbstractSampler - proposal_dist::Distributions.Distribution +struct IndependentMH{T} <: AbstractMCMC.AbstractSampler + proposal_dist::T end function AbstractMCMC.step( From af208bc408ff6064a6190c0bfdb90b87c31809a9 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 22 Sep 2024 15:19:40 +0100 Subject: [PATCH 39/56] updates --- test/gibbs_example/hier_normal.jl | 28 ++++++++----- test/gibbs_example/mh.jl | 70 +++++++++++++++++++++---------- 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/test/gibbs_example/hier_normal.jl b/test/gibbs_example/hier_normal.jl index 6e807b34..9fa49ab3 100644 --- a/test/gibbs_example/hier_normal.jl +++ b/test/gibbs_example/hier_normal.jl @@ -1,14 +1,16 @@ abstract type AbstractHierNormal end -struct HierNormal <: AbstractHierNormal - data::NamedTuple +struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal + data::Tdata end -struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal - data::NamedTuple - conditioned_values::NamedTuple{conditioned_vars} +struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <: + AbstractHierNormal + data::Tdata + conditioned_values::Tconditioned_vars end +# `mu` and `tau2` are length-1 vectors to make function log_joint(; mu, tau2, x) # mu is the mean # tau2 is the variance @@ -39,14 +41,18 @@ function AbstractMCMC.condition(hn::HierNormal, conditioned_values::NamedTuple) end function LogDensityProblems.logdensity( - hn::ConditionedHierNormal{names}, params::AbstractVector + hier_normal_model::ConditionedHierNormal{names}, params::AbstractVector ) where {names} - if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2 - return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x) - elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu - return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x) + variable_to_condition = only(names) + data = hier_normal_model.data + conditioned_values = hier_normal_model.conditioned_values + + if variable_to_condition == :mu + return log_joint(; mu=conditioned_values.mu, tau2=params, x=data.x) + elseif variable_to_condition == :tau2 + return log_joint(; mu=params, tau2=conditioned_values.tau2, x=data.x) else - error("Unsupported conditioning configuration.") + error("Unsupported conditioning variable: $variable_to_condition") end end diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 24f1522b..93bbc555 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -1,5 +1,7 @@ using Distributions +abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end + struct MHState{T} params::Vector{T} logp::Float64 @@ -9,65 +11,89 @@ struct MHTransition{T} params::Vector{T} end +# Interface 1: LogDensityProblems.logdensity +# This function takes the logdensity function and the state (state is 
defined by the sampler package) +# and returns the logdensity. It allows for optional recomputation of the log probability. +# If recomputation is not needed, it returns the stored log probability from the state. function AbstractMCMC.logdensity_and_state( - logdensity_function, state::MHState; recompute_logp::Bool=true + logdensity_function, state::MHState; recompute_logp=true ) - if recompute_logp - logp = AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) - return logp, MHState(state.params, logp) + return if recompute_logp + AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) else - return state.logp, state + state.logp end end +# Interface 2: Base.vec +# This function takes a state and returns a vector of the parameter values stored in the state. +# It is part of the interface for interacting with the state object. function Base.vec(state::MHState) return state.params end -struct RandomWalkMH{T} <: AbstractMCMC.AbstractSampler +""" + RandomWalkMH{T} <: AbstractMCMC.AbstractSampler + +A random walk Metropolis-Hastings sampler with a normal proposal distribution. The field σ +is the standard deviation of the proposal distribution. +""" +struct RandomWalkMH{T} <: AbstractMHSampler σ::T end -struct IndependentMH{T} <: AbstractMCMC.AbstractSampler +""" + IndependentMH{T} <: AbstractMCMC.AbstractSampler + +A Metropolis-Hastings sampler with an independent proposal distribution. +""" +struct IndependentMH{T} <: AbstractMHSampler proposal_dist::T end +# the first step of the sampler function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::Union{RandomWalkMH,IndependentMH}, + sampler::AbstractMHSampler, args...; initial_params, kwargs..., ) - return MHTransition(initial_params), - MHState( - initial_params, - only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), - ) + logdensity_function = logdensity_model.logdensity + transition = MHTransition(initial_params) + state = MHState(initial_params, only(logdensity_function(initial_params))) + + return transition, state end +@inline proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) = + MvNormal(current_params, sampler.σ) +@inline proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} = + sampler.proposal_dist + +# the subsequent steps of the sampler function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::Union{RandomWalkMH,IndependentMH}, + sampler::AbstractMHSampler, state::MHState, args...; kwargs..., ) - params = state.params - proposal_dist = - sampler isa RandomWalkMH ? 
MvNormal(state.params, sampler.σ) : sampler.proposal_dist - proposal = rand(rng, proposal_dist) + logdensity_function = logdensity_model.logdensity + current_params = state.params + proposal_dist = proposal_dist(sampler, current_params) + proposed_params = rand(rng, proposal_dist) logp_proposal = only( - LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) + LogDensityProblems.logdensity(logdensity_function, proposed_params) ) if log(rand(rng)) < - compute_log_acceptance_ratio(sampler, state, proposal, logp_proposal) - return MHTransition(proposal), MHState(proposal, logp_proposal) + compute_log_acceptance_ratio(sampler, state, proposed_params, logp_proposal) + return MHTransition(proposed_params), MHState(proposed_params, logp_proposal) else - return MHTransition(params), MHState(params, state.logp) + return MHTransition(current_params), MHState(current_params, state.logp) end end From fd472dfee65450277eeb5c36a78687e13f7fcfda Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 22 Sep 2024 16:29:32 +0100 Subject: [PATCH 40/56] rework the code; still not type stable --- Project.toml | 6 ++- src/AbstractMCMC.jl | 1 - test/gibbs_example/gibbs.jl | 63 +++++++++++++++++++------------ test/gibbs_example/hier_normal.jl | 13 +++++-- test/gibbs_example/mh.jl | 25 +++++++----- test/runtests.jl | 2 +- 6 files changed, 69 insertions(+), 41 deletions(-) diff --git a/Project.toml b/Project.toml index 6c6b39c0..a30156a4 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ desc = "A lightweight interface for common MCMC methods." version = "5.3.0" [deps] -AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -28,6 +27,7 @@ ConsoleProgressMonitor = "0.1" FillArrays = "1" LogDensityProblems = "2" LoggingExtras = "0.4, 0.5, 1" +MCMCChains = "6" ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" @@ -35,11 +35,13 @@ Transducers = "0.4.30" julia = "1.6" [extras] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AbstractPPL","FillArrays", "Distributions", "IJulia", "Statistics", "Test"] +test = ["AbstractPPL","FillArrays", "Distributions", "IJulia", "MCMCChains", "Statistics", "Test"] diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 4c7813ff..dc464d42 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -87,6 +87,5 @@ include("sample.jl") include("stepper.jl") include("transducer.jl") include("logdensityproblems.jl") -include("gibbs.jl") end # module AbstractMCMC diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 44f24811..72f03828 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -1,5 +1,7 @@ -using AbstractMCMC, AbstractPPL -using BangBang.ConstructorBase: ConstructorBase +using AbstractMCMC: AbstractMCMC +using AbstractPPL: AbstractPPL +using MCMCChains: Chains +using Random """ Gibbs(sampler_map::NamedTuple) @@ -99,27 +101,34 @@ function update_trace(trace::NamedTuple, gibbs_state::GibbsState) return trace end +function error_if_not_fully_initialized( + initial_params::NamedTuple{ParamNames}, 
sampler::Gibbs{<:NamedTuple{SamplerNames}} +) where {ParamNames,SamplerNames} + if Set(ParamNames) != Set(SamplerNames) + throw( + ArgumentError( + "initial_params must contain all parameters in the model, expected $(SamplerNames), got $(ParamNames)", + ), + ) + end +end + function AbstractMCMC.step( rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::Gibbs, + sampler::Gibbs{Tsamplingmap}, args...; initial_params::NamedTuple, kwargs..., -) - if Set(keys(initial_params)) != Set(keys(sampler.sampler_map)) - throw( - ArgumentError( - "initial_params must contain all parameters in the model, expected $(keys(sampler.sampler_map)), got $(keys(initial_params))", - ), - ) - end +) where {Tsamplingmap} + error_if_not_fully_initialized(initial_params, sampler) - mcmc_states, variable_sizes = map(keys(sampler.sampler_map)) do parameter_variable + model_parameter_names = fieldnames(Tsamplingmap) + results = map(model_parameter_names) do parameter_variable sub_sampler = sampler.sampler_map[parameter_variable] variables_to_be_conditioned_on = setdiff( - keys(sampler.sampler_map), (parameter_variable,) + model_parameter_names, (parameter_variable,) ) conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) @@ -137,7 +146,7 @@ function AbstractMCMC.step( AbstractMCMC.step( rng, AbstractMCMC.LogDensityModel( - AbstractMCMC.condition( + AbstractPPL.condition( logdensity_model.logdensity, conditioning_variables_values ), ), @@ -150,40 +159,46 @@ function AbstractMCMC.step( (sub_state, Tuple(size(initial_params[parameter_variable]))) end + mcmc_states = first.(results) + variable_sizes = last.(results) + gibbs_state = GibbsState( initial_params, - NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states), - NamedTuple{Tuple(keys(sampler.sampler_map))}(variable_sizes), + NamedTuple{Tuple(model_parameter_names)}(mcmc_states), + NamedTuple{Tuple(model_parameter_names)}(variable_sizes), ) + trace = update_trace(NamedTuple(), gibbs_state) return GibbsTransition(trace), gibbs_state end +# subsequent steps function AbstractMCMC.step( rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::Gibbs, + sampler::Gibbs{Tsamplingmap}, gibbs_state::GibbsState, args...; kwargs..., -) +) where {Tsamplingmap} (; trace, mcmc_states, variable_sizes) = gibbs_state - mcmc_states = map(keys(sampler.sampler_map)) do parameter_variable + model_parameter_names = fieldnames(Tsamplingmap) + mcmc_states = map(model_parameter_names) do parameter_variable sub_sampler = sampler.sampler_map[parameter_variable] sub_state = mcmc_states[parameter_variable] variables_to_be_conditioned_on = setdiff( - sampler.parameter_names, (parameter_variable,) + model_parameter_names, (parameter_variable,) ) conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( Tuple([trace[g] for g in variables_to_be_conditioned_on]) ) - cond_logdensity = AbstractMCMC.condition( + cond_logdensity = AbstractPPL.condition( logdensity_model.logdensity, conditioning_variables_values ) - logp = LogDensityProblems.logdensity_and_state(cond_logdensity, sub_state) - sub_state = constructorof(typeof(sub_state))(; logp=logp) + logp = LogDensityProblems.logdensity(cond_logdensity, sub_state) + sub_state = (sub_state)(logp) sub_state = last( AbstractMCMC.step( rng, @@ -197,7 +212,7 @@ function AbstractMCMC.step( trace = update_trace(trace, gibbs_state) sub_state end - mcmc_states = 
NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states) + mcmc_states = NamedTuple{Tuple(model_parameter_names)}(mcmc_states) return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes) end diff --git a/test/gibbs_example/hier_normal.jl b/test/gibbs_example/hier_normal.jl index 9fa49ab3..2f58bf1e 100644 --- a/test/gibbs_example/hier_normal.jl +++ b/test/gibbs_example/hier_normal.jl @@ -1,3 +1,5 @@ +using AbstractPPL: AbstractPPL + abstract type AbstractHierNormal end struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal @@ -7,6 +9,8 @@ end struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <: AbstractHierNormal data::Tdata + + " The variable to be conditioned on and its value" conditioned_values::Tconditioned_vars end @@ -36,14 +40,15 @@ function log_joint(; mu, tau2, x) return logp end -function AbstractMCMC.condition(hn::HierNormal, conditioned_values::NamedTuple) +function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple) return ConditionedHierNormal(hn.data, conditioned_values) end function LogDensityProblems.logdensity( - hier_normal_model::ConditionedHierNormal{names}, params::AbstractVector -) where {names} - variable_to_condition = only(names) + hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars}, + params::AbstractVector, +) where {Tdata,Tconditioned_vars} + variable_to_condition = only(fieldnames(Tconditioned_vars)) data = hier_normal_model.data conditioned_values = hier_normal_model.conditioned_values diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index 93bbc555..a8a3240b 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -1,5 +1,6 @@ +using AbstractMCMC: AbstractMCMC, LogDensityProblems using Distributions - +using Random abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end struct MHState{T} @@ -7,6 +8,11 @@ struct MHState{T} logp::Float64 end +# Interface 3: (state::MHState)(logp::Float64) +# This function allows the state to be updated with a new log probability. +# ! this makes state into a Julia functor +(state::MHState)(logp::Float64) = MHState(state.params, logp) + struct MHTransition{T} params::Vector{T} end @@ -15,7 +21,7 @@ end # This function takes the logdensity function and the state (state is defined by the sampler package) # and returns the logdensity. It allows for optional recomputation of the log probability. # If recomputation is not needed, it returns the stored log probability from the state. -function AbstractMCMC.logdensity_and_state( +function LogDensityProblems.logdensity( logdensity_function, state::MHState; recompute_logp=true ) return if recompute_logp @@ -28,9 +34,7 @@ end # Interface 2: Base.vec # This function takes a state and returns a vector of the parameter values stored in the state. # It is part of the interface for interacting with the state object. 
-function Base.vec(state::MHState) - return state.params -end +Base.vec(state::MHState) = state.params """ RandomWalkMH{T} <: AbstractMCMC.AbstractSampler @@ -62,14 +66,17 @@ function AbstractMCMC.step( ) logdensity_function = logdensity_model.logdensity transition = MHTransition(initial_params) - state = MHState(initial_params, only(logdensity_function(initial_params))) + state = MHState( + initial_params, + only(LogDensityProblems.logdensity(logdensity_function, initial_params)), + ) return transition, state end -@inline proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) = +@inline get_proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) = MvNormal(current_params, sampler.σ) -@inline proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} = +@inline get_proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} = sampler.proposal_dist # the subsequent steps of the sampler @@ -83,7 +90,7 @@ function AbstractMCMC.step( ) logdensity_function = logdensity_model.logdensity current_params = state.params - proposal_dist = proposal_dist(sampler, current_params) + proposal_dist = get_proposal_dist(sampler, current_params) proposed_params = rand(rng, proposal_dist) logp_proposal = only( LogDensityProblems.logdensity(logdensity_function, proposed_params) diff --git a/test/runtests.jl b/test/runtests.jl index afc804b8..5ecd67d9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,5 +24,5 @@ include("utils.jl") include("stepper.jl") include("transducer.jl") include("logdensityproblems.jl") - include("gibbs_example/gibbs.jl") + include("gibbs_example/gibbs_test.jl") end From 4306aee03ed4e9bde05fde229563c6e2d7f63c6b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 22 Sep 2024 19:39:21 +0100 Subject: [PATCH 41/56] fix test --- test/gibbs_example/gibbs.jl | 8 ++++---- test/gibbs_example/gibbs_test.jl | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 72f03828..1c4c2cf4 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -159,13 +159,13 @@ function AbstractMCMC.step( (sub_state, Tuple(size(initial_params[parameter_variable]))) end - mcmc_states = first.(results) - variable_sizes = last.(results) + mcmc_states_tuple = first.(results) + variable_sizes_tuple = last.(results) gibbs_state = GibbsState( initial_params, - NamedTuple{Tuple(model_parameter_names)}(mcmc_states), - NamedTuple{Tuple(model_parameter_names)}(variable_sizes), + NamedTuple{Tuple(model_parameter_names)}(mcmc_states_tuple), + NamedTuple{Tuple(model_parameter_names)}(variable_sizes_tuple), ) trace = update_trace(NamedTuple(), gibbs_state) diff --git a/test/gibbs_example/gibbs_test.jl b/test/gibbs_example/gibbs_test.jl index 87286b5a..7f1e6d1e 100644 --- a/test/gibbs_example/gibbs_test.jl +++ b/test/gibbs_example/gibbs_test.jl @@ -1,3 +1,4 @@ +include("gibbs.jl") include("mh.jl") # include("gmm.jl") include("hier_normal.jl") @@ -14,7 +15,7 @@ include("hier_normal.jl") samples = sample( hn, - AbstractMCMC.Gibbs(( + Gibbs(( mu=RandomWalkMH(0.3), tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])), )), From b798b2eb2a49a92a3fc3170c19cadd73fb38b884 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 22 Sep 2024 20:10:43 +0100 Subject: [PATCH 42/56] update doc -- need proofread --- docs/src/gibbs.md | 453 +++++++++++++++++++------------ test/gibbs_example/gibbs.jl | 12 +- test/gibbs_example/gibbs_test.jl | 4 +- test/gibbs_example/mh.jl | 4 +- 4 
files changed, 292 insertions(+), 181 deletions(-)

diff --git a/docs/src/gibbs.md b/docs/src/gibbs.md
index 4b9ef8f9..bb2b3455 100644
--- a/docs/src/gibbs.md
+++ b/docs/src/gibbs.md
@@ -2,13 +2,26 @@
 
 We encourage sampler packages to implement the following interface functions for the `state` type(s) they maintain:
 
-```@doc
-get_logprob
-set_logprob!!
-get_params
-set_params!!
+```julia
+LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state; recompute_logp=true)
+```
+
+This function takes the logdensity model and the state, and returns the log probability of the state.
+If `recompute_logp` is `true`, it should recompute the log probability of the state.
+Otherwise, it should use the log probability stored in the state.
+
+```julia
+Base.vec(state)
+```
+
+This function takes the state and returns a vector of the parameter values stored in the state.
+
+```julia
+(state::StateType)(logp::Float64)
 ```
 
+This function takes the state and a log probability value, and returns a new state with the updated log probability.
+
 These functions provide a minimal interface for interacting with the `state` type, which a sampler package does not have to expose.
 
 ## Using the `state` Interface for block sampling within Gibbs
@@ -59,17 +72,20 @@ To make using `LogDensityProblems` interface, we create a simple type for this m
 ```julia
 abstract type AbstractHierNormal end
 
-struct HierNormal <: AbstractHierNormal
-    data::NamedTuple
+struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal
+    data::Tdata
 end
 
-struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal
-    data::NamedTuple
-    conditioned_values::NamedTuple{conditioned_vars}
+struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <:
+       AbstractHierNormal
+    data::Tdata
+
+    "The variable to be conditioned on and its value."
+    conditioned_values::Tconditioned_vars
 end
 ```
 
-where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and 
+where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and
 
 ```julia
 function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple)
@@ -81,14 +97,19 @@ then we can simply write down the `LogDensityProblems` interface for this model.
 ```julia
 function LogDensityProblems.logdensity(
-    hn::ConditionedHierNormal{names}, params::AbstractVector
-) where {names}
-    if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2
-        return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x)
-    elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu
-        return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x)
+    hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars},
+    params::AbstractVector,
+) where {Tdata,Tconditioned_vars}
+    variable_to_condition = only(fieldnames(Tconditioned_vars))
+    data = hier_normal_model.data
+    conditioned_values = hier_normal_model.conditioned_values
+
+    if variable_to_condition == :mu
+        return log_joint(; mu=conditioned_values.mu, tau2=params, x=data.x)
+    elseif variable_to_condition == :tau2
+        return log_joint(; mu=params, tau2=conditioned_values.tau2, x=data.x)
     else
-        error("Unsupported conditioning configuration.")
+        error("Unsupported conditioning variable: $variable_to_condition")
     end
 end
 
@@ -101,34 +122,6 @@ function LogDensityProblems.capabilities(::ConditionedHierNormal)
 end
 ```
 
-the model should also define a function that allows the recomputation of the log probability given a sampler state.
-The reason for this is that, when we break down the joint probability into conditional probabilities, individual conditional probability problems are conditional on the values of the other variables.
-Between the Gibbs sampler sweeps, the values of the variables may change, and we need to recompute the log probability of the current state.
-
-A recomputation function could use the `state` interface to return a new state with the updated log probability.
-E.g.
-
-```julia
-function recompute_logprob!!(hn::ConditionedHierNormal, vals, state)
-    return AbstractMCMC.set_logprob!!(state, LogDensityProblems.logdensity(hn, vals))
-end
-```
-
-where the model doesn't need to know the details of the `state` type, as long as it can access the `log_joint` function.
-
-Additionally, we define a couple of helper functions to transform between the sampler representation and the model representation of the parameters values.
-In this simple example, the model representation is a vector, and the sampler representation is a named tuple.
-
-```julia
-function flatten(nt::NamedTuple)
-    return only(values(nt))
-end
-
-function unflatten(vec::AbstractVector, group::Tuple)
-    return NamedTuple((only(group) => vec,))
-end
-```
-
 ## Sampler Packages
 
 To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: one with a random walk proposal and one with an independent proposal.
@@ -148,216 +141,329 @@ struct MHState{T}
 end
 ```
 
-Next we define the four `state` interface functions.
+Next we define the `state` interface functions mentioned at the beginning of this section.
 
 ```julia
-AbstractMCMC.get_params(state::MHState) = state.params
-AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp)
-AbstractMCMC.get_logprob(state::MHState) = state.logp
-AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp)
+# Interface 1: LogDensityProblems.logdensity
+# This function takes the logdensity model and the state (whose type is defined by the sampler package)
+# and returns the logdensity. It allows for optional recomputation of the log probability.
+# If recomputation is not needed, it returns the stored log probability from the state.
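+# Recomputation matters for Gibbs: between sub-steps the other blocks update the variables
+# this block is conditioned on, so a log probability cached in the state can be stale and
+# must be refreshed against the newly conditioned model.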
+function LogDensityProblems.logdensity( + logdensity_model::AbstractMCMC.LogDensityModel, state::MHState; recompute_logp=true +) + logdensity_function = logdensity_model.logdensity + return if recompute_logp + AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) + else + state.logp + end +end -These are the functions that was used in the `recompute_logprob!!` function above. +# Interface 2: Base.vec +# This function takes a state and returns a vector of the parameter values stored in the state. +# It is part of the interface for interacting with the state object. +Base.vec(state::MHState) = state.params -It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `get_logprob` to easily read the log probability of the current state. +# Interface 3: (state::MHState)(logp::Float64) +# This function allows the state to be updated with a new log probability. +# ! this makes state into a Julia functor +(state::MHState)(logp::Float64) = MHState(state.params, logp) +``` + +It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `LogDensityProblems.logdensity` to easily read the log probability of the current state. ```julia -struct RandomWalkMH <: AbstractMCMC.AbstractSampler - σ::Float64 +""" + RandomWalkMH{T} <: AbstractMCMC.AbstractSampler + +A random walk Metropolis-Hastings sampler with a normal proposal distribution. The field σ +is the standard deviation of the proposal distribution. +""" +struct RandomWalkMH{T} <: AbstractMHSampler + σ::T +end + +""" + IndependentMH{T} <: AbstractMCMC.AbstractSampler + +A Metropolis-Hastings sampler with an independent proposal distribution. +""" +struct IndependentMH{T} <: AbstractMHSampler + proposal_dist::T end +# the first step of the sampler function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::RandomWalkMH, + sampler::AbstractMHSampler, args...; initial_params, kwargs..., ) - return MHTransition(initial_params), - MHState( + logdensity_function = logdensity_model.logdensity + transition = MHTransition(initial_params) + state = MHState( initial_params, - only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), + only(LogDensityProblems.logdensity(logdensity_function, initial_params)), ) + + return transition, state end +@inline get_proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) = + MvNormal(current_params, sampler.σ) +@inline get_proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} = + sampler.proposal_dist + +# the subsequent steps of the sampler function AbstractMCMC.step( rng::AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::RandomWalkMH, + sampler::AbstractMHSampler, state::MHState, args...; kwargs..., ) - params = state.params - proposal_dist = MvNormal(zeros(length(params)), sampler.σ) - proposal = params .+ rand(rng, proposal_dist) + logdensity_function = logdensity_model.logdensity + current_params = state.params + proposal_dist = get_proposal_dist(sampler, current_params) + proposed_params = rand(rng, proposal_dist) logp_proposal = only( - LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) + LogDensityProblems.logdensity(logdensity_function, proposed_params) ) - log_acceptance_ratio = min(0, logp_proposal - get_logprob(state)) - - if log(rand(rng)) < log_acceptance_ratio - return MHTransition(proposal), MHState(proposal, logp_proposal) + if log(rand(rng)) < + 
compute_log_acceptance_ratio(sampler, state, proposed_params, logp_proposal) + return MHTransition(proposed_params), MHState(proposed_params, logp_proposal) else - return MHTransition(params), MHState(params, get_logprob(state)) + return MHTransition(current_params), MHState(current_params, state.logp) end end -``` -```julia -struct IndependentMH <: AbstractMCMC.AbstractSampler - prior_dist::Distribution -end - -function AbstractMCMC.step( - rng::AbstractRNG, - logdensity_model::AbstractMCMC.LogDensityModel, - sampler::IndependentMH, - args...; - initial_params, - kwargs..., +function compute_log_acceptance_ratio( + ::RandomWalkMH, state::MHState, ::Vector{Float64}, logp_proposal::Float64 ) - return MHTransition(initial_params), - MHState( - initial_params, - only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), - ) + return min(0, logp_proposal - state.logp) end -function AbstractMCMC.step( - rng::AbstractRNG, - logdensity_model::AbstractMCMC.LogDensityModel, - sampler::IndependentMH, - state::MHState, - args...; - kwargs..., -) - params = get_params(state) - proposal_dist = sampler.prior_dist - proposal = rand(rng, proposal_dist) - logp_proposal = only( - LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) - ) - - log_acceptance_ratio = min( +function compute_log_acceptance_ratio( + sampler::IndependentMH, state::MHState, proposal::Vector{T}, logp_proposal::Float64 +) where {T} + return min( 0, - logp_proposal - get_logprob(state) + logpdf(proposal_dist, params) - - logpdf(proposal_dist, proposal), + logp_proposal - state.logp + logpdf(sampler.proposal_dist, state.params) - + logpdf(sampler.proposal_dist, proposal), ) - - if log(rand(rng)) < log_acceptance_ratio - return MHTransition(proposal), MHState(proposal, logp_proposal) - else - return MHTransition(params), MHState(params, get_logprob(state)) - end end ``` At last, we can proceed to implement the Gibbs sampler. ```julia -struct Gibbs <: AbstractMCMC.AbstractSampler - sampler_map::OrderedDict +""" + Gibbs(sampler_map::NamedTuple) + +A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter. +""" +struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler + sampler_map::T +end + +struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple} + "Contains the values of all parameters up to the last iteration." + trace::TraceNT + + "Maps parameters to their sampler-specific MCMC states." + mcmc_states::StateNT + + "Maps parameters to their sizes." + variable_sizes::SizeNT end -struct GibbsState - vi::NamedTuple - states::OrderedDict +struct GibbsTransition{ValuesNT<:NamedTuple} + "Realizations of the parameters, this is considered a \"sample\" in the MCMC chain." + values::ValuesNT end -struct GibbsTransition - values::NamedTuple +""" + update_trace(trace::NamedTuple, gibbs_state::GibbsState) + +Update the trace with the values from the MCMC states of the sub-problems. 
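+
+The parameter values of each sub-state are read out with `Base.vec` and reshaped back into
+the model's named-tuple representation via `unflatten` before being merged into the trace.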
+""" +function update_trace(trace::NamedTuple, gibbs_state::GibbsState) + for parameter_variable in keys(gibbs_state.mcmc_states) + sub_state = gibbs_state.mcmc_states[parameter_variable] + sub_state_params = Base.vec(sub_state) + unflattened_sub_state_params = unflatten( + sub_state_params, + NamedTuple{(parameter_variable,)}(( + gibbs_state.variable_sizes[parameter_variable], + )), + ) + trace = merge(trace, unflattened_sub_state_params) + end + return trace +end + +function error_if_not_fully_initialized( + initial_params::NamedTuple{ParamNames}, sampler::Gibbs{<:NamedTuple{SamplerNames}} +) where {ParamNames,SamplerNames} + if Set(ParamNames) != Set(SamplerNames) + throw( + ArgumentError( + "initial_params must contain all parameters in the model, expected $(SamplerNames), got $(ParamNames)", + ), + ) + end end function AbstractMCMC.step( - rng::AbstractRNG, + rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - spl::Gibbs, + sampler::Gibbs{Tsamplingmap}, args...; initial_params::NamedTuple, kwargs..., -) - states = OrderedDict() - for group in keys(spl.sampler_map) - sub_spl = spl.sampler_map[group] +) where {Tsamplingmap} + error_if_not_fully_initialized(initial_params, sampler) - vars_to_be_conditioned_on = setdiff(keys(initial_params), group) - cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}( - Tuple([initial_params[g] for g in vars_to_be_conditioned_on]) + model_parameter_names = fieldnames(Tsamplingmap) + results = map(model_parameter_names) do parameter_variable + sub_sampler = sampler.sampler_map[parameter_variable] + + variables_to_be_conditioned_on = setdiff( + model_parameter_names, (parameter_variable,) + ) + conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( + Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) ) - params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group])) + sub_problem_parameters_values = NamedTuple{(parameter_variable,)}(( + initial_params[parameter_variable], + )) + + # LogDensityProblems' `logdensity` function expects a single vector of real numbers + # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values + # and unflatten after the sampling step + flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values) + sub_state = last( AbstractMCMC.step( rng, AbstractMCMC.LogDensityModel( - condition(logdensity_model.logdensity, cond_val) + AbstractPPL.condition( + logdensity_model.logdensity, conditioning_variables_values + ), ), - sub_spl, + sub_sampler, args...; - initial_params=flatten(params_val), + initial_params=flattened_sub_problem_parameters_values, kwargs..., ), ) - states[group] = sub_state + (sub_state, Tuple(size(initial_params[parameter_variable]))) end - return GibbsTransition(initial_params), GibbsState(initial_params, states) + + mcmc_states_tuple = first.(results) + variable_sizes_tuple = last.(results) + + gibbs_state = GibbsState( + initial_params, + NamedTuple{Tuple(model_parameter_names)}(mcmc_states_tuple), + NamedTuple{Tuple(model_parameter_names)}(variable_sizes_tuple), + ) + + trace = update_trace(NamedTuple(), gibbs_state) + return GibbsTransition(trace), gibbs_state end +# subsequent steps function AbstractMCMC.step( - rng::AbstractRNG, + rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - spl::Gibbs, - state::GibbsState, + sampler::Gibbs{Tsamplingmap}, + gibbs_state::GibbsState, args...; kwargs..., -) - vi = state.vi - for group in 
keys(spl.sampler_map) - for (group, sub_state) in state.states - vi = merge(vi, unflatten(get_params(sub_state), group)) - end - sub_spl = spl.sampler_map[group] - sub_state = state.states[group] - group_complement = setdiff(keys(vi), group) - cond_val = NamedTuple{Tuple(group_complement)}( - Tuple([vi[g] for g in group_complement]) +) where {Tsamplingmap} + (; trace, mcmc_states, variable_sizes) = gibbs_state + + model_parameter_names = fieldnames(Tsamplingmap) + mcmc_states = map(model_parameter_names) do parameter_variable + sub_sampler = sampler.sampler_map[parameter_variable] + sub_state = mcmc_states[parameter_variable] + variables_to_be_conditioned_on = setdiff( + model_parameter_names, (parameter_variable,) + ) + conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( + Tuple([trace[g] for g in variables_to_be_conditioned_on]) ) - cond_logdensity = condition(logdensity_model.logdensity, cond_val) - sub_state = recompute_logprob!!(cond_logdensity, get_params(sub_state), sub_state) + cond_logdensity = AbstractPPL.condition( + logdensity_model.logdensity, conditioning_variables_values + ) + cond_logdensity_model = AbstractMCMC.LogDensityModel(cond_logdensity) + + logp = LogDensityProblems.logdensity(cond_logdensity_model, sub_state) + sub_state = (sub_state)(logp) sub_state = last( AbstractMCMC.step( - rng, - AbstractMCMC.LogDensityModel(cond_logdensity), - sub_spl, - sub_state, - args...; - kwargs..., + rng, cond_logdensity_model, sub_sampler, sub_state, args...; kwargs... ), ) - state.states[group] = sub_state + trace = update_trace(trace, gibbs_state) + sub_state end - for (group, sub_state) in state.states - vi = merge(vi, unflatten(get_params(sub_state), group)) + mcmc_states = NamedTuple{Tuple(model_parameter_names)}(mcmc_states) + + return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes) +end +``` + +where we use two utility functions `flatten` and `unflatten` to convert between the single vector of real numbers and the named tuple of parameters. + +```julia +""" + flatten(trace::NamedTuple) + +Flatten all the values in the trace into a single vector. Variable names information is discarded. +""" +function flatten(trace::NamedTuple) + return reduce(vcat, vec.(values(trace))) +end + +""" + unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple}) + +Reverse operation of flatten. Reshape the vector into the original arrays using size information. +""" +function unflatten( + vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names} +) where {variable_names} + result = Dict{Symbol,Array}() + start_idx = 1 + for name in variable_names + size = variable_names_and_sizes[name] + end_idx = start_idx + prod(size) - 1 + result[name] = reshape(vec[start_idx:end_idx], size...) + start_idx = end_idx + 1 end - return GibbsTransition(vi), GibbsState(vi, state.states) + + return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names])) end ``` Some points worth noting: -1. We are using `OrderedDict` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps. +1. We are using `NamedTuple` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps. A limitation is that exactly one sampler for each variable is required, which means it is less flexible than Gibbs in `Turing.jl`. 2. 
For each conditional probability problem, we need to store the sampler states for each variable group and also the values of all the variables from last iteration. 3. The first step of the Gibbs sampler is to setup the states for each conditional probability problem. 4. In the following steps of the Gibbs sampler, it will do a sweep over all the conditional probability problems, and update the sampler states for each problem. In each step of the sweep, it will do the following: - - first update the values from the last step of the sweep into the `vi`, which stores the values of all variables at the moment of the Gibbs sweep. - condition on the values of all variables that are not in the current group - recompute the log probability of the current state, because the values of the variables that are not in the current group may have changed - perform a step of the sampler for the conditional probability problem, and update the sampler state - update the `vi` with the new values from the sampler state -Again, the `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and acquire the values of the parameters from individual sampler states. +The `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and acquire the values of the parameters from individual sampler states. Now we can use the Gibbs sampler to sample from the hierarchical normal model. @@ -382,13 +488,11 @@ Using Gibbs sampling allows us to use random walk MH for `mu` and prior MH for ` ```julia samples = sample( hn, - Gibbs( - OrderedDict( - (:mu,) => RandomWalkMH(1), - (:tau2,) => IndependentMH(product_distribution([InverseGamma(1, 1)])), - ), - ), - 100000; + Gibbs(( + mu=RandomWalkMH(0.3), + tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])), + )), + 10000; initial_params=(mu=[0.0], tau2=[1.0]), ) ``` @@ -401,4 +505,13 @@ tau2_samples = [sample.values.tau2 for sample in samples][20001:end] mean(mu_samples) mean(tau2_samples) +(mu_mean, tau2_mean) +``` + +the result should looks like: + +```julia +(4.995812149309413, 1.9372372289677886) ``` + +which is close to the true values `(5, 2)`. diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 1c4c2cf4..3e8efcf9 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -1,5 +1,5 @@ using AbstractMCMC: AbstractMCMC -using AbstractPPL: AbstractPPL +using AbstractPPL: AbstractPPL using MCMCChains: Chains using Random @@ -196,17 +196,13 @@ function AbstractMCMC.step( cond_logdensity = AbstractPPL.condition( logdensity_model.logdensity, conditioning_variables_values ) + cond_logdensity_model = AbstractMCMC.LogDensityModel(cond_logdensity) - logp = LogDensityProblems.logdensity(cond_logdensity, sub_state) + logp = LogDensityProblems.logdensity(cond_logdensity_model, sub_state) sub_state = (sub_state)(logp) sub_state = last( AbstractMCMC.step( - rng, - AbstractMCMC.LogDensityModel(cond_logdensity), - sub_sampler, - sub_state, - args...; - kwargs..., + rng, cond_logdensity_model, sub_sampler, sub_state, args...; kwargs... 
), ) trace = update_trace(trace, gibbs_state) diff --git a/test/gibbs_example/gibbs_test.jl b/test/gibbs_example/gibbs_test.jl index 7f1e6d1e..eedb41d1 100644 --- a/test/gibbs_example/gibbs_test.jl +++ b/test/gibbs_example/gibbs_test.jl @@ -19,11 +19,11 @@ include("hier_normal.jl") mu=RandomWalkMH(0.3), tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])), )), - 200000; + 10000; initial_params=(mu=[0.0], tau2=[1.0]), ) - warmup = 40000 + warmup = 5000 thin = 10 thinned_samples = samples[(warmup + 1):thin:end] mu_samples = [sample.values.mu for sample in thinned_samples] diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index a8a3240b..ab8c11d4 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -1,6 +1,7 @@ using AbstractMCMC: AbstractMCMC, LogDensityProblems using Distributions using Random + abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end struct MHState{T} @@ -22,8 +23,9 @@ end # and returns the logdensity. It allows for optional recomputation of the log probability. # If recomputation is not needed, it returns the stored log probability from the state. function LogDensityProblems.logdensity( - logdensity_function, state::MHState; recompute_logp=true + logdensity_model::AbstractMCMC.LogDensityModel, state::MHState; recompute_logp=true ) + logdensity_function = logdensity_model.logdensity return if recompute_logp AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) else From 3ed5cb3d1bc32aebf81e00610a7bd7571fd5390a Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 22 Sep 2024 21:00:16 +0100 Subject: [PATCH 43/56] fix 1.6 struct field splatting compat issue --- test/gibbs_example/gibbs.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 3e8efcf9..0795e36a 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -181,7 +181,9 @@ function AbstractMCMC.step( args...; kwargs..., ) where {Tsamplingmap} - (; trace, mcmc_states, variable_sizes) = gibbs_state + trace = gibbs_state.trace + mcmc_states = gibbs_state.mcmc_states + variable_sizes = gibbs_state.variable_sizes model_parameter_names = fieldnames(Tsamplingmap) mcmc_states = map(model_parameter_names) do parameter_variable From 6fde1980521a43ae9d0059e2ce40ce1b62965923 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 27 Sep 2024 17:01:29 +0100 Subject: [PATCH 44/56] update code and doc --- docs/src/gibbs.md | 130 +++++++++++++----------------- test/gibbs_example/gibbs.jl | 107 ++++++------------------ test/gibbs_example/gibbs_test.jl | 2 +- test/gibbs_example/hier_normal.jl | 2 +- 4 files changed, 80 insertions(+), 161 deletions(-) diff --git a/docs/src/gibbs.md b/docs/src/gibbs.md index bb2b3455..9de77c0e 100644 --- a/docs/src/gibbs.md +++ b/docs/src/gibbs.md @@ -8,7 +8,7 @@ LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, st This function takes the logdensity model and the state, and returns the log probability of the state. If `recompute_logp` is `true`, it should recompute the log probability of the state. -Otherwise, it should use the log probability stored in the state. +Otherwise, it could use the log probability stored in the state. ```julia Base.vec(state) @@ -20,9 +20,11 @@ This function takes the state and returns a vector of the parameter values store (state::StateType)(logp::Float64) ``` -This function takes the state and a log probability value, and updates the state with the new log probability. 
+This function takes the state and a log probability value, and returns a new state with the updated log probability. -These function will provide a minimum interface to interact with the `state` datatype, which a sampler package doesn't have to expose. +These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement. +The interface facilitates the implementation of "meta-algorithms" that combine different samplers. +We will demonstrate how it can be used to implement Gibbs sampling in the following sections. ## Using the `state` Interface for block sampling within Gibbs @@ -122,7 +124,7 @@ function LogDensityProblems.capabilities(::ConditionedHierNormal) end ``` -## Sampler Packages +### Implementing A Sampler with `AbstractMCMC` Interface To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: random walk and static proposal. @@ -258,15 +260,11 @@ function compute_log_acceptance_ratio( end ``` -At last, we can proceed to implement the Gibbs sampler. +At last, we can proceed to implement a very simple Gibbs sampler. ```julia -""" - Gibbs(sampler_map::NamedTuple) - -A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter. -""" struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler + "Maps variables to their samplers." sampler_map::T end @@ -291,16 +289,18 @@ end Update the trace with the values from the MCMC states of the sub-problems. """ -function update_trace(trace::NamedTuple, gibbs_state::GibbsState) - for parameter_variable in keys(gibbs_state.mcmc_states) +function update_trace( + trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT} +) where {trace_names,TraceNT,StateNT,SizeNT} + for parameter_variable in fieldnames(StateNT) sub_state = gibbs_state.mcmc_states[parameter_variable] - sub_state_params = Base.vec(sub_state) - unflattened_sub_state_params = unflatten( - sub_state_params, - NamedTuple{(parameter_variable,)}(( - gibbs_state.variable_sizes[parameter_variable], - )), + sub_state_params_values = Base.vec(sub_state) + reshaped_sub_state_params_values = reshape( + sub_state_params_values, gibbs_state.variable_sizes[parameter_variable] ) + unflattened_sub_state_params = NamedTuple{(parameter_variable,)}(( + reshaped_sub_state_params_values, + )) trace = merge(trace, unflattened_sub_state_params) end return trace @@ -321,8 +321,7 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::Gibbs{Tsamplingmap}, - args...; + sampler::Gibbs{Tsamplingmap}; initial_params::NamedTuple, kwargs..., ) where {Tsamplingmap} @@ -338,30 +337,27 @@ function AbstractMCMC.step( conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) ) - sub_problem_parameters_values = NamedTuple{(parameter_variable,)}(( - initial_params[parameter_variable], - )) # LogDensityProblems' `logdensity` function expects a single vector of real numbers # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values # and unflatten after the sampling step - flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values) + flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable]) + sub_logdensity_model = AbstractMCMC.LogDensityModel( + AbstractPPL.condition( + logdensity_model.logdensity, 
conditioning_variables_values
+            ),
+        )
         sub_state = last(
             AbstractMCMC.step(
                 rng,
-                AbstractMCMC.LogDensityModel(
-                    AbstractPPL.condition(
-                        logdensity_model.logdensity, conditioning_variables_values
-                    ),
-                ),
-                sub_sampler,
-                args...;
+                sub_logdensity_model,
+                sub_sampler;
                 initial_params=flattened_sub_problem_parameters_values,
                 kwargs...,
             ),
         )
-        (sub_state, Tuple(size(initial_params[parameter_variable])))
+        (sub_state, size(initial_params[parameter_variable]))
     end
 
     mcmc_states_tuple = first.(results)
@@ -382,11 +378,12 @@ function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     logdensity_model::AbstractMCMC.LogDensityModel,
     sampler::Gibbs{Tsamplingmap},
-    gibbs_state::GibbsState,
-    args...;
+    gibbs_state::GibbsState;
     kwargs...,
 ) where {Tsamplingmap}
-    (; trace, mcmc_states, variable_sizes) = gibbs_state
+    trace = gibbs_state.trace
+    mcmc_states = gibbs_state.mcmc_states
+    variable_sizes = gibbs_state.variable_sizes
 
     model_parameter_names = fieldnames(Tsamplingmap)
     mcmc_states = map(model_parameter_names) do parameter_variable
@@ -407,7 +404,7 @@ function AbstractMCMC.step(
         sub_state = (sub_state)(logp)
         sub_state = last(
             AbstractMCMC.step(
-                rng, cond_logdensity_model, sub_sampler, sub_state, args...; kwargs...
+                rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...
             ),
         )
         trace = update_trace(trace, gibbs_state)
@@ -419,53 +416,36 @@ end
 ```
 
-where we use two utility functions `flatten` and `unflatten` to convert between the single vector of real numbers and the named tuple of parameters.
+We are using `NamedTuple` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps. A limitation is that exactly one sampler for each variable is required, which means it is less flexible than Gibbs in `Turing.jl`.
 
-```julia
-"""
-    flatten(trace::NamedTuple)
-
-Flatten all the values in the trace into a single vector. Variable names information is discarded.
-"""
-function flatten(trace::NamedTuple)
-    return reduce(vcat, vec.(values(trace)))
-end
+We use `AbstractPPL.condition` to divide the full model into smaller conditional probability problems.
+Each conditional probability problem corresponds to one sampler and its associated state.
 
-"""
-    unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple})
+The `Gibbs` sampler has the same interface as other samplers in `AbstractMCMC` (we don't implement the above state interface for `GibbsState` to keep it simple, but it can be implemented similarly).
 
-Reverse operation of flatten. Reshape the vector into the original arrays using size information.
-"""
-function unflatten(
-    vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names}
-) where {variable_names}
-    result = Dict{Symbol,Array}()
-    start_idx = 1
-    for name in variable_names
-        size = variable_names_and_sizes[name]
-        end_idx = start_idx + prod(size) - 1
-        result[name] = reshape(vec[start_idx:end_idx], size...)
-        start_idx = end_idx + 1
-    end
+The Gibbs sampler operates in two main phases:
 
-    return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names]))
-end
-```
+1. Initialization:
+   - Set up initial states for each conditional probability problem.
 
-Some points worth noting:
+2. Iterative Sampling:
+   For each iteration, the sampler performs a sweep over all conditional probability problems:
 
-1. We are using `NamedTuple` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps.
A limitation is that exactly one sampler for each variable is required, which means it is less flexible than Gibbs in `Turing.jl`. -2. For each conditional probability problem, we need to store the sampler states for each variable group and also the values of all the variables from last iteration. -3. The first step of the Gibbs sampler is to setup the states for each conditional probability problem. -4. In the following steps of the Gibbs sampler, it will do a sweep over all the conditional probability problems, and update the sampler states for each problem. In each step of the sweep, it will do the following: - - condition on the values of all variables that are not in the current group - - recompute the log probability of the current state, because the values of the variables that are not in the current group may have changed - - perform a step of the sampler for the conditional probability problem, and update the sampler state - - update the `vi` with the new values from the sampler state + a. Condition on other variables: + - Fix the values of all variables except the current one. + b. Update current variable: + - Recompute the log probability of the current state, as other variables may have changed: + - Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability. + - Update the state with `sub_state = sub_state(logp)` to incorporate the new log probability. + - Perform a sampling step for the current conditional probability problem: + - Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state. + - Update the global trace: + - Extract parameter values from the new state using `Base.vec(new_sub_state)`. + - Incorporate these values into the overall Gibbs state trace. -The `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and acquire the values of the parameters from individual sampler states. +This process allows the Gibbs sampler to iteratively update each variable while conditioning on the others, gradually exploring the joint distribution of all variables. -Now we can use the Gibbs sampler to sample from the hierarchical normal model. +Now we can use the Gibbs sampler to sample from the hierarchical Normal model. First we generate some data, diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 0795e36a..82d7b5a7 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -1,14 +1,9 @@ using AbstractMCMC: AbstractMCMC using AbstractPPL: AbstractPPL -using MCMCChains: Chains using Random -""" - Gibbs(sampler_map::NamedTuple) - -A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter. -""" struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler + "Maps variables to their samplers." sampler_map::T end @@ -28,74 +23,23 @@ struct GibbsTransition{ValuesNT<:NamedTuple} values::ValuesNT end -""" - flatten(trace::NamedTuple) - -Flatten all the values in the trace into a single vector. Variable names information is discarded. - -# Examples - -```jldoctest; setup = :(using AbstractMCMC: flatten) -julia> flatten((a=ones(2), b=ones(2, 2))) -6-element Vector{Float64}: - 1.0 - 1.0 - 1.0 - 1.0 - 1.0 - 1.0 - -``` -""" -function flatten(trace::NamedTuple) - return reduce(vcat, vec.(values(trace))) -end - -""" - unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple}) - -Reverse operation of flatten. 
Reshape the vector into the original arrays using size information. - -# Examples - -```jldoctest; setup = :(using AbstractMCMC: unflatten) -julia> unflatten([1,2,3,4,5], (a=(2,), b=(3,))) -(a = [1, 2], b = [3, 4, 5]) - -julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,))) -(x = [1.0 3.0; 2.0 4.0], y = [5.0, 6.0]) -``` -""" -function unflatten( - vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names} -) where {variable_names} - result = Dict{Symbol,Array}() - start_idx = 1 - for name in variable_names - size = variable_names_and_sizes[name] - end_idx = start_idx + prod(size) - 1 - result[name] = reshape(vec[start_idx:end_idx], size...) - start_idx = end_idx + 1 - end - - return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names])) -end - """ update_trace(trace::NamedTuple, gibbs_state::GibbsState) Update the trace with the values from the MCMC states of the sub-problems. """ -function update_trace(trace::NamedTuple, gibbs_state::GibbsState) - for parameter_variable in keys(gibbs_state.mcmc_states) +function update_trace( + trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT} +) where {trace_names,TraceNT,StateNT,SizeNT} + for parameter_variable in fieldnames(StateNT) sub_state = gibbs_state.mcmc_states[parameter_variable] - sub_state_params = Base.vec(sub_state) - unflattened_sub_state_params = unflatten( - sub_state_params, - NamedTuple{(parameter_variable,)}(( - gibbs_state.variable_sizes[parameter_variable], - )), + sub_state_params_values = Base.vec(sub_state) + reshaped_sub_state_params_values = reshape( + sub_state_params_values, gibbs_state.variable_sizes[parameter_variable] ) + unflattened_sub_state_params = NamedTuple{(parameter_variable,)}(( + reshaped_sub_state_params_values, + )) trace = merge(trace, unflattened_sub_state_params) end return trace @@ -116,8 +60,7 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::Gibbs{Tsamplingmap}, - args...; + sampler::Gibbs{Tsamplingmap}; initial_params::NamedTuple, kwargs..., ) where {Tsamplingmap} @@ -133,30 +76,27 @@ function AbstractMCMC.step( conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) ) - sub_problem_parameters_values = NamedTuple{(parameter_variable,)}(( - initial_params[parameter_variable], - )) # LogDensityProblems' `logdensity` function expects a single vector of real numbers # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values # and unflatten after the sampling step - flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values) + flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable]) + sub_logdensity_model = AbstractMCMC.LogDensityModel( + AbstractPPL.condition( + logdensity_model.logdensity, conditioning_variables_values + ), + ) sub_state = last( AbstractMCMC.step( rng, - AbstractMCMC.LogDensityModel( - AbstractPPL.condition( - logdensity_model.logdensity, conditioning_variables_values - ), - ), - sub_sampler, - args...; + sub_logdensity_model, + sub_sampler; initial_params=flattened_sub_problem_parameters_values, kwargs..., ), ) - (sub_state, Tuple(size(initial_params[parameter_variable]))) + (sub_state, size(initial_params[parameter_variable])) end mcmc_states_tuple = first.(results) @@ -177,8 +117,7 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, 
logdensity_model::AbstractMCMC.LogDensityModel, sampler::Gibbs{Tsamplingmap}, - gibbs_state::GibbsState, - args...; + gibbs_state::GibbsState; kwargs..., ) where {Tsamplingmap} trace = gibbs_state.trace @@ -204,7 +143,7 @@ function AbstractMCMC.step( sub_state = (sub_state)(logp) sub_state = last( AbstractMCMC.step( - rng, cond_logdensity_model, sub_sampler, sub_state, args...; kwargs... + rng, cond_logdensity_model, sub_sampler, sub_state; kwargs... ), ) trace = update_trace(trace, gibbs_state) diff --git a/test/gibbs_example/gibbs_test.jl b/test/gibbs_example/gibbs_test.jl index eedb41d1..48e679c3 100644 --- a/test/gibbs_example/gibbs_test.jl +++ b/test/gibbs_example/gibbs_test.jl @@ -33,7 +33,7 @@ include("hier_normal.jl") tau2_mean = only(mean(tau2_samples)) @test mu_mean ≈ mu_true atol = 0.1 - @test tau2_mean ≈ tau2_true atol = 0.3 + @test tau2_mean ≈ tau2_true atol = 0.1 end # This is too difficult to sample, disable for now diff --git a/test/gibbs_example/hier_normal.jl b/test/gibbs_example/hier_normal.jl index 2f58bf1e..2e3e381a 100644 --- a/test/gibbs_example/hier_normal.jl +++ b/test/gibbs_example/hier_normal.jl @@ -15,7 +15,7 @@ struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <: end # `mu` and `tau2` are length-1 vectors to make -function log_joint(; mu, tau2, x) +function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64}) # mu is the mean # tau2 is the variance # x is data From c7f577d3d38dc34d7e5b994e8d0dc75ddbeeeacc Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 28 Sep 2024 06:52:01 +0100 Subject: [PATCH 45/56] relax test error --- test/gibbs_example/gibbs_test.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/gibbs_example/gibbs_test.jl b/test/gibbs_example/gibbs_test.jl index 48e679c3..da42fd34 100644 --- a/test/gibbs_example/gibbs_test.jl +++ b/test/gibbs_example/gibbs_test.jl @@ -32,8 +32,8 @@ include("hier_normal.jl") mu_mean = only(mean(mu_samples)) tau2_mean = only(mean(tau2_samples)) - @test mu_mean ≈ mu_true atol = 0.1 - @test tau2_mean ≈ tau2_true atol = 0.1 + @test mu_mean ≈ mu_true rtol = 0.1 + @test tau2_mean ≈ tau2_true rtol = 0.1 end # This is too difficult to sample, disable for now From 8f11a15b1d9ede5267d136afa8223c7cc41a4305 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 28 Sep 2024 06:52:49 +0100 Subject: [PATCH 46/56] rename gibbs markdown file --- docs/make.jl | 2 +- docs/src/{gibbs.md => state_interface.md} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename docs/src/{gibbs.md => state_interface.md} (100%) diff --git a/docs/make.jl b/docs/make.jl index a2adb8e9..adec1df9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,7 +8,7 @@ makedocs(; sitename="AbstractMCMC", format=Documenter.HTML(), modules=[AbstractMCMC], - pages=["Home" => "index.md", "api.md", "design.md", "gibbs.md"], + pages=["Home" => "index.md", "api.md", "design.md", "state_interface.md"], checkdocs=:exports, ) diff --git a/docs/src/gibbs.md b/docs/src/state_interface.md similarity index 100% rename from docs/src/gibbs.md rename to docs/src/state_interface.md From 48a160d942c3627c87b8e725a3948d8fe9d916ba Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 28 Sep 2024 06:53:42 +0100 Subject: [PATCH 47/56] change title --- docs/src/state_interface.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md index 9de77c0e..613832d9 100644 --- a/docs/src/state_interface.md +++ b/docs/src/state_interface.md @@ -1,4 +1,4 @@ -# 
`state` Interface Functions
+# Interface For Sampler `state` and Gibbs Sampling
 
 We encourage sampler packages to implement the following interface functions for the `state` type(s) they maintain:
 

From 8d7488900c3992e19ec9430e27f0ee405d76d115 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Tue, 1 Oct 2024 18:11:41 +0100
Subject: [PATCH 48/56] update code and note

---
 design_notes/on_gibbs_implementation.md | 59 +++++++++++++++++++++++++
 docs/src/state_interface.md             | 32 +++++++-------
 test/gibbs_example/gibbs.jl             |  7 ++-
 test/gibbs_example/mh.jl                |  7 +--
 4 files changed, 83 insertions(+), 22 deletions(-)
 create mode 100644 design_notes/on_gibbs_implementation.md

diff --git a/design_notes/on_gibbs_implementation.md b/design_notes/on_gibbs_implementation.md
new file mode 100644
index 00000000..e7976d53
--- /dev/null
+++ b/design_notes/on_gibbs_implementation.md
@@ -0,0 +1,59 @@
+# On `AbstractMCMC` Interface Supporting `Gibbs`

+This note was written on Oct 1st, 2024. The versions of the packages described in this passage are:
+
+* `Turing.jl`: 0.34.1
+
+In this passage, `Gibbs` refers to `Experimental.Gibbs`.
+
+## Current Implementation of `Gibbs` in `Turing`
+
+Here I describe the current implementation of `Gibbs` in `Turing` and the interface it requires from its sampler states.
+
+### Interface 1: `getparams`
+
+From the [definition of `GibbsState`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/experimental/gibbs.jl#L244-L248), we can see that a `vi::DynamicPPL.AbstractVarInfo` field is used to keep track of the names and values of parameters and the log density. The `states` field collects the sampler-specific *state*s.
+
+(The *link*ing of *varinfo*s is omitted in this discussion.)
+A local `VarInfo` is initially created with `DynamicPPL.subset(::VarInfo, ::Vector{<:VarName})` to make the conditioned model. After the Gibbs step, an updated `varinfo` is obtained by calling `Turing.Inference.varinfo` on the sampler state.
+
+For samplers and their states defined in `Turing` (including `DynamicHMC`, as `DynamicNUTSState` is defined by `Turing` in the package extension), we (i.e., the `Turing.jl` package) assume that the *state*s all have a field called `vi`. Then `varinfo(_some_sampler_state_)` is simply `varinfo(state) = state.vi` (defined in [`src/mcmc/gibbs.jl`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/gibbs.jl#L97)). (`GibbsState` conforms to this assumption.)
+
+For `ExternalSamplers`, we currently only support `AdvancedHMC` and `AdvancedMH`. The mechanism is as follows: at the end of the `step` call with an external sampler, [`transition_to_turing` and `state_to_turing` are called](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/abstractmcmc.jl#L147). These two functions then call `getparams` on the sampler state of the external samplers. `getparams` for `AdvancedHMC.HMCState` and `AdvancedMH.Transition` (`AdvancedMH` uses `Transition` as state) are defined in `abstractmcmc.jl`.
+
+Thus, the first interface emerges: `getparams`. As `getparams` is designed to be implemented by a sampler that works with the `LogDensityProblems` interface, it makes sense for `getparams` to return a vector of `Real`s. The `logdensity_problem` should then be responsible for performing the transformation between its underlying representation and the vector of `Real`s.
+
+It's worth noting that:
+
+* `getparams` is not a function specific to `Gibbs`. It is required for the current support of external samplers.
+* There is another [`getparams`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/Inference.jl#L328-L351) in `Turing.jl` that takes *model* and *varinfo*, then returns a `NamedTuple`.
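+
+For illustration only, a minimal sketch of this interface (where `MyState` is a hypothetical state type, not actual `Turing.jl` or `AdvancedMH` source) could look like:
+
+```julia
+# hypothetical sampler state that stores its parameter values as a vector of reals
+struct MyState
+    params::Vector{Float64}
+end
+
+# the sampler package implements `getparams` to expose those values
+getparams(state::MyState) = state.params
+```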
+
+### Interface 2: `recompute_logp!!`
+
+Consider a model with multiple groups of variables, say $\theta_1, \theta_2, \ldots, \theta_k$. At the beginning of the $t$-th Gibbs step, the model parameters in the `GibbsState` are typically updated and different from the $(t-1)$-th step. The `GibbsState` maintains $k$ sub-states, one for each variable group, denoted as $\text{state}_{t,1}, \text{state}_{t,2}, \ldots, \text{state}_{t,k}$.
+
+The parameter values in each sub-state, i.e., $\theta_{t,i}$ in $\text{state}_{t,i}$, are always in sync with the corresponding values in the `GibbsState`. At the end of the $t$-th Gibbs step, $\text{state}_{t,i}$ will store the log density of the $i$-th variable group conditioned on all other variable groups at their values from step $t$, denoted as $\log p(\theta_{t,i} \mid \theta_{t,-i})$. This log density is equal to the joint log density of the whole model evaluated at the current parameter values $(\theta_{t,1}, \ldots, \theta_{t,k})$.
+
+However, the log density stored in each sub-state is in general not equal to the log density needed for the next Gibbs step at $t+1$, i.e., $\log p(\theta_{t,i} \mid \theta_{t+1,-i})$. This is because the values of the other variable groups $\theta_{-i}$ will have been updated in the Gibbs step from $t$ to $t+1$, changing the conditioning set. Therefore, the log density typically needs to be recomputed at each Gibbs step to account for the updated values of the conditioning variables.
+
+Only in certain special cases can the recomputation be skipped. For example, in a Metropolis-Hastings step where the proposal is rejected for all other variable groups, i.e., $\theta_{t+1,-i} = \theta_{t,-i}$, the log density $\log p(\theta_{t,i} \mid \theta_{t,-i})$ remains valid and doesn't need to be recomputed.
+
+The `recompute_logp!!` function in `abstractmcmc.jl` handles this recomputation. It takes an updated conditioned log density function $\log p(\cdot \mid \theta_{t+1,-i})$ and the parameter values $\theta_{t,i}$ stored in $\text{state}_{t,i}$ to compute the updated log density $\log p(\theta_{t,i} \mid \theta_{t+1,-i})$.
+
+## Proposed Interface
+
+The two functions `getparams` and `recompute_logp!!` form a minimal interface to support the `Gibbs` implementation. However, there are concerns about introducing them directly into `AbstractMCMC`. The main reason is that `AbstractMCMC` is a root dependency of the `Turing` packages, so we want to be very careful with new releases.
+
+Here, I propose some alternative functions that achieve the same functionality as `getparams` and `recompute_logp!!`, but without introducing new interface functions.
+
+For `getparams`, I propose we use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`. Since `getparams` should return a vector, using `vec` makes sense. The concern is that, officially, `Base.vec` is defined for `AbstractArray`, so it remains a question whether we should only introduce `vec` in the absence of other `AbstractArray` interfaces.
+
+For `recompute_logp!!`, I propose we overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability.
If `recompute_logp` is `true`, it should recompute the log probability of the state. Otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define outer constructor for their state types `StateType(state::StateType, logp)` that takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability. + +While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector for the second argument somewhat deviate from the interface in `LogDensityProblems`, I believe it provides a clean and extensible solution for handling log probability recomputation within the existing interface. + +An example demonstrating these interfaces is provided in `src/state_interface.md`. + +## A More Standalone `Gibbs` Implementation + +`Gibbs` should not manage a `variable name → sampler` but rather `range → sampler`, i.e. it maintain a vector of parameter values. while `logdensity_problem` should manage both the name and transformations. diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md index 613832d9..a7affe54 100644 --- a/docs/src/state_interface.md +++ b/docs/src/state_interface.md @@ -17,10 +17,10 @@ Base.vec(state) This function takes the state and returns a vector of the parameter values stored in the state. ```julia -(state::StateType)(logp::Float64) +state = StateType(state, logp) ``` -This function takes the state and a log probability value, and returns a new state with the updated log probability. +This function takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability. These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement. The interface facilitates the implementation of "meta-algorithms" that combine different samplers. @@ -166,10 +166,12 @@ end # It is part of the interface for interacting with the state object. Base.vec(state::MHState) = state.params -# Interface 3: (state::MHState)(logp::Float64) +# Interface 3: constructorof and MHState(state::MHState, logp::Float64) # This function allows the state to be updated with a new log probability. -# ! this makes state into a Julia functor -(state::MHState)(logp::Float64) = MHState(state.params, logp) +BangBang.constructorof(state::MHState{T}) where {T} = MHState +function MHState(state::MHState, logp::Float64) + return MHState(state.params, logp) +end ``` It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `LogDensityProblems.logdensity` to easily read the log probability of the current state. @@ -400,8 +402,10 @@ function AbstractMCMC.step( ) cond_logdensity_model = AbstractMCMC.LogDensityModel(cond_logdensity) - logp = LogDensityProblems.logdensity(cond_logdensity_model, sub_state) - sub_state = (sub_state)(logp) + logp = LogDensityProblems.logdensity( + cond_logdensity_model, sub_state; recompute_logp=true + ) + sub_state = constructorof(typeof(sub_state))(sub_state, logp) sub_state = last( AbstractMCMC.step( rng, cond_logdensity_model, sub_sampler, sub_state; kwargs... @@ -449,7 +453,7 @@ Now we can use the Gibbs sampler to sample from the hierarchical Normal model. 
First we generate some data, -```julia +```@example gibbs_example N = 100 # Number of data points mu_true = 0.5 # True mean tau2_true = 2.0 # True variance @@ -459,13 +463,13 @@ x_data = rand(Normal(mu_true, sqrt(tau2_true)), N) Then we can create a `HierNormal` model, with the data we just generated. -```julia +```@example gibbs_example hn = HierNormal((x=x_data,)) ``` Using Gibbs sampling allows us to use random walk MH for `mu` and prior MH for `tau2`, because `tau2` has support only on positive real numbers. -```julia +```@example gibbs_example samples = sample( hn, Gibbs(( @@ -479,7 +483,7 @@ samples = sample( Then we can extract the samples and compute the mean of the samples. -```julia +```@example gibbs_example mu_samples = [sample.values.mu for sample in samples][20001:end] tau2_samples = [sample.values.tau2 for sample in samples][20001:end] @@ -488,10 +492,4 @@ mean(tau2_samples) (mu_mean, tau2_mean) ``` -the result should looks like: - -```julia -(4.995812149309413, 1.9372372289677886) -``` - which is close to the true values `(5, 2)`. diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 82d7b5a7..17bc9552 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -1,5 +1,6 @@ using AbstractMCMC: AbstractMCMC using AbstractPPL: AbstractPPL +using BangBang: constructorof using Random struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler @@ -139,8 +140,10 @@ function AbstractMCMC.step( ) cond_logdensity_model = AbstractMCMC.LogDensityModel(cond_logdensity) - logp = LogDensityProblems.logdensity(cond_logdensity_model, sub_state) - sub_state = (sub_state)(logp) + logp = LogDensityProblems.logdensity( + cond_logdensity_model, sub_state; recompute_logp=true + ) + sub_state = constructorof(typeof(sub_state))(sub_state, logp) sub_state = last( AbstractMCMC.step( rng, cond_logdensity_model, sub_sampler, sub_state; kwargs... diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index ab8c11d4..4068268d 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -9,10 +9,11 @@ struct MHState{T} logp::Float64 end -# Interface 3: (state::MHState)(logp::Float64) +# Interface 3: outer constructor that takes a state and a logp # This function allows the state to be updated with a new log probability. -# ! this makes state into a Julia functor -(state::MHState)(logp::Float64) = MHState(state.params, logp) +function MHState(state::MHState, logp::Float64) + return MHState(state.params, logp) +end struct MHTransition{T} params::Vector{T} From bceb510234edba4c55c4f3830c567f23bacdc879 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 1 Oct 2024 18:25:30 +0100 Subject: [PATCH 49/56] fix doc example --- docs/src/state_interface.md | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md index a7affe54..60a72af3 100644 --- a/docs/src/state_interface.md +++ b/docs/src/state_interface.md @@ -17,7 +17,7 @@ Base.vec(state) This function takes the state and returns a vector of the parameter values stored in the state. ```julia -state = StateType(state, logp) +state = StateType(state::StateType, logp) ``` This function takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability. 
@@ -42,7 +42,17 @@ x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2}) We can write the log joint probability function as follows, where for the sake of simplicity for the following steps, we will assume that the `mu` and `tau2` parameters are one-element vectors. And `x` is the data. -```julia +```@example gibbs_example +using AbstractMCMC: AbstractMCMC, LogDensityProblems # hide +using Distributions # hide +using Random # hide +using AbstractMCMC: AbstractMCMC # hide +using AbstractPPL: AbstractPPL # hide +using BangBang: constructorof # hide +using AbstractPPL: AbstractPPL +``` + +```@example gibbs_example function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64}) # mu is the mean # tau2 is the variance @@ -71,7 +81,7 @@ end To make using `LogDensityProblems` interface, we create a simple type for this model. -```julia +```@example gibbs_example abstract type AbstractHierNormal end struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal @@ -89,7 +99,7 @@ end where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and -```julia +```@example gibbs_example function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple) return ConditionedHierNormal(hn.data, conditioned_values) end @@ -97,7 +107,7 @@ end then we can simply write down the `LogDensityProblems` interface for this model. -```julia +```@example gibbs_example function LogDensityProblems.logdensity( hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars}, params::AbstractVector, @@ -132,7 +142,7 @@ Although the interface doesn't force the sampler to implement `Transition` and ` Here we define some bare minimum types to represent the transitions and states. -```julia +```@example gibbs_example struct MHTransition{T} params::Vector{T} end @@ -145,7 +155,7 @@ end Next we define the `state` interface functions mentioned at the beginning of this section. -```julia +```@example gibbs_example # Interface 1: LogDensityProblems.logdensity # This function takes the logdensity function and the state (state is defined by the sampler package) # and returns the logdensity. It allows for optional recomputation of the log probability. @@ -168,7 +178,6 @@ Base.vec(state::MHState) = state.params # Interface 3: constructorof and MHState(state::MHState, logp::Float64) # This function allows the state to be updated with a new log probability. -BangBang.constructorof(state::MHState{T}) where {T} = MHState function MHState(state::MHState, logp::Float64) return MHState(state.params, logp) end @@ -176,7 +185,7 @@ end It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `LogDensityProblems.logdensity` to easily read the log probability of the current state. -```julia +```@example gibbs_example """ RandomWalkMH{T} <: AbstractMCMC.AbstractSampler @@ -264,7 +273,7 @@ end At last, we can proceed to implement a very simple Gibbs sampler. -```julia +```@example gibbs_example struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler "Maps variables to their samplers." sampler_map::T @@ -440,7 +449,7 @@ The Gibbs sampler operates in two main phases: b. Update current variable: - Recompute the log probability of the current state, as other variables may have changed: - Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability. - - Update the state with `sub_state = sub_state(logp)` to incorporate the new log probability. 
+ - Update the state with `sub_state = constructorof(typeof(sub_state))(sub_state, logp)` to incorporate the new log probability. - Perform a sampling step for the current conditional probability problem: - Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state. - Update the global trace: From c177271d7ab121cc7a1264d7006d06fb43d3b697 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 1 Oct 2024 18:32:22 +0100 Subject: [PATCH 50/56] try to fix doc example error --- docs/make.jl | 3 +++ docs/src/state_interface.md | 14 ++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index adec1df9..6fe9b086 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,7 @@ using AbstractMCMC +using AbstractPPL +using BangBang +using Distributions using Documenter using Random diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md index 60a72af3..299001cb 100644 --- a/docs/src/state_interface.md +++ b/docs/src/state_interface.md @@ -49,7 +49,6 @@ using Random # hide using AbstractMCMC: AbstractMCMC # hide using AbstractPPL: AbstractPPL # hide using BangBang: constructorof # hide -using AbstractPPL: AbstractPPL ``` ```@example gibbs_example @@ -493,11 +492,14 @@ samples = sample( Then we can extract the samples and compute the mean of the samples. ```@example gibbs_example -mu_samples = [sample.values.mu for sample in samples][20001:end] -tau2_samples = [sample.values.tau2 for sample in samples][20001:end] - -mean(mu_samples) -mean(tau2_samples) +warmup = 5000 +thin = 10 +thinned_samples = samples[(warmup + 1):thin:end] +mu_samples = [sample.values.mu for sample in thinned_samples] +tau2_samples = [sample.values.tau2 for sample in thinned_samples] + +mu_mean = only(mean(mu_samples)) +tau2_mean = only(mean(tau2_samples)) (mu_mean, tau2_mean) ``` From bdba893b5d898740b565622485ab1f2f4258c3e5 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 1 Oct 2024 18:36:57 +0100 Subject: [PATCH 51/56] fix doc deps --- docs/Project.toml | 4 ++++ docs/make.jl | 3 --- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index f74dfb58..040a68b0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,8 @@ [deps] +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/docs/make.jl b/docs/make.jl index 6fe9b086..adec1df9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,7 +1,4 @@ using AbstractMCMC -using AbstractPPL -using BangBang -using Distributions using Documenter using Random From e7e2870a096b810e7637fb25a8eea57c0320c164 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 1 Oct 2024 21:04:35 +0100 Subject: [PATCH 52/56] fix more doc example error --- docs/src/state_interface.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md index 299001cb..782f5bf2 100644 --- a/docs/src/state_interface.md +++ b/docs/src/state_interface.md @@ -185,6 +185,8 @@ end It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `LogDensityProblems.logdensity` to easily read the log probability of the current state. 
```@example gibbs_example +abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end + """ RandomWalkMH{T} <: AbstractMCMC.AbstractSampler From 80df1879ba7a225434e345c0f58223a0b706dcf1 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 1 Oct 2024 21:14:16 +0100 Subject: [PATCH 53/56] minor update --- docs/src/state_interface.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md index 782f5bf2..a6f9032b 100644 --- a/docs/src/state_interface.md +++ b/docs/src/state_interface.md @@ -43,15 +43,13 @@ x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2}) We can write the log joint probability function as follows, where for the sake of simplicity for the following steps, we will assume that the `mu` and `tau2` parameters are one-element vectors. And `x` is the data. ```@example gibbs_example -using AbstractMCMC: AbstractMCMC, LogDensityProblems # hide +using AbstractMCMC: AbstractMCMC # hide +using LogDensityProblems # hide using Distributions # hide using Random # hide using AbstractMCMC: AbstractMCMC # hide using AbstractPPL: AbstractPPL # hide using BangBang: constructorof # hide -``` - -```@example gibbs_example function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64}) # mu is the mean # tau2 is the variance From 076e4318ee1811cdf6c0a02d93b33f4d0497e2a5 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 3 Oct 2024 22:57:34 +0800 Subject: [PATCH 54/56] Apply suggestions from code review Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- design_notes/on_gibbs_implementation.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/design_notes/on_gibbs_implementation.md b/design_notes/on_gibbs_implementation.md index e7976d53..78a42e60 100644 --- a/design_notes/on_gibbs_implementation.md +++ b/design_notes/on_gibbs_implementation.md @@ -44,16 +44,16 @@ The `recompute_logp!!` function in `abstractmcmc.jl` handles this recomputation. The two functions `getparams` and `recompute_logp!!` form a minimal interface to support the `Gibbs` implementation. However, there are concerns about introducing them directly into `AbstractMCMC`. The main reason is that `AbstractMCMC` is a root dependency of the `Turing` packages, so we want to be very careful with new releases. -Here, I propose some alternative functions that achieve the same functionality as `getparams` and `recompute_logp!!`, but without introducing new interface functions. +Here, some alternative functions that achieve the same functionality as `getparams` and `recompute_logp!!` are proposed, but without introducing new interface functions. -For `getparams`, I propose we use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`. Since `getparams` should return a vector, using `vec` makes sense. The concern is that, officially, `Base.vec` is defined for `AbstractArray`, so it remains a question whether we should only introduce `vec` in the absence of other `AbstractArray` interfaces. +For `getparams`, we can use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`. Since `getparams` should return a vector, using `vec` makes sense. The concern is that, officially, `Base.vec` is defined for `AbstractArray`, so it remains a question whether we should only introduce `vec` in the absence of other `AbstractArray` interfaces. 
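+
+As a rough sketch (assuming a hypothetical state type `MyState` that already stores its parameter values as a vector), this amounts to a one-line method:
+
+```julia
+# MyState is assumed to keep its parameter values in a `params` vector field
+Base.vec(state::MyState) = state.params
+```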
-For `recompute_logp!!`, I propose we overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability. If `recompute_logp` is `true`, it should recompute the log probability of the state. Otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define outer constructor for their state types `StateType(state::StateType, logp)` that takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability.
+For `recompute_logp!!`, we could overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability. If `recompute_logp` is `true`, it should recompute the log probability of the state. Otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define an outer constructor for their state types `StateType(state::StateType, logdensity=logp)` that takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability.
 
-While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector for the second argument somewhat deviate from the interface in `LogDensityProblems`, I believe it provides a clean and extensible solution for handling log probability recomputation within the existing interface.
+While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector for the second argument somewhat deviates from the interface in `LogDensityProblems`, it provides a clean and extensible solution for handling log probability recomputation within the existing interface.
 
 An example demonstrating these interfaces is provided in `src/state_interface.md`.
 
 ## A More Standalone `Gibbs` Implementation
 
-`Gibbs` should not manage a `variable name → sampler` but rather `range → sampler`, i.e. it maintain a vector of parameter values. while `logdensity_problem` should manage both the name and transformations.
+`AbstractMCMC.Gibbs` should not manage a `variable name → sampler` but rather `range → sampler`, i.e. it maintains a vector of parameter values, while a higher-level interface like `AbstractPPL` / `DynamicPPL` should manage both the name and transformations.

From 4293868c231230a643bd7a371bb7351becafa7e8 Mon Sep 17 00:00:00 2001
From: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Date: Thu, 3 Oct 2024 23:50:34 +0800
Subject: [PATCH 55/56] Update docs/src/state_interface.md

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

---
 docs/src/state_interface.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md
index a6f9032b..c9274593 100644
--- a/docs/src/state_interface.md
+++ b/docs/src/state_interface.md
@@ -8,7 +8,7 @@ LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, st
 
 This function takes the logdensity model and the state, and returns the log probability of the state.
 If `recompute_logp` is `true`, it should recompute the log probability of the state.
-Otherwise, it could use the log probability stored in the state.
+Otherwise, if available, it will use the log probability stored in the state.
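+
+A sketch of an implementation, mirroring the `MHState` example developed later in this document, could look like:
+
+```julia
+function LogDensityProblems.logdensity(
+    logdensity_model::AbstractMCMC.LogDensityModel, state::MHState; recompute_logp=true
+)
+    logdensity_function = logdensity_model.logdensity
+    # either recompute the log probability from the stored parameter values,
+    # or reuse the value cached in the state
+    return if recompute_logp
+        LogDensityProblems.logdensity(logdensity_function, state.params)
+    else
+        state.logp
+    end
+end
+```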
```julia Base.vec(state) From 1cee0ab58c781a71d8e089a1f6c3f5db612a7312 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 3 Oct 2024 23:50:44 +0800 Subject: [PATCH 56/56] Update docs/src/state_interface.md Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- docs/src/state_interface.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md index c9274593..26efb9d4 100644 --- a/docs/src/state_interface.md +++ b/docs/src/state_interface.md @@ -24,7 +24,7 @@ This function takes an existing `state` and a log probability value `logp`, and These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement. The interface facilitates the implementation of "meta-algorithms" that combine different samplers. -We will demonstrate how it can be used to implement Gibbs sampling in the following sections. +We will demonstrate this in the following sections. ## Using the `state` Interface for block sampling within Gibbs