From 29017e76db7ea3be39942f7f7a36876254f0282a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 28 Sep 2022 12:34:25 +0200 Subject: [PATCH 01/44] Add support for Enzyme --- src/essential/Essential.jl | 1 + src/essential/ad.jl | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index df0a9b5ac..331dbeed9 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -42,6 +42,7 @@ export @model, setadbackend, setadsafe, ForwardDiffAD, + EnzymeAD, TrackerAD, ZygoteAD, ReverseDiffAD, diff --git a/src/essential/ad.jl b/src/essential/ad.jl index b56ce0140..b2530435d 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -12,6 +12,9 @@ end function _setadbackend(::Val{:forwarddiff}) ADBACKEND[] = :forwarddiff end +function _setadbackend(::Val{:enzyme}) + ADBACKEND[] = :enzyme +end function _setadbackend(::Val{:tracker}) ADBACKEND[] = :tracker end @@ -47,6 +50,7 @@ getchunksize(::ForwardDiffAD{chunk}) where chunk = chunk standardtag(::ForwardDiffAD{<:Any,true}) = true standardtag(::ForwardDiffAD) = false +struct EnzymeAD <: ADBackend end struct TrackerAD <: ADBackend end struct ZygoteAD <: ADBackend end @@ -64,6 +68,7 @@ ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} +ADBackend(::Val{:enzyme}) = EnzymeAD ADBackend(::Val{:tracker}) = TrackerAD ADBackend(::Val{:zygote}) = ZygoteAD ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()} @@ -102,6 +107,10 @@ function LogDensityProblems.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensity return LogDensityProblems.ADgradient(Val(:ForwardDiff), ℓ; gradientconfig=config) end +function LogDensityProblems.ADgradient(::EnzymeAD, ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(Val(:Enzyme), ℓ) +end + function LogDensityProblems.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction) return LogDensityProblems.ADgradient(Val(:Tracker), ℓ) end From 43ef4c4a10eb5d9deaed10665f79815c716fb931 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 23 Dec 2022 22:47:19 +0100 Subject: [PATCH 02/44] Apply suggestions from code review --- src/essential/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 61b678edc..848f10f68 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -108,11 +108,11 @@ function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensi end function LogDensityProblemsAD.ADgradient(::EnzymeAD, ℓ::Turing.LogDensityFunction) - return LogDensityProblems.ADgradient(Val(:Enzyme), ℓ) + return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ) end function LogDensityProblemsAD.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction) - return LogDensityProblems.ADgradient(Val(:Tracker), ℓ) + return LogDensityProblemsAD.ADgradient(Val(:Tracker), ℓ) end function LogDensityProblemsAD.ADgradient(::ZygoteAD, ℓ::Turing.LogDensityFunction) From 3e5841f30492f70559c5f197f657f203db74467f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 29 Dec 2022 12:21:12 +0100 Subject: [PATCH 03/44] Add Enzyme to test dependencies --- test/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index a858a9ec6..02ce8cc1d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -41,6 +42,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.21" +Enzyme = "0.10.13" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32" LogDensityProblems = "2" From 66bce4ed46a86cf89b37a15d2b52a421cb1de5e9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 29 Dec 2022 12:22:56 +0100 Subject: [PATCH 04/44] Test Enzyme --- test/runtests.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7eee46813..b32605fa4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,7 @@ using Turing: BinomialLogit, ForwardDiffAD, Sampler, SampleFromPrior, NUTS, Trac using Turing.Essential: TuringDenseMvNormal, TuringDiagMvNormal using Turing.Variational: TruncatedADAGrad, DecayedADAGrad, AdvancedVI +import Enzyme import LogDensityProblems import LogDensityProblemsAD @@ -65,7 +66,7 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ end Turing.setrdcache(false) - for adbackend in (:forwarddiff, :tracker, :reversediff) + for adbackend in (:forwarddiff, :tracker, :reversediff, :enzyme) @timeit TIMEROUTPUT "inference: $adbackend" begin Turing.setadbackend(adbackend) @info "Testing $(adbackend)" From 789013467cc6dbe0b5a3f9774b936c3a1084d2d3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 29 Dec 2022 17:40:31 +0100 Subject: [PATCH 05/44] Update ad.jl --- src/essential/ad.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 848f10f68..59289c184 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -9,6 +9,10 @@ function setadbackend(backend::Val) Bijectors.setadbackend(backend) end +# TODO: Add support to AdvancedVI and Bijectors +# (or better: use common interface package) +setadbackend(backend::Val{:enzyme}) = _setadbackend(backend) + function _setadbackend(::Val{:forwarddiff}) ADBACKEND[] = :forwarddiff end From f4bd1bf0f2c6e4f4f2b429d6a14d1cc8cba6f9a0 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:59:05 +0000 Subject: [PATCH 06/44] Update Project.toml --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index abae4a539..567ed47f6 100644 --- a/Project.toml +++ b/Project.toml @@ -40,14 +40,14 @@ AbstractMCMC = "4" AdvancedHMC = "0.3.0, 0.4" AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" -AdvancedVI = "0.1" +AdvancedVI = "0.2" BangBang = "0.3" -Bijectors = "0.8, 0.9, 0.10" +Bijectors = "0.11, 0.12" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" -DynamicPPL = "0.21.5" +DynamicPPL = "0.22" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" From c8e01d0da369cfb55dab045efb8ee27d0c15fbdf Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:59:32 +0000 Subject: [PATCH 07/44] Update advi.jl --- src/variational/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index d91f8a897..47b35ddb0 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -1,5 +1,5 @@ # TODO(torfjelde): Find a better solution. -struct Vec{N, B<:Bijectors.Bijector{N}} <: Bijectors.Bijector{1} +struct Vec{N, B<:Bijectors.Transform} <: Bijectors.Transform b::B size::NTuple{N, Int} end From 946e594d6e6c11decc0b9f49b519de4b837c14f5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 7 Mar 2023 10:20:12 +0100 Subject: [PATCH 08/44] Do not call `Bijectors.setadbackend` --- Project.toml | 2 +- src/essential/ad.jl | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 12c2ad79e..b64082ae6 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" AdvancedVI = "0.2" BangBang = "0.3" -Bijectors = "0.11, 0.12" +Bijectors = "0.12" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" diff --git a/src/essential/ad.jl b/src/essential/ad.jl index d3d73326f..4df698da3 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -6,13 +6,8 @@ setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) function setadbackend(backend::Val) _setadbackend(backend) AdvancedVI.setadbackend(backend) - Bijectors.setadbackend(backend) end -# TODO: Add support to AdvancedVI and Bijectors -# (or better: use common interface package) -setadbackend(backend::Val{:enzyme}) = _setadbackend(backend) - function _setadbackend(::Val{:forwarddiff}) ADBACKEND[] = :forwarddiff end From e9eedd10cd452962b23e8269c0dce8390ce34461 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 13 Apr 2023 11:04:14 +0200 Subject: [PATCH 09/44] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index ca7b13390..cc92ba98b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -42,7 +42,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.22" -Enzyme = "0.10.13" +Enzyme = "0.11" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" LogDensityProblems = "2" From 8d8d0310c672734b20a760a9faa0c0647193832b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 00:24:44 +0200 Subject: [PATCH 10/44] Address comments --- test/Project.toml | 2 +- test/runtests.jl | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 1cec0ec19..aab5004fa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -Enzyme = "0.11" +Enzyme = "0.11.2" FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" diff --git a/test/runtests.jl b/test/runtests.jl index 64d33ee17..282594576 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,9 @@ import LogDensityProblemsAD setprogress!(false) +# Disable Enzyme warnings +Enzyme.API.typeWarning!(false) + include(pkgdir(Turing)*"/test/test_utils/AllUtils.jl") # Collect timing and allocations information to show in a clear way. From e5916304d417b031fd2f82e5d1a6363a51600955 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 02:11:04 +0200 Subject: [PATCH 11/44] Update runtests.jl --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 282594576..8c7b729bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,6 +50,9 @@ setprogress!(false) # Disable Enzyme warnings Enzyme.API.typeWarning!(false) +# Enable runtime activity (workaround) +Enzyme.API.runtimeActivity!(true) + include(pkgdir(Turing)*"/test/test_utils/AllUtils.jl") # Collect timing and allocations information to show in a clear way. From 568cdaceb8e1bd58cca981edd4924a430a78e9d3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 7 Jul 2023 19:48:43 +0200 Subject: [PATCH 12/44] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index aab5004fa..5933cf717 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -Enzyme = "0.11.2" +Enzyme = "0.11.3" FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 6f0bf67e079f46ba7be130b89e80c8df5b31d306 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 7 Jul 2023 22:09:56 +0200 Subject: [PATCH 13/44] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 5933cf717..138eea80d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -Enzyme = "0.11.3" +Enzyme = "0.11.4" FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 5ba7ac6bdbf0eac941ef28fe08dd09b1ad717769 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 13 Jul 2023 21:37:15 +0200 Subject: [PATCH 14/44] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 138eea80d..1772782c8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -Enzyme = "0.11.4" +Enzyme = "0.11.5" FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 162755be3427a8400d83ddc9e2c32c5beeebbfc7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 14 Jul 2023 22:03:12 +0200 Subject: [PATCH 15/44] Test against Enzyme#main --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 6c618d4d7..d6b3c2a7d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,6 @@ +using Pkg +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) + using AbstractMCMC using AdvancedMH using Clustering From e44e7560bc79a45f01c4b74ca1813140ae2f6924 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 Jul 2023 10:28:21 +0200 Subject: [PATCH 16/44] Try addr13 branch --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index d6b3c2a7d..a75c1c771 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="addr13")) using AbstractMCMC using AdvancedMH From 1f1b1140e3ab1d6b41324dcb36d4e79061088f01 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 27 Jul 2023 20:08:37 +0200 Subject: [PATCH 17/44] Update runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index a75c1c771..d6b3c2a7d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="addr13")) +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) using AbstractMCMC using AdvancedMH From bb795e6e03b6b65e664ec8052784e3ae79571dc8 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 31 Jul 2023 20:46:32 +0100 Subject: [PATCH 18/44] Disable Gibbs tests temporarily --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d6b3c2a7d..1b99986f3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,8 +81,8 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ @info "Testing $(adbackend)" @testset "inference: $adbackend" begin @testset "samplers" begin - @timeit_include("inference/gibbs.jl") - @timeit_include("inference/gibbs_conditional.jl") + # @timeit_include("inference/gibbs.jl") + # @timeit_include("inference/gibbs_conditional.jl") @timeit_include("inference/hmc.jl") @timeit_include("inference/Inference.jl") @timeit_include("contrib/inference/dynamichmc.jl") From 1c7f20efd54a74148b33cda574f5efa29f19b5b2 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 31 Jul 2023 22:13:11 +0100 Subject: [PATCH 19/44] Update test/Project.toml Co-authored-by: David Widmann --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 6eae1dcfd..8c777f43f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -61,4 +61,3 @@ StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" Tracker = "0.2.11" Zygote = "0.5.4, 0.6" -julia = "1.6" From 012a0cbd3dc26e6e877f0d2a9cc731b1f5a53d3e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Sep 2023 09:54:56 +0100 Subject: [PATCH 20/44] disable tests unrelated to enzyme + limit CI to avoid over-use of resources --- .github/workflows/DynamicHMC.yml | 1 - .github/workflows/Numerical.yml | 1 - .github/workflows/TuringCI.yml | 14 ---------- test/runtests.jl | 45 ++++++++++++++++---------------- 4 files changed, 23 insertions(+), 38 deletions(-) diff --git a/.github/workflows/DynamicHMC.yml b/.github/workflows/DynamicHMC.yml index 099f70fcf..d66c6988b 100644 --- a/.github/workflows/DynamicHMC.yml +++ b/.github/workflows/DynamicHMC.yml @@ -12,7 +12,6 @@ jobs: strategy: matrix: version: - - '1.7' - '1' os: - ubuntu-latest diff --git a/.github/workflows/Numerical.yml b/.github/workflows/Numerical.yml index 314241fbe..977fc86f7 100644 --- a/.github/workflows/Numerical.yml +++ b/.github/workflows/Numerical.yml @@ -12,7 +12,6 @@ jobs: strategy: matrix: version: - - '1.7' - '1' os: - ubuntu-latest diff --git a/.github/workflows/TuringCI.yml b/.github/workflows/TuringCI.yml index 88cc27bcb..cc8648a7a 100644 --- a/.github/workflows/TuringCI.yml +++ b/.github/workflows/TuringCI.yml @@ -13,7 +13,6 @@ jobs: strategy: matrix: version: - - '1.7' - '1' os: - ubuntu-latest @@ -22,19 +21,6 @@ jobs: num_threads: - 1 - 2 - include: - - version: '1.7' - os: ubuntu-latest - arch: x86 - num_threads: 2 - - version: '1.7' - os: windows-latest - arch: x64 - num_threads: 2 - - version: '1.7' - os: macOS-latest - arch: x64 - num_threads: 2 steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/test/runtests.jl b/test/runtests.jl index 88b3e8a04..23f5fcc30 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,19 +63,20 @@ const TIMEROUTPUT = TimerOutputs.TimerOutput() macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($path)) end @testset "Turing" begin - @testset "essential" begin - @timeit_include("essential/ad.jl") - end - - @testset "samplers (without AD)" begin - @timeit_include("mcmc/particle_mcmc.jl") - @timeit_include("mcmc/emcee.jl") - @timeit_include("mcmc/ess.jl") - @timeit_include("mcmc/is.jl") - end + # NOTE: Doesn't contain Enzyme tests. + # @testset "essential" begin + # @timeit_include("essential/ad.jl") + # end + + # @testset "samplers (without AD)" begin + # @timeit_include("mcmc/particle_mcmc.jl") + # @timeit_include("mcmc/emcee.jl") + # @timeit_include("mcmc/ess.jl") + # @timeit_include("mcmc/is.jl") + # end Turing.setrdcache(false) - for adbackend in (:forwarddiff, :reversediff, :enzyme) + for adbackend in (:enzyme,) @timeit TIMEROUTPUT "inference: $adbackend" begin Turing.setadbackend(adbackend) @info "Testing $(adbackend)" @@ -104,19 +105,19 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ end end - @testset "variational optimisers" begin - @timeit_include("variational/optimisers.jl") - end + # @testset "variational optimisers" begin + # @timeit_include("variational/optimisers.jl") + # end - Turing.setadbackend(:forwarddiff) - @testset "stdlib" begin - @timeit_include("stdlib/distributions.jl") - @timeit_include("stdlib/RandomMeasures.jl") - end + # Turing.setadbackend(:forwarddiff) + # @testset "stdlib" begin + # @timeit_include("stdlib/distributions.jl") + # @timeit_include("stdlib/RandomMeasures.jl") + # end - @testset "utilities" begin - @timeit_include("mcmc/utilities.jl") - end + # @testset "utilities" begin + # @timeit_include("mcmc/utilities.jl") + # end end show(TIMEROUTPUT; compact=true, sortby=:firstexec) From 577734477186d8c02307398d809816905a8fb19e Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 12 Dec 2023 08:30:31 +0000 Subject: [PATCH 21/44] import `AutoEnzyme` --- src/essential/Essential.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index 5ae03d9c3..ab1460fe6 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -11,7 +11,7 @@ using Bijectors: PDMatDistribution using AdvancedVI using StatsFuns: logsumexp, softmax @reexport using DynamicPPL -using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote +using ADTypes: ADTypes, AutoForwardDiff, AutoEnzyme, AutoTracker, AutoReverseDiff, AutoZygote import AdvancedPS import LogDensityProblems From 121df7d5246bc53803ec4f4b67d3686d391805c7 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 16 Dec 2023 10:19:26 +0000 Subject: [PATCH 22/44] Test hmc only --- test/mcmc/Inference.jl | 3 ++- test/mcmc/hmc.jl | 3 ++- test/mcmc/sghmc.jl | 4 +++- test/runtests.jl | 52 +++++++++++++++++++++--------------------- 4 files changed, 33 insertions(+), 29 deletions(-) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 1f5a14869..d8c05136f 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -1,4 +1,5 @@ -@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +# @testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +@testset "Testing inference.jl with $adbackend" for adbackend in (AutoEnzyme(),) # Only test threading if 1.3+. if VERSION > v"1.2" @testset "threaded sampling" begin diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index fe18fa773..0d408f01b 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -1,4 +1,5 @@ -@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +# @testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) # Set a seed rng = StableRNG(123) @numerical_testset "constrained bounded" begin diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 4405b505a..9079f94c0 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -1,4 +1,5 @@ @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) @turing_testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) @test alg isa SGHMC @@ -24,7 +25,8 @@ end end -@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +# @testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoEnzyme(),) @turing_testset "sgld constructor" begin alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) @test alg isa SGLD diff --git a/test/runtests.jl b/test/runtests.jl index ab4b8b7b1..71344d62d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,47 +64,47 @@ const TIMEROUTPUT = TimerOutputs.TimerOutput() macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($path)) end @testset "Turing" begin -# @testset "essential" begin -# @timeit_include("essential/ad.jl") -# end - -# @testset "samplers (without AD)" begin -# @timeit_include("mcmc/particle_mcmc.jl") -# @timeit_include("mcmc/emcee.jl") -# @timeit_include("mcmc/ess.jl") -# @timeit_include("mcmc/is.jl") -# end + # @testset "essential" begin + # @timeit_include("essential/ad.jl") + # end + + # @testset "samplers (without AD)" begin + # @timeit_include("mcmc/particle_mcmc.jl") + # @timeit_include("mcmc/emcee.jl") + # @timeit_include("mcmc/ess.jl") + # @timeit_include("mcmc/is.jl") + # end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" begin -# @timeit_include("mcmc/gibbs.jl") -# @timeit_include("mcmc/gibbs_conditional.jl") + # @timeit_include("mcmc/gibbs.jl") + # @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") - @timeit_include("mcmc/abstractmcmc.jl") - @timeit_include("mcmc/mh.jl") - @timeit_include("ext/dynamichmc.jl") + # @timeit_include("mcmc/abstractmcmc.jl") + # @timeit_include("mcmc/mh.jl") + # @timeit_include("ext/dynamichmc.jl") end - @testset "variational algorithms" begin - @timeit_include("variational/advi.jl") - end + # @testset "variational algorithms" begin + # @timeit_include("variational/advi.jl") + # end - @testset "mode estimation" begin - @timeit_include("optimisation/OptimInterface.jl") - @timeit_include("ext/Optimisation.jl") - end + # @testset "mode estimation" begin + # @timeit_include("optimisation/OptimInterface.jl") + # @timeit_include("ext/Optimisation.jl") + # end end # @testset "variational optimisers" begin # @timeit_include("variational/optimisers.jl") # end -# @testset "stdlib" begin -# @timeit_include("stdlib/distributions.jl") -# @timeit_include("stdlib/RandomMeasures.jl") -# end + # @testset "stdlib" begin + # @timeit_include("stdlib/distributions.jl") + # @timeit_include("stdlib/RandomMeasures.jl") + # end # @testset "utilities" begin # @timeit_include("mcmc/utilities.jl") From a164707f10ddf8bacf3ade53a292f6f41229f71b Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 21 Dec 2023 13:50:40 -0600 Subject: [PATCH 23/44] Update sghmc.jl --- test/mcmc/sghmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 9079f94c0..16b6508ee 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -1,4 +1,4 @@ -@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +# @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) @turing_testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) From 97f1fb6bb5b7d4f90bfd63ea86d12003dc687a6a Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 21 Dec 2023 13:51:38 -0600 Subject: [PATCH 24/44] Update runtests.jl --- test/runtests.jl | 64 ++++++++++++++++++++++++------------------------ 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 71344d62d..b4a6d37e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,51 +64,51 @@ const TIMEROUTPUT = TimerOutputs.TimerOutput() macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($path)) end @testset "Turing" begin - # @testset "essential" begin - # @timeit_include("essential/ad.jl") - # end - - # @testset "samplers (without AD)" begin - # @timeit_include("mcmc/particle_mcmc.jl") - # @timeit_include("mcmc/emcee.jl") - # @timeit_include("mcmc/ess.jl") - # @timeit_include("mcmc/is.jl") - # end + @testset "essential" begin + @timeit_include("essential/ad.jl") + end + + @testset "samplers (without AD)" begin + @timeit_include("mcmc/particle_mcmc.jl") + @timeit_include("mcmc/emcee.jl") + @timeit_include("mcmc/ess.jl") + @timeit_include("mcmc/is.jl") + end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" begin - # @timeit_include("mcmc/gibbs.jl") - # @timeit_include("mcmc/gibbs_conditional.jl") + @timeit_include("mcmc/gibbs.jl") + @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") - # @timeit_include("mcmc/abstractmcmc.jl") - # @timeit_include("mcmc/mh.jl") - # @timeit_include("ext/dynamichmc.jl") + @timeit_include("mcmc/abstractmcmc.jl") + @timeit_include("mcmc/mh.jl") + @timeit_include("ext/dynamichmc.jl") end - # @testset "variational algorithms" begin - # @timeit_include("variational/advi.jl") - # end + @testset "variational algorithms" begin + @timeit_include("variational/advi.jl") + end - # @testset "mode estimation" begin - # @timeit_include("optimisation/OptimInterface.jl") - # @timeit_include("ext/Optimisation.jl") - # end + @testset "mode estimation" begin + @timeit_include("optimisation/OptimInterface.jl") + @timeit_include("ext/Optimisation.jl") + end end - # @testset "variational optimisers" begin - # @timeit_include("variational/optimisers.jl") - # end + @testset "variational optimisers" begin + @timeit_include("variational/optimisers.jl") + end - # @testset "stdlib" begin - # @timeit_include("stdlib/distributions.jl") - # @timeit_include("stdlib/RandomMeasures.jl") - # end + @testset "stdlib" begin + @timeit_include("stdlib/distributions.jl") + @timeit_include("stdlib/RandomMeasures.jl") + end - # @testset "utilities" begin - # @timeit_include("mcmc/utilities.jl") - # end + @testset "utilities" begin + @timeit_include("mcmc/utilities.jl") + end end show(TIMEROUTPUT; compact=true, sortby=:firstexec) From c7b6cf4c18ae1c730e2b1e5a58b80d1f3dc8360e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 25 Jan 2024 14:57:43 -0500 Subject: [PATCH 25/44] disable Type unstable getfield --- test/mcmc/Inference.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index d8c05136f..374c48e1e 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -337,6 +337,8 @@ alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s)) chn = sample(gdemo_default, alg, 1000) end + # Type unstable getfield of tuple not supported in Enzyme yet + if adbackend != AutoEnzyme() @testset "vectorization @." begin # https://github.com/FluxML/Tracker.jl/issues/119 @model function vdemo1(x) @@ -519,6 +521,7 @@ vdemo3kw(; T) = vdemo3(T) sample(vdemo3kw(; T=Vector{Float64}), alg, 250) end + end @testset "names_values" begin ks, xs = Turing.Inference.names_values([ From efdd8e7537111591043366c28c0c67631e9a69b8 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 25 Jan 2024 15:00:58 -0500 Subject: [PATCH 26/44] use release --- test/Project.toml | 2 +- test/runtests.jl | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 355ded217..72beb62b8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -41,7 +41,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -Enzyme = "0.11.5" +Enzyme = "0.11.12" DynamicPPL = "0.24" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" diff --git a/test/runtests.jl b/test/runtests.jl index b4a6d37e8..2475f6a38 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,3 @@ -using Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) - using AbstractMCMC using AdvancedMH using AdvancedPS From 2fdf5464165a1f013f91afd51fb879348662a16a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 25 Jan 2024 23:35:13 +0100 Subject: [PATCH 27/44] Remove seemingly unnecessary definition --- src/essential/ad.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 289bf55fd..da5f827e9 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -40,10 +40,6 @@ function LogDensityProblemsAD.ADgradient(ad::AutoForwardDiff, ℓ::Turing.LogDen return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x = θ) end -function LogDensityProblemsAD.ADgradient(::AutoEnzyme, ℓ::Turing.LogDensityFunction) - return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ) -end - function LogDensityProblemsAD.ADgradient(ad::AutoReverseDiff, ℓ::Turing.LogDensityFunction) return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val(ad.compile), x=DynamicPPL.getparams(ℓ)) end From 4d8cd2313cfd3c93665df0dcc463246812e328e3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 26 Jan 2024 02:02:11 +0100 Subject: [PATCH 28/44] Run tests on Enzyme#main again --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 2475f6a38..b1f66287d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,6 @@ +import Pkg +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) + using AbstractMCMC using AdvancedMH using AdvancedPS From b8296bed1cd59532ad120e9ea14f3cc0e2b274d2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Mar 2024 01:18:49 +0100 Subject: [PATCH 29/44] Test with cholesky fixes --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b1f66287d..2101bebf3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ import Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) +Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git", rev="fix-cholesky")) using AbstractMCMC using AdvancedMH From 2b54d69a7118e7d129248f3c24080ad550c563e5 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 29 May 2024 23:31:11 +0100 Subject: [PATCH 30/44] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index c70d2009b..ccaf715cb 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -42,7 +42,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -Enzyme = "0.11.12" +Enzyme = "0.12" DynamicPPL = "0.27" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 2823a41efc85986ec3785852f0ff648269b97e93 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 30 May 2024 00:19:30 +0100 Subject: [PATCH 31/44] Update Turing.jl --- src/Turing.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Turing.jl b/src/Turing.jl index 99e9880d2..5ef60aca5 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -106,6 +106,7 @@ export @model, # modelling AutoForwardDiff, # ADTypes AutoReverseDiff, AutoZygote, + AutoEnzyme, AutoTracker, AutoTapir, From 0c376a66d5c54459815e81ffdf00c92291bfd17f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 1 Jul 2024 19:57:32 +0100 Subject: [PATCH 32/44] Attempt at fix for `bnn` tests as outlined in #2277 --- test/mcmc/hmc.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 6ac96a551..855d729e9 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -111,7 +111,7 @@ Enzyme.API.runtimeActivity!(true) alpha = 0.16 # regularizatin term var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior - @model function bnn(ts) + @model function bnn(ts, var_prior) b1 ~ MvNormal([0. ;0.; 0.], [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) @@ -129,7 +129,7 @@ Enzyme.API.runtimeActivity!(true) end # Sampling - chain = sample(rng, bnn(ts), HMC(0.1, 5; adtype=adbackend), 10) + chain = sample(rng, bnn(ts, var_prior), HMC(0.1, 5; adtype=adbackend), 10) end @testset "hmcda inference" begin From 76b5e48532277879fd4f7b98d1b27778dee8ec7a Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 17:49:03 +0100 Subject: [PATCH 33/44] Update test/runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9294b3487..6fa0b20ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ import Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git", rev="fix-cholesky")) +Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git")) include("test_utils/SelectiveTests.jl") using .SelectiveTests: isincluded, parse_args From 784b8cb66e149f9d1829f7aaa86e5213c55e954f Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 17:50:21 +0100 Subject: [PATCH 34/44] Update runtests.jl --- test/runtests.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6fa0b20ea..48a00122d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,3 @@ -import Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git")) - include("test_utils/SelectiveTests.jl") using .SelectiveTests: isincluded, parse_args using Pkg From 5bfd06dca74c0d1bab2e3627b9a9606aec558ac1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 9 Jul 2024 20:09:38 +0100 Subject: [PATCH 35/44] remove implicit usage of `hvcat` --- test/mcmc/hmc.jl | 94 +++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 49 deletions(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 855d729e9..d8f9e0049 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -28,52 +28,51 @@ Enzyme.API.runtimeActivity!(true) # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin - obs = [0,1,0,1,1,1,1,1,1,1] + obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @model function constrained_test(obs) - p ~ Beta(2,2) - for i = 1:length(obs) + p ~ Beta(2, 2) + for i in 1:length(obs) obs[i] ~ Bernoulli(p) end - p + return p end chain = sample( rng, constrained_test(obs), HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5) - 1000) + 1000, + ) - check_numerical(chain, [:p], [10/14], atol=0.1) + check_numerical(chain, [:p], [10 / 14]; atol=0.1) end @testset "constrained simplex" begin - obs12 = [1,2,1,2,2,2,2,2,2,2] + obs12 = [1, 2, 1, 2, 2, 2, 2, 2, 2, 2] @model function constrained_simplex_test(obs12) ps ~ Dirichlet(2, 3) pd ~ Dirichlet(4, 1) - for i = 1:length(obs12) + for i in 1:length(obs12) obs12[i] ~ Categorical(ps) end return ps end chain = sample( - rng, - constrained_simplex_test(obs12), - HMC(0.75, 2; adtype=adbackend), - 1000) + rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000 + ) - check_numerical(chain, ["ps[1]", "ps[2]"], [5/16, 11/16], atol=0.015) + check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015) end @testset "hmc reverse diff" begin alg = HMC(0.1, 10; adtype=adbackend) res = sample(rng, gdemo_default, alg, 4000) - check_gdemo(res, rtol=0.1) + check_gdemo(res; rtol=0.1) end @testset "matrix support" begin @model function hmcmatrixsup() - v ~ Wishart(7, [1 0.5; 0.5 1]) + return v ~ Wishart(7, [1 0.5; 0.5 1]) end model_f = hmcmatrixsup() @@ -81,7 +80,7 @@ Enzyme.API.runtimeActivity!(true) vs = map(1:3) do _ chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples) r = reshape(Array(group(chain, :v)), n_samples, 2, 2) - reshape(mean(r; dims = 1), 2, 2) + reshape(mean(r; dims=1), 2, 2) end @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 @@ -98,10 +97,10 @@ Enzyme.API.runtimeActivity!(true) M = N ÷ 4 x1s = rand(M) * 5 x2s = rand(M) * 5 - xt1s = Array([[x1s[i]; x2s[i]] for i = 1:M]) - append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i = 1:M])) - xt0s = Array([[x1s[i]; x2s[i] - 6] for i = 1:M]) - append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i = 1:M])) + xt1s = Array([[x1s[i]; x2s[i]] for i in 1:M]) + append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i in 1:M])) + xt0s = Array([[x1s[i]; x2s[i] - 6] for i in 1:M]) + append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i in 1:M])) xs = [xt1s; xt0s] ts = [ones(M); ones(M); zeros(M); zeros(M)] @@ -112,20 +111,18 @@ Enzyme.API.runtimeActivity!(true) var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior @model function bnn(ts, var_prior) - b1 ~ MvNormal([0. ;0.; 0.], - [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) - w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) - w12 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) - w13 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) + b1 ~ MvNormal(zeros(3), var_prior * I) + w11 ~ MvNormal(zeros(2), var_prior * I) + w12 ~ MvNormal(zeros(2), var_prior * I) + w13 ~ MvNormal(zeros(2), var_prior * I) bo ~ Normal(0, var_prior) - wo ~ MvNormal([0.; 0; 0], - [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) - for i = rand(1:N, 10) + wo ~ MvNormal(zeros(3), var_prior * I) + for i in rand(1:N, 10) y = nn(xs[i], b1, w11, w12, w13, bo, wo) ts[i] ~ Bernoulli(y) end - b1, w11, w12, w13, bo, wo + return b1, w11, w12, w13, bo, wo end # Sampling @@ -153,7 +150,7 @@ Enzyme.API.runtimeActivity!(true) Random.seed!(12345) # particle samplers do not support user-provided `rng` yet alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend)) - res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000) + res3 = sample(rng, gdemo_default, alg3, 3000; discard_initial=1000) check_gdemo(res3) end @@ -197,8 +194,8 @@ Enzyme.API.runtimeActivity!(true) @testset "check discard" begin alg = NUTS(100, 0.8; adtype=adbackend) - c1 = sample(rng, gdemo_default, alg, 500, discard_adapt=true) - c2 = sample(rng, gdemo_default, alg, 500, discard_adapt=false) + c1 = sample(rng, gdemo_default, alg, 500; discard_adapt=true) + c2 = sample(rng, gdemo_default, alg, 500; discard_adapt=false) @test size(c1, 1) == 500 @test size(c2, 1) == 500 @@ -216,20 +213,20 @@ Enzyme.API.runtimeActivity!(true) # https://github.com/TuringLang/DynamicPPL.jl/issues/27 @model function mwe1(::Type{T}=Float64) where {T<:Real} m = Matrix{T}(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains @model function mwe2(::Type{T}=Matrix{Float64}) where {T} m = T(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains # https://github.com/TuringLang/Turing.jl/issues/1308 @model function mwe3(::Type{T}=Array{Float64}) where {T} m = T(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains end @@ -247,13 +244,17 @@ Enzyme.API.runtimeActivity!(true) @model function demo_hmc_prior() # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance # which means that it's _very_ difficult to find a good tolerance in the test below:) - s ~ truncated(Normal(3, 1), lower=0) - m ~ Normal(0, sqrt(s)) + s ~ truncated(Normal(3, 1); lower=0) + return m ~ Normal(0, sqrt(s)) end alg = NUTS(1000, 0.8; adtype=adbackend) - gdemo_default_prior = DynamicPPL.contextualize(demo_hmc_prior(), DynamicPPL.PriorContext()) + gdemo_default_prior = DynamicPPL.contextualize( + demo_hmc_prior(), DynamicPPL.PriorContext() + ) chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0]) - check_numerical(chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0], atol=0.2) + check_numerical( + chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2 + ) end @testset "warning for difficult init params" begin @@ -268,7 +269,7 @@ Enzyme.API.runtimeActivity!(true) @test_logs ( :warn, "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", - ) (:info,) match_mode=:any begin + ) (:info,) match_mode = :any begin sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5) end end @@ -280,7 +281,7 @@ Enzyme.API.runtimeActivity!(true) @model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV} xs = Vector{TV}(undef, 2) xs[1] ~ Dirichlet(ones(5)) - xs[2] ~ Dirichlet(ones(5)) + return xs[2] ~ Dirichlet(ones(5)) end model = vector_of_dirichlet() chain = sample(model, NUTS(), 1000) @@ -306,15 +307,10 @@ Enzyme.API.runtimeActivity!(true) end end - model = buggy_model(); - num_samples = 1_000; + model = buggy_model() + num_samples = 1_000 - chain = sample( - model, - NUTS(), - num_samples; - initial_params=[0.5, 1.75, 1.0] - ) + chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0]) chain_prior = sample(model, Prior(), num_samples) # Extract the `x` like this because running `generated_quantities` was how From ce13e032d6a5901daac026d9f1c1085ddc9a465a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 25 Jul 2024 22:24:43 +0200 Subject: [PATCH 36/44] Re-activate CIs disabled for Enzyme testing --- .github/workflows/Tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 174b4bd8a..8de296e5e 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -26,15 +26,15 @@ jobs: - "mcmc/ess.jl" - "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl" version: - #- '1.7' TODO(mhauru): Temporarily disabled for Enzyme + - '1.7' - '1' os: - ubuntu-latest - #- windows-latest TODO(mhauru): Temporarily disabled for Enzyme - #- macOS-latest TODO(mhauru): Temporarily disabled for Enzyme + - windows-latest + - macOS-latest arch: - x64 - #- x86 TODO(mhauru): Temporarily disabled for Enzyme + - x86 num_threads: - 1 - 2 From e2c069345184af87f345cadc7699c5f4149a1bca Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 15 Aug 2024 11:05:16 +0200 Subject: [PATCH 37/44] Re-enable tests with other AD backends --- Project.toml | 2 +- test/mcmc/Inference.jl | 12 +++++++----- test/mcmc/hmc.jl | 11 +++++++---- test/mcmc/sghmc.jl | 6 ++---- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 33d4be908..8845ebbe2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.33.3" +version = "0.33.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 15468669f..a6d9998ef 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -20,8 +20,7 @@ Enzyme.API.typeWarning!(false) # Enable runtime activity (workaround) Enzyme.API.runtimeActivity!(true) -# @testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) -@testset "Testing inference.jl with $adbackend" for adbackend in (AutoEnzyme(),) +@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false), AutoEnzyme()) # Only test threading if 1.3+. if VERSION > v"1.2" @testset "threaded sampling" begin @@ -374,8 +373,6 @@ Enzyme.API.runtimeActivity!(true) alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s)) chn = sample(gdemo_default, alg, 1000) end - # Type unstable getfield of tuple not supported in Enzyme yet - if adbackend != AutoEnzyme() @testset "vectorization @." begin # https://github.com/FluxML/Tracker.jl/issues/119 @model function vdemo1(x) @@ -407,6 +404,8 @@ Enzyme.API.runtimeActivity!(true) alg = HMC(0.01, 5; adtype=adbackend) res = sample(vdemo2(randn(D, 100)), alg, 250) + # Type unstable getfield of tuple not supported in Enzyme yet + if !(adbackend isa AutoEnzyme) # Vector assumptions N = 10 alg = HMC(0.2, 4; adtype=adbackend) @@ -452,6 +451,7 @@ Enzyme.API.runtimeActivity!(true) end sample(vdemo7(), alg, 1000) + end end @testset "vectorization .~" begin @model function vdemo1(x) @@ -474,6 +474,8 @@ Enzyme.API.runtimeActivity!(true) alg = HMC(0.01, 5; adtype=adbackend) res = sample(vdemo2(randn(D, 100)), alg, 250) + # Type unstable getfield of tuple not supported in Enzyme yet + if !(adbackend isa AutoEnzyme) # Vector assumptions N = 10 alg = HMC(0.2, 4; adtype=adbackend) @@ -518,6 +520,7 @@ Enzyme.API.runtimeActivity!(true) end sample(vdemo7(), alg, 1000) + end end @testset "Type parameters" begin N = 10 @@ -558,7 +561,6 @@ Enzyme.API.runtimeActivity!(true) vdemo3kw(; T) = vdemo3(T) sample(vdemo3kw(; T=DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250) end - end @testset "names_values" begin ks, xs = Turing.Inference.names_values([ diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index d8f9e0049..e0e8da469 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -23,8 +23,7 @@ Enzyme.API.typeWarning!(false) # Enable runtime activity (workaround) Enzyme.API.runtimeActivity!(true) -# @testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) -@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) +@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false), AutoEnzyme()) # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin @@ -76,14 +75,18 @@ Enzyme.API.runtimeActivity!(true) end model_f = hmcmatrixsup() - n_samples = 1_000 + n_samples = 5_000 vs = map(1:3) do _ chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples) r = reshape(Array(group(chain, :v)), n_samples, 2, 2) reshape(mean(r; dims=1), 2, 2) end - @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 + if VERSION > v"1.7" || !(adbackend isa AutoEnzyme) + @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 + else + @test_broken maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 + end end @testset "multivariate support" begin # Define NN flow diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 955995570..411278f90 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -17,8 +17,7 @@ Enzyme.API.typeWarning!(false) # Enable runtime activity (workaround) Enzyme.API.runtimeActivity!(true) -# @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) -@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) +@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false), AutoEnzyme()) @testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) @test alg isa SGHMC @@ -44,8 +43,7 @@ Enzyme.API.runtimeActivity!(true) end end -# @testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) -@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoEnzyme(),) +@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false), AutoEnzyme()) @testset "sgld constructor" begin alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) @test alg isa SGLD From 2115d524ff6f3825debfaba874c4a693243a9337 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 15 Aug 2024 13:03:12 +0200 Subject: [PATCH 38/44] Load `@test_broken` --- test/mcmc/hmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 212fced39..0e85fd16b 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -15,7 +15,7 @@ using LinearAlgebra: I, dot, vec import Random using StableRNGs: StableRNG using StatsFuns: logistic -using Test: @test, @test_logs, @testset +using Test: @test, @test_broken, @test_logs, @testset using Turing # Disable Enzyme warnings From 79d057c37fe2d0202edb5162454480cb89292c4f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 11:11:05 +0100 Subject: [PATCH 39/44] Bump Enzyme to 0.13 in tests --- test/Project.toml | 2 +- test/mcmc/Inference.jl | 4 ++-- test/mcmc/hmc.jl | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 9eb5b4fdc..07d5e0136 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -46,7 +46,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -Enzyme = "0.12" +Enzyme = "0.13" DynamicPPL = "0.29" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 02a00766c..1ad807ef3 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -412,7 +412,7 @@ using Turing alg = HMC(0.01, 5; adtype=adbackend) res = sample(vdemo2(randn(D, 100)), alg, 250) - # Type unstable getfield of tuple not supported in Enzyme yet + # TODO(mhauru) Type unstable getfield of tuple not supported in Enzyme yet if !(adbackend isa AutoEnzyme) # Vector assumptions N = 10 @@ -482,7 +482,7 @@ using Turing alg = HMC(0.01, 5; adtype=adbackend) res = sample(vdemo2(randn(D, 100)), alg, 250) - # Type unstable getfield of tuple not supported in Enzyme yet + # TODO(mhauru) Type unstable getfield of tuple not supported in Enzyme yet if !(adbackend isa AutoEnzyme) # Vector assumptions N = 10 diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 295f2703b..2a9a89761 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -78,7 +78,8 @@ using Turing reshape(mean(r; dims=1), 2, 2) end - if VERSION > v"1.7" || !(adbackend isa AutoEnzyme) + # TODO(mhauru) The below remains broken for Enzyme. Need to investigate why. + if !(adbackend isa AutoEnzyme) @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 else @test_broken maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 From c98bbc96af637df114deed81d346544ddda6bc41 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 11:13:25 +0100 Subject: [PATCH 40/44] Run JuliaFormatter on more files, remove trailing whitespace --- .JuliaFormatter.toml | 3 -- .github/workflows/DocsNav.yml | 6 +-- src/mcmc/mh.jl | 2 +- test/mcmc/hmc.jl | 98 +++++++++++++++++------------------ test/mcmc/sghmc.jl | 10 ++-- 5 files changed, 58 insertions(+), 61 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 15ecbc5c3..2772de28b 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -8,7 +8,4 @@ ignore = [ # https://github.com/TuringLang/Turing.jl/pull/2328/files "src/experimental/gibbs.jl", "test/experimental/gibbs.jl", - # https://github.com/TuringLang/Turing.jl/pull/1887 # Enzyme PR - "test/mcmc/hmc.jl", - "test/mcmc/sghmc.jl", ] diff --git a/.github/workflows/DocsNav.yml b/.github/workflows/DocsNav.yml index 14614d1fd..301ee7393 100644 --- a/.github/workflows/DocsNav.yml +++ b/.github/workflows/DocsNav.yml @@ -32,13 +32,13 @@ jobs: # Define the URL of the navbar to be used NAVBAR_URL="https://raw.githubusercontent.com/TuringLang/turinglang.github.io/main/assets/scripts/TuringNavbar.html" - + # Update all HTML files in the current directory (gh-pages root) ./insert_navbar.sh . $NAVBAR_URL - + # Remove the insert_navbar.sh file rm insert_navbar.sh - + # Check if there are any changes if [[ -n $(git status -s) ]]; then git add . diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 433add6b5..bc2519d71 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -54,7 +54,7 @@ Specifying a single distribution implies the use of static MH: ```julia # Use a static proposal for s² (which happens to be the same -# as the prior) and a static proposal for m (note that this +# as the prior) and a static proposal for m (note that this # isn't a random walk proposal). chain = sample( gdemo(1.5, 2.0), diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 7404dbf43..27c055394 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -22,52 +22,51 @@ using Turing # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin - obs = [0,1,0,1,1,1,1,1,1,1] + obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @model function constrained_test(obs) - p ~ Beta(2,2) - for i = 1:length(obs) + p ~ Beta(2, 2) + for i in 1:length(obs) obs[i] ~ Bernoulli(p) end - p + return p end chain = sample( rng, constrained_test(obs), HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5) - 1000) + 1000, + ) - check_numerical(chain, [:p], [10/14], atol=0.1) + check_numerical(chain, [:p], [10 / 14]; atol=0.1) end @testset "constrained simplex" begin - obs12 = [1,2,1,2,2,2,2,2,2,2] + obs12 = [1, 2, 1, 2, 2, 2, 2, 2, 2, 2] @model function constrained_simplex_test(obs12) ps ~ Dirichlet(2, 3) pd ~ Dirichlet(4, 1) - for i = 1:length(obs12) + for i in 1:length(obs12) obs12[i] ~ Categorical(ps) end return ps end chain = sample( - rng, - constrained_simplex_test(obs12), - HMC(0.75, 2; adtype=adbackend), - 1000) + rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000 + ) - check_numerical(chain, ["ps[1]", "ps[2]"], [5/16, 11/16], atol=0.015) + check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015) end @testset "hmc reverse diff" begin alg = HMC(0.1, 10; adtype=adbackend) res = sample(rng, gdemo_default, alg, 4000) - check_gdemo(res, rtol=0.1) + check_gdemo(res; rtol=0.1) end @testset "matrix support" begin @model function hmcmatrixsup() - v ~ Wishart(7, [1 0.5; 0.5 1]) + return v ~ Wishart(7, [1 0.5; 0.5 1]) end model_f = hmcmatrixsup() @@ -75,7 +74,7 @@ using Turing vs = map(1:3) do _ chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples) r = reshape(Array(group(chain, :v)), n_samples, 2, 2) - reshape(mean(r; dims = 1), 2, 2) + reshape(mean(r; dims=1), 2, 2) end @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 @@ -92,10 +91,10 @@ using Turing M = N ÷ 4 x1s = rand(M) * 5 x2s = rand(M) * 5 - xt1s = Array([[x1s[i]; x2s[i]] for i = 1:M]) - append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i = 1:M])) - xt0s = Array([[x1s[i]; x2s[i] - 6] for i = 1:M]) - append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i = 1:M])) + xt1s = Array([[x1s[i]; x2s[i]] for i in 1:M]) + append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i in 1:M])) + xt0s = Array([[x1s[i]; x2s[i] - 6] for i in 1:M]) + append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i in 1:M])) xs = [xt1s; xt0s] ts = [ones(M); ones(M); zeros(M); zeros(M)] @@ -106,20 +105,22 @@ using Turing var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior @model function bnn(ts) - b1 ~ MvNormal([0. ;0.; 0.], - [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) - w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) - w12 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) - w13 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) + b1 ~ MvNormal( + [0.0; 0.0; 0.0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior] + ) + w11 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) + w12 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) + w13 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) bo ~ Normal(0, var_prior) - wo ~ MvNormal([0.; 0; 0], - [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) - for i = rand(1:N, 10) + wo ~ MvNormal( + [0.0; 0; 0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior] + ) + for i in rand(1:N, 10) y = nn(xs[i], b1, w11, w12, w13, bo, wo) ts[i] ~ Bernoulli(y) end - b1, w11, w12, w13, bo, wo + return b1, w11, w12, w13, bo, wo end # Sampling @@ -147,7 +148,7 @@ using Turing Random.seed!(12345) # particle samplers do not support user-provided `rng` yet alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend)) - res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000) + res3 = sample(rng, gdemo_default, alg3, 3000; discard_initial=1000) check_gdemo(res3) end @@ -191,8 +192,8 @@ using Turing @testset "check discard" begin alg = NUTS(100, 0.8; adtype=adbackend) - c1 = sample(rng, gdemo_default, alg, 500, discard_adapt=true) - c2 = sample(rng, gdemo_default, alg, 500, discard_adapt=false) + c1 = sample(rng, gdemo_default, alg, 500; discard_adapt=true) + c2 = sample(rng, gdemo_default, alg, 500; discard_adapt=false) @test size(c1, 1) == 500 @test size(c2, 1) == 500 @@ -210,20 +211,20 @@ using Turing # https://github.com/TuringLang/DynamicPPL.jl/issues/27 @model function mwe1(::Type{T}=Float64) where {T<:Real} m = Matrix{T}(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains @model function mwe2(::Type{T}=Matrix{Float64}) where {T} m = T(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains # https://github.com/TuringLang/Turing.jl/issues/1308 @model function mwe3(::Type{T}=Array{Float64}) where {T} m = T(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains end @@ -241,13 +242,17 @@ using Turing @model function demo_hmc_prior() # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance # which means that it's _very_ difficult to find a good tolerance in the test below:) - s ~ truncated(Normal(3, 1), lower=0) - m ~ Normal(0, sqrt(s)) + s ~ truncated(Normal(3, 1); lower=0) + return m ~ Normal(0, sqrt(s)) end alg = NUTS(1000, 0.8; adtype=adbackend) - gdemo_default_prior = DynamicPPL.contextualize(demo_hmc_prior(), DynamicPPL.PriorContext()) + gdemo_default_prior = DynamicPPL.contextualize( + demo_hmc_prior(), DynamicPPL.PriorContext() + ) chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0]) - check_numerical(chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0], atol=0.2) + check_numerical( + chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2 + ) end @testset "warning for difficult init params" begin @@ -262,7 +267,7 @@ using Turing @test_logs ( :warn, "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", - ) (:info,) match_mode=:any begin + ) (:info,) match_mode = :any begin sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5) end end @@ -271,7 +276,7 @@ using Turing @model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV} xs = Vector{TV}(undef, 2) xs[1] ~ Dirichlet(ones(5)) - xs[2] ~ Dirichlet(ones(5)) + return xs[2] ~ Dirichlet(ones(5)) end model = vector_of_dirichlet() chain = sample(model, NUTS(), 1000) @@ -296,15 +301,10 @@ using Turing end end - model = buggy_model(); - num_samples = 1_000; + model = buggy_model() + num_samples = 1_000 - chain = sample( - model, - NUTS(), - num_samples; - initial_params=[0.5, 1.75, 1.0] - ) + chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0]) chain_prior = sample(model, Prior(), num_samples) # Extract the `x` like this because running `generated_quantities` was how diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 1f8179503..c1d07d2ce 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -34,7 +34,7 @@ using Turing alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adbackend) chain = sample(rng, gdemo_default, alg, 10_000) - check_gdemo(chain, atol=0.1) + check_gdemo(chain; atol=0.1) end end @@ -58,15 +58,15 @@ end @testset "sgld inference" begin rng = StableRNG(1) - chain = sample(rng, gdemo_default, SGLD(; stepsize = PolynomialStepsize(0.5)), 20_000) - check_gdemo(chain, atol = 0.2) + chain = sample(rng, gdemo_default, SGLD(; stepsize=PolynomialStepsize(0.5)), 20_000) + check_gdemo(chain; atol=0.2) # Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh) v = get(chain, [:SGLD_stepsize, :s, :m]) s_weighted = dot(v.SGLD_stepsize, v.s) / sum(v.SGLD_stepsize) m_weighted = dot(v.SGLD_stepsize, v.m) / sum(v.SGLD_stepsize) - @test s_weighted ≈ 49/24 atol=0.2 - @test m_weighted ≈ 7/6 atol=0.2 + @test s_weighted ≈ 49 / 24 atol = 0.2 + @test m_weighted ≈ 7 / 6 atol = 0.2 end end From 9d391938f2b04a8254ad2a23077bb07abfe77a7e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 11:44:48 +0100 Subject: [PATCH 41/44] Restore compat with Enzyme v0.12 --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 07d5e0136..0841eb27e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -46,7 +46,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -Enzyme = "0.13" +Enzyme = "0.12, 0.13" DynamicPPL = "0.29" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 66cd80a1ffeb9a4e233a247d04f596d1ec8ccc2d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 13:11:34 +0100 Subject: [PATCH 42/44] Import Enzyme in abstractmcmc and gibbs tests --- test/mcmc/abstractmcmc.jl | 1 + test/mcmc/gibbs.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index 449b43b71..a113f1b7c 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -5,6 +5,7 @@ using AdvancedMH: AdvancedMH using Distributions: sample using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL +import Enzyme using ForwardDiff: ForwardDiff using LinearAlgebra: I using LogDensityProblems: LogDensityProblems diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index cd044910b..0121687cd 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -5,6 +5,7 @@ using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical import ..ADUtils using Distributions: InverseGamma, Normal using Distributions: sample +import Enzyme using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff From ec34e4120b40fc10613d1c650fdaee1b47e98de2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 13:37:25 +0100 Subject: [PATCH 43/44] Add Enzyme imports to a couple of other tests files --- test/mcmc/gibbs_conditional.jl | 1 + test/optimisation/Optimisation.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index d6d81cbe0..abbdc03c5 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -5,6 +5,7 @@ using ..NumericalTests: check_gdemo, check_numerical import ..ADUtils using Clustering: Clustering using Distributions: Categorical, InverseGamma, Normal, sample +import Enzyme using ForwardDiff: ForwardDiff using LinearAlgebra: Diagonal, I using Random: Random diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index d8afd83db..8758e946f 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -5,6 +5,7 @@ using ..ADUtils: ADUtils using Distributions using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL +using Enzyme: Enzyme using ForwardDiff: ForwardDiff using LinearAlgebra: Diagonal, I using Mooncake: Mooncake From 120230a65727898d1cf9708be47f1d66cfdcc1e6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 14:21:28 +0100 Subject: [PATCH 44/44] Remove unnecessary version conditions in tests --- test/mcmc/Inference.jl | 95 ++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 55 deletions(-) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 1ad807ef3..bf53af36e 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -18,70 +18,55 @@ using Test: @test, @test_throws, @testset using Turing @testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends - # Only test threading if 1.3+. - if VERSION > v"1.2" - @testset "threaded sampling" begin - # Test that chains with the same seed will sample identically. - @testset "rng" begin - model = gdemo_default - - # multithreaded sampling with PG causes segfaults on Julia 1.5.4 - # https://github.com/TuringLang/Turing.jl/issues/1571 - samplers = @static if VERSION <= v"1.5.3" || VERSION >= v"1.6.0" - ( - HMC(0.1, 7; adtype=adbackend), - PG(10), - IS(), - MH(), - Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), - ) - else - ( - HMC(0.1, 7; adtype=adbackend), - IS(), - MH(), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), - ) - end - for sampler in samplers - Random.seed!(5) - chain1 = sample(model, sampler, MCMCThreads(), 1000, 4) + @testset "threaded sampling" begin + # Test that chains with the same seed will sample identically. + @testset "rng" begin + model = gdemo_default + + samplers = ( + HMC(0.1, 7; adtype=adbackend), + PG(10), + IS(), + MH(), + Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)), + Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + ) + for sampler in samplers + Random.seed!(5) + chain1 = sample(model, sampler, MCMCThreads(), 1000, 4) - Random.seed!(5) - chain2 = sample(model, sampler, MCMCThreads(), 1000, 4) + Random.seed!(5) + chain2 = sample(model, sampler, MCMCThreads(), 1000, 4) - @test chain1.value == chain2.value - end + @test chain1.value == chain2.value + end - # Should also be stable with am explicit RNG - seed = 5 - rng = Random.MersenneTwister(seed) - for sampler in samplers - Random.seed!(rng, seed) - chain1 = sample(rng, model, sampler, MCMCThreads(), 1000, 4) + # Should also be stable with am explicit RNG + seed = 5 + rng = Random.MersenneTwister(seed) + for sampler in samplers + Random.seed!(rng, seed) + chain1 = sample(rng, model, sampler, MCMCThreads(), 1000, 4) - Random.seed!(rng, seed) - chain2 = sample(rng, model, sampler, MCMCThreads(), 1000, 4) + Random.seed!(rng, seed) + chain2 = sample(rng, model, sampler, MCMCThreads(), 1000, 4) - @test chain1.value == chain2.value - end + @test chain1.value == chain2.value end + end - # Smoke test for default sample call. - Random.seed!(100) - chain = sample( - gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4 - ) - check_gdemo(chain) + # Smoke test for default sample call. + Random.seed!(100) + chain = sample(gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4) + check_gdemo(chain) - # run sampler: progress logging should be disabled and - # it should return a Chains object - sampler = Sampler(HMC(0.1, 7; adtype=adbackend), gdemo_default) - chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4) - @test chains isa MCMCChains.Chains - end + # run sampler: progress logging should be disabled and + # it should return a Chains object + sampler = Sampler(HMC(0.1, 7; adtype=adbackend), gdemo_default) + chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4) + @test chains isa MCMCChains.Chains end + @testset "chain save/resume" begin Random.seed!(1234)