Skip to content

Commit

Permalink
Fix #245
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandrebouchard committed Aug 13, 2024
1 parent 0dc2048 commit 2ea66d4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/explorers/SliceSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ function slice_shrink!(h::SliceSampler, replica, z, L, R, lp_L, lp_R, pointer, l
else
Rbar = new_position
end
if Lbar Rbar
# see https://github.com/UBC-Stat-ML/blangSDK/blob/b8642c9c2a0adab8a5b6da96f2a7889f1b81b6cc/src/main/java/blang/mcmc/RealSliceSampler.java#L111
pointer[] = old_position
return log_potential(state)
end
n += 1
end
# code should never get here, because eventually
Expand Down
15 changes: 15 additions & 0 deletions test/test_slice_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,19 @@ end
explorer = SliceSampler(w = 0.1, p = 20, n_passes = 1, max_iter = 1_024)
)
@test_throws "AssertionError: for integer variables, the width should be an integer. Got: 0.1" pt = pigeons(inputs)
end

# This covers the Lbar ≈ Rbar check in slice_shrink!
struct Dirac end
function (::Dirac)(x) # Dirac in first coordinate, Gaussian in the second
return x[1] == 1.1 ? -x[2]^2/2.0 : -Inf64
end
Pigeons.initialization(::Dirac, ::AbstractRNG, ::Int) = [1.1, 0.0]

@testset "Dirac" begin
pt = pigeons(target = Dirac(), reference = Dirac(), n_chains = 1, record = [online], n_rounds = 15)
@test mean(pt)[1] == 1.1
@test (mean(pt)[2], 0.0, atol = 0.01)
@test var(pt)[1] == 0.0
@test (var(pt)[2], 1.0, atol = 0.01)
end

0 comments on commit 2ea66d4

Please sign in to comment.