From 89d4c6b7fdf27e9ed46bec85f969b3fd11de178d Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Tue, 11 Jan 2022 18:49:38 +0000 Subject: [PATCH 1/5] support generic dataclasses --- marshmallow_dataclass/__init__.py | 50 ++++++++++++++++++++++++++++--- tests/test_class_schema.py | 27 +++++++++++++++++ 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 0459b95e..6513d3b5 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -313,7 +313,9 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ - if not dataclasses.is_dataclass(clazz): + if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass( + clazz + ): clazz = dataclasses.dataclass(clazz) return _internal_class_schema(clazz, base_schema) @@ -323,8 +325,7 @@ def _internal_class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None ) -> Type[marshmallow.Schema]: try: - # noinspection PyDataclass - fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) + class_name, fields = _dataclass_name_and_fields(clazz) except TypeError: # Not a dataclass try: warnings.warn( @@ -363,7 +364,7 @@ def _internal_class_schema( if field.init ) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -662,6 +663,47 @@ def _get_field_default(field: dataclasses.Field): return field.default +def _is_generic_alias_of_dataclass(clazz: type) -> bool: + """ + Check if given class is a generic alias of a dataclass, if the dataclass is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + return typing_inspect.is_generic_type(clazz) and dataclasses.is_dataclass( + typing_inspect.get_origin(clazz) + ) + + +# noinspection PyDataclass +def _dataclass_name_and_fields( + clazz: type, +) -> Tuple[str, Tuple[dataclasses.Field, ...]]: + if not _is_generic_alias_of_dataclass(clazz): + return clazz.__name__, dataclasses.fields(clazz) + + base_dataclass = typing_inspect.get_origin(clazz) + base_parameters = typing_inspect.get_parameters(base_dataclass) + type_arguments = typing_inspect.get_args(clazz) + params_to_args = dict(zip(base_parameters, type_arguments)) + non_generic_fields = [ # swap generic typed fields with types in given type arguments + ( + f.name, + params_to_args.get(f.type, f.type), + dataclasses.field( + default=f.default, + # ignoring mypy: https://github.com/python/mypy/issues/6910 + default_factory=f.default_factory, # type: ignore + init=f.init, + metadata=f.metadata, + ), + ) + for f in dataclasses.fields(base_dataclass) + ] + non_generic_dataclass = dataclasses.make_dataclass( + cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields + ) + return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass) + + def NewType( name: str, typ: Type[_U], diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 02f3ba3d..36b04887 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -324,6 +324,33 @@ class J: [validator_a, validator_b, validator_c, validator_d], ) + def test_generic_dataclass(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class Nested: + data: SimpleGeneric[int] + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_n = class_schema(Nested)() + self.assertEqual( + Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}}) + ) + self.assertEqual( + schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}} + ) + with self.assertRaises(ValidationError): + schema_n.load({"data": {"data": "str"}}) + if __name__ == "__main__": unittest.main() From 7ab7b524417a32e3caf7286873359fb5e07a736c Mon Sep 17 00:00:00 2001 From: lovasoa Date: Thu, 21 Apr 2022 18:17:53 +0200 Subject: [PATCH 2/5] update pre-commit hooks --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dde3d4df..b04918f7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/asottile/pyupgrade - rev: v2.31.0 # Later versions do not support python 3.6 + rev: v2.32.0 hooks: - id: pyupgrade args: ["--py36-plus"] @@ -15,13 +15,13 @@ repos: - id: flake8 additional_dependencies: ['flake8-bugbear==19.8.0'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 # Later versions do not support python 3.6 + rev: v0.942 hooks: - id: mypy additional_dependencies: [marshmallow-enum,typeguard,marshmallow] args: [--show-error-codes] - repo: https://github.com/asottile/blacken-docs - rev: v1.12.0 # Later versions do not support python 3.6 + rev: v1.12.1 hooks: - id: blacken-docs additional_dependencies: [black==19.3b0] From b3d3797c79343806b7db79d6b74f664f28cdba8b Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Mon, 28 Nov 2022 17:05:40 +0800 Subject: [PATCH 3/5] support nested generic dataclasses --- marshmallow_dataclass/__init__.py | 80 +++++++++++++++++-------------- tests/test_class_schema.py | 34 ++++++++++--- 2 files changed, 73 insertions(+), 41 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 279488da..2cf28438 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -356,20 +356,27 @@ def class_schema( del current_frame _RECURSION_GUARD.seen_classes = {} try: - return _internal_class_schema(clazz, base_schema, clazz_frame) + return _internal_class_schema(clazz, base_schema, clazz_frame, None) finally: _RECURSION_GUARD.seen_classes.clear() +def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: + if _is_generic_alias_of_dataclass(clazz): + clazz = typing_inspect.get_origin(clazz) + return dataclasses.fields(clazz) + + @lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, clazz_frame: types.FrameType = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> Type[marshmallow.Schema]: _RECURSION_GUARD.seen_classes[clazz] = clazz.__name__ try: - class_name, fields = _dataclass_name_and_fields(clazz) + fields = _dataclass_fields(clazz) except TypeError: # Not a dataclass try: warnings.warn( @@ -384,7 +391,9 @@ def _internal_class_schema( "****** WARNING ******" ) created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema(created_dataclass, base_schema, clazz_frame) + return _internal_class_schema( + created_dataclass, base_schema, clazz_frame, generic_params_to_args + ) except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -397,10 +406,11 @@ def _internal_class_schema( if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST } + if _is_generic_alias_of_dataclass(clazz) and generic_params_to_args is None: + generic_params_to_args = _generic_params_to_args(clazz) + + type_hints = _dataclass_type_hints(clazz, clazz_frame, generic_params_to_args) # Update the schema members to contain marshmallow fields instead of dataclass fields - type_hints = get_type_hints( - clazz, localns=clazz_frame.f_locals if clazz_frame else None - ) attributes.update( ( field.name, @@ -410,13 +420,14 @@ def _internal_class_schema( field.metadata, base_schema, clazz_frame, + generic_params_to_args, ), ) for field in fields if field.init ) - schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -551,7 +562,7 @@ def _field_for_generic_type( ), ) return tuple_type(children, **metadata) - elif origin in (dict, Dict, collections.abc.Mapping, Mapping): + if origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( keys=field_for_schema( @@ -603,6 +614,7 @@ def field_for_schema( metadata: Mapping[str, Any] = None, base_schema: Optional[Type[marshmallow.Schema]] = None, typ_frame: Optional[types.FrameType] = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Get a marshmallow Field corresponding to the given python type. @@ -732,7 +744,7 @@ def field_for_schema( nested_schema or forward_reference or _RECURSION_GUARD.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema, typ_frame) + or _internal_class_schema(typ, base_schema, typ_frame, generic_params_to_args) ) return marshmallow.fields.Nested(nested, **metadata) @@ -786,35 +798,33 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool: ) -# noinspection PyDataclass -def _dataclass_name_and_fields( - clazz: type, -) -> Tuple[str, Tuple[dataclasses.Field, ...]]: - if not _is_generic_alias_of_dataclass(clazz): - return clazz.__name__, dataclasses.fields(clazz) - +def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]: base_dataclass = typing_inspect.get_origin(clazz) base_parameters = typing_inspect.get_parameters(base_dataclass) type_arguments = typing_inspect.get_args(clazz) - params_to_args = dict(zip(base_parameters, type_arguments)) - non_generic_fields = [ # swap generic typed fields with types in given type arguments - ( - f.name, - params_to_args.get(f.type, f.type), - dataclasses.field( - default=f.default, - # ignoring mypy: https://github.com/python/mypy/issues/6910 - default_factory=f.default_factory, # type: ignore - init=f.init, - metadata=f.metadata, - ), - ) - for f in dataclasses.fields(base_dataclass) - ] - non_generic_dataclass = dataclasses.make_dataclass( - cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields - ) - return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass) + return tuple(zip(base_parameters, type_arguments)) + + +def _dataclass_type_hints( + clazz: type, + clazz_frame: types.FrameType = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, +) -> Mapping[str, type]: + localns = clazz_frame.f_locals if clazz_frame else None + if not _is_generic_alias_of_dataclass(clazz): + return get_type_hints(clazz, localns=localns) + # dataclass is generic + generic_type_hints = get_type_hints(typing_inspect.get_origin(clazz), localns) + generic_params_map = dict(generic_params_to_args if generic_params_to_args else {}) + + def _get_hint(_t: type) -> type: + if isinstance(_t, TypeVar): + return generic_params_map[_t] + return _t + + return { + field_name: _get_hint(typ) for field_name, typ in generic_type_hints.items() + } def NewType( diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index db051116..5f643a17 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -14,7 +14,7 @@ from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType +from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass class TestClassSchema(unittest.TestCase): @@ -409,24 +409,45 @@ class SimpleGeneric(typing.Generic[T]): data: T @dataclasses.dataclass - class Nested: + class NestedFixed: data: SimpleGeneric[int] + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric)) + schema_s = class_schema(SimpleGeneric[str])() self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) with self.assertRaises(ValidationError): schema_s.load({"data": 2}) - schema_n = class_schema(Nested)() + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() self.assertEqual( - Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}}) + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), ) self.assertEqual( - schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}} + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, ) with self.assertRaises(ValidationError): - schema_n.load({"data": {"data": "str"}}) + schema_nested_generic.load({"data": {"data": "str"}}) def test_recursive_reference(self): @dataclasses.dataclass @@ -461,5 +482,6 @@ class Second: {"first": {"second": {"first": None}}}, ) + if __name__ == "__main__": unittest.main() From f9d894bc464957fa744374e5189d9b5ecbfb60f3 Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Fri, 9 Dec 2022 22:19:06 +1100 Subject: [PATCH 4/5] add test for repeated fields, fix __name__ attr for py<3.10 --- marshmallow_dataclass/__init__.py | 54 +++++++++++++++++++++++++------ tests/test_class_schema.py | 25 ++++++++++++++ 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 2cf28438..74a3c30c 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -374,7 +374,10 @@ def _internal_class_schema( clazz_frame: types.FrameType = None, generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> Type[marshmallow.Schema]: - _RECURSION_GUARD.seen_classes[clazz] = clazz.__name__ + # generic aliases do not have a __name__ prior python 3.10 + _name = getattr(clazz, "__name__", repr(clazz)) + + _RECURSION_GUARD.seen_classes[clazz] = _name try: fields = _dataclass_fields(clazz) except TypeError: # Not a dataclass @@ -427,7 +430,7 @@ def _internal_class_schema( if field.init ) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -446,6 +449,7 @@ def _field_by_supertype( metadata: dict, base_schema: Optional[Type[marshmallow.Schema]], typ_frame: Optional[types.FrameType], + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. (Usually spawned from NewType) @@ -477,6 +481,7 @@ def _field_by_supertype( default=default, base_schema=base_schema, typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) @@ -501,6 +506,7 @@ def _field_for_generic_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], typ_frame: Optional[types.FrameType], + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ @@ -514,7 +520,10 @@ def _field_for_generic_type( if origin in (list, List): child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) list_type = cast( Type[marshmallow.fields.List], @@ -529,14 +538,20 @@ def _field_for_generic_type( from . import collection_field child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): from . import collection_field child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata @@ -545,14 +560,22 @@ def _field_for_generic_type( from . import collection_field child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) if origin in (tuple, Tuple): children = tuple( - field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame) + field_for_schema( + arg, + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, + ) for arg in arguments ) tuple_type = cast( @@ -566,10 +589,16 @@ def _field_for_generic_type( dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( keys=field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ), values=field_for_schema( - arguments[1], base_schema=base_schema, typ_frame=typ_frame + arguments[1], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ), **metadata, ) @@ -587,6 +616,7 @@ def _field_for_generic_type( metadata=metadata, base_schema=base_schema, typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) from . import union_field @@ -599,6 +629,7 @@ def _field_for_generic_type( metadata={"required": True}, base_schema=base_schema, typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ), ) for subtyp in subtypes @@ -707,7 +738,9 @@ def field_for_schema( ) else: subtyp = Any - return field_for_schema(subtyp, default, metadata, base_schema, typ_frame) + return field_for_schema( + subtyp, default, metadata, base_schema, typ_frame, generic_params_to_args + ) # Generic types generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata) @@ -725,6 +758,7 @@ def field_for_schema( metadata=metadata, base_schema=base_schema, typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) # enumerations diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 5f643a17..52df73a1 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -449,6 +449,31 @@ class NestedGeneric(typing.Generic[T]): with self.assertRaises(ValidationError): schema_nested_generic.load({"data": {"data": "str"}}) + def test_generic_dataclass_repeated_fields(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class AA: + a: int + + @dataclasses.dataclass + class BB(typing.Generic[T]): + b: T + + @dataclasses.dataclass + class Nested: + x: BB[float] + z: BB[float] + # if y is the first field in this class, deserialisation will fail. + # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 + y: BB[AA] + + schema_nested = class_schema(Nested)() + self.assertEqual( + Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), + schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), + ) + def test_recursive_reference(self): @dataclasses.dataclass class Tree: From de27c32150643c3cdcc1ff15003b8884d6366886 Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Thu, 15 Dec 2022 15:28:57 +1100 Subject: [PATCH 5/5] support py3.6 --- marshmallow_dataclass/__init__.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index bff9fe52..2515ab86 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -827,8 +827,13 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool: Check if given class is a generic alias of a dataclass, if the dataclass is defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed """ - return typing_inspect.is_generic_type(clazz) and dataclasses.is_dataclass( - typing_inspect.get_origin(clazz) + is_generic = typing_inspect.is_generic_type(clazz) + type_arguments = typing_inspect.get_args(clazz) + origin_class = typing_inspect.get_origin(clazz) + return ( + is_generic + and len(type_arguments) > 0 + and dataclasses.is_dataclass(origin_class) )