Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft of Nutpie/Nuts-rs mass matrix adaption #312

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

svilupp
Copy link

@svilupp svilupp commented Jan 11, 2023

NOT READY!

Quick draft of the new mass matrix adaption begin tested/used by PyMC team for discussion. Attempt at #311

Notes:

  • First venture in AbstractHMC
  • Translated from Rust; a few choices that could be challenged (eg, frequencies, window sizes, but also the fact that diag. matrix is updated on each iteration)
  • made quick and dirty choices around implementation (composite WelfordVar struct, instead of a new and cleaner var; reusing some fields within it)
  • health checks missing - eg, is this a good draw (divergences, termination at step=0), which are present in the original Rust implementation
  • would require bending the current Abstract HMC APIs a bit (eg, adapt! needs to provide gradient information (DualValues from Phasepoint) to the adaptor etc.)

It's a draft -- my VS Code made several formatting changes that I'd need to unwind + I need to take out changes to Project.toml (used for running examples)

@svilupp
Copy link
Author

svilupp commented Jan 11, 2023

I did a quick benchmark on a simple linear regression model and I don't see much difference in how quickly it adapts (perhaps it's too easy to see any difference)

This is the first 100 tuning draws and what values were sampled.
The hypothesis was that if Nutpie is better, it would hone in on the right values faster, but as you can see it takes only c. 10 draws to get it right for both algorithms -- ie, hard to compare.
comparison

# Example for Nuts-rs / Nutpie Adaptor
using AdvancedHMC, ForwardDiff
using LinearAlgebra
using Distributions
using Plots
const A = AdvancedHMC
using LogDensityProblems, TransformVariables, TransformedLogDensities, Parameters
using AbstractMCMC: LogDensityModel

# Example taken from https://www.tamaspapp.eu/DynamicHMCExamples.jl/latest/example_linear_regression/
"""
Linear regression model ``y ∼ Xβ + ϵ``, where ``ϵ ∼ N(0, σ²)`` IID.

Weakly informative prior for `β`, half-T for `σ`.
"""
struct LinearRegressionProblem{TY<:AbstractVector,TX<:AbstractMatrix,
    Tν<:Real}
    "Observations."
    y::TY
    "Covariates"
    X::TX
    "Degrees of freedom for prior."
    ν::Tν
end
function (problem::LinearRegressionProblem)(θ)
    @unpack y, X, ν = problem   # extract the data
    @unpack β, σ = θ            # works on the named tuple too
    ϵ_distribution = Normal(0, σ) # the error term
    ℓ_error = mapreduce((y, x) -> logpdf(ϵ_distribution, y - dot(x, β)), +,
        y, eachrow(X))    # likelihood for error
    ℓ_σ = logpdf(TDist(ν), σ)             # prior for σ
    ℓ_β = loglikelihood(Normal(0, 10), β) # prior for β
    ℓ_error + ℓ_σ + ℓ_β
end
# Random data
n_samples, n_adapts = 2_000, 1_000
N = 100
β = [-10.0, 2.0, -1.0, -0.3, 3, 0.1, 8]
σ = 2
X = hcat(ones(N), randn(N, length(β) - 1));
y = X * β .+ randn(N) .* σ;
p = LinearRegressionProblem(y, X, 1.0);
@info string("log density: ", p((β=β, σ=σ)))
# Transform it to unconstrained space
function problem_transformation(p::LinearRegressionProblem)
    as((β=as(Array, size(p.X, 2)), σ=asℝ₊))
end
t = problem_transformation(p)
P = TransformedLogDensity(t, p)
initial_θ = ones(length(β) + 1)
metric = DiagEuclideanMetric(length(initial_θ))

# Sampling w Default
hamiltonian = Hamiltonian(metric, P, ForwardDiff)
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
# sample(LogDensityModel(P), proposal, metric, adaptor, 100)
@time samples1, stats1 = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; drop_warmup=false, progress=true);


# NUTPIE
# https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs
hamiltonian = Hamiltonian(metric, P, ForwardDiff)
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
pc = A.ExpWeightedWelfordVar(size(metric))
adaptor = A.NutpieHMCAdaptor(pc, StepSizeAdaptor(0.8, integrator))
@time samples2, stats2 = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; drop_warmup=false, progress=true);

# # Plots
# plot how many tuning draws it takes to get to the real values
compare_coefficient = let samples1 = samples1, samples2 = samples2, β = β, σ = σ
    function plotter(idx)
        pl = plot(getindex.(samples1, idx), xlim=(0, 100), label="Default", color=palette(:default)[1])
        plot!(pl, getindex.(samples2, idx), xlim=(0, 100), label="Nutpie/Nuts-rs", color=palette(:default)[2])
        if idx <= length(β)
            hline!(pl, β[[idx]], color=:red, linestyle=:dash, label="True value")
        else # if it's not beta, it must be sigma
            hline!(pl, [1 / σ], color=:red, linestyle=:dash, label="True value (untrf σ")
        end
        return pl
    end
