From 787c96f7b635f18cb4e5360cb497557fcc783261 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Wed, 10 Apr 2024 15:08:37 -0700 Subject: [PATCH] updated nnx dropout --- flax/experimental/nnx/nnx/nn/stochastic.py | 13 ++-- .../nnx/tests/nn/test_stochastic.py | 71 ++++++++++++++++--- 2 files changed, 69 insertions(+), 15 deletions(-) diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py index 684ed28e02..a4a676df7e 100644 --- a/flax/experimental/nnx/nnx/nn/stochastic.py +++ b/flax/experimental/nnx/nnx/nn/stochastic.py @@ -75,12 +75,11 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ - rngs = rngs or self.rngs deterministic = first_from( deterministic, self.deterministic, error_msg="""No `deterministic` argument was provided to Dropout - as either a __call__ argument, class attribute, or nnx.flag.""", + as either a __call__ argument or class attribute""", ) if (self.rate == 0.0) or deterministic: @@ -90,10 +89,12 @@ def __call__( if self.rate == 1.0: return jnp.zeros_like(inputs) - if rngs is None: - raise ValueError( - "Dropout needs to generate a random mask but no 'rngs' were provided." - ) + rngs = first_from( + rngs, + self.rngs, + error_msg="""`deterministic` is False, but no `rngs` argument was provided to Dropout + as either a __call__ argument or class attribute.""", + ) keep_prob = 1.0 - self.rate rng = rngs[self.rng_collection]() diff --git a/flax/experimental/nnx/tests/nn/test_stochastic.py b/flax/experimental/nnx/tests/nn/test_stochastic.py index 21bdbabbd5..64c3c8ccca 100644 --- a/flax/experimental/nnx/tests/nn/test_stochastic.py +++ b/flax/experimental/nnx/tests/nn/test_stochastic.py @@ -14,28 +14,81 @@ import jax.numpy as jnp +import numpy as np from flax.experimental import nnx +import pytest + class TestStochastic: def test_dropout_internal_rngs(self): n = 0 - m = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0)) + m1 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0)) + m2 = nnx.Dropout(rate=0.5, deterministic=False) + rngs2 = nnx.Rngs(dropout=0) @nnx.jit - def f(m, x): + def f(m, x, rngs=None): nonlocal n n += 1 - return m(x) + return m(x, rngs=rngs) x = jnp.ones((1, 10)) - assert m.rngs is not None and m.rngs.dropout.count.value == 0 + assert m1.rngs is not None and m1.rngs.dropout.count.value == 0 - y = f(m, x) + y1 = f(m1, x) assert n == 1 - assert m.rngs.dropout.count.value == 1 + assert m1.rngs.dropout.count.value == 1 + y2 = f(m2, x, rngs=rngs2) + assert n == 2 + assert rngs2.dropout.count.value == 1 + np.testing.assert_allclose(y1, y2) - y = f(m, x) - assert n == 1 - assert m.rngs.dropout.count.value == 2 + y1 = f(m1, x) + assert m1.rngs.dropout.count.value == 2 + y2 = f(m2, x, rngs=rngs2) + assert rngs2.dropout.count.value == 2 + np.testing.assert_allclose(y1, y2) + + assert n == 2 + + def test_dropout_rng_override(self): + m1 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0)) + m2 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=1)) + x = jnp.ones((1, 10)) + + y1 = m1(x) + y2 = m2(x) + with pytest.raises(AssertionError): + np.testing.assert_allclose(y1, y2) + + y2 = m2(x, rngs=nnx.Rngs(dropout=0)) + np.testing.assert_allclose(y1, y2) + + def test_dropout_arg_override(self): + m = nnx.Dropout(rate=0.5) + x = jnp.ones((1, 10)) + + # no deterministic arg provided + with pytest.raises( + ValueError, match='No `deterministic` argument was provided to Dropout' + ): + m(x) + # deterministic call arg provided + m(x, deterministic=True) + # deterministic constructor arg provided + m.set_attributes(deterministic=True) + y = m(x) + # both deterministic call and constructor arg provided + with pytest.raises(AssertionError): + np.testing.assert_allclose( + y, m(x, deterministic=False, rngs=nnx.Rngs(dropout=0)) + ) + # no rng arg provided + m.set_attributes(deterministic=False) + with pytest.raises( + ValueError, + match='`deterministic` is False, but no `rngs` argument was provided to Dropout', + ): + m(x)