From 2ea66d4d8953eb46586dae4036fd7d92b77406cd Mon Sep 17 00:00:00 2001 From: Alexandre Bouchard Date: Tue, 13 Aug 2024 13:56:16 -0700 Subject: [PATCH] Fix #245 --- src/explorers/SliceSampler.jl | 5 +++++ test/test_slice_sampler.jl | 15 +++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/explorers/SliceSampler.jl b/src/explorers/SliceSampler.jl index 1a940e732..43622c3b5 100644 --- a/src/explorers/SliceSampler.jl +++ b/src/explorers/SliceSampler.jl @@ -166,6 +166,11 @@ function slice_shrink!(h::SliceSampler, replica, z, L, R, lp_L, lp_R, pointer, l else Rbar = new_position end + if Lbar ≈ Rbar + # see https://github.com/UBC-Stat-ML/blangSDK/blob/b8642c9c2a0adab8a5b6da96f2a7889f1b81b6cc/src/main/java/blang/mcmc/RealSliceSampler.java#L111 + pointer[] = old_position + return log_potential(state) + end n += 1 end # code should never get here, because eventually diff --git a/test/test_slice_sampler.jl b/test/test_slice_sampler.jl index 274a0d2b2..be84a1dc0 100644 --- a/test/test_slice_sampler.jl +++ b/test/test_slice_sampler.jl @@ -117,4 +117,19 @@ end explorer = SliceSampler(w = 0.1, p = 20, n_passes = 1, max_iter = 1_024) ) @test_throws "AssertionError: for integer variables, the width should be an integer. Got: 0.1" pt = pigeons(inputs) +end + +# This covers the Lbar ≈ Rbar check in slice_shrink! +struct Dirac end +function (::Dirac)(x) # Dirac in first coordinate, Gaussian in the second + return x[1] == 1.1 ? -x[2]^2/2.0 : -Inf64 +end +Pigeons.initialization(::Dirac, ::AbstractRNG, ::Int) = [1.1, 0.0] + +@testset "Dirac" begin + pt = pigeons(target = Dirac(), reference = Dirac(), n_chains = 1, record = [online], n_rounds = 15) + @test mean(pt)[1] == 1.1 + @test ≈(mean(pt)[2], 0.0, atol = 0.01) + @test var(pt)[1] == 0.0 + @test ≈(var(pt)[2], 1.0, atol = 0.01) end \ No newline at end of file