Skip to content

Commit

Permalink
Fix particle gibbs example (#58)
Browse files Browse the repository at this point in the history
* Fix particle gibbs example

* Format
  • Loading branch information
FredericWantiez authored Jun 16, 2022
1 parent 21d0425 commit 0bb8fde
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 58 deletions.
77 changes: 41 additions & 36 deletions examples/gaussian-ssm/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ using Plots
# y_{t} = x_{t} + \nu_{t} \quad \nu_{t} \sim \mathcal{N}(0, r^2)
# ```
#
# Here we assume the static parameters $\theta = (a, q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
# To use particle gibbs with the ancestor sampling step we need to provide both the transition and observation densities.
# From the definition above we get:
# Here we assume the static parameters $\theta = (a, q^2, r^2)$ are known and we are only interested in sampling from the latent states $x_t$.
# To use particle gibbs with the ancestor sampling update step we need to provide both the transition and observation densities.
#
# From the definition above we get:
# ```math
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2)
# ```
Expand All @@ -25,55 +26,54 @@ using Plots
# ```
# as well as the initial distribution $f_0(x) = \mathcal{N}(0, q^2/(1-a^2))$.

# We are ready to use `AdvancedPS` with our model. We first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`.
# To use `AdvancedPS` we first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`.
Parameters = @NamedTuple begin
a::Float64
q::Float64
r::Float64
end

mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
mutable struct LinearSSM <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::Parameters
NonLinearTimeSeries::Parameters) = new(Vector{Float64}(), θ)
LinearSSM::Parameters) = new(Vector{Float64}(), θ)
end

# and the densities defined above.
f(m::NonLinearTimeSeries, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density
g(m::NonLinearTimeSeries, state, t) = Normal(state, m.θ.r) # Observation density
f₀(m::NonLinearTimeSeries) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density
f(m::LinearSSM, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density
g(m::LinearSSM, state, t) = Normal(state, m.θ.r) # Observation density
f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density
#md nothing #hide

# To implement `AdvancedPS.AbstractStateSpaceModel` we need to define a few functions to specify the dynamics of the system:
# - `AdvancedPS.initialization` the initial state density
# - `AdvancedPS.transition` the state transition density
# - `AdvancedPS.observation` the observation score given the observed data
# - `AdvancedPS.isdone` signals the end of the execution for the model
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
# We also need to specify the dynamics of the system through the transition equations:
# - `AdvancedPS.initialization`: the initial state density
# - `AdvancedPS.transition`: the state transition density
# - `AdvancedPS.observation`: the observation score given the observed data
# - `AdvancedPS.isdone`: signals the end of the execution for the model
AdvancedPS.initialization(model::LinearSSM) = f₀(model)
AdvancedPS.transition(model::LinearSSM, state, step) = f(model, state, step)
function AdvancedPS.observation(model::LinearSSM, state, step)
return logpdf(g(model, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ
AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ

# Everything is now ready to simulate some data.

a = 0.9 # Scale
q = 0.32 # State variance
r = 1 # Observation variance
Tₘ = 300 # Number of observation
Nₚ = 50 # Number of particles
Tₘ = 200 # Number of observation
Nₚ = 20 # Number of particles
Nₛ = 500 # Number of samples
seed = 9 # Reproduce everything
seed = 1 # Reproduce everything

θ₀ = Parameters((a, q, r))

rng = Random.MersenneTwister(seed)

x = zeros(Tₘ)
y = zeros(Tₘ)

reference = NonLinearTimeSeries(θ₀)
reference = LinearSSM(θ₀)
x[1] = rand(rng, f₀(reference))
for t in 1:Tₘ
if t < Tₘ
Expand All @@ -82,35 +82,40 @@ for t in 1:Tₘ
y[t] = rand(rng, g(reference, x[t], t))
end

# Let's have a look at the simulated data from the latent state dynamics
# Here are the latent and obseravation timeseries
plot(x; label="x")
xlabel!("t")

# and the observation data
#
plot(y; label="y")
xlabel!("x")

#
model = NonLinearTimeSeries(θ₀)
# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel
# and a model interface.
model = LinearSSM(θ₀)
pgas = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pgas, Nₛ; progress=false)
chains = sample(rng, model, pgas, Nₛ; progress=false);
#md nothing #hide

# The actual sampled trajectory is in the trajectory inner model
particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states
mean_trajectory = mean(particles; dims=2)
#
particles = hcat([chain.trajectory.model.X for chain in chains]...)
mean_trajectory = mean(particles; dims=2);
#md nothing #hide

#
# This toy model is small enough to inspect all the generated traces:
scatter(particles; label=false, opacity=0.01, color=:black)
plot!(x; color=:red, label="Original Trajectory")
plot!(mean_trajectory; color=:orange, label="Mean trajectory", opacity=0.9)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
xlabel!("t")
ylabel!("State")

# By sampling an ancestor from the reference particle in the particle gibbs sampler we split the
# trajectory of the reference particle.
# We used a particle gibbs kernel with the ancestor updating step which should help with the particle
# degeneracy problem and improve the mixing.
# We can compute the update rate of $x_t$ vs $t$ defined as the proportion of times $t$ where $x_t$ gets updated:
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
#md nothing #hide

# and compare it to the theoretical value of $1 - 1/Nₚ$.
plot(update_rate; label=false, ylim=[0, 1], legend=:bottomleft)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")
xlabel!("Iteration")
Expand Down
82 changes: 60 additions & 22 deletions examples/particle-gibbs/script.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # Particle Gibbs with Ancestor Sampling
# # Particle Gibbs for non-linear models
using AdvancedPS
using Random
using Distributions
Expand All @@ -8,10 +8,27 @@ using Random123
using Libtask: TArray
using Libtask

# We consider the following stochastic volatility model:
#
# ```math
# x_{t+1} = a x_t + v_t \quad v_{t} \sim \mathcal{N}(0, r^2)
# ```
# ```math
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad v_{t} \sim \mathcal{N}(0, 1)
# ```
#
# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
# We can reformulate the above in terms of transition and observation densities:
# ```math
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2)
# ```
# ```math
# y_t \sim g_{\theta}(y_t|x_t) = \mathcal{N}(0, \exp(\frac{1}{2}x_t)^2)
# ```
# with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$.
Parameters = @NamedTuple begin
a::Float64
q::Float64
r::Float64
T::Int
end

Expand All @@ -21,22 +38,20 @@ mutable struct NonLinearTimeSeries <: AbstractMCMC.AbstractModel
NonLinearTimeSeries::Parameters) = new(TArray(Float64, θ.T), θ)
end

f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q) # Transition density
g(model::NonLinearTimeSeries, state, t) = Normal(state, model.θ.r) # Observation density
f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q) # Initial state density
f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q)
g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state)^2)
f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q)
#md nothing #hide

