From 0bb8fdea519dac7558851753caf28dd2d56b5a6c Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 16 Jun 2022 20:22:00 +0100 Subject: [PATCH] Fix particle gibbs example (#58) * Fix particle gibbs example * Format --- examples/gaussian-ssm/script.jl | 77 +++++++++++++++-------------- examples/particle-gibbs/script.jl | 82 ++++++++++++++++++++++--------- 2 files changed, 101 insertions(+), 58 deletions(-) diff --git a/examples/gaussian-ssm/script.jl b/examples/gaussian-ssm/script.jl index c4694962..cb4cc15f 100644 --- a/examples/gaussian-ssm/script.jl +++ b/examples/gaussian-ssm/script.jl @@ -14,9 +14,10 @@ using Plots # y_{t} = x_{t} + \nu_{t} \quad \nu_{t} \sim \mathcal{N}(0, r^2) # ``` # -# Here we assume the static parameters $\theta = (a, q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$. -# To use particle gibbs with the ancestor sampling step we need to provide both the transition and observation densities. -# From the definition above we get: +# Here we assume the static parameters $\theta = (a, q^2, r^2)$ are known and we are only interested in sampling from the latent states $x_t$. +# To use particle gibbs with the ancestor sampling update step we need to provide both the transition and observation densities. +# +# From the definition above we get: # ```math # x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2) # ``` @@ -25,55 +26,54 @@ using Plots # ``` # as well as the initial distribution $f_0(x) = \mathcal{N}(0, q^2/(1-a^2))$. -# We are ready to use `AdvancedPS` with our model. We first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`. +# To use `AdvancedPS` we first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`. 
Parameters = @NamedTuple begin a::Float64 q::Float64 r::Float64 end -mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel +mutable struct LinearSSM <: AdvancedPS.AbstractStateSpaceModel X::Vector{Float64} θ::Parameters - NonLinearTimeSeries(θ::Parameters) = new(Vector{Float64}(), θ) + LinearSSM(θ::Parameters) = new(Vector{Float64}(), θ) end # and the densities defined above. -f(m::NonLinearTimeSeries, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density -g(m::NonLinearTimeSeries, state, t) = Normal(state, m.θ.r) # Observation density -f₀(m::NonLinearTimeSeries) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density +f(m::LinearSSM, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density +g(m::LinearSSM, state, t) = Normal(state, m.θ.r) # Observation density +f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density #md nothing #hide -# To implement `AdvancedPS.AbstractStateSpaceModel` we need to define a few functions to specify the dynamics of the system: -# - `AdvancedPS.initialization` the initial state density -# - `AdvancedPS.transition` the state transition density -# - `AdvancedPS.observation` the observation score given the observed data -# - `AdvancedPS.isdone` signals the end of the execution for the model -AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model) -AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model, state, step) -function AdvancedPS.observation(model::NonLinearTimeSeries, state, step) +# We also need to specify the dynamics of the system through the transition equations: +# - `AdvancedPS.initialization`: the initial state density +# - `AdvancedPS.transition`: the state transition density +# - `AdvancedPS.observation`: the observation score given the observed data +# - `AdvancedPS.isdone`: signals the end of the execution for the model +AdvancedPS.initialization(model::LinearSSM) = f₀(model) +AdvancedPS.transition(model::LinearSSM, state, step) = 
f(model, state, step) +function AdvancedPS.observation(model::LinearSSM, state, step) return logpdf(g(model, state, step), y[step]) end -AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ +AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ # Everything is now ready to simulate some data. a = 0.9 # Scale q = 0.32 # State variance r = 1 # Observation variance -Tₘ = 300 # Number of observation -Nₚ = 50 # Number of particles +Tₘ = 200 # Number of observation +Nₚ = 20 # Number of particles Nₛ = 500 # Number of samples -seed = 9 # Reproduce everything +seed = 1 # Reproduce everything θ₀ = Parameters((a, q, r)) - rng = Random.MersenneTwister(seed) x = zeros(Tₘ) y = zeros(Tₘ) -reference = NonLinearTimeSeries(θ₀) +reference = LinearSSM(θ₀) x[1] = rand(rng, f₀(reference)) for t in 1:Tₘ if t < Tₘ @@ -82,35 +82,40 @@ for t in 1:Tₘ y[t] = rand(rng, g(reference, x[t], t)) end -# Let's have a look at the simulated data from the latent state dynamics +# Here are the latent and observation timeseries plot(x; label="x") xlabel!("t") -# and the observation data +# plot(y; label="y") xlabel!("x") -# -model = NonLinearTimeSeries(θ₀) +# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel +# and a model interface. +model = LinearSSM(θ₀) pgas = AdvancedPS.PGAS(Nₚ) -chains = sample(rng, model, pgas, Nₛ; progress=false) +chains = sample(rng, model, pgas, Nₛ; progress=false); #md nothing #hide -# The actual sampled trajectory is in the trajectory inner model -particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states -mean_trajectory = mean(particles; dims=2) +# +particles = hcat([chain.trajectory.model.X for chain in chains]...) 
+mean_trajectory = mean(particles; dims=2); #md nothing #hide -# +# This toy model is small enough to inspect all the generated traces: scatter(particles; label=false, opacity=0.01, color=:black) -plot!(x; color=:red, label="Original Trajectory") -plot!(mean_trajectory; color=:orange, label="Mean trajectory", opacity=0.9) +plot!(x; color=:darkorange, label="Original Trajectory") +plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9) xlabel!("t") ylabel!("State") -# By sampling an ancestor from the reference particle in the particle gibbs sampler we split the -# trajectory of the reference particle. +# We used a particle gibbs kernel with the ancestor updating step which should help with the particle +# degeneracy problem and improve the mixing. +# We can compute the update rate of $x_t$ vs $t$ defined as the proportion of times $t$ where $x_t$ gets updated: update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ +#md nothing #hide + +# and compare it to the theoretical value of $1 - 1/Nₚ$. plot(update_rate; label=false, ylim=[0, 1], legend=:bottomleft) hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)") xlabel!("Iteration") diff --git a/examples/particle-gibbs/script.jl b/examples/particle-gibbs/script.jl index 28d934e8..a6b372c7 100644 --- a/examples/particle-gibbs/script.jl +++ b/examples/particle-gibbs/script.jl @@ -1,4 +1,4 @@ -# # Particle Gibbs with Ancestor Sampling +# # Particle Gibbs for non-linear models using AdvancedPS using Random using Distributions @@ -8,10 +8,27 @@ using Random123 using Libtask: TArray using Libtask +# We consider the following stochastic volatility model: +# +# ```math +# x_{t+1} = a x_t + v_t \quad v_{t} \sim \mathcal{N}(0, q^2) +# ``` +# ```math +# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad e_{t} \sim \mathcal{N}(0, 1) +# ``` +# +# Here we assume the static parameters $\theta = (a, q^2)$ are known and we are only interested in sampling from the latent state $x_t$. 
+# We can reformulate the above in terms of transition and observation densities: +# ```math +# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2) +# ``` +# ```math +# y_t \sim g_{\theta}(y_t|x_t) = \mathcal{N}(0, \exp(\frac{1}{2}x_t)^2) +# ``` +# with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$. Parameters = @NamedTuple begin a::Float64 q::Float64 - r::Float64 T::Int end @@ -21,22 +38,20 @@ mutable struct NonLinearTimeSeries <: AbstractMCMC.AbstractModel NonLinearTimeSeries(θ::Parameters) = new(TArray(Float64, θ.T), θ) end -f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q) # Transition density -g(model::NonLinearTimeSeries, state, t) = Normal(state, model.θ.r) # Observation density -f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q) # Initial state density +f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q) +g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state)^2) +f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q) +#md nothing #hide -# Everything is now ready to simulate some data. 
- -a = 0.9 # Scale -q = 0.32 # State variance -r = 1 # Observation variance -Tₘ = 300 # Number of observation -Nₚ = 20 # Number of particles -Nₛ = 500 # Number of samples -seed = 9 # Reproduce everything - -θ₀ = Parameters((a, q, r, Tₘ)) +# Let's simulate some data +a = 0.9 # Scale +q = 0.5 # State noise standard deviation +Tₘ = 200 # Number of observation +Nₚ = 20 # Number of particles +Nₛ = 500 # Number of samples +seed = 1 # Reproduce everything +θ₀ = Parameters((a, q, Tₘ)) rng = Random.MersenneTwister(seed) x = zeros(Tₘ) @@ -51,6 +66,15 @@ for t in 1:Tₘ y[t] = rand(rng, g(reference, x[t], t)) end +# Here are the latent and observation series: +plot(x; label="x") +xlabel!("t") + +# +plot(y; label="y") +xlabel!("x") + +# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition: function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG) x₀ = rand(rng, f₀(model)) model.X[1] = x₀ @@ -65,14 +89,18 @@ function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG) end end +# `AdvancedPS` relies on `Libtask` to copy models during their execution but we need to make sure the +# internal data of each model is properly copied over as well. Libtask.tape_copy(model::NonLinearTimeSeries) = deepcopy(model) -Random.seed!(rng, seed) +# Here we use the particle gibbs kernel without adaptive resampling. model = NonLinearTimeSeries(θ₀) -pgas = AdvancedPS.PG(Nₚ) -chains = sample(rng, model, pgas, Nₛ; progress=false) +pgas = AdvancedPS.PG(Nₚ, 1.0) +chains = sample(rng, model, pgas, Nₛ; progress=false); +#md nothing #hide -# Utility to replay a particle trajectory +# The trajectories are not stored during the sampling and we need to regenerate the history of each +# sample if we want to look at the individual traces. function replay(particle::AdvancedPS.Particle) trng = deepcopy(particle.rng) Random123.set_counter!(trng.rng, 0) @@ -92,14 +120,24 @@ end particles = hcat([trajectory.model.f.X for trajectory in trajectories]...) 
# Concat all sampled states mean_trajectory = mean(particles; dims=2) +#md nothing #hide +# We can now plot all the generated traces. +# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help +# with the degeneracy problem. +plot() scatter(particles; label=false, opacity=0.01, color=:black) -plot!(x; color=:red, label="Original Trajectory") -plot!(mean_trajectory; color=:orange, label="Mean trajectory", opacity=0.9) +plot!(x; color=:darkorange, label="Original Trajectory") +plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9) xlabel!("t") ylabel!("State") +# We can also check the mixing as defined in the Gaussian State Space model example. As seen on the +# scatter plot above, we are mostly left with a single trajectory before timestep 150. The orange +# bar is the optimal mixing rate for the number of particles we use. update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ +#md nothing #hide + plot(update_rate; label=false, ylim=[0, 1], legend=:bottomleft) hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)") xlabel!("Iteration")