Skip to content

Commit

Permalink
updated nnx dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed Apr 19, 2024
1 parent 1fbe17d commit 787c96f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 15 deletions.
13 changes: 7 additions & 6 deletions flax/experimental/nnx/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,11 @@ def __call__(
Returns:
The masked inputs reweighted to preserve mean.
"""
rngs = rngs or self.rngs
deterministic = first_from(
deterministic,
self.deterministic,
error_msg="""No `deterministic` argument was provided to Dropout
as either a __call__ argument, class attribute, or nnx.flag.""",
as either a __call__ argument or class attribute""",
)

if (self.rate == 0.0) or deterministic:
Expand All @@ -90,10 +89,12 @@ def __call__(
if self.rate == 1.0:
return jnp.zeros_like(inputs)

if rngs is None:
raise ValueError(
"Dropout needs to generate a random mask but no 'rngs' were provided."
)
rngs = first_from(
rngs,
self.rngs,
error_msg="""`deterministic` is False, but no `rngs` argument was provided to Dropout
as either a __call__ argument or class attribute.""",
)

keep_prob = 1.0 - self.rate
rng = rngs[self.rng_collection]()
Expand Down
71 changes: 62 additions & 9 deletions flax/experimental/nnx/tests/nn/test_stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,81 @@


import jax.numpy as jnp
import numpy as np

from flax.experimental import nnx

import pytest


class TestStochastic:
def test_dropout_internal_rngs(self):
n = 0
m = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0))
m1 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0))
m2 = nnx.Dropout(rate=0.5, deterministic=False)
rngs2 = nnx.Rngs(dropout=0)

@nnx.jit
def f(m, x):
def f(m, x, rngs=None):
nonlocal n
n += 1
return m(x)
return m(x, rngs=rngs)

x = jnp.ones((1, 10))
assert m.rngs is not None and m.rngs.dropout.count.value == 0
assert m1.rngs is not None and m1.rngs.dropout.count.value == 0

y = f(m, x)
y1 = f(m1, x)
assert n == 1
assert m.rngs.dropout.count.value == 1
assert m1.rngs.dropout.count.value == 1
y2 = f(m2, x, rngs=rngs2)
assert n == 2
assert rngs2.dropout.count.value == 1
np.testing.assert_allclose(y1, y2)

y = f(m, x)
assert n == 1
assert m.rngs.dropout.count.value == 2
y1 = f(m1, x)
assert m1.rngs.dropout.count.value == 2
y2 = f(m2, x, rngs=rngs2)
assert rngs2.dropout.count.value == 2
np.testing.assert_allclose(y1, y2)

assert n == 2

def test_dropout_rng_override(self):
m1 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0))
m2 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=1))
x = jnp.ones((1, 10))

y1 = m1(x)
y2 = m2(x)
with pytest.raises(AssertionError):
np.testing.assert_allclose(y1, y2)

y2 = m2(x, rngs=nnx.Rngs(dropout=0))
np.testing.assert_allclose(y1, y2)

def test_dropout_arg_override(self):
m = nnx.Dropout(rate=0.5)
x = jnp.ones((1, 10))

# no deterministic arg provided
with pytest.raises(
ValueError, match='No `deterministic` argument was provided to Dropout'
):
m(x)
# deterministic call arg provided
m(x, deterministic=True)
# deterministic constructor arg provided
m.set_attributes(deterministic=True)
y = m(x)
# both deterministic call and constructor arg provided
with pytest.raises(AssertionError):
np.testing.assert_allclose(
y, m(x, deterministic=False, rngs=nnx.Rngs(dropout=0))
)
# no rng arg provided
m.set_attributes(deterministic=False)
with pytest.raises(
ValueError,
match='`deterministic` is False, but no `rngs` argument was provided to Dropout',
):
m(x)

0 comments on commit 787c96f

Please sign in to comment.