-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add auxiliary particle filter #11
base: main
Are you sure you want to change the base?
Changes from all commits
ce97457
9dc1f99
840d7a9
7daa3d1
6564f54
72282af
9a8b28e
b78efce
c791095
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
export AuxiliaryParticleFilter, APF | ||
|
||
# Auxiliary Particle Filter (APF) with `N` particles.
# `RS` is the conditional resampler used to (conditionally) resample the particle set.
mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: AbstractParticleFilter{N}
    resampler::RS
    # First-stage (auxiliary) log-weights, overwritten each `predict` step with the
    # per-particle predictive log-likelihood approximation.
    # NOTE(review): untyped `Vector` field is type-unstable — consider parametrizing
    # (would change the type's parameter list, so left as-is here).
    aux::Vector # Auxiliary weights
end
|
||
"""
    AuxiliaryParticleFilter(N; threshold=0.0, resampler=Systematic())

Construct an auxiliary particle filter with `N` particles. The supplied
`resampler` is wrapped in an `ESSResampler` that triggers at the given
effective-sample-size `threshold`.
"""
function AuxiliaryParticleFilter(
    N::Integer; threshold::Real=0.0, resampler::AbstractResampler=Systematic()
)
    cond_resampler = ESSResampler(threshold, resampler)
    aux_weights = zeros(N)
    return AuxiliaryParticleFilter{N,typeof(cond_resampler)}(cond_resampler, aux_weights)
end
|
||
const APF = AuxiliaryParticleFilter | ||
|
||
"""
    initialise(rng, model, filter::AuxiliaryParticleFilter, [ref_state]; kwargs...)

Draw `N` particles from the model's initial dynamics and return a
`ParticleContainer` with uniform (zero) log-weights, conditioned on
`ref_state` via `update_ref!` when one is supplied.
"""
function initialise(
    rng::AbstractRNG,
    model::StateSpaceModel{T},
    filter::AuxiliaryParticleFilter{N},
    ref_state::Union{Nothing,AbstractVector}=nothing;
    # BUGFIX: `kwargs...` was declared *before* the keyword separator `;`, making it a
    # positional varargs slurp; splatting that into `simulate(...; kwargs...)` would
    # fail for actual keyword arguments. Every sibling method in this file takes
    # `; kwargs...` — align with them.
    kwargs...,
) where {N,T}
    initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
    # Zero log-weights == uniform normalised weights.
    initial_weights = zeros(T, N)

    return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter)
end
|
||
# Reweight `states` in place by the auxiliary weighting function `eta`.
# NOTE(review): `eta` is not defined anywhere in this file and the only call site
# is commented out in `predict` — this helper appears to be dead code; confirm
# before relying on it. `kwargs` is accepted but not forwarded to `eta` — verify
# whether that is intentional.
function update_weights!(
    rng::AbstractRNG, filter, model, step, states, observation; kwargs...
)
    increments = eta(rng, model, step, states, observation)
    states.log_weights = states.log_weights + increments
    return states.log_weights
end
|
||
"""
    predict(rng, model, filter::AuxiliaryParticleFilter, step, states, observation; ref_state, kwargs...)

First (auxiliary) stage of the APF step: reweight the filtered particles by an
approximation of the predictive likelihood of `observation`, resample, then
propagate the survivors through the dynamics. Returns the mutated `states`
container (after `update_ref!`).
"""
function predict(
    rng::AbstractRNG,
    model::StateSpaceModel,
    filter::AuxiliaryParticleFilter,
    step::Integer,
    states::ParticleContainer{T},
    observation;
    ref_state::Union{Nothing,AbstractVector{T}}=nothing,
    kwargs...,
) where {T}
    # states = update_weights!(rng, filter.eta, model, step, states.filtered, observation; kwargs...)

    # Compute auxiliary weights.
    # POC: use the simplest approximation to the predictive likelihood — score the
    # observation at the *mean* of each particle's transition distribution.
    # Ideally should be something like update_weights!(filter, ...)
    predicted = map(
        x -> mean(SSMProblems.distribution(model.dyn, step, x; kwargs...)),
        states.filtered.particles,
    )
    auxiliary_weights = map(
        x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), predicted
    )
    # Fold the auxiliary log-weights into the filtered weights so resampling favours
    # particles likely to explain the upcoming observation; stash them on the filter
    # so `reset_weights!` can subtract them back out after resampling.
    states.filtered.log_weights .+= auxiliary_weights
    filter.aux = auxiliary_weights

    states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter)
    # Propagate the resampled particles one step through the dynamics.
    states.proposed.particles = map(
        x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...),
        states.proposed.particles,
    )

    return update_ref!(states, ref_state, filter, step)
end
|
||
"""
    update(model, filter::AuxiliaryParticleFilter, step, states, observation; kwargs...)

Second stage of the APF step: score every proposed particle against
`observation`, store the reweighted particles as the filtered state, and
return the container together with the log-marginal increment.
"""
function update(
    model::StateSpaceModel{T},
    filter::AuxiliaryParticleFilter,
    step::Integer,
    states::ParticleContainer,
    observation;
    kwargs...,
) where {T}
    @debug "step $step"
    proposed_particles = collect(states.proposed.particles)
    log_increments = [
        SSMProblems.logdensity(model.obs, step, p, observation; kwargs...) for
        p in proposed_particles
    ]

    states.filtered.log_weights = states.proposed.log_weights + log_increments
    states.filtered.particles = states.proposed.particles

    return states, logmarginal(states, filter)
