From eedb02f3f5ce867b085c020ee2a0efb6bea1d7ab Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 21 Oct 2024 07:17:58 +0100 Subject: [PATCH 01/18] add `getparams` and `setparams!!` --- Project.toml | 4 +-- src/abstractmcmc.jl | 9 ++++++ test/abstractmcmc.jl | 70 ++++++++++++++++++++++++++------------------ 3 files changed, 52 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index fa49a335..83102a0f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.6.2" +version = "0.6.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -30,7 +30,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains" AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq" [compat] -AbstractMCMC = "5" +AbstractMCMC = "5.5" ArgCheck = "1, 2" CUDA = "3, 4, 5" DocStringExtensions = "0.8, 0.9" diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 24ce799c..1472a622 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -30,6 +30,15 @@ getadaptor(state::HMCState) = state.adaptor getmetric(state::HMCState) = state.metric getintegrator(state::HMCState) = state.κ.τ.integrator +function AbstractMCMC.getparams(state::HMCState) + # TODO(sunxd): should we return a copy? + return state.transition.z.θ +end + +function AbstractMCMC.setparams!!(state::HMCState, θ) + return @set state.transition.z.θ = θ +end + """ $(TYPEDSIGNATURES) diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index 25359cd6..a5fa2536 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -8,7 +8,7 @@ using Statistics: mean θ_init = randn(rng, 2) nuts = NUTS(0.8) - hmc = HMC(100; integrator = Leapfrog(0.05)) + hmc = HMC(100; integrator=Leapfrog(0.05)) hmcda = HMCDA(0.8, 0.1) integrator = Leapfrog(1e-3) @@ -21,15 +21,27 @@ using Statistics: mean LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo), ) + @testset "getparams and setparams!!" begin + t, s = AbstractMCMC.step( + rng, + model, + nuts; + ) + + θ = AbstractMCMC.getparams(s) + @test θ == t.z.θ + @test AbstractMCMC.setparams!!(s, θ) == s + end + samples_nuts = AbstractMCMC.sample( rng, model, nuts, n_adapts + n_samples; - n_adapts = n_adapts, - initial_params = θ_init, - progress = false, - verbose = false, + n_adapts=n_adapts, + initial_params=θ_init, + progress=false, + verbose=false, ) # Error if keyword argument `nadapts` is used @@ -38,10 +50,10 @@ using Statistics: mean model, nuts, n_adapts + n_samples; - nadapts = n_adapts, - initial_params = θ_init, - progress = false, - verbose = false, + nadapts=n_adapts, + initial_params=θ_init, + progress=false, + verbose=false, ) @test_throws ArgumentError AbstractMCMC.sample( rng, @@ -50,10 +62,10 @@ using Statistics: mean MCMCThreads(), n_adapts + n_samples, 2; - nadapts = n_adapts, - initial_params = θ_init, - progress = false, - verbose = false, + nadapts=n_adapts, + initial_params=θ_init, + progress=false, + verbose=false, ) # Transform back to original space. @@ -73,10 +85,10 @@ using Statistics: mean model, hmc, n_adapts + n_samples; - n_adapts = n_adapts, - initial_params = θ_init, - progress = false, - verbose = false, + n_adapts=n_adapts, + initial_params=θ_init, + progress=false, + verbose=false, ) # Transform back to original space. @@ -96,10 +108,10 @@ using Statistics: mean model, custom, n_adapts + n_samples; - n_adapts = 0, - initial_params = θ_init, - progress = false, - verbose = false, + n_adapts=0, + initial_params=θ_init, + progress=false, + verbose=false, ) # Transform back to original space. @@ -122,20 +134,20 @@ using Statistics: mean model, custom, 10; - n_adapts = 0, - initial_params = θ_init, - progress = false, - verbose = false, + n_adapts=0, + initial_params=θ_init, + progress=false, + verbose=false, ) samples2 = AbstractMCMC.sample( rng2, model, custom, 10; - n_adapts = 0, - initial_params = θ_init, - progress = false, - verbose = false, + n_adapts=0, + initial_params=θ_init, + progress=false, + verbose=false, ) @test mapreduce(*, samples1, samples2) do s1, s2 s1.z.θ == s2.z.θ From f340df79ac4613ef5d1d864e1da358bb2554cc0e Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 21 Oct 2024 07:22:03 +0100 Subject: [PATCH 02/18] undo formatting --- test/abstractmcmc.jl | 58 ++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index a5fa2536..2c1c2ddb 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -8,7 +8,7 @@ using Statistics: mean θ_init = randn(rng, 2) nuts = NUTS(0.8) - hmc = HMC(100; integrator=Leapfrog(0.05)) + hmc = HMC(100; integrator = Leapfrog(0.05)) hmcda = HMCDA(0.8, 0.1) integrator = Leapfrog(1e-3) @@ -38,10 +38,10 @@ using Statistics: mean model, nuts, n_adapts + n_samples; - n_adapts=n_adapts, - initial_params=θ_init, - progress=false, - verbose=false, + n_adapts = n_adapts, + initial_params = θ_init, + progress = false, + verbose = false, ) # Error if keyword argument `nadapts` is used @@ -50,10 +50,10 @@ using Statistics: mean model, nuts, n_adapts + n_samples; - nadapts=n_adapts, - initial_params=θ_init, - progress=false, - verbose=false, + nadapts = n_adapts, + initial_params = θ_init, + progress = false, + verbose = false, ) @test_throws ArgumentError AbstractMCMC.sample( rng, @@ -62,10 +62,10 @@ using Statistics: mean MCMCThreads(), n_adapts + n_samples, 2; - nadapts=n_adapts, - initial_params=θ_init, - progress=false, - verbose=false, + nadapts = n_adapts, + initial_params = θ_init, + progress = false, + verbose = false, ) # Transform back to original space. @@ -85,10 +85,10 @@ using Statistics: mean model, hmc, n_adapts + n_samples; - n_adapts=n_adapts, - initial_params=θ_init, - progress=false, - verbose=false, + n_adapts = n_adapts, + initial_params = θ_init, + progress = false, + verbose = false, ) # Transform back to original space. @@ -108,10 +108,10 @@ using Statistics: mean model, custom, n_adapts + n_samples; - n_adapts=0, - initial_params=θ_init, - progress=false, - verbose=false, + n_adapts = 0, + initial_params = θ_init, + progress = false, + verbose = false, ) # Transform back to original space. @@ -134,20 +134,20 @@ using Statistics: mean model, custom, 10; - n_adapts=0, - initial_params=θ_init, - progress=false, - verbose=false, + n_adapts = 0, + initial_params = θ_init, + progress = false, + verbose = false, ) samples2 = AbstractMCMC.sample( rng2, model, custom, 10; - n_adapts=0, - initial_params=θ_init, - progress=false, - verbose=false, + n_adapts = 0, + initial_params = θ_init, + progress = false, + verbose = false, ) @test mapreduce(*, samples1, samples2) do s1, s2 s1.z.θ == s2.z.θ From 161efffa0ec34d266fd7a90987e779ec238a5fba Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:23:54 +0800 Subject: [PATCH 03/18] Update test/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/abstractmcmc.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index 2c1c2ddb..7fbeb872 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -22,11 +22,7 @@ using Statistics: mean ) @testset "getparams and setparams!!" begin - t, s = AbstractMCMC.step( - rng, - model, - nuts; - ) + t, s = AbstractMCMC.step(rng, model, nuts;) θ = AbstractMCMC.getparams(s) @test θ == t.z.θ From 933efe5d4e25534bd5d1d2f2f071b0e9a300df81 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 22 Oct 2024 07:08:43 +0100 Subject: [PATCH 04/18] add some new tests --- test/abstractmcmc.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index 7fbeb872..da3a9744 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -27,6 +27,10 @@ using Statistics: mean θ = AbstractMCMC.getparams(s) @test θ == t.z.θ @test AbstractMCMC.setparams!!(s, θ) == s + + new_θ = randn(rng, 2) + new_state = AbstractMCMC.setparams!!(s, new_θ) + @test AbstractMCMC.getparams(new_state) == new_θ end samples_nuts = AbstractMCMC.sample( From 4b0be7fe5bbb3fd91f6bc342206b9adcc59f6bc1 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 22 Oct 2024 19:44:17 +0100 Subject: [PATCH 05/18] update `setparams!!` --- research/tests/runtests.jl | 2 +- src/abstractmcmc.jl | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/research/tests/runtests.jl b/research/tests/runtests.jl index 803458c6..2bb8e8d3 100644 --- a/research/tests/runtests.jl +++ b/research/tests/runtests.jl @@ -11,6 +11,6 @@ include("../src/riemannian_hmc.jl") include("relativistic_hmc.jl") include("riemannian_hmc.jl") -@main function runtests(patterns...; dry::Bool = false) +Comonicon.@main function runtests(patterns...; dry::Bool = false) retest(patterns...; dry = dry, verbose = Inf) end diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 1472a622..3a2ff638 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -35,8 +35,12 @@ function AbstractMCMC.getparams(state::HMCState) return state.transition.z.θ end -function AbstractMCMC.setparams!!(state::HMCState, θ) - return @set state.transition.z.θ = θ +function AbstractMCMC.setparams!!(state::HMCState, params) + hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, params, state.transition.z.r; + ℓκ=state.transition.z.ℓκ + ) end """ From 8f801d42d992b5ebc96edb37d92e5700411fe257 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Wed, 23 Oct 2024 02:50:03 +0800 Subject: [PATCH 06/18] Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/abstractmcmc.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 3a2ff638..ef32b5ec 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -38,8 +38,10 @@ end function AbstractMCMC.setparams!!(state::HMCState, params) hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, params, state.transition.z.r; - ℓκ=state.transition.z.ℓκ + hamiltonian, + params, + state.transition.z.r; + ℓκ = state.transition.z.ℓκ, ) end From 57f9e2bdc5f72eff2e6d6baf5169c71480ae4233 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 22 Oct 2024 20:57:46 +0100 Subject: [PATCH 07/18] add comment --- src/abstractmcmc.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index ef32b5ec..db7ebd5e 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -31,18 +31,17 @@ getmetric(state::HMCState) = state.metric getintegrator(state::HMCState) = state.κ.τ.integrator function AbstractMCMC.getparams(state::HMCState) - # TODO(sunxd): should we return a copy? return state.transition.z.θ end +# Using @set to update state.transition.z.θ can lead to inconsistencies: +# - It retains cached log-joint and gradient computations that become invalid +# - This can cause incorrect evaluations in subsequent steps (e.g. MH) +# +# TODO: adopt https://github.com/TuringLang/MCMCTempering.jl/blob/deb96684496f3fbd011b9f70f28c49a161def23f/ext/MCMCTemperingAdvancedHMCExt.jl#L10-L17 +# if in the future the interface provides access to the log density function function AbstractMCMC.setparams!!(state::HMCState, params) - hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) - return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, - params, - state.transition.z.r; - ℓκ = state.transition.z.ℓκ, - ) + return @set state.transition.z.θ = θ end """ @@ -429,4 +428,4 @@ end function make_kernel(spl::HMCSampler, integrator::AbstractIntegrator) return spl.κ -end +end \ No newline at end of file From 26b09707cc1875a1f5f8527f6b8389d68e5313fb Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Tue, 22 Oct 2024 21:25:10 +0100 Subject: [PATCH 08/18] Update abstractmcmc.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/abstractmcmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index db7ebd5e..92f58daa 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -41,7 +41,7 @@ end # TODO: adopt https://github.com/TuringLang/MCMCTempering.jl/blob/deb96684496f3fbd011b9f70f28c49a161def23f/ext/MCMCTemperingAdvancedHMCExt.jl#L10-L17 # if in the future the interface provides access to the log density function function AbstractMCMC.setparams!!(state::HMCState, params) - return @set state.transition.z.θ = θ + return @set state.transition.z.θ = params end """ From b9dfa361164837ac20f04206896e24f7081a7a6d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 23 Oct 2024 07:11:32 +0100 Subject: [PATCH 09/18] format --- src/abstractmcmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 92f58daa..d3a8fe98 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -428,4 +428,4 @@ end function make_kernel(spl::HMCSampler, integrator::AbstractIntegrator) return spl.κ -end \ No newline at end of file +end From c7f3163bb12fd635576398fd01633704b42e2efa Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 28 Oct 2024 05:36:25 +0000 Subject: [PATCH 10/18] update implementation --- src/abstractmcmc.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index d3a8fe98..67d4ec9e 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -34,14 +34,12 @@ function AbstractMCMC.getparams(state::HMCState) return state.transition.z.θ end -# Using @set to update state.transition.z.θ can lead to inconsistencies: -# - It retains cached log-joint and gradient computations that become invalid -# - This can cause incorrect evaluations in subsequent steps (e.g. MH) -# -# TODO: adopt https://github.com/TuringLang/MCMCTempering.jl/blob/deb96684496f3fbd011b9f70f28c49a161def23f/ext/MCMCTemperingAdvancedHMCExt.jl#L10-L17 -# if in the future the interface provides access to the log density function -function AbstractMCMC.setparams!!(state::HMCState, params) - return @set state.transition.z.θ = params +function AbstractMCMC.setparams!!(model, state::HMCState, params) + hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, params, state.transition.z.r; + ℓκ=state.transition.z.ℓκ + ) end """ From 9d05e46bc9c78aa733f35b2c4faa555c6f91aa22 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 28 Oct 2024 05:37:34 +0000 Subject: [PATCH 11/18] bump AbstractMCMC --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 83102a0f..f35b1df4 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains" AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq" [compat] -AbstractMCMC = "5.5" +AbstractMCMC = "5.6" ArgCheck = "1, 2" CUDA = "3, 4, 5" DocStringExtensions = "0.8, 0.9" From 3b0576946e5b66ce516493a5b9c06f158a3c626e Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 28 Oct 2024 08:39:08 +0000 Subject: [PATCH 12/18] update test --- test/abstractmcmc.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index da3a9744..e2355f77 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -26,10 +26,10 @@ using Statistics: mean θ = AbstractMCMC.getparams(s) @test θ == t.z.θ - @test AbstractMCMC.setparams!!(s, θ) == s + @test AbstractMCMC.setparams!!(model, s, θ) == s new_θ = randn(rng, 2) - new_state = AbstractMCMC.setparams!!(s, new_θ) + new_state = AbstractMCMC.setparams!!(model, s, new_θ) @test AbstractMCMC.getparams(new_state) == new_θ end From 480dedc9b71b54cff07156d5eba4fe683850d41e Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:40:26 +0000 Subject: [PATCH 13/18] Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/abstractmcmc.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 67d4ec9e..9ebedabe 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -37,8 +37,10 @@ end function AbstractMCMC.setparams!!(model, state::HMCState, params) hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, params, state.transition.z.r; - ℓκ=state.transition.z.ℓκ + hamiltonian, + params, + state.transition.z.r; + ℓκ = state.transition.z.ℓκ, ) end From faf5bb0dc6926a21c1ac75700aa66e521cd72ec6 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 31 Oct 2024 08:24:06 +0000 Subject: [PATCH 14/18] fix method ambiguity --- src/abstractmcmc.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 9ebedabe..bf4a0094 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -34,13 +34,13 @@ function AbstractMCMC.getparams(state::HMCState) return state.transition.z.θ end -function AbstractMCMC.setparams!!(model, state::HMCState, params) +function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, state::HMCState, params) hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( hamiltonian, params, state.transition.z.r; - ℓκ = state.transition.z.ℓκ, + ℓκ=state.transition.z.ℓκ, ) end From cd35f1b2c5d6323438d3d15455ddf032c159c328 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 31 Oct 2024 08:26:05 +0000 Subject: [PATCH 15/18] Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/abstractmcmc.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index bf4a0094..5090b77f 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -34,7 +34,11 @@ function AbstractMCMC.getparams(state::HMCState) return state.transition.z.θ end -function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, state::HMCState, params) +function AbstractMCMC.setparams!!( + model::AbstractMCMC.LogDensityModel, + state::HMCState, + params, +) hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( hamiltonian, From 3711a0ca81a53f6bdc21219ecc34251cb587f8b8 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 31 Oct 2024 08:26:11 +0000 Subject: [PATCH 16/18] Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/abstractmcmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 5090b77f..40585909 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -44,7 +44,7 @@ function AbstractMCMC.setparams!!( hamiltonian, params, state.transition.z.r; - ℓκ=state.transition.z.ℓκ, + ℓκ = state.transition.z.ℓκ, ) end From 027cd31e17ec29955792b41a895ece1b47e35f54 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 31 Oct 2024 09:13:22 +0000 Subject: [PATCH 17/18] fix test error --- test/abstractmcmc.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index e2355f77..b620a85f 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -26,8 +26,12 @@ using Statistics: mean θ = AbstractMCMC.getparams(s) @test θ == t.z.θ - @test AbstractMCMC.setparams!!(model, s, θ) == s - + new_state = AbstractMCMC.setparams!!(model, s, θ) + @test new_state.transition.z.θ == θ + @test new_state.transition.z.ℓπ == s.transition.z.ℓπ + @test new_state.transition.z.ℓκ == s.transition.z.ℓκ + @test new_state.transition.z.r == s.transition.z.r + new_θ = randn(rng, 2) new_state = AbstractMCMC.setparams!!(model, s, new_θ) @test AbstractMCMC.getparams(new_state) == new_θ From 30dca23817f4de70b4f7aa27645be1ee482b8072 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 31 Oct 2024 09:44:29 +0000 Subject: [PATCH 18/18] fix more test error --- test/abstractmcmc.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index b620a85f..7d325f03 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -28,10 +28,14 @@ using Statistics: mean @test θ == t.z.θ new_state = AbstractMCMC.setparams!!(model, s, θ) @test new_state.transition.z.θ == θ - @test new_state.transition.z.ℓπ == s.transition.z.ℓπ - @test new_state.transition.z.ℓκ == s.transition.z.ℓκ + new_state_logπ = new_state.transition.z.ℓπ + @test new_state_logπ.value == s.transition.z.ℓπ.value + @test new_state_logπ.gradient == s.transition.z.ℓπ.gradient + new_state_logκ = new_state.transition.z.ℓκ + @test new_state_logκ.value == s.transition.z.ℓκ.value + @test new_state_logκ.gradient == s.transition.z.ℓκ.gradient @test new_state.transition.z.r == s.transition.z.r - + new_θ = randn(rng, 2) new_state = AbstractMCMC.setparams!!(model, s, new_θ) @test AbstractMCMC.getparams(new_state) == new_θ