From 9678ce903031320963900879d3e48a6f7ee57e14 Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 13:37:22 -0700 Subject: [PATCH 01/12] fix target chain identifier --- src/swap/VariationalDEO.jl | 2 +- test/test_two_legs.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/swap/VariationalDEO.jl b/src/swap/VariationalDEO.jl index 07fd31553..1a2b23960 100644 --- a/src/swap/VariationalDEO.jl +++ b/src/swap/VariationalDEO.jl @@ -18,4 +18,4 @@ create_swap_graph(deo::VariationalDEO, shared) = iseven(shared.iterators.scan) ? even(deo.n_chains_fixed, deo.n_chains_var) : odd(deo.n_chains_fixed, deo.n_chains_var) is_reference(deo::VariationalDEO, chain::Int) = (chain == 1) || (chain == n_chains(deo)) -is_target(deo::VariationalDEO, chain::Int) = (chain == deo.n_chains_fixed) || (chain == deo.n_chains_fixed + 1) \ No newline at end of file +is_target(deo::VariationalDEO, chain::Int) = (chain == deo.n_chains_var) || (chain == deo.n_chains_var + 1) \ No newline at end of file diff --git a/test/test_two_legs.jl b/test/test_two_legs.jl index 2288ebf25..e5d83c254 100644 --- a/test/test_two_legs.jl +++ b/test/test_two_legs.jl @@ -25,4 +25,12 @@ for approx in [gcb_1, gcb_2_1, gcb_2_2] @test isapprox(approx, truth, atol = 0.1) end + + @testset "Issue #290" begin + # test: if multiple targets exist, they should live on different legs + idx_targets = Pigeons.target_chains(pt_2_legs) + indexer = pt_2_legs.shared.tempering.indexer + n_distinct_legs = length(Set(last(indexer.i2t[idx]) for idx in idx_targets)) + @test length(idx_targets) == n_distinct_legs + end end \ No newline at end of file From 18d99cb79793b7fe1944b3f6d6be3104fe361558 Mon Sep 17 00:00:00 2001 From: Alexandre Bouchard Date: Mon, 21 Oct 2024 16:08:25 -0700 Subject: [PATCH 02/12] Test case that would have caught #290 --- test/test_mala.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_mala.jl b/test/test_mala.jl index d8a880f9c..644413a84 100644 --- a/test/test_mala.jl +++ b/test/test_mala.jl @@ -8,6 +8,7 @@ which also cover checks of the Hamiltonian dynamics. target = toy_mvn_target(2), n_chains = 2, explorer = MALA(), + n_chains_variational = 10, record = [Pigeons.online], n_rounds = 10); for var_name in Pigeons.continuous_variables(pt) From a14e5d82b3326895000dd9b0d2fc602b06f234a2 Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 16:10:35 -0700 Subject: [PATCH 03/12] more tests --- src/pt/process_sample.jl | 4 ++-- test/test_two_legs.jl | 39 ++++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/pt/process_sample.jl b/src/pt/process_sample.jl index 55fba60e3..4f22596c4 100644 --- a/src/pt/process_sample.jl +++ b/src/pt/process_sample.jl @@ -40,7 +40,7 @@ end function target_chains(pt::PT) n = n_chains(pt.inputs) - return filter(i -> is_target(pt.shared.tempering.swap_graphs, i), 1:n) + return (i for i in 1:n if is_target(pt.shared.tempering.swap_graphs, i)) end """ @@ -84,7 +84,7 @@ The `chain` option can be omitted and by default the first chain targetting the distribution of interest will be used (in many cases, there will be only one, in variational cases, two). """ -get_sample(pt::PT, chain = target_chains(pt)[1]) = SampleArray(pt, chain) +get_sample(pt::PT, chain = first(target_chains(pt))) = SampleArray(pt, chain) function Base.show(io::IO, s::SampleArray{T,PT}) where {T,PT} println(io, "SampleArray{$T}") diff --git a/test/test_two_legs.jl b/test/test_two_legs.jl index e5d83c254..8c1168d02 100644 --- a/test/test_two_legs.jl +++ b/test/test_two_legs.jl @@ -1,10 +1,11 @@ -@testset "Two legs schedule adaptation" begin +@testset "Test StabilizedPT machinery" begin n_rounds = 10 - n_chains = 10 + n_chains = 8 + n_chains_variational = 7 pt_2_legs = pigeons(; target = Pigeons.toy_turing_unid_target(), variational = GaussianReference(first_tuning_round = n_rounds + 1), # never activate - n_chains_variational = n_chains, + n_chains_variational = n_chains_variational, n_chains, n_rounds) @@ -15,22 +16,26 @@ n_chains, n_rounds) - - @show gcb_1 = Pigeons.global_barrier(pt_1_leg.shared.tempering) - @show gcb_2_1 = Pigeons.global_barrier(pt_2_legs.shared.tempering) - @show gcb_2_2 = Pigeons.global_barrier_variational(pt_2_legs.shared.tempering) - - truth = 3.5 # based on 15 rounds - - for approx in [gcb_1, gcb_2_1, gcb_2_2] - @test isapprox(approx, truth, atol = 0.1) + @testset "Two legs schedule adaptation" begin + @show gcb_1 = Pigeons.global_barrier(pt_1_leg.shared.tempering) + @show gcb_2_1 = Pigeons.global_barrier(pt_2_legs.shared.tempering) + @show gcb_2_2 = Pigeons.global_barrier_variational(pt_2_legs.shared.tempering) + truth = 3.5 # based on 15 rounds + for approx in [gcb_1, gcb_2_1, gcb_2_2] + @test isapprox(approx, truth, rtol = 0.1) + end end - + @testset "Issue #290" begin - # test: if multiple targets exist, they should live on different legs - idx_targets = Pigeons.target_chains(pt_2_legs) + n = Pigeons.n_chains(pt_2_legs.inputs) + idxs_targets = Pigeons.target_chains(pt_2_legs) + idxs_refs = (i for i in 1:n if Pigeons.is_reference(pt_2_legs.shared.tempering.swap_graphs, i)) indexer = pt_2_legs.shared.tempering.indexer - n_distinct_legs = length(Set(last(indexer.i2t[idx]) for idx in idx_targets)) - @test length(idx_targets) == n_distinct_legs + @test isempty(intersect(idxs_refs, idxs_targets)) # targets and references should be different + # test: if multiple references/targets exist, they should live on different legs + for idxs in (idxs_targets, idxs_refs) + n_distinct_legs = length(Set(last(indexer.i2t[idx]) for idx in idxs)) + @test length(collect(idxs)) == n_distinct_legs + end end end \ No newline at end of file From f73a8a315fbf890d593cd2011e56f82664c7ca4e Mon Sep 17 00:00:00 2001 From: Alexandre Bouchard Date: Mon, 21 Oct 2024 16:15:04 -0700 Subject: [PATCH 04/12] Make it run faster --- test/test_mala.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_mala.jl b/test/test_mala.jl index 644413a84..45e8462a9 100644 --- a/test/test_mala.jl +++ b/test/test_mala.jl @@ -8,7 +8,7 @@ which also cover checks of the Hamiltonian dynamics. target = toy_mvn_target(2), n_chains = 2, explorer = MALA(), - n_chains_variational = 10, + n_chains_variational = 4, record = [Pigeons.online], n_rounds = 10); for var_name in Pigeons.continuous_variables(pt) From 2573731c6b61ab454a29f4e5b8ab7cafbda78925 Mon Sep 17 00:00:00 2001 From: Alexandre Bouchard Date: Mon, 21 Oct 2024 16:18:00 -0700 Subject: [PATCH 05/12] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 637747b31..63b579ddb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Pigeons" uuid = "0eb8d820-af6a-4919-95ae-11206f830c31" authors = ["Alexandre Bouchard-Côté , Nikola Surjanovic , Paul Tiede , Trevor Campbell, Miguel Biron-Lattes, Saifuddin Syed"] -version = "0.4.6" +version = "0.4.7" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" From d257d3e4dddf384c86f4e20e7086a93a9fccbd9f Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 17:50:02 -0700 Subject: [PATCH 06/12] use first instead of [] --- src/pt/process_sample.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pt/process_sample.jl b/src/pt/process_sample.jl index 4f22596c4..1c6da182c 100644 --- a/src/pt/process_sample.jl +++ b/src/pt/process_sample.jl @@ -34,8 +34,8 @@ end chains_with_samples(pt) = pt.inputs.extended_traces ? (1:n_chains(pt.inputs)) : target_chains(pt) function sample_dim_size(pt::PT, chains) - sample = get_sample(pt, chains[1]) - return length(sample[1]), length(sample) + sample = get_sample(pt, first(chains)) + return length(first(sample)), length(sample) end function target_chains(pt::PT) From eec16d9fd27858f0b8caa03f2ae063328b585c52 Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 17:59:15 -0700 Subject: [PATCH 07/12] fix length for iterator --- src/pt/process_sample.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pt/process_sample.jl b/src/pt/process_sample.jl index 1c6da182c..e4c375416 100644 --- a/src/pt/process_sample.jl +++ b/src/pt/process_sample.jl @@ -19,7 +19,8 @@ and pair plots (via [PairPlots](https://sefffal.github.io/PairPlots.jl/dev/chain function sample_array(pt::PT) chains = chains_with_samples(pt) dim, size = sample_dim_size(pt, chains) - result = zeros(size, dim, length(chains)) + n_chains_with_samples = count(!isnothing, chains) # iterators have no `length` method + result = zeros(size, dim, n_chains_with_samples) for chain_index in eachindex(chains) t = chains[chain_index] sample = get_sample(pt, t) From 738a7a5ed6eb2cc5701316ff6f4fedb74bfcf41d Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 18:06:02 -0700 Subject: [PATCH 08/12] use enumerate --- src/pt/process_sample.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/pt/process_sample.jl b/src/pt/process_sample.jl index e4c375416..0a12d3dee 100644 --- a/src/pt/process_sample.jl +++ b/src/pt/process_sample.jl @@ -21,11 +21,10 @@ function sample_array(pt::PT) dim, size = sample_dim_size(pt, chains) n_chains_with_samples = count(!isnothing, chains) # iterators have no `length` method result = zeros(size, dim, n_chains_with_samples) - for chain_index in eachindex(chains) - t = chains[chain_index] + for (chain_index, t) in enumerate(chains) sample = get_sample(pt, t) - for i in 1:size - vector = sample[i] + for i in 1:size + vector = sample[i] result[i, :, chain_index] .= vector end end From e442c0987f78fe4ecb6e3485486353f8712d1c5a Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 18:11:55 -0700 Subject: [PATCH 09/12] add back missing FillArrays --- test/supporting/setup.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/supporting/setup.jl b/test/supporting/setup.jl index c8b3dd27c..dcfee1d41 100644 --- a/test/supporting/setup.jl +++ b/test/supporting/setup.jl @@ -7,6 +7,7 @@ using Pigeons, Distributions, DynamicPPL, Enzyme, + FillArrays, ForwardDiff, HypothesisTests, LinearAlgebra, From 0c3cbc8db4912ed188aff6051e58d352f39ee78d Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 18:42:18 -0700 Subject: [PATCH 10/12] fix CI to avoid precompilation latency in building docs --- .github/workflows/CI.yml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5dea45729..aee9e7d18 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -39,13 +39,13 @@ jobs: with: distribution: 'temurin' java-version: '11' - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest - - uses: julia-actions/julia-processcoverage@latest + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -71,7 +71,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: julia-actions/cache@v2 @@ -85,7 +85,7 @@ jobs: MPIPreferences.use_jll_binary("OpenMPI_jll") rm("test/Manifest.toml") - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/julia-runtest@v1 # # CI is getting too slow! @@ -119,7 +119,7 @@ jobs: distribution: 'temurin' java-version: '11' - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} @@ -137,7 +137,7 @@ jobs: run(`sed -i.bu 's/unknown/MPICH/' test/LocalPreferences.toml`) # fix wrong abi detection for mpich rm("test/Manifest.toml") - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/julia-runtest@v1 docs: @@ -151,11 +151,11 @@ jobs: with: distribution: 'temurin' java-version: '11' - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: '1.10' - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-docdeploy@latest + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | From fda52a9f0071d62ea7331e302e3e49ba3fa328f5 Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 18:47:33 -0700 Subject: [PATCH 11/12] use cache --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index aee9e7d18..6a9e0cff0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -154,6 +154,7 @@ jobs: - uses: julia-actions/setup-julia@v2 with: version: '1.10' + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-docdeploy@v1 env: From 0c8cea301e2417aae7dc2f4d64d89a33671cf29a Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Mon, 21 Oct 2024 19:03:15 -0700 Subject: [PATCH 12/12] rm unnecessary buildpkg --- .github/workflows/CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6a9e0cff0..2fbdac85f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -155,7 +155,6 @@ jobs: with: version: '1.10' - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}