end
|
||
# One full APF iteration: first-stage reweight + resample + propagate (`predict`),
# then second-stage correction against the observation (`update`).
# Returns `(filtered_state, log_likelihood_increment)`.
function step(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    alg::AuxiliaryParticleFilter,
    iter::Integer,
    state,
    observation;
    kwargs...,
)
    proposed = predict(rng, model, alg, iter, state, observation; kwargs...)
    return update(model, alg, iter, proposed, observation; kwargs...)
end
|
||
# After resampling, subtract the first-stage (auxiliary) log-weights of the
# surviving particles so the second stage corrects for the biased proposal.
# From Chopin: An Introduction to Sequential Monte Carlo, section 10.3.3.
function reset_weights!(
    state::ParticleState{T,WT}, idxs, filter::AuxiliaryParticleFilter
) where {T,WT<:Real}
    adjusted = state.log_weights[idxs] .- filter.aux[idxs]
    state.log_weights = adjusted
    return state
end
|
||
# APF log-marginal likelihood increment: log-ratio of the weight normalisers
# after and before the measurement update.
function logmarginal(states::ParticleContainer, ::AuxiliaryParticleFilter)
    filtered_norm = logsumexp(states.filtered.log_weights)
    proposed_norm = logsumexp(states.proposed.log_weights)
    return filtered_norm - proposed_norm
end
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
using DataStructures: Stack | ||
using Random: rand | ||
import Random: rand | ||
|
||
## GAUSSIAN STATES ######################################################################### | ||
|
||
|
@@ -105,13 +105,20 @@ Base.keys(state::ParticleState) = LinearIndices(state.particles) | |
Base.@propagate_inbounds Base.getindex(state::ParticleState, i) = state.particles[i] | ||
# Base.@propagate_inbounds Base.getindex(state::ParticleState, i::Vector{Int}) = state.particles[i] | ||
|
||
# Reset all log-weights to zero (i.e. uniform normalised weights).
# `idx` and the filter argument are unused in this default method; they exist so
# that filters such as the auxiliary particle filter can dispatch on their own
# specialised method.
# NOTE(review): this returns `state.log_weights`, while the APF method returns
# `state` — confirm which return convention callers rely on.
function reset_weights!(state::ParticleState{T,WT}, idx, ::AbstractFilter) where {T,WT<:Real}
    fill!(state.log_weights, zero(WT))
    return state.log_weights
end
|
||
# Default log-marginal likelihood increment: log-ratio of the weight
# normalisers of the filtered and proposed particle sets.
function logmarginal(states::ParticleContainer, ::AbstractFilter)
    num = logsumexp(states.filtered.log_weights)
    den = logsumexp(states.proposed.log_weights)
    return num - den
end
|
||
function update_ref!( | ||
pc::ParticleContainer{T}, ref_state::Union{Nothing,AbstractVector{T}}, step::Integer=0 | ||
pc::ParticleContainer{T}, | ||
ref_state::Union{Nothing,AbstractVector{T}}, | ||
::AbstractFilter, | ||
step::Integer=0, | ||
) where {T} | ||
Comment on lines
117
to
122
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need the filter passed through this function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need it here, but we might need to dispatch on the filter type when updating the reference particle. That would be useful for ancestor resampling for example |
||
# this comes from Nicolas Chopin's package particles | ||
if !isnothing(ref_state) | ||
|
@@ -122,10 +129,6 @@ function update_ref!( | |
return pc | ||
end | ||
|
||
function logmarginal(states::ParticleContainer) | ||
return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) | ||
end | ||
|
||
## SPARSE PARTICLE STORAGE ################################################################# | ||
|
||
Base.append!(s::Stack, a::AbstractVector) = map(x -> push!(s, x), a) | ||
|
@@ -180,7 +183,7 @@ function prune!(tree::ParticleTree, offspring::Vector{Int64}) | |
end | ||
|
||
function insert!( | ||
tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{Int64} | ||
tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{<:Integer} | ||
) where {T} | ||
# parents of new generation | ||
parents = getindex(tree.leaves, ancestors) | ||
|
@@ -213,7 +216,7 @@ function expand!(tree::ParticleTree) | |
return tree | ||
end | ||
|
||
function get_offspring(a::AbstractVector{Int64}) | ||
function get_offspring(a::AbstractVector{<:Integer}) | ||
offspring = zero(a) | ||
for i in a | ||
offspring[i] += 1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should work without redefining `step`, as long as `AuxiliaryParticleFilter` is a subtype of `AbstractFilter`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current API doesn't forward `observation` to `predict`. I also tend to agree with #9 about splitting `resample` and `predict`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if it should. For general proposal distributions other than forward simulation (i.e. bootstrap filter) that would be needed.
Or at least have a subtype for Guided, Independent and Auxiliary filters that has this modified step/predict method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I completely forgot about that. We may need to think about defining some sort of `AbstractProposal` for more complex transition kernels.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And yeah, I agree we should split up `resample` and `predict`.