Skip to content

Commit

Permalink
copy_function_with_new_rngs warns with JAXLinker
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Jan 10, 2025
1 parent 88308f2 commit 7c597a8
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
10 changes: 10 additions & 0 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.link.jax.linker import JAXLinker
from pytensor.scalar.basic import Cast
from pytensor.scan.op import Scan
from pytensor.tensor.basic import _as_tensor_variable
Expand Down Expand Up @@ -1208,6 +1209,15 @@ def copy_function_with_new_rngs(
fn_ = fn.f if isinstance(fn, PointFunc) else fn
shared_rngs = [var for var in fn_.get_shared() if isinstance(var.type, RandomGeneratorType)]
n_shared_rngs = len(shared_rngs)
if n_shared_rngs > 0 and isinstance(fn_.maker.linker, JAXLinker):
# Reseeding RVs in JAX backend requires a different logic, becuase the SharedVariables
# used internally are not the ones that `function.get_shared()` returns.
warnings.warn(
"At the moment, it is not possible to set the random generator's key for "
"JAX linked functions. This means that the draws yielded by the random "
"variables that are requested by 'Deterministic' will not be reproducible."
)
return fn
swap = {
old_shared_rng: shared(rng, borrow=True)
for old_shared_rng, rng in zip(shared_rngs, rng_gen.spawn(n_shared_rngs), strict=True)
Expand Down
27 changes: 22 additions & 5 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,12 +929,29 @@ def trace_backend(request):
return trace


def test_random_deterministics(trace_backend):
@pytest.fixture(scope="function", params=["FAST_COMPILE", "NUMBA", "JAX"])
def pytensor_mode(request):
return request.param


def test_random_deterministics(trace_backend, pytensor_mode):
with pm.Model() as m:
x = pm.Bernoulli("x", p=0.5) * 0 # Force it to be zero
pm.Deterministic("y", x + pm.Normal.dist())

idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)

assert idata1.posterior.equals(idata2.posterior)
if pytensor_mode == "JAX":
expected_warning = (
"At the moment, it is not possible to set the random generator's key for "
"JAX linked functions. This means that the draws yielded by the random "
"variables that are requested by 'Deterministic' will not be reproducible."
)
with pytest.warns(UserWarning, match=expected_warning):
with pytensor.config.change_flags(mode=pytensor_mode):
idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
assert not idata1.posterior.equals(idata2.posterior)
else:
with pytensor.config.change_flags(mode=pytensor_mode):
idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
assert idata1.posterior.equals(idata2.posterior)

0 comments on commit 7c597a8

Please sign in to comment.