From 1b0586403642dfb6da5ac7b5180ae9f4e37cd6b7 Mon Sep 17 00:00:00 2001 From: Stephen Spencer Date: Mon, 23 May 2022 01:40:43 -0700 Subject: [PATCH] Internal changes PiperOrigin-RevId: 450380384 --- chex/_src/dataclass.py | 39 ++++++++++++++++++++++++++++++++++--- chex/_src/dataclass_test.py | 11 ++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index f74db778..865dd3db 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -20,6 +20,7 @@ from absl import logging import jax +import tree FrozenInstanceError = dataclasses.FrozenInstanceError @@ -62,7 +63,8 @@ def new_init(self, *orig_args, **orig_kwargs): all_kwargs = dict(*orig_args, **orig_kwargs) unknown_kwargs = set(all_kwargs.keys()) - all_fields if unknown_kwargs: - raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.") + raise ValueError( + f"__init__() got unexpected keyword arguments: {unknown_kwargs}.") # Pass only arguments corresponding to fields with `init=True`. valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields} @@ -91,7 +93,7 @@ def dataclass( order=False, unsafe_hash=False, frozen=False, - mappable_dataclass=True, # pylint: disable=redefined-outer-name + mappable_dataclass=False, # pylint: disable=redefined-outer-name ): """JAX-friendly wrapper for :py:func:`dataclasses.dataclass`. @@ -185,7 +187,7 @@ def __call__(self, cls): delattr(dcls, attr) # delete def _from_tuple(args): - return dcls(zip(dcls.__dataclass_fields__.keys(), args)) + return dcls(**dict(zip(dcls.__dataclass_fields__.keys(), args))) def _to_tuple(self): return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys()) @@ -202,6 +204,8 @@ def _getstate(self): def _setstate(self, state): if not class_self.registered: register_dataclass_type_with_jax_tree_util(dcls) + if not class_self.mappable_dataclass: + register_dataclass_type_with_dm_tree(dcls) class_self.registered = True self.__dict__.update(state) @@ -213,6 +217,8 @@ def _setstate(self, state): def _init(self, *args, **kwargs): if not class_self.registered: register_dataclass_type_with_jax_tree_util(dcls) + if not class_self.mappable_dataclass: + register_dataclass_type_with_dm_tree(dcls) class_self.registered = True return orig_init(self, *args, **kwargs) @@ -246,3 +252,30 @@ def register_dataclass_type_with_jax_tree_util(data_class): nodetype=data_class, flatten_func=flatten, unflatten_func=unflatten) except ValueError: logging.info("%s is already registered as JAX PyTree node.", data_class) + + +def register_dataclass_type_with_dm_tree(data_class): + """Register an existing dataclass with dm_tree node registry. + + This will mean that functions in dm_tree will operate over fields of the + dataclass. + + Args: + data_class: A class created using dataclasses.dataclass. It must be + constructable from keyword arguments corresponding to the members exposed + in instance.__dict__. + """ + + def to_iterable(d): + keys, values = jax.util.unzip2(sorted(d.__dict__.items())) + return values, keys, keys + + def from_iterable(keys, values): + return data_class(**dict(zip(keys, values))) + + try: + tree.register_node(data_class, to_iterable, from_iterable) + except ValueError: + logging.log_first_n(logging.INFO, + "%s is already registered as dm_tree node.", 1, + data_class) diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py index e4ba950e..a79ee7a0 100644 --- a/chex/_src/dataclass_test.py +++ b/chex/_src/dataclass_test.py @@ -365,12 +365,17 @@ class SimpleDataclass: b: int = 2 SimpleDataclass(a=1, b=3) - with self.assertRaisesRegex(ValueError, 'init.*got unexpected kwargs'): + with self.assertRaisesRegex((ValueError, TypeError), + '.*unexpected keyword argument.*'): SimpleDataclass(a=1, b=3, c=4) - def test_tuple_conversion(self): + @parameterized.named_parameters( + ('non_mappable', False), + ('mappable', True), + ) + def test_tuple_conversion(self, mappable): - @chex_dataclass() + @chex_dataclass(mappable_dataclass=mappable) class SimpleDataclass: b: int a: int