Skip to content

Commit

Permalink
Fix generic dataclasses
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 521409933
  • Loading branch information
stompchicken authored and ChexDev committed Apr 3, 2023
1 parent 06426a1 commit 0ac42e3
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 38 deletions.
80 changes: 47 additions & 33 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_RESERVED_DCLS_FIELD_NAMES = frozenset(("from_tuple", "replace", "to_tuple"))


def mappable_dataclass(cls):
def mappable_dataclass(cls, update_constructor=False):
"""Exposes dataclass as ``collections.abc.Mapping`` descendent.
Allows dataclasses to be traversed by methods from the `dm-tree` library.
Expand All @@ -37,39 +37,39 @@ def mappable_dataclass(cls):
Args:
cls: A dataclass to mutate.
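update_constructor: Whether to replace ``__init__`` with a dict-style
constructor that accepts either a single positional mapping or keyword
arguments, rejects names that are not dataclass fields, and forwards only
the fields declared with ``init=True`` to the original ``__init__``.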
Returns:
Mutated dataclass implementing ``collections.abc.Mapping`` interface.
"""
if not dataclasses.is_dataclass(cls):
raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).")

# Define methods for compatibility with `collections.abc.Mapping`.
setattr(cls, "__getitem__", lambda self, x: self.__dict__[x])
setattr(cls, "__len__", lambda self: len(self.__dict__))
setattr(cls, "__iter__", lambda self: iter(self.__dict__))

# Update constructor.
orig_init = cls.__init__
all_fields = set(f.name for f in cls.__dataclass_fields__.values())
init_fields = [f.name for f in cls.__dataclass_fields__.values() if f.init]

@functools.wraps(orig_init)
def new_init(self, *orig_args, **orig_kwargs):
if (orig_args and orig_kwargs) or len(orig_args) > 1:
raise ValueError(
"Mappable dataclass constructor doesn't support positional args."
"(it has the same constructor as python dict)")
all_kwargs = dict(*orig_args, **orig_kwargs)
unknown_kwargs = set(all_kwargs.keys()) - all_fields
if unknown_kwargs:
raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.")

# Pass only arguments corresponding to fields with `init=True`.
valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
orig_init(self, **valid_kwargs)

cls.__init__ = new_init
if update_constructor:
orig_init = cls.__init__
all_fields = set(f.name for f in cls.__dataclass_fields__.values())
init_fields = [f.name for f in cls.__dataclass_fields__.values() if f.init]

@functools.wraps(orig_init)
def new_init(self, *orig_args, **orig_kwargs):
if (orig_args and orig_kwargs) or len(orig_args) > 1:
raise ValueError(
"Mappable dataclass constructor doesn't support positional args."
"(it has the same constructor as python dict)")
all_kwargs = dict(*orig_args, **orig_kwargs)
unknown_kwargs = set(all_kwargs.keys()) - all_fields
if unknown_kwargs:
raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.")

# Pass only arguments corresponding to fields with `init=True`.
valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
orig_init(self, **valid_kwargs)

cls.__init__ = new_init

# Update base class to derive from Mapping
dct = dict(cls.__dict__)
Expand Down Expand Up @@ -155,6 +155,14 @@ def __init__(
def __call__(self, cls):
"""Forwards class to dataclasses's wrapper and registers it with JAX."""

if self.mappable_dataclass:
cls = mappable_dataclass(cls)
# We remove `collections.abc.Mapping` mixin methods here to allow
# fields with these names.
for attr in ("values", "keys", "get", "items"):
setattr(cls, attr, None) # redefine
delattr(cls, attr) # delete

# Remove once https://github.com/python/cpython/pull/24484 is merged.
for base in cls.__bases__:
if (dataclasses.is_dataclass(base) and
Expand All @@ -169,6 +177,7 @@ def __call__(self, cls):
eq=self.eq,
order=self.order,
unsafe_hash=self.unsafe_hash,
kw_only=True,
frozen=self.frozen)
# pytype: enable=wrong-keyword-args

Expand All @@ -178,16 +187,8 @@ def __call__(self, cls):
raise ValueError(f"The following dataclass fields are disallowed: "
f"{invalid_fields} ({dcls}).")

if self.mappable_dataclass:
dcls = mappable_dataclass(dcls)
# We remove `collections.abc.Mapping` mixin methods here to allow
# fields with these names.
for attr in ("values", "keys", "get", "items"):
setattr(dcls, attr, None) # redefine
delattr(dcls, attr) # delete

def _from_tuple(args):
return dcls(zip(dcls.__dataclass_fields__.keys(), args))
return dcls(**dict(zip(dcls.__dataclass_fields__.keys(), args)))

def _to_tuple(self):
return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())
Expand All @@ -209,14 +210,27 @@ def _setstate(self, state):

orig_init = dcls.__init__

all_fields = set(f.name for f in cls.__dataclass_fields__.values())
init_fields = [f.name for f in dcls.__dataclass_fields__.values() if f.init]

# Patch object's __init__ such that the class is registered on creation if
# it is not registered on deserialization.
@functools.wraps(orig_init)
def _init(self, *args, **kwargs):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
class_self.registered = True
return orig_init(self, *args, **kwargs)

if (args and kwargs) or len(args) > 1:
raise ValueError(
"Mappable dataclass constructor doesn't support positional args."
"(it has the same constructor as python dict)")
all_kwargs = dict(*args, **kwargs)
unknown_kwargs = set(all_kwargs.keys()) - all_fields
if unknown_kwargs:
raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.")
valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
return orig_init(self, **valid_kwargs)

setattr(dcls, "from_tuple", _from_tuple)
setattr(dcls, "to_tuple", _to_tuple)
Expand Down
17 changes: 12 additions & 5 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def __post_init__(self, k_init_only):
nested_cls = chex_dataclass(NestedClass, mappable_dataclass=True)
elif test_type == 'original':
cls = mappable_dataclass(orig_dataclass(Class))
nested_cls = mappable_dataclass(orig_dataclass(NestedClass))
nested_cls = mappable_dataclass(
orig_dataclass(NestedClass), update_constructor=True
)
else:
raise ValueError(f'Unknown test type: {test_type}')

Expand Down Expand Up @@ -521,19 +523,24 @@ def _is_leaf(value) -> bool:
jax.tree_util.tree_map(lambda x: x, dcls, is_leaf=_is_leaf), dcls)

@parameterized.named_parameters(
('mappable', True),
('not_mappable', False),
('mappable_frozen', True, True),
('not_mappable_frozen', False, True),
('mappable_not_frozen', True, False),
('not_mappable_not_frozen', False, False),
)
def test_generic_dataclass(self, mappable):
def test_generic_dataclass(self, mappable, frozen):
T = TypeVar('T')

@chex_dataclass(mappable_dataclass=mappable)
@chex_dataclass(mappable_dataclass=mappable, frozen=frozen)
class GenericDataclass(Generic[T]):
a: T # pytype: disable=invalid-annotation # enable-bare-annotations

obj = GenericDataclass(a=np.array([1.0, 1.0]))
asserts.assert_tree_all_close(obj.a, 1.0)

obj = GenericDataclass[np.array](a=np.array([1.0, 1.0]))
asserts.assert_tree_all_close(obj.a, 1.0)

def test_mappable_eq_override(self):

@chex_dataclass(mappable_dataclass=True)
Expand Down

0 comments on commit 0ac42e3

Please sign in to comment.