Skip to content

Commit

Permalink
Move ADTypeCheckContext tests to a separate module (#2383)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru authored Oct 30, 2024
1 parent 5426eca commit 397d1a7
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 38 deletions.
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ macro timeit_include(path::AbstractString)
end

@testset "Turing" begin
@testset "Test utils" begin
@timeit_include("test_utils/test_utils.jl")
end

@testset "Aqua" begin
@timeit_include("Aqua.jl")
end
Expand Down
38 changes: 0 additions & 38 deletions test/test_utils/ad_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,44 +229,6 @@ function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, righ
return logp, vi
end

# Check that the ADTypeCheckContext works as expected.
Test.@testset "ADTypeCheckContext" begin
Turing.@model test_model() = x ~ Turing.Normal(0, 1)
tm = test_model()
adtypes = (
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(),
Turing.AutoZygote(),
# TODO: Mooncake
# Turing.AutoMooncake(config=nothing),
)
for actual_adtype in adtypes
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
for expected_adtype in adtypes
if (
actual_adtype == Turing.AutoForwardDiff() &&
expected_adtype == Turing.AutoZygote()
)
# TODO(mhauru) We are currently unable to check this case.
continue
end
contextualised_tm = DynamicPPL.contextualize(
tm, ADTypeCheckContext(expected_adtype, tm.context)
)
Test.@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
if actual_adtype == expected_adtype
# Check that this does not throw an error.
Turing.sample(contextualised_tm, sampler, 2)
else
Test.@test_throws AbstractWrongADBackendError Turing.sample(
contextualised_tm, sampler, 2
)
end
end
end
end
end

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# List of AD backends to test.

Expand Down
50 changes: 50 additions & 0 deletions test/test_utils/test_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Module for testing the test utils themselves."""
module TestUtilsTests

using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError
using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff
using Test: @test, @testset, @test_throws
using Turing: Turing
using Turing: DynamicPPL
using Zygote: Zygote

# Check that the ADTypeCheckContext works as expected.
@testset "ADTypeCheckContext" begin
Turing.@model test_model() = x ~ Turing.Normal(0, 1)
tm = test_model()
adtypes = (
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(),
Turing.AutoZygote(),
# TODO: Mooncake
# Turing.AutoMooncake(config=nothing),
)
for actual_adtype in adtypes
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
for expected_adtype in adtypes
if (
actual_adtype == Turing.AutoForwardDiff() &&
expected_adtype == Turing.AutoZygote()
)
# TODO(mhauru) We are currently unable to check this case.
continue
end
contextualised_tm = DynamicPPL.contextualize(
tm, ADTypeCheckContext(expected_adtype, tm.context)
)
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
if actual_adtype == expected_adtype
# Check that this does not throw an error.
Turing.sample(contextualised_tm, sampler, 2)
else
@test_throws AbstractWrongADBackendError Turing.sample(
contextualised_tm, sampler, 2
)
end
end
end
end
end

end

0 comments on commit 397d1a7

Please sign in to comment.