end
k = 8
pl = plot([compare_coefficient(i) for i in 1:k]..., layout=(k, 1), size=(600, 400 + k * 100),
    plot_title="Comparison of components", legend=:bottomright)
savefig(pl, "comparison.png")

@aseyboldt
Copy link

From a very quick look:

  • I think you need a more challenging example problem. With that one everything will converge really quickly, and you won't see much of a differences between anything. An indep. normal probably also isn't ideal, nutpie should have the exact mass matrix with that after the first window switch and two draws (still a good test case though :-) ).
  • In your plot, I think it would make more sense to use number of gradient evaluations as x-axis, instead of number of draws, because those might look very different for different adaptation methods.
  • Plotting the mass matrix estimate or the step size over time is I think also often helpful to understand what's happening

Translated from Rust; a few choices that could be challenged (eg, frequencies, window sizes, but also the fact that diag. matrix is updated on each iteration)

If you experiment with those, I'd love to see the results. :-)

I don't think the cost of updating the mass matrix will matter that much in real word problems however. This still only happens once per draw, so every couple of gradient evals. And gradient evals are usually significantly more expensive than a diag mass matrix update.
The mass matrix updates in every iteration do however lead to a bit of trouble. Since we do it every time, we strictly speaking break the mcmc sampler, which can slightly bias the results, which can then bias the mass matrix estimate, which then makes sampling efficiency worse after tuning. It seemed to me that the advantages out way the costs, but I don't actually know for sure.

@aseyboldt
Copy link

Oh, and I'm not sure, but doesn't this look biased?
image

@svilupp
Copy link
Author

svilupp commented Feb 12, 2023

Quick update:

  • added dispatch for sample() to handle types that require gradient information (dispatches over a new abstract type AbstractHMCAdaptorWithGradients)
  • added variations of NutpieHMCAdaptor to benchmark various aspects of the new routine separately (eg, no initialization from gradients, no switching, etc)
  • cleaned up slightly for better readability
  • connected to PosteriorDB thanks to Sethaxen's amazing package universe + added some wrappers to easily plot the results, which made it obvious that something isn't right...

