Add auxiliary particle filter #11

Status: Open · wants to merge 9 commits into base: main
1 change: 1 addition & 0 deletions Project.toml
@@ -21,6 +21,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
DataStructures = "0.18.20"
GaussianDistributions = "0.5.2"
SSMProblems = "0.4.0"
StatsBase = "0.34.3"

[extras]
19 changes: 19 additions & 0 deletions src/GeneralisedFilters.jl
@@ -13,6 +13,8 @@ using NNlib

abstract type AbstractFilter <: AbstractSampler end

abstract type AbstractParticleFilter{N} <: AbstractFilter end

"""
predict([rng,] model, alg, iter, state; kwargs...)

@@ -41,6 +43,22 @@
Perform a combined predict and update call on a single iteration of the filter.
"""
function step end
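
For orientation, these primitives compose into a standard filtering loop. A minimal sketch, assuming only the initialise/step contract defined in this file (run_filter is hypothetical; the package's exported GeneralisedFilters.filter plays this role):

function run_filter(rng, model, alg::AbstractFilter, observations; kwargs...)
    state = initialise(rng, model, alg; kwargs...)
    log_evidence = 0.0
    for (iter, observation) in enumerate(observations)
        # step performs predict + update and returns the new state together
        # with the log-marginal likelihood increment for this observation
        state, ll = step(rng, model, alg, iter, state, observation; kwargs...)
        log_evidence += ll
    end
    return state, log_evidence
end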

"""
reset_weights!(log_weights, filter)

Reset the container's log-weights after a resampling step.
"""
function reset_weights! end

"""
update_weights!
"""
function update_weights! end

function logmarginal end

function update_ref! end

function initialise(model, alg; kwargs...)
return initialise(default_rng(), model, alg; kwargs...)
end
@@ -106,6 +124,7 @@ include("models/hierarchical.jl")

# Filtering/smoothing algorithms
include("algorithms/bootstrap.jl")
include("algorithms/apf.jl")
include("algorithms/kalman.jl")
include("algorithms/forward.jl")
include("algorithms/rbpf.jl")
116 changes: 116 additions & 0 deletions src/algorithms/apf.jl
@@ -0,0 +1,116 @@
export AuxiliaryParticleFilter, APF

mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: AbstractParticleFilter{N}
resampler::RS
aux::Vector # Auxiliary weights
end

function AuxiliaryParticleFilter(
N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(conditional_resampler, zeros(N))
end

const APF = AuxiliaryParticleFilter
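
A hypothetical usage sketch (not part of the diff), mirroring the test added at the bottom of this PR; rng, model and data are assumed to be in scope:

apf = APF(2^10; threshold=1.0)  # 1024 particles; threshold=1.0 makes the ESSResampler resample every step
state, ll = GeneralisedFilters.filter(rng, model, apf, data)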

function initialise(
rng::AbstractRNG,
model::StateSpaceModel{T},
filter::AuxiliaryParticleFilter{N};
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
) where {N,T}
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
initial_weights = zeros(T, N)

return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter)
end

function update_weights!(
rng::AbstractRNG, filter, model, step, states, observation; kwargs...
)
simulation_weights = eta(rng, model, step, states, observation)
return states.log_weights += simulation_weights
end
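
Note: eta is not defined anywhere in this PR; it is a placeholder for a user-supplied estimate of the predictive likelihood p(y_t | x_{t-1}). The proof-of-concept in predict below inlines the simplest such estimate instead of calling update_weights!.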

function predict(
rng::AbstractRNG,
model::StateSpaceModel,
filter::AuxiliaryParticleFilter,
step::Integer,
states::ParticleContainer{T},
observation;
ref_state::Union{Nothing,AbstractVector{T}}=nothing,
kwargs...,
) where {T}
# states = update_weights!(rng, filter.eta, model, step, states.filtered, observation; kwargs...)

# Compute auxiliary weights
# POC: use the simplest approximation to the predictive likelihood
# Ideally should be something like update_weights!(filter, ...)
predicted = map(
x -> mean(SSMProblems.distribution(model.dyn, step, x; kwargs...)),
states.filtered.particles,
)
auxiliary_weights = map(
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), predicted
)
states.filtered.log_weights .+= auxiliary_weights
filter.aux = auxiliary_weights

states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter)
states.proposed.particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...),
states.proposed.particles,
)

return update_ref!(states, ref_state, filter, step)
end
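
For context, this predict implements the first stage of the standard two-stage auxiliary scheme (Pitt & Shephard, 1999). A log-space summary as comments, with p̂ denoting the point-mass approximation of the predictive likelihood used above:

# First stage (above): augment the weights with the estimated predictive likelihood,
#     log ŵ_i = log w_i + log p̂(y_t | μ_i),   μ_i = E[x_t | x_{t-1}^i],
# and resample according to ŵ. Second stage (reset_weights! further down): the
# resampled particles start the new generation at -log p̂(y_t | μ_{a_i}), so that
# once update adds the exact log p(y_t | x_t^i) the weights target the correct posterior.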

function update(
model::StateSpaceModel{T},
filter::AuxiliaryParticleFilter,
step::Integer,
states::ParticleContainer,
observation;
kwargs...,
) where {T}
@debug "step $step"
log_increments = map(
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...),
collect(states.proposed.particles),
)

states.filtered.log_weights = states.proposed.log_weights + log_increments
states.filtered.particles = states.proposed.particles

return states, logmarginal(states, filter)
end

function step(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
alg::AuxiliaryParticleFilter,
iter::Integer,
state,
observation;
kwargs...,
)
proposed_state = predict(rng, model, alg, iter, state, observation; kwargs...)
filtered_state, ll = update(model, alg, iter, proposed_state, observation; kwargs...)

return filtered_state, ll
end
Comment on lines +91 to +104

Collaborator: It should work without redefining step, as long as AuxiliaryParticleFilter is a subtype of AbstractFilter.

Member (Author): The current API doesn't forward observation to predict. I also tend to agree with #9 and splitting resample and predict.

Collaborator:

> The current API doesn't forward observation to predict. I also tend to agree with #9 and splitting resample and predict.

I wonder if it should. For general proposal distributions other than forward simulation (i.e. the bootstrap filter), that would be needed. Or at least have a subtype for Guided, Independent and Auxiliary filters that has this modified step/predict method.

Collaborator:

> The current API doesn't forward observation to predict

I completely forgot about that. We may need to think about defining some sort of AbstractProposal for more complex transition kernels.

Collaborator: And yeah, I agree we should split up resample and predict.
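
A rough sketch of the AbstractProposal idea floated in this thread; purely illustrative, none of these names exist in the package:

abstract type AbstractProposal end

# Forward simulation from the dynamics recovers the bootstrap filter.
struct ForwardProposal <: AbstractProposal end

function propose(rng, ::ForwardProposal, model, step, particle, observation; kwargs...)
    return SSMProblems.simulate(rng, model.dyn, step, particle; kwargs...)
end

# A guided proposal would condition on the current observation, which is
# exactly why predict would need observation forwarded to it.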


function reset_weights!(
state::ParticleState{T,WT}, idxs, filter::AuxiliaryParticleFilter
) where {T,WT<:Real}
# From Chopin & Papaspiliopoulos, An Introduction to Sequential Monte Carlo, Section 10.3.3
state.log_weights = state.log_weights[idxs] - filter.aux[idxs]
return state
end

function logmarginal(states::ParticleContainer, ::AuxiliaryParticleFilter)
return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
end
40 changes: 26 additions & 14 deletions src/algorithms/bootstrap.jl
@@ -1,31 +1,32 @@
export BootstrapFilter, BF

-struct BootstrapFilter{RS<:AbstractResampler} <: AbstractFilter
-N::Integer
+struct BootstrapFilter{N,RS<:AbstractResampler} <: AbstractParticleFilter{N}
resampler::RS
end

-"""Shorthand for `BootstrapFilter`"""
-const BF = BootstrapFilter
-
function BootstrapFilter(
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
-return BootstrapFilter(N, conditional_resampler)
+return BootstrapFilter{N,typeof(conditional_resampler)}(conditional_resampler)
end

+"""Shorthand for `BootstrapFilter`"""
+const BF = BootstrapFilter
+
function initialise(
rng::AbstractRNG,
model::StateSpaceModel{T},
-filter::BootstrapFilter;
+filter::BootstrapFilter{N};
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
-) where {T}
-initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N))
-initial_weights = zeros(T, filter.N)
+) where {N,T}
+initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
+initial_weights = zeros(T, N)

-return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state)
+return update_ref!(
+ParticleContainer(initial_states, initial_weights), ref_state, filter
+)
end

function predict(
@@ -37,13 +38,13 @@ function predict(
ref_state::Union{Nothing,AbstractVector{T}}=nothing,
kwargs...,
) where {T}
-states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered)
+states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter)
states.proposed.particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...),
collect(states.proposed),
)

-return update_ref!(states, ref_state, step)
+return update_ref!(states, ref_state, filter, step)
end

function update(
@@ -62,5 +63,16 @@ function update(
states.filtered.log_weights = states.proposed.log_weights + log_increments
states.filtered.particles = states.proposed.particles

-return states, logmarginal(states)
+return states, logmarginal(states, filter)
end
+
+function reset_weights!(
+state::ParticleState{T,WT}, idxs, filter::BootstrapFilter
+) where {T,WT<:Real}
+fill!(state.log_weights, -log(WT(length(state.particles))))
+return state
+end
+
+function logmarginal(states::ParticleContainer, ::BootstrapFilter)
+return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
+end
4 changes: 2 additions & 2 deletions src/algorithms/rbpf.jl
@@ -72,7 +72,7 @@ end
function predict(
rng::AbstractRNG, model::HierarchicalSSM, algo::RBPF, t::Integer, states; kwargs...
)
-states.proposed, states.ancestors = resample(rng, algo.resampler, states.filtered)
+states.proposed, states.ancestors = resample(rng, algo.resampler, states.filtered, algo)

states.proposed.particles = map(
x -> marginal_predict(rng, model, algo, t, x; kwargs...),
@@ -108,7 +108,7 @@ function update(

states.filtered.log_weights = states.proposed.log_weights + log_increments

-return states, logmarginal(states)
+return states, logmarginal(states, algo)
end

#################################
21 changes: 12 additions & 9 deletions src/containers.jl
@@ -1,5 +1,5 @@
using DataStructures: Stack
-using Random: rand
+import Random: rand

## GAUSSIAN STATES #########################################################################

@@ -105,13 +105,20 @@ Base.keys(state::ParticleState) = LinearIndices(state.particles)
Base.@propagate_inbounds Base.getindex(state::ParticleState, i) = state.particles[i]
# Base.@propagate_inbounds Base.getindex(state::ParticleState, i::Vector{Int}) = state.particles[i]

-function reset_weights!(state::ParticleState{T,WT}) where {T,WT<:Real}
+function reset_weights!(state::ParticleState{T,WT}, idx, ::AbstractFilter) where {T,WT<:Real}
fill!(state.log_weights, zero(WT))
return state.log_weights
end

+function logmarginal(states::ParticleContainer, ::AbstractFilter)
+return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
+end
+
function update_ref!(
-pc::ParticleContainer{T}, ref_state::Union{Nothing,AbstractVector{T}}, step::Integer=0
+pc::ParticleContainer{T},
+ref_state::Union{Nothing,AbstractVector{T}},
+::AbstractFilter,
+step::Integer=0,
) where {T}
Comment on lines 117 to 122

Collaborator: Why do we need the filter passed through this function?

Member (Author): We don't need it here, but we might need to dispatch on the filter type when updating the reference particle. That would be useful for ancestor resampling, for example.

# this comes from Nicolas Chopin's package particles
if !isnothing(ref_state)
@@ -122,10 +129,6 @@ function update_ref!(
return pc
end

-function logmarginal(states::ParticleContainer)
-return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
-end

## SPARSE PARTICLE STORAGE #################################################################

Base.append!(s::Stack, a::AbstractVector) = map(x -> push!(s, x), a)
@@ -180,7 +183,7 @@ function prune!(tree::ParticleTree, offspring::Vector{Int64})
end

function insert!(
-tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{Int64}
+tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{<:Integer}
) where {T}
# parents of new generation
parents = getindex(tree.leaves, ancestors)
@@ -213,7 +216,7 @@ function expand!(tree::ParticleTree)
return tree
end

-function get_offspring(a::AbstractVector{Int64})
+function get_offspring(a::AbstractVector{<:Integer})
offspring = zero(a)
for i in a
offspring[i] += 1
19 changes: 11 additions & 8 deletions src/resamplers.jl
@@ -8,13 +8,15 @@ export Multinomial, Systematic, Metropolis, Rejection
abstract type AbstractResampler end

function resample(
-rng::AbstractRNG, resampler::AbstractResampler, states::ParticleState{PT,WT}
+rng::AbstractRNG,
+resampler::AbstractResampler,
+states::ParticleState{PT,WT},
+filter::AbstractFilter;
+weights::AbstractVector{WT}=StatsBase.weights(states)
) where {PT,WT}
-weights = StatsBase.weights(states)
idxs = sample_ancestors(rng, resampler, weights)
-
-new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states)))
-
+new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states)))
+reset_weights!(new_state, idxs, filter)
return new_state, idxs
end

@@ -23,8 +25,9 @@ function resample(
rng::AbstractRNG,
resampler::AbstractResampler,
states::RaoBlackwellisedParticleState{T,M,ZT},
+::AbstractFilter;
+weights=StatsBase.weights(states)
) where {T,M,ZT}
-weights = StatsBase.weights(states)
idxs = sample_ancestors(rng, resampler, weights)

new_state = RaoBlackwellisedParticleState(
@@ -49,7 +52,7 @@ struct ESSResampler <: AbstractConditionalResampler
end

function resample(
-rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT}
+rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT}, filter::AbstractFilter
) where {PT,WT}
n = length(state)
# TODO: computing weights twice. Should create a wrapper to avoid this
@@ -58,7 +61,7 @@
@debug "ESS: $ess"

if cond_resampler.threshold * n ≥ ess
-return resample(rng, cond_resampler.resampler, state)
+return resample(rng, cond_resampler.resampler, state, filter; weights=weights)
else
return deepcopy(state), collect(1:n)
end
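
For reference, the effective sample size used by ESSResampler is conventionally computed from the unnormalised log-weights as below; a self-contained sketch (the helper name is an assumption, the package computes this inside resample):

using NNlib: logsumexp  # already a dependency of this package

# ESS = (Σ w_i)^2 / Σ w_i^2, computed stably in log space
ess_from_log_weights(logw) = exp(2 * logsumexp(logw) - logsumexp(2 .* logw))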
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -109,7 +109,9 @@ end
_, _, data = sample(rng, model, 20)

bf = BF(2^12; threshold=0.8)
apf = APF(2^10; threshold=1.0)
bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data)
_, llapf = GeneralisedFilters.filter(rng, model, apf, data)
kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data)

xs = bf_state.filtered.particles
@@ -120,6 +122,7 @@

# since this is log valued, we can up the tolerance
@test llkf ≈ llbf atol = 0.1
@test llkf ≈ llapf atol = 2
end

@testitem "Forward algorithm test" begin