Skip to content

Commit

Permalink
feat(inference): add nuts sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
chentoast committed Oct 4, 2022
1 parent dce003c commit 321fd29
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 12 deletions.
12 changes: 0 additions & 12 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
function sample_momenta(n::Int)
Float64[random(normal, 0, 1) for _=1:n]
end

function assess_momenta(momenta)
logprob = 0.
for val in momenta
logprob += logpdf(normal, val, 0, 1)
end
logprob
end

"""
(new_trace, accepted) = hmc(
trace, selection::Selection; L=10, eps=0.1,
Expand Down
39 changes: 39 additions & 0 deletions src/inference/hmc_common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
function sample_momenta(n::Int)
Float64[random(normal, 0, 1) for _=1:n]
end

function assess_momenta(momenta)
logprob = 0.
for val in momenta
logprob += logpdf(normal, val, 0, 1)
end
logprob
end

function add_choicemaps(a::ChoiceMap, b::ChoiceMap)
out = choicemap()

for (name, val) in get_values_shallow(a)
out[name] = val + b[name]
end

for (name, submap) in get_submaps_shallow(a)
out.internal_nodes[name] = add_choicemaps(submap, get_submap(b, name))
end

return out
end

function scale_choicemap(a::ChoiceMap, scale)
out = choicemap()

for (name, val) in get_values_shallow(a)
out[name] = val * scale
end

for (name, submap) in get_submaps_shallow(a)
out.internal_nodes[name] = scale_choicemap(submap, scale)
end

return out
end
3 changes: 3 additions & 0 deletions src/inference/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ export logsumexp

include("trace_translators.jl")

include("hmc_common.jl")

# mcmc
include("kernel_dsl.jl")
include("mh.jl")
include("hmc.jl")
include("nuts.jl")
include("mala.jl")
include("elliptical_slice.jl")

Expand Down
179 changes: 179 additions & 0 deletions src/inference/nuts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
using LinearAlgebra: dot

struct Tree
val_left :: ChoiceMap
momenta_left :: ChoiceMap
val_right :: ChoiceMap
momenta_right :: ChoiceMap
val_sample :: ChoiceMap
n :: Int
weight :: Float64
stop :: Bool
diverging :: Bool
end

function u_turn(values_left, values_right, momenta_left, momenta_right)
return (dot(values_left - values_right, momenta_right) >= 0) &&
(dot(values_right - values_left, momenta_left) >= 0)
end

function leapfrog(values_trie, momenta_trie, eps, integrator_state)
selection, retval_grad, trace = integrator_state

(trace, _, _) = update(trace, values_trie)
(_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)

# half step on momenta
momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))

# full step on positions
values_trie = add_choicemaps(values_trie, scale_choicemap(momenta_trie, eps))

# get new gradient
(trace, _, _) = update(trace, values_trie)
(_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)

# half step on momenta
momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))
return values_trie, momenta_trie, get_score(trace)
end

function build_root(val, momenta, eps, direction, weight_init, integrator_state)
val, momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state)
weight = lp + assess_momenta(to_array(momenta, Float64))

diverging = weight - weight_init > 1000

return Tree(val, momenta, val, momenta, val, 1, weight, false, diverging)
end

function merge_trees(tree_left, tree_right)
# multinomial sampling
if log(rand()) < tree_right.weight - tree_left.weight
sample = tree_right.val_sample
else
sample = tree_left.val_sample
end

weight = logsumexp(tree_left.weight, tree_right.weight)
n = tree_left.n + tree_right.n

stop = tree_left.stop || tree_right.stop || u_turn(to_array(tree_left.val_left, Float64),
to_array(tree_right.val_right, Float64),
to_array(tree_left.momenta_left, Float64),
to_array(tree_right.momenta_right, Float64))
diverging = tree_left.diverging || tree_right.diverging

return Tree(tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
tree_right.momenta_right, sample, n, weight, stop, diverging)
end

function build_tree(val, momenta, depth, eps, direction, weight_init, integrator_state)
if depth == 0
return build_root(val, momenta, eps, direction, weight_init, integrator_state)
end

tree = build_tree(val, momenta, depth - 1, eps, direction, weight_init, integrator_state)

if tree.stop || tree.diverging
return tree
end

if direction == 1
other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction,
weight_init, integrator_state)
return merge_trees(tree, other_tree)
else
other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction,
weight_init, integrator_state)
return merge_trees(other_tree, tree)
end
end

"""
(new_trace, sampler_statistics) = nuts(
trace, selection::Selection;eps=0.1,
max_treedepth=15, check=false, observations=EmptyChoiceMap())
Apply a Hamiltonian Monte Carlo (HMC) update with a No U Turn stopping criterion that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a `Bool` indicating whether the move was accepted or not..
The NUT sampler allows for sampling trajectories of dynamic lengths, removing the need to specify the length of the trajectory as a parameter.
The sample will be returned early if the height of the sampled tree exceeds `max_treedepth`.
`sampler_statistics` is a struct containing the following fields:
- depth: the depth of the trajectory tree
- n: the number of samples in the trajectory tree
- sum_alpha: the sum of the individual mh acceptance probabilities for each sample in the tree
- n_accept: how many intermediate samples were accepted
- accept: whether the sample was accepted or not
# References
Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. URL: https://doi.org/10.48550/arXiv.1701.02434
Hoffman, M. D., & Gelman, A. (2022). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. URL: https://arxiv.org/abs/1111.4246
"""
function nuts(
trace::Trace, selection::Selection; eps=0.1, max_treedepth=15,
check=false, observations=EmptyChoiceMap())
prev_model_score = get_score(trace)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

# values needed for a leapfrog step
(_, values_trie, _) = choice_gradients(trace, selection, retval_grad)

momenta = sample_momenta(length(to_array(values_trie, Float64)))
momenta_trie = from_array(values_trie, momenta)
prev_momenta_score = assess_momenta(momenta)

weight_init = prev_model_score + prev_momenta_score

integrator_state = (selection, retval_grad, trace)

tree = Tree(values_trie, momenta_trie, values_trie, momenta_trie, values_trie, 1, -Inf, false, false)

direction = 0
depth = 0
stop = false
while depth < max_treedepth
direction = rand([-1, 1])

if direction == 1 # going right
other_tree = build_tree(tree.val_right, tree.momenta_right, depth, eps, direction,
weight_init, integrator_state)
tree = merge_trees(tree, other_tree)
else # going left
other_tree = build_tree(tree.val_left, tree.momenta_left, depth, eps, direction,
weight_init, integrator_state)
tree = merge_trees(other_tree, tree)
end

stop = stop || tree.stop || tree.diverging
if stop
break
end
depth += 1
end

(new_trace, _, _) = update(trace, tree.val_sample)
check && check_observations(get_choices(new_trace), observations)

# assess new model score (negative potential energy)
new_model_score = get_score(new_trace)

# assess new momenta score (negative kinetic energy)
if direction == 1
new_momenta_score = assess_momenta(to_array(tree.momenta_right, Float64))
else
new_momenta_score = assess_momenta(to_array(tree.momenta_left, Float64))
end

# accept or reject
alpha = new_model_score + new_momenta_score - weight_init
if log(rand()) < alpha
return (new_trace, true)
else
return (trace, false)
end
end

export nuts

20 changes: 20 additions & 0 deletions test/inference/nuts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@testset "nuts" begin

# smoke test a function without retval gradient
@gen function foo()
x = @trace(normal(0, 1), :x)
return x
end

(trace, _) = generate(foo, ())
(new_trace, accepted) = nuts(trace, select(:x))

# smoke test a function with retval gradient
@gen (grad) function foo()
x = @trace(normal(0, 1), :x)
return x
end

(trace, _) = generate(foo, ())
(new_trace, accepted) = nuts(trace, select(:x))
end

0 comments on commit 321fd29

Please sign in to comment.