diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f7d43470e..4015ab331 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -4,9 +4,11 @@ on: push: branches: - master + - backport-* pull_request: branches: - master + - backport-* merge_group: types: [checks_requested] diff --git a/src/utils.jl b/src/utils.jl index a809fda17..5fedd3039 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -941,10 +941,10 @@ end """ float_type_with_fallback(x) -Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`. +Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`. """ -float_type_with_fallback(::Type) = Real -float_type_with_fallback(::Type{Union{}}) = Real +float_type_with_fallback(::Type) = float(Real) +float_type_with_fallback(::Type{Union{}}) = float(Real) float_type_with_fallback(::Type{T}) where {T<:Real} = float(T) """ diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index f1d805505..8aa19241d 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -342,4 +342,19 @@ model = state_space(y, length(t)) @test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n end + + if Threads.nthreads() > 1 + @testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin + @model function f(x) + ns ~ filldist(Normal(0, 2.0), 3) + m ~ Uniform(0, 1) + x ~ Normal(m, 1) + end + model = f(1) + chain = sample(model, NUTS(), MCMCThreads(), 10, 2); + loglikelihood(model, chain) + logprior(model, chain) + logjoint(model, chain) + end + end end