Skip to content

Commit

Permalink
C++ tree with path API
Browse files Browse the repository at this point in the history
Make tree_util.tree_flatten_with_path and tree_map_with_path APIs to be C++-based, to speed up the pytree flattening.

PiperOrigin-RevId: 694219933
  • Loading branch information
IvyZX authored and ChexDev committed Nov 18, 2024
1 parent 1dc7862 commit 90649be
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def _flatten_with_path(dcls):
k = jax.tree_util.GetAttrKey(k)
path.append((k, v))
keys.append(k)
return path, keys
return path, tuple([k.name for k in keys])


@functools.cache
Expand Down

0 comments on commit 90649be

Please sign in to comment.