Draft of Nutpie/Nuts-rs mass matrix adaptation #312

Draft: wants to merge 4 commits into base: master
Project.toml (4 additions, 0 deletions)
@@ -5,11 +5,15 @@ version = "0.4.3"
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
src/adaptation/Adaptation.jl (10 additions, 2 deletions)
@@ -5,16 +5,17 @@ using LinearAlgebra: LinearAlgebra
using Statistics: Statistics
using UnPack: @unpack, @pack!

using ..AdvancedHMC: DEBUG, AbstractScalarOrVec
using ..AdvancedHMC: DEBUG, AbstractScalarOrVec, PhasePoint

abstract type AbstractAdaptor end
abstract type AbstractHMCAdaptorWithGradients <: AbstractAdaptor end
function getM⁻¹ end
function getϵ end
function adapt! end
function reset! end
function initialize! end
function finalize! end
export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹
export AbstractAdaptor, AbstractHMCAdaptorWithGradients, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹

struct NoAdaptation <: AbstractAdaptor end
export NoAdaptation
@@ -57,4 +58,11 @@ finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa)
include("stan_adaptor.jl")
export NaiveHMCAdaptor, StanHMCAdaptor

# clamping bounds for the diagonal (inverse mass matrix) estimate produced by the Nutpie-style adaptor
const LOWER_LIMIT::Float64 = 1e-10
const UPPER_LIMIT::Float64 = 1e10

include("nutpie_adaptor.jl")
export NutpieHMCAdaptor, ExpWeightedWelfordVar, NutpieHMCAdaptorNoSwitch, NutpieHMCAdaptorNoGradInit, NutpieHMCAdaptorNoSwitchNoGradInit


end # module
src/adaptation/massmatrix.jl (96 additions, 6 deletions)
@@ -11,7 +11,7 @@ function adapt!(
adaptor::MassMatrixAdaptor,
θ::AbstractVecOrMat{<:AbstractFloat},
α::AbstractScalarOrVec{<:AbstractFloat},
is_update::Bool = true,
is_update::Bool=true,
)
resize!(adaptor, θ)
push!(adaptor, θ)
@@ -38,7 +38,7 @@ adapt!(
::UnitMassMatrix,
::AbstractVecOrMat{<:AbstractFloat},
::AbstractScalarOrVec{<:AbstractFloat},
is_update::Bool = true,
is_update::Bool=true,
) = nothing

## Diagonal mass matrix adaptor
@@ -91,8 +91,8 @@ Base.show(io::IO, ::WelfordVar) = print(io, "WelfordVar")

function WelfordVar{T}(
sz::Union{Tuple{Int},Tuple{Int,Int}};
n_min::Int = 10,
var = ones(T, sz),
n_min::Int=10,
var=ones(T, sz)
) where {T<:AbstractFloat}
return WelfordVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var)
end
@@ -133,6 +133,96 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat}
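# equivalently (n / (n + 5)) * (M / (n - 1)) .+ (5 / (n + 5)) * ϵ, i.e. the sample variance
# (M is Welford's running sum of squared deviations) shrunk toward ϵ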
return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5))
end

# Port of the mass-matrix adaptation used by Nutpie, which comes from nuts-rs, the Rust implementation of NUTS
# Source: https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs

mutable struct ExpWeightedWelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T}
exp_variance_draw::WelfordVar{T,E}     # foreground (current) variance of the draws θ
exp_variance_grad::WelfordVar{T,E}     # foreground (current) variance of the gradients ∇logπ
exp_variance_draw_bg::WelfordVar{T,E}  # background (next window) variance of the draws
exp_variance_grad_bg::WelfordVar{T,E}  # background (next window) variance of the gradients
function ExpWeightedWelfordVar(exp_variance_draw::WelfordVar{T,E}, exp_variance_grad::WelfordVar{T,E}, exp_variance_draw_bg::WelfordVar{T,E}, exp_variance_grad_bg::WelfordVar{T,E}) where {T,E}
return new{eltype(E),E}(exp_variance_draw, exp_variance_grad, exp_variance_draw_bg, exp_variance_grad_bg)
end
end

# the best current estimate of the variance is stored in the `var` slot of the foreground `WelfordVar` (see `update!` below)
getM⁻¹(ve::ExpWeightedWelfordVar) = ve.exp_variance_draw.var

Base.show(io::IO, ::ExpWeightedWelfordVar) = print(io, "ExpWeightedWelfordVar")

function ExpWeightedWelfordVar{T}(
sz::Union{Tuple{Int},Tuple{Int,Int}};
n_min::Int=4, var=ones(T, sz)
) where {T<:AbstractFloat}
return ExpWeightedWelfordVar(WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var))
end

ExpWeightedWelfordVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) = ExpWeightedWelfordVar{Float64}(sz; kwargs...)

function Base.resize!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, ∇logπ::AbstractVecOrMat{T}) where {T<:AbstractFloat}
@assert size(θ) == size(∇logπ) "Size of draw and grad must be the same."
resize!(wv.exp_variance_draw, θ)
resize!(wv.exp_variance_grad, ∇logπ)
resize!(wv.exp_variance_draw_bg, θ)
resize!(wv.exp_variance_grad_bg, ∇logπ)
end

function reset!(wv::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat}
reset!(wv.exp_variance_draw)
reset!(wv.exp_variance_grad)
reset!(wv.exp_variance_draw_bg)
reset!(wv.exp_variance_grad_bg)
end

function Base.push!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, ∇logπ::AbstractVecOrMat{T}) where {T}
@assert size(θ) == size(∇logπ) "Size of draw and grad must be the same."
push!(wv.exp_variance_draw, θ)
push!(wv.exp_variance_grad, ∇logπ)
push!(wv.exp_variance_draw_bg, θ)
push!(wv.exp_variance_grad_bg, ∇logπ)
end

# Swap the background and foreground estimators for both the draw and grad variances, then reset the
# new background so it starts a fresh window. Unlike the Rust implementation, the estimators are not
# updated inside the switch; `update!` is called separately.
function switch!(wv::ExpWeightedWelfordVar)
# swap (rather than assign) so the two fields keep pointing at distinct estimators;
# assigning and then resetting the background would also reset the just-promoted foreground
wv.exp_variance_draw, wv.exp_variance_draw_bg = wv.exp_variance_draw_bg, wv.exp_variance_draw
reset!(wv.exp_variance_draw_bg)
wv.exp_variance_grad, wv.exp_variance_grad_bg = wv.exp_variance_grad_bg, wv.exp_variance_grad
reset!(wv.exp_variance_grad_bg)
end
current_count(wv::ExpWeightedWelfordVar) = wv.exp_variance_draw.n
background_count(wv::ExpWeightedWelfordVar) = wv.exp_variance_draw_bg.n
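
As an illustration of the two-window scheme, a minimal sketch (not part of the diff) that assumes the definitions in this file and feeds random draws in place of sampler output:

est = ExpWeightedWelfordVar((3,))            # Float64 estimator over 3 dimensions, n_min = 4
for _ in 1:6
    adapt!(est, randn(3), 1.0, randn(3))     # α = 1.0 is a placeholder; this adaptor ignores α
end
current_count(est), background_count(est)    # (6, 6): every draw is pushed into both windows
switch!(est)                                 # promote the background window to foreground
current_count(est), background_count(est)    # (6, 0): the new background starts empty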

function adapt!(
adaptor::ExpWeightedWelfordVar,
θ::AbstractVecOrMat{<:AbstractFloat},
α::AbstractScalarOrVec{<:AbstractFloat},
∇logπ::AbstractVecOrMat{<:AbstractFloat},
is_update::Bool=true
)
resize!(adaptor, θ, ∇logπ)
push!(adaptor, θ, ∇logπ)
is_update && update!(adaptor)
end

# TODO: handle NaN
function get_estimation(ad::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat}
var_draw = get_estimation(ad.exp_variance_draw)
var_grad = get_estimation(ad.exp_variance_grad)
# mimics: let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT);
var = clamp.(sqrt.(var_draw ./ var_grad), LOWER_LIMIT, UPPER_LIMIT)
# re-use the last estimate `var` if the current estimate is not valid
return all(isfinite.(var)) ? var : ad.exp_variance_draw.var
end
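
A worked element-wise example of this estimate (illustrative numbers only):

var_draw = [4.0, 1.0, 1.0e-30]                    # variance of the draws θ
var_grad = [1.0, 4.0, 1.0]                        # variance of the gradients ∇logπ
clamp.(sqrt.(var_draw ./ var_grad), 1e-10, 1e10)  # -> [2.0, 0.5, 1.0e-10]; the last entry hits LOWER_LIMIT
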
# reuse the `var` slot of `exp_variance_draw` (a `WelfordVar`) to store the combined estimate
# returned by `get_estimation`, i.e. the current ("foreground") diagonal of M⁻¹
function update!(ad::ExpWeightedWelfordVar)
current_count(ad) >= ad.exp_variance_draw.n_min && (ad.exp_variance_draw.var .= get_estimation(ad))
end
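
Putting the pieces together, a minimal end-to-end sketch of driving the estimator by hand (not part of the diff; it assumes the names in this file are in scope, e.g. via this branch's AdvancedHMC.Adaptation module, and uses random numbers in place of sampler output):

dim = 5
est = ExpWeightedWelfordVar((dim,))      # defaults: Float64, n_min = 4
for step in 1:100
    θ, ∇logπ = randn(dim), randn(dim)    # stand-ins for the sampler's position and gradient
    adapt!(est, θ, 0.8, ∇logπ)           # push into all four accumulators; update! once past n_min
    step == 50 && switch!(est)           # midway, promote the background window
end
M⁻¹ = getM⁻¹(est)                        # current diagonal of the inverse mass matrix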


## Dense mass matrix adaptor

abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end
@@ -179,8 +269,8 @@ Base.show(io::IO, ::WelfordCov) = print(io, "WelfordCov")

function WelfordCov{T}(
sz::Tuple{Int};
n_min::Int = 10,
cov = LinearAlgebra.diagm(0 => ones(T, first(sz))),
n_min::Int=10,
cov=LinearAlgebra.diagm(0 => ones(T, first(sz)))
) where {T<:AbstractFloat}
d = first(sz)
return WelfordCov(0, n_min, zeros(T, d), zeros(T, d, d), zeros(T, d), cov)