Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for trace checking in jax_utils.jit. #359

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torax/fvm/residual_and_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
'convection_neumann_mode',
'theta_imp',
],
assert_max_traces=1,
)
def theta_method_matrix_equation(
dt: jax.Array,
Expand Down Expand Up @@ -296,6 +297,7 @@ def theta_method_block_residual(
'source_models',
'evolving_names',
],
assert_max_traces=1,
)
def theta_method_block_jacobian(*args, **kwargs):
return jax.jacfwd(theta_method_block_residual, has_aux=True)(*args, **kwargs)
Expand All @@ -309,6 +311,7 @@ def theta_method_block_jacobian(*args, **kwargs):
'source_models',
'evolving_names',
],
assert_max_traces=1,
)
def theta_method_block_loss(
x_new_guess_vec: jax.Array,
Expand Down
39 changes: 33 additions & 6 deletions torax/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import contextlib
import dataclasses
import os
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Iterable, Optional, Sequence, TypeVar, Union

import chex
import equinox as eqx
Expand Down Expand Up @@ -172,7 +172,7 @@ def jax_default(value: chex.Numeric) -> ...:

def compat_linspace(
start: Union[chex.Numeric, jax.Array], stop: jax.Array, num: jax.Array
)-> jax.Array:
) -> jax.Array:
"""See np.linspace.

This implementation of a subset of the linspace API reproduces the
Expand Down Expand Up @@ -229,11 +229,38 @@ def is_tracer(var: jax.Array) -> bool:
assert False # Should be unreachable


def jit(*args, **kwargs) -> Callable[..., Any]:
"""Calls jax.jit iff TORAX_COMPILATION_ENABLED is True."""
def jit(
fun,
*,
static_argnums: int | Sequence[int] | None = None,
static_argnames: str | Iterable[str] | None = None,
assert_max_traces: int | None = None,
) -> Callable[..., Any]:
"""Custom JIT for Torax.

Args:
fun: The function to jit.
static_argnums: optional, an int or collection of ints that specify which
positional arguments to treat as static (trace- and compile-time
constant).
static_argnames: optional, a string or collection of strings specifying
which named arguments to treat as static (compile-time constant).
assert_max_traces: if not `None`, checks that the function `fun` is
re-traced at most `assert_max_traces` times during program execution.
Raises a `AssertionError` if not.

Returns:
A JITted version of `fun` iff `TORAX_COMPILATION_ENABLED=True` and the
original `fun` if not.
"""

if _COMPILATION_ENABLED:
return jax.jit(*args, **kwargs)
return args[0]
if assert_max_traces is not None:
fun = chex.assert_max_traces(fun, n=assert_max_traces)
return jax.jit(
fun, static_argnums=static_argnums, static_argnames=static_argnames
)
return fun


def py_while(
Expand Down
21 changes: 21 additions & 0 deletions torax/tests/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
from jax import numpy as jnp
from torax import jax_utils
Expand Down Expand Up @@ -63,6 +64,26 @@ def test_enable_errors(self):

self._should_error()

def test_jit(self):
"""Test for jax_utils.jit."""

x1 = jnp.array(3.3)

def f(x, y):
if y:
return jnp.sin(x)
else:
return jnp.cos(x)

out_ref = f(x1, y=True)

f_jit = jax_utils.jit(f, static_argnames=["y"], assert_max_traces=1)
self.assertTrue(hasattr(f_jit, "lower"))
chex.assert_trees_all_close(f_jit(x1, y=True), out_ref)

with self.assertRaises(AssertionError):
f_jit(jnp.array([3.3]), y=False)


if __name__ == "__main__":
absltest.main()
Loading