# Everything is now ready to simulate some data.

a = 0.9 # Scale
q = 0.32 # State variance
r = 1 # Observation variance
Tₘ = 300 # Number of observation
Nₚ = 20 # Number of particles
Nₛ = 500 # Number of samples
seed = 9 # Reproduce everything

θ₀ = Parameters((a, q, r, Tₘ))
# Let's simulate some data
a = 0.9 # State Variance
q = 0.5 # Observation variance
Tₘ = 200 # Number of observation
Nₚ = 20 # Number of particles
Nₛ = 500 # Number of samples
seed = 1 # Reproduce everything

θ₀ = Parameters((a, q, Tₘ))
rng = Random.MersenneTwister(seed)

x = zeros(Tₘ)
Expand All @@ -51,6 +66,15 @@ for t in 1:Tₘ
y[t] = rand(rng, g(reference, x[t], t))
end

# Here are the latent and observation series:
plot(x; label="x")
xlabel!("t")

#
plot(y; label="y")
xlabel!("x")

# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
x₀ = rand(rng, f₀(model))
model.X[1] = x₀
Expand All @@ -65,14 +89,18 @@ function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
end
end

# `AdvancedPS` relies on `Libtask` to copy models during their execution but we need to make sure the
# internal data of each model is properly copied over as well.
Libtask.tape_copy(model::NonLinearTimeSeries) = deepcopy(model)

Random.seed!(rng, seed)
# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
pgas = AdvancedPS.PG(Nₚ)
chains = sample(rng, model, pgas, Nₛ; progress=false)
pgas = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pgas, Nₛ; progress=false);
#md nothing #hide

# Utility to replay a particle trajectory
# The trajectories are not stored during the sampling and we need to regenerate the history of each
# sample if we want to look at the individual traces.
function replay(particle::AdvancedPS.Particle)
trng = deepcopy(particle.rng)
Random123.set_counter!(trng.rng, 0)
Expand All @@ -92,14 +120,24 @@ end

particles = hcat([trajectory.model.f.X for trajectory in trajectories]...) # Concat all sampled states
mean_trajectory = mean(particles; dims=2)
#md nothing #hide

# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help
# with the degeneracy problem.
plot()
scatter(particles; label=false, opacity=0.01, color=:black)
plot!(x; color=:red, label="Original Trajectory")
plot!(mean_trajectory; color=:orange, label="Mean trajectory", opacity=0.9)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
xlabel!("t")
ylabel!("State")

# We can also check the mixing as defined in the Gaussian State Space model example. As seen on the
# scatter plot above, we are mostly left with a single trajectory before timestep 150. The orange
# bar is the optimal mixing rate for the number of particles we use.
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
#md nothing #hide

plot(update_rate; label=false, ylim=[0, 1], legend=:bottomleft)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")
xlabel!("Iteration")
Expand Down

2 comments on commit 0bb8fde

@yebai
Copy link
Member

@yebai yebai commented on 0bb8fde Jun 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62639

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" 0bb8fdea519dac7558851753caf28dd2d56b5a6c
git push origin v0.4.0

Please sign in to comment.