diff --git a/torax/fvm/residual_and_loss.py b/torax/fvm/residual_and_loss.py index a6b791a0..27c03120 100644 --- a/torax/fvm/residual_and_loss.py +++ b/torax/fvm/residual_and_loss.py @@ -54,6 +54,7 @@ 'convection_neumann_mode', 'theta_imp', ], + assert_max_traces=1, ) def theta_method_matrix_equation( dt: jax.Array, @@ -296,6 +297,7 @@ def theta_method_block_residual( 'source_models', 'evolving_names', ], + assert_max_traces=1, ) def theta_method_block_jacobian(*args, **kwargs): return jax.jacfwd(theta_method_block_residual, has_aux=True)(*args, **kwargs) @@ -309,6 +311,7 @@ def theta_method_block_jacobian(*args, **kwargs): 'source_models', 'evolving_names', ], + assert_max_traces=1, ) def theta_method_block_loss( x_new_guess_vec: jax.Array, diff --git a/torax/jax_utils.py b/torax/jax_utils.py index 1f57cd0a..2049613c 100644 --- a/torax/jax_utils.py +++ b/torax/jax_utils.py @@ -17,7 +17,7 @@ import contextlib import dataclasses import os -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, Optional, Sequence, TypeVar, Union import chex import equinox as eqx @@ -172,7 +172,7 @@ def jax_default(value: chex.Numeric) -> ...: def compat_linspace( start: Union[chex.Numeric, jax.Array], stop: jax.Array, num: jax.Array -)-> jax.Array: +) -> jax.Array: """See np.linspace. This implementation of a subset of the linspace API reproduces the @@ -229,11 +229,38 @@ def is_tracer(var: jax.Array) -> bool: assert False # Should be unreachable -def jit(*args, **kwargs) -> Callable[..., Any]: - """Calls jax.jit iff TORAX_COMPILATION_ENABLED is True.""" +def jit( + fun, + *, + static_argnums: int | Sequence[int] | None = None, + static_argnames: str | Iterable[str] | None = None, + assert_max_traces: int | None = None, +) -> Callable[..., Any]: + """Custom JIT for Torax. + + Args: + fun: The function to jit. + static_argnums: optional, an int or collection of ints that specify which + positional arguments to treat as static (trace- and compile-time + constant). + static_argnames: optional, a string or collection of strings specifying + which named arguments to treat as static (compile-time constant). + assert_max_traces: if not `None`, checks that the function `fun` is + re-traced at most `assert_max_traces` times during program execution. + Raises a `AssertionError` if not. + + Returns: + A JITted version of `fun` iff `TORAX_COMPILATION_ENABLED=True` and the + original `fun` if not. + """ + if _COMPILATION_ENABLED: - return jax.jit(*args, **kwargs) - return args[0] + if assert_max_traces is not None: + fun = chex.assert_max_traces(fun, n=assert_max_traces) + return jax.jit( + fun, static_argnums=static_argnums, static_argnames=static_argnames + ) + return fun def py_while( diff --git a/torax/tests/jax_utils.py b/torax/tests/jax_utils.py index 53ad9c1e..98f27630 100644 --- a/torax/tests/jax_utils.py +++ b/torax/tests/jax_utils.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized +import chex import jax from jax import numpy as jnp from torax import jax_utils @@ -63,6 +64,26 @@ def test_enable_errors(self): self._should_error() + def test_jit(self): + """Test for jax_utils.jit.""" + + x1 = jnp.array(3.3) + + def f(x, y): + if y: + return jnp.sin(x) + else: + return jnp.cos(x) + + out_ref = f(x1, y=True) + + f_jit = jax_utils.jit(f, static_argnames=["y"], assert_max_traces=1) + self.assertTrue(hasattr(f_jit, "lower")) + chex.assert_trees_all_close(f_jit(x1, y=True), out_ref) + + with self.assertRaises(AssertionError): + f_jit(jnp.array([3.3]), y=False) + if __name__ == "__main__": absltest.main()