Skip to content

Commit

Permalink
Stricter types for evaluate!! methods (#629)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru authored Jul 12, 2024
1 parent 3b3840d commit ad848b2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -909,13 +909,18 @@ function AbstractPPL.evaluate!!(model::Model, context::AbstractContext)
return evaluate!!(model, VarInfo(), context)
end

function AbstractPPL.evaluate!!(model::Model, args...)
function AbstractPPL.evaluate!!(
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
)
return evaluate!!(model, Random.default_rng(), args...)
end

# without VarInfo
function AbstractPPL.evaluate!!(
model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args...
model::Model,
rng::Random.AbstractRNG,
sampler::AbstractSampler,
args::AbstractContext...,
)
return evaluate!!(model, rng, VarInfo(), sampler, args...)
end
Expand Down
14 changes: 14 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,4 +396,18 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end
end
end

@testset "Erroneous model call" begin
# Calling a model with the wrong arguments used to lead to infinite recursion, see
# https://github.com/TuringLang/Turing.jl/issues/2182. This guards against it.
@model function a_model(x)
m ~ Normal(0, 1)
x ~ Normal(m, 1)
return nothing
end
instance = a_model(1.0)
# `instance` should be called with rng, context, etc., but one may easily get
# confused and call it the way you are meant to call `a_model`.
@test_throws MethodError instance(1.0)
end
end

0 comments on commit ad848b2

Please sign in to comment.