diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5dea45729..2fbdac85f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -39,13 +39,13 @@ jobs: with: distribution: 'temurin' java-version: '11' - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest - - uses: julia-actions/julia-processcoverage@latest + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -71,7 +71,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: julia-actions/cache@v2 @@ -85,7 +85,7 @@ jobs: MPIPreferences.use_jll_binary("OpenMPI_jll") rm("test/Manifest.toml") - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/julia-runtest@v1 # # CI is getting too slow! @@ -119,7 +119,7 @@ jobs: distribution: 'temurin' java-version: '11' - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} @@ -137,7 +137,7 @@ jobs: run(`sed -i.bu 's/unknown/MPICH/' test/LocalPreferences.toml`) # fix wrong abi detection for mpich rm("test/Manifest.toml") - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/julia-runtest@v1 docs: @@ -151,11 +151,11 @@ jobs: with: distribution: 'temurin' java-version: '11' - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: '1.10' - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-docdeploy@latest + - uses: julia-actions/cache@v2 + - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | diff --git a/Project.toml b/Project.toml index 637747b31..63b579ddb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Pigeons" uuid = "0eb8d820-af6a-4919-95ae-11206f830c31" authors = ["Alexandre Bouchard-Côté , Nikola Surjanovic , Paul Tiede , Trevor Campbell, Miguel Biron-Lattes, Saifuddin Syed"] -version = "0.4.6" +version = "0.4.7" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" diff --git a/src/pt/process_sample.jl b/src/pt/process_sample.jl index 55fba60e3..0a12d3dee 100644 --- a/src/pt/process_sample.jl +++ b/src/pt/process_sample.jl @@ -19,12 +19,12 @@ and pair plots (via [PairPlots](https://sefffal.github.io/PairPlots.jl/dev/chain function sample_array(pt::PT) chains = chains_with_samples(pt) dim, size = sample_dim_size(pt, chains) - result = zeros(size, dim, length(chains)) - for chain_index in eachindex(chains) - t = chains[chain_index] + n_chains_with_samples = count(!isnothing, chains) # iterators have no `length` method + result = zeros(size, dim, n_chains_with_samples) + for (chain_index, t) in enumerate(chains) sample = get_sample(pt, t) - for i in 1:size - vector = sample[i] + for i in 1:size + vector = sample[i] result[i, :, chain_index] .= vector end end @@ -34,13 +34,13 @@ end chains_with_samples(pt) = pt.inputs.extended_traces ? (1:n_chains(pt.inputs)) : target_chains(pt) function sample_dim_size(pt::PT, chains) - sample = get_sample(pt, chains[1]) - return length(sample[1]), length(sample) + sample = get_sample(pt, first(chains)) + return length(first(sample)), length(sample) end function target_chains(pt::PT) n = n_chains(pt.inputs) - return filter(i -> is_target(pt.shared.tempering.swap_graphs, i), 1:n) + return (i for i in 1:n if is_target(pt.shared.tempering.swap_graphs, i)) end """ @@ -84,7 +84,7 @@ The `chain` option can be omitted and by default the first chain targetting the distribution of interest will be used (in many cases, there will be only one, in variational cases, two). """ -get_sample(pt::PT, chain = target_chains(pt)[1]) = SampleArray(pt, chain) +get_sample(pt::PT, chain = first(target_chains(pt))) = SampleArray(pt, chain) function Base.show(io::IO, s::SampleArray{T,PT}) where {T,PT} println(io, "SampleArray{$T}") diff --git a/src/swap/VariationalDEO.jl b/src/swap/VariationalDEO.jl index 07fd31553..1a2b23960 100644 --- a/src/swap/VariationalDEO.jl +++ b/src/swap/VariationalDEO.jl @@ -18,4 +18,4 @@ create_swap_graph(deo::VariationalDEO, shared) = iseven(shared.iterators.scan) ? even(deo.n_chains_fixed, deo.n_chains_var) : odd(deo.n_chains_fixed, deo.n_chains_var) is_reference(deo::VariationalDEO, chain::Int) = (chain == 1) || (chain == n_chains(deo)) -is_target(deo::VariationalDEO, chain::Int) = (chain == deo.n_chains_fixed) || (chain == deo.n_chains_fixed + 1) \ No newline at end of file +is_target(deo::VariationalDEO, chain::Int) = (chain == deo.n_chains_var) || (chain == deo.n_chains_var + 1) \ No newline at end of file diff --git a/test/supporting/setup.jl b/test/supporting/setup.jl index c8b3dd27c..dcfee1d41 100644 --- a/test/supporting/setup.jl +++ b/test/supporting/setup.jl @@ -7,6 +7,7 @@ using Pigeons, Distributions, DynamicPPL, Enzyme, + FillArrays, ForwardDiff, HypothesisTests, LinearAlgebra, diff --git a/test/test_mala.jl b/test/test_mala.jl index d8a880f9c..45e8462a9 100644 --- a/test/test_mala.jl +++ b/test/test_mala.jl @@ -8,6 +8,7 @@ which also cover checks of the Hamiltonian dynamics. target = toy_mvn_target(2), n_chains = 2, explorer = MALA(), + n_chains_variational = 4, record = [Pigeons.online], n_rounds = 10); for var_name in Pigeons.continuous_variables(pt) diff --git a/test/test_two_legs.jl b/test/test_two_legs.jl index 2288ebf25..8c1168d02 100644 --- a/test/test_two_legs.jl +++ b/test/test_two_legs.jl @@ -1,10 +1,11 @@ -@testset "Two legs schedule adaptation" begin +@testset "Test StabilizedPT machinery" begin n_rounds = 10 - n_chains = 10 + n_chains = 8 + n_chains_variational = 7 pt_2_legs = pigeons(; target = Pigeons.toy_turing_unid_target(), variational = GaussianReference(first_tuning_round = n_rounds + 1), # never activate - n_chains_variational = n_chains, + n_chains_variational = n_chains_variational, n_chains, n_rounds) @@ -15,14 +16,26 @@ n_chains, n_rounds) - - @show gcb_1 = Pigeons.global_barrier(pt_1_leg.shared.tempering) - @show gcb_2_1 = Pigeons.global_barrier(pt_2_legs.shared.tempering) - @show gcb_2_2 = Pigeons.global_barrier_variational(pt_2_legs.shared.tempering) + @testset "Two legs schedule adaptation" begin + @show gcb_1 = Pigeons.global_barrier(pt_1_leg.shared.tempering) + @show gcb_2_1 = Pigeons.global_barrier(pt_2_legs.shared.tempering) + @show gcb_2_2 = Pigeons.global_barrier_variational(pt_2_legs.shared.tempering) + truth = 3.5 # based on 15 rounds + for approx in [gcb_1, gcb_2_1, gcb_2_2] + @test isapprox(approx, truth, rtol = 0.1) + end + end - truth = 3.5 # based on 15 rounds - - for approx in [gcb_1, gcb_2_1, gcb_2_2] - @test isapprox(approx, truth, atol = 0.1) + @testset "Issue #290" begin + n = Pigeons.n_chains(pt_2_legs.inputs) + idxs_targets = Pigeons.target_chains(pt_2_legs) + idxs_refs = (i for i in 1:n if Pigeons.is_reference(pt_2_legs.shared.tempering.swap_graphs, i)) + indexer = pt_2_legs.shared.tempering.indexer + @test isempty(intersect(idxs_refs, idxs_targets)) # targets and references should be different + # test: if multiple references/targets exist, they should live on different legs + for idxs in (idxs_targets, idxs_refs) + n_distinct_legs = length(Set(last(indexer.i2t[idx]) for idx in idxs)) + @test length(collect(idxs)) == n_distinct_legs + end end end \ No newline at end of file