Skip to content

Commit

Permalink
JET optimize SqKalman
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Nov 6, 2024
1 parent da2a9ff commit 9e0900d
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 17 deletions.
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ pf = ParticleFilter(N, dynamics, measurement, df, dg, d0)
du = MvNormal(m,1) # Control input distribution
x,u,y = simulate(pf,T,du) # Simulate trajectory using the model in the filter
tosvec(y) = reinterpret(SVector{length(y[1]),Float64}, reduce(hcat,y))[:] |> copy
x,u,y = tosvec.((x,u,y))
x,u,y = tosvec.((x,u,y)) # It's good for performance to use StaticArrays to the extent possible
xb,ll = smooth(pf, M, u, y) # Sample smooting particles
xb,ll = smooth(pf, M, u, y) # Sample smoothing particles
xbm = smoothed_mean(xb) # Calculate the mean of smoothing trajectories
xbc = smoothed_cov(xb) # And covariance
xbt = smoothed_trajs(xb) # Get smoothing trajectories
Expand Down
44 changes: 32 additions & 12 deletions src/sq_kalman.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@with_kw struct SqKalmanFilter{AT,BT,CT,DT,R1T,R2T,R2DT,D0T,XT,RT,P,αT} <: AbstractKalmanFilter
@with_kw mutable struct SqKalmanFilter{AT,BT,CT,DT,R1T,R2T,R2DT,D0T,XT,RT,P,αT} <: AbstractKalmanFilter
A::AT
B::BT
C::CT
Expand Down Expand Up @@ -48,9 +48,14 @@ function SqKalmanFilter(A,B,C,D,R1,R2,d0=MvNormal(Matrix(R1)); p = SciMLBase.Nul
if check
maximum(abs, eigvals(A isa SMatrix ? Matrix(A) : A)) 2 && @warn "The dynamics matrix A has eigenvalues with absolute value ≥ 2. This is either a highly unstable system, or you have forgotten to discretize a continuous-time model. If you are sure that the system is provided in discrete time, you can disable this warning by setting check=false." maxlog=1
end
R = UpperTriangular(convert_cov_type(R1, cholesky(d0.Σ).U))
R1 = cholesky(R1).U
R2 = cholesky(R2).U
SqKalmanFilter(A,B,C,D,R1,R2,MvNormal(Matrix(R2'R2)), d0, Vector(d0.μ), UpperTriangular(Matrix(cholesky(d0.Σ).U)), Ref(1), p, α)

R2d = convert_cov_type(R2, R2'R2)
x0 = convert_x0_type(d0.μ)

SqKalmanFilter(A,B,C,D,R1,R2,R2d, d0, x0, R, Ref(1), p, α)
end


Expand Down Expand Up @@ -80,8 +85,8 @@ covtype(kf::SqKalmanFilter) = typeof(kf.R.data)
Reset the initial distribution of the state. Optionally, a new mean vector `x0` can be provided.
"""
function reset!(kf::SqKalmanFilter; x0 = kf.d0.μ)
kf.x .= Vector(x0)
kf.R .= cholesky(kf.d0.Σ).U
kf.x = convert_x0_type(x0)
kf.R = UpperTriangular(convert_cov_type(kf.R1, cholesky(kf.d0.Σ).U))
kf.t[] = 1
end

Expand All @@ -94,11 +99,21 @@ function predict!(kf::SqKalmanFilter, u, p=parameters(kf), t::Real = index(kf);
@unpack A,B,x,R = kf
At = get_mat(A, x, u, p, t)
Bt = get_mat(B, x, u, p, t)
x .= At*x .+ Bt*u |> vec
kf.x = At*x .+ Bt*u |> vec
if kf.α == 1
R .= UpperTriangular(qr!([R*At';R1]).R)
M1 = [R*At';R1]
if R.data isa SMatrix
kf.R = UpperTriangular(qr(M1).R)
else
kf.R = UpperTriangular(qr!(M1).R)
end
else
R .= UpperTriangular(qr!([sqrt(kf.α)*R*At';R1]).R) # symmetrize(kf.α*At*R*At') + R1
M = [sqrt(kf.α)*R*At';R1]
if R.data isa SMatrix
kf.R = UpperTriangular(qr(M).R) # symmetrize(kf.α*At*R*At') + R1
else
kf.R = UpperTriangular(qr!(M).R) # symmetrize(kf.α*At*R*At') + R1
end
end
kf.t[] += 1
end
Expand All @@ -115,16 +130,21 @@ function correct!(kf::SqKalmanFilter, u, y, p=parameters(kf), t::Real = index(kf
Dt = get_mat(D, x, u, p, t)
e = y .- Ct*x
if !iszero(D)
e .-= Dt*u
e -= Dt*u
end
S0 = qr([R*Ct';R2]).R
S = UpperTriangular(S0)
if det(S) < 0 # Cheap for triangular matrices
@. S0 = -S0 # To avoid log(negative) in logpdf
if any(<(0), @view(S0[diagind(S0)])) || det(S) < 0 # Cheap for triangular matrices
S0 = -S0 # To avoid log(negative) in logpdf
end
K = ((R'*(R*Ct'))/S)/(S')
x .+= K*e
R .= UpperTriangular(qr!([R*(I - K*Ct)';R2*K']).R)
kf.x += K*e
M = [R*(I - K*Ct)';R2*K']
if R.data isa SMatrix
kf.R = UpperTriangular(qr(M).R)
else
kf.R = UpperTriangular(qr!(M).R)
end
SS = S'S
Sᵪ = Cholesky(S0, 'U', 0)
ll = logpdf(MvNormal(PDMat(SS, Sᵪ)), e)# - 1/2*logdet(S) # logdet is included in logpdf
Expand Down
24 changes: 21 additions & 3 deletions test/test_jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,28 @@ x,u,y = tosvec.((x,u,y))
@report_call predict!(kf, u[1])

@test_opt correct!(kf, u[1], y[1])
@report_call predict!(kf, u[1])
@report_call correct!(kf, u[1], y[1])

ukf = UnscentedKalmanFilter(dynamics, measurement, eye(nx), eye(ny), d0; ny, nu)
@test ukf.R1 isa SMatrix{2, 2, Float64, 4}
@test_opt predict!(ukf, u[1])
@report_call predict!(ukf, u[1])

@test_opt correct!(ukf, u[1], y[1])
@report_call predict!(ukf, u[1])
@report_call correct!(ukf, u[1], y[1])


skf = SqKalmanFilter(_A, _B, _C, 0, eye(nx), eye(ny), d0)
@test skf.R1.data isa SMatrix{2, 2, Float64, 4}
@test_opt predict!(skf, u[1])
@report_call predict!(skf, u[1])

@test_opt correct!(skf, u[1], y[1])
@report_call correct!(skf, u[1], y[1])


## Test allocations

## Test allocations ============================================================
forward_trajectory(kf, u, y)
a = @allocations forward_trajectory(kf, u, y)
@test a <= 15
Expand All @@ -53,3 +63,11 @@ a = @allocations forward_trajectory(kf, u, y)
forward_trajectory(ukf, u, y)
a = @allocations forward_trajectory(ukf, u, y)
@test a <= 15

forward_trajectory(skf, u, y)
a = @allocations forward_trajectory(skf, u, y)

@test a <= 50 # was 7 on julia v1.10.6


## Test differentiability ======================================================

0 comments on commit 9e0900d

Please sign in to comment.