Results: Poor! I'm unable to get Nutpie to sample properly.

  • It too often results in divergences (so the ESS/grad comparison is meaningless), so there must be an issue somewhere in my code (given that Adrian's benchmarks were fine)
  • I'm using a lot of the WelfordVar variance estimators, so, hopefully, the issue will be easier to trace down. The get_estimation() for the variance estimate is different from Nutpie, but it matches the linked Stan implementation, so I don't think the issue is there (the default routine seems to be fine)
  • The most likely culprit is that I don't stop the numeric instabilities from flowing in, so the next step is to listen to the variance estimates and implement some filters for invalid gradients/samples (like in Nutpie), to prevent them from polluting the estimation

Initial results for the Diamonds model below (incl. the code to reproduce the example)

image

Code to reproduce the above chart

using Pkg;
using BridgeStan, AdvancedHMC, PosteriorDB, Random, StanLogDensityProblems, LogDensityProblems
using AdvancedHMC: ExpWeightedWelfordVar, NutpieHMCAdaptor, ExpWeightedWelfordVar, NutpieHMCAdaptorNoSwitch, NutpieHMCAdaptorNoGradInit, NutpieHMCAdaptorNoSwitchNoGradInit
# using DataFramesMeta
using MCMCDiagnosticTools
using Folds
using StatsPlots


# # Benchmark setup
# Seth Axen's amazing wrapper based on https://github.com/mlcolab/PathfinderBenchmarks.jl/blob/main/src/dynamichmc.jl
# wrapper to count the number of function evaluations and gradient evaluations
mutable struct EvalCountingProblem{P}
    const prob::P
    num_evals::Int
    num_grad_evals::Int
end
EvalCountingProblem(prob) = EvalCountingProblem(prob, 0, 0)

function LogDensityProblems.capabilities(::Type{<:EvalCountingProblem{P}}) where {P}
    return LogDensityProblems.capabilities(P)
end

function LogDensityProblems.dimension(prob::EvalCountingProblem)
    return LogDensityProblems.dimension(prob.prob)
end

function LogDensityProblems.logdensity(prob::EvalCountingProblem, x)
    prob.num_evals += 1
    return LogDensityProblems.logdensity(prob.prob, x)
end

function LogDensityProblems.logdensity_and_gradient(prob::EvalCountingProblem, x)
    prob.num_grad_evals += 1
    return LogDensityProblems.logdensity_and_gradient(prob.prob, x)
end
zero!(prob::EvalCountingProblem) = (prob.num_evals = 0; prob.num_grad_evals = 0)
Base.show(io::IO, m::MIME"text/plain", prob::EvalCountingProblem) = (ioinner = IOBuffer(); show(ioinner, m, prob.prob); print(io, "EvalCountingProblem($(take!(ioinner)|>String))"))

function ess_rhat(x)
    ess, rhat_bulk = MCMCDiagnosticTools.ess_rhat_bulk(x; maxlag=typemax(Int))
    rhat_tail = MCMCDiagnosticTools.rhat_tail(x)
    rhat = max.(rhat_bulk, rhat_tail)
    return (; ess, rhat)
end

#####################################
### Estimation routines
function sample_one_chain_(adaptor_type::Type{TA},metric_adaptor_type::Type{TMA}, rng_chain, prob; n_samples=1000, n_adapts=200,verbose=false) where {TA, TMA}
    D = LogDensityProblems.dimension(prob)
    initial_θ = rand(rng_chain, D)
    metric = DiagEuclideanMetric(D)
    count_prob = EvalCountingProblem(prob)
    hamiltonian = Hamiltonian(metric, count_prob)
    initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
    integrator = Leapfrog(initial_ϵ)
    proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
    # clear the counter after the step size is found
    zero!(count_prob)
    # provide an adaptor of choice
    metric_adaptor = metric_adaptor_type(size(metric))
    adaptor = adaptor_type(metric_adaptor, StepSizeAdaptor(0.8, integrator))
    samples, stats = sample(rng_chain, hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; drop_warmup=false, progress=false,verbose)
    samples = mapreduce(permutedims, vcat, samples) |> x -> reshape(x, (n_samples, 1, D))
    return samples, stats, count_prob.num_evals, count_prob.num_grad_evals
end

function sample_one_chain(adaptor_type,metric_adaptor_type, seed, rng, prob; n_samples=1000, n_adapts=200, n_tries=100,verbose=false)
    rng_chain = deepcopy(rng)
    Random.seed!(rng_chain, seed)
    i = 0
    while i < n_tries
        try
            samples, stats, num_evals, num_grad_evals = sample_one_chain_(adaptor_type,metric_adaptor_type, rng_chain, prob; n_samples, n_adapts,verbose)
            return samples, stats, num_evals, num_grad_evals
        catch err
            # @warn err
            @warn "Failed to sample, trying again ($i/$n_tries)"
            i += 1
        end
    end
    @error "Failed failed to sample after $n_tries tries"
end

function run_scenario(adaptor_type,metric_adaptor_type, rng, prob; n_samples=1000, n_adapts=200, n_chains=4, verbose=false)
    seeds = rand(rng, UInt, n_chains)
    # measure time
    time = @elapsed results = Folds.collect(sample_one_chain(adaptor_type,metric_adaptor_type, seed, rng, prob; n_samples, n_adapts, verbose) for seed in seeds);
    # extract results
    samples = hcat(getindex.(results, 1)...)
    stats = hcat(getindex.(results, 2)...)
    num_evals = getindex.(results, 3) |> sum
    num_grads_evals = getindex.(results, 4) |> sum
    # alternative
    # num_grads_evals = getindex.(stats, :n_steps) |> sum
    divergences = getindex.(stats, :numerical_error) |> sum
    ess, rhat = ess_rhat(samples)
    #
    ess_per_grad_mean = mean(ess) ./ num_grads_evals
    ess_mean = mean(ess)
    ess_min = minimum(ess) 
    ess_max = maximum(ess)
    rhat_max = maximum(rhat)
    @info "Results: $(ess_per_grad_mean) ESS/grad eval in $(time) seconds with $(divergences) divergences (max Rhat: $rhat_max)"
    return  (;ess_mean,ess_max,ess_min,num_grads_evals,ess_per_grad_mean,rhat_max, divergences, time)
end

# # Explore posteriorDB
# posterior_name = "diamonds-diamonds"
# post = posterior(pdb, posterior_name)
# mod = model(post)
# data = dataset(post)
# info(post)

# impl = implementation(mod, "stan")
# mod_code = load(impl)
# println(mod_code)

# load(data)
# ref = reference_posterior(post)
# info(ref)
# ref = DataFrame(load(ref))
# prob = StanProblem(post, ".", force=true, make_args=["STAN_THREADS=true"])
# LogDensityProblems.capabilities(prob)
# rng = Random.default_rng();
# LogDensityProblems.logdensity(prob, initial_θ)

# # Run Scenario: Diamond Model
pdb = database()
posterior_name = "diamonds-diamonds"
post = posterior(pdb, posterior_name)
prob = StanProblem(post, ".", force=true, make_args=["STAN_THREADS=true"])

# Set parameters
n_adapts = 1000
n_samples = 1000
n_chains = 4
n_tries = 100

# STAN DEFAULT
# rng = Random.default_rng();
rng = Random.MersenneTwister(1234);
res1=run_scenario(StanHMCAdaptor,WelfordVar, rng, prob; n_samples, n_adapts,n_chains,verbose=false);

# NUTPIE Variants
# https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs
rng = Random.MersenneTwister(1234);
res2=run_scenario(NutpieHMCAdaptor,ExpWeightedWelfordVar, rng, prob; n_samples, n_adapts, n_chains,verbose=false);
rng = Random.MersenneTwister(1234);
res3=run_scenario(NutpieHMCAdaptorNoGradInit,ExpWeightedWelfordVar, rng, prob; n_samples, n_adapts, n_chains,verbose=false);
rng = Random.MersenneTwister(1234);
res4=run_scenario(NutpieHMCAdaptorNoSwitch,ExpWeightedWelfordVar, rng, prob; n_samples, n_adapts, n_chains,verbose=false);
rng = Random.MersenneTwister(1234);
res5=run_scenario(NutpieHMCAdaptorNoSwitchNoGradInit,ExpWeightedWelfordVar, rng, prob; n_samples, n_adapts, n_chains,verbose=false);


# Plot it
function plot_bar(res_array,metric,labels=["Stan-like","Nutpie","Nutpie no grad init","Nutpie no switch","Nutpie no switch no grad init"];
    kwargs...)
    pl = bar(labels,getfield.(res_array,metric);xrotation=15,left_margin=5Plots.mm,bottom_margin=5Plots.mm,kwargs...)
end

pl = let res_array = [res1,res2,res3,res4,res5]
    plot(
        plot_bar(res_array,:ess_per_grad_mean; title="ESS per gradient evaluation",
            ylabel="ESS/grad eval",legend=false),
        plot_bar(res_array,:ess_mean; title="ESS (mean)",
            ylabel="ESS (mean)",legend=false), 
        plot_bar(res_array,:num_grads_evals; title="# of Gradient Eval.",
            ylabel="Grad. evaluations",legend=false),
        plot_bar(res_array,:rhat_max; title="Max Rhat",
            ylabel="Rhat (max)",legend=false),
        plot_bar(res_array,:divergences; title="# of Divergences",
            ylabel="Divergences",legend=false),
        plot_bar(res_array,:time; title="Time Elapsed",
            ylabel="Time (s)",legend=false),
            size=(900,600),layout=(2,3),titlefontsize=12)
end

# savefig(pl,"diamonds_20230212.png")

@aseyboldt
Copy link

I don't know for sure, bad step size adaptation might explain what you are seeing. If the mass matrix itself is good, but the final step size doesn't match the mass matrix, you might see lot's of divergences.
If that's the problem, the actual mean acceptance rate after tuning would not match the target.

And shouldn't this be a final step size adaptation, instead of mass matrix adaptation?
https://github.com/TuringLang/AdvancedHMC.jl/pull/312/files#diff-622665792b73235f2c5b58233a3eb82abaea68b4133c51bf9471d7bac99c10d3R118

@svilupp
Copy link
Author

svilupp commented Feb 13, 2023

I don't know for sure, bad step size adaptation might explain what you are seeing. If the mass matrix itself is good, but the final step size doesn't match the mass matrix, you might see lot's of divergences. If that's the problem, the actual mean acceptance rate after tuning would not match the target.

Good tip! Thanks - I'll look into it.
My assumption so far is that the mass matrix is bad, so it gives too much kinetic energy (hence, the divergences), but I'll look out for the step sizes too!

And shouldn't this be a final step size adaptation, instead of mass matrix adaptation? https://github.com/TuringLang/AdvancedHMC.jl/pull/312/files#diff-622665792b73235f2c5b58233a3eb82abaea68b4133c51bf9471d7bac99c10d3R118

I've removed these files, they were just background artefacts. The one you references is a chatGPT re-write of your codebase - I think the line you referenced is this one.

The actual implementation is here, where it first updates the step size (same as default in this package) and then we update the mass_matrix as per Nutpie here (it's decoupled from the switch! which happens 4 lines earlier).

This adapt! calls the adaptation method of the variance accumulator, which is just a bundle of 4 Welford variance accumulators (WelfordVar) - two for draws, two for grads. The different lingo here is that push! methods add new samples to the accumulators, update! updates the variance estimator in adaptor.exp_variance_draw.var (I tried to adhere a bit to your naming) here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants