diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index c4e8426..7e57d88 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -282,7 +282,7 @@ def _flatten_with_path(dcls): k = jax.tree_util.GetAttrKey(k) path.append((k, v)) keys.append(k) - return path, keys + return path, tuple([k.name for k in keys]) @functools.cache