Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support generic dataclasses #172

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 62 additions & 10 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,9 @@ def class_schema(
>>> class_schema(Custom)().load({})
Custom(name=None)
"""
if not dataclasses.is_dataclass(clazz):
if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass(
clazz
):
clazz = dataclasses.dataclass(clazz)
if not clazz_frame:
current_frame = inspect.currentframe()
Expand All @@ -354,21 +356,27 @@ def class_schema(
del current_frame
_RECURSION_GUARD.seen_classes = {}
try:
return _internal_class_schema(clazz, base_schema, clazz_frame)
return _internal_class_schema(clazz, base_schema, clazz_frame, None)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have to add None here explicitly so we hit the cache for _internal_class_schema. lru_cache uses all given params as key, not adding the default values for not specified params

finally:
_RECURSION_GUARD.seen_classes.clear()


def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
if _is_generic_alias_of_dataclass(clazz):
clazz = typing_inspect.get_origin(clazz)
return dataclasses.fields(clazz)


@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
def _internal_class_schema(
clazz: type,
base_schema: Optional[Type[marshmallow.Schema]] = None,
clazz_frame: types.FrameType = None,
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
) -> Type[marshmallow.Schema]:
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
try:
# noinspection PyDataclass
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
fields = _dataclass_fields(clazz)
except TypeError: # Not a dataclass
try:
warnings.warn(
Expand All @@ -383,7 +391,9 @@ def _internal_class_schema(
"****** WARNING ******"
)
created_dataclass: type = dataclasses.dataclass(clazz)
return _internal_class_schema(created_dataclass, base_schema, clazz_frame)
return _internal_class_schema(
created_dataclass, base_schema, clazz_frame, generic_params_to_args
)
except Exception as exc:
raise TypeError(
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
Expand All @@ -396,10 +406,11 @@ def _internal_class_schema(
if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST
}

if _is_generic_alias_of_dataclass(clazz) and generic_params_to_args is None:
generic_params_to_args = _generic_params_to_args(clazz)

type_hints = _dataclass_type_hints(clazz, clazz_frame, generic_params_to_args)
# Update the schema members to contain marshmallow fields instead of dataclass fields
type_hints = get_type_hints(
clazz, localns=clazz_frame.f_locals if clazz_frame else None
)
attributes.update(
(
field.name,
Expand All @@ -409,6 +420,7 @@ def _internal_class_schema(
field.metadata,
base_schema,
clazz_frame,
generic_params_to_args,
),
)
for field in fields
Expand Down Expand Up @@ -550,7 +562,7 @@ def _field_for_generic_type(
),
)
return tuple_type(children, **metadata)
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
if origin in (dict, Dict, collections.abc.Mapping, Mapping):
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
return dict_type(
keys=field_for_schema(
Expand Down Expand Up @@ -602,6 +614,7 @@ def field_for_schema(
metadata: Mapping[str, Any] = None,
base_schema: Optional[Type[marshmallow.Schema]] = None,
typ_frame: Optional[types.FrameType] = None,
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
) -> marshmallow.fields.Field:
"""
Get a marshmallow Field corresponding to the given python type.
Expand Down Expand Up @@ -731,7 +744,7 @@ def field_for_schema(
nested_schema
or forward_reference
or _RECURSION_GUARD.seen_classes.get(typ)
or _internal_class_schema(typ, base_schema, typ_frame)
or _internal_class_schema(typ, base_schema, typ_frame, generic_params_to_args)
)

return marshmallow.fields.Nested(nested, **metadata)
Expand Down Expand Up @@ -775,6 +788,45 @@ def _get_field_default(field: dataclasses.Field):
return field.default


def _is_generic_alias_of_dataclass(clazz: type) -> bool:
"""
Check if given class is a generic alias of a dataclass, if the dataclass is
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
"""
return typing_inspect.is_generic_type(clazz) and dataclasses.is_dataclass(
typing_inspect.get_origin(clazz)
)


def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]:
base_dataclass = typing_inspect.get_origin(clazz)
base_parameters = typing_inspect.get_parameters(base_dataclass)
type_arguments = typing_inspect.get_args(clazz)
return tuple(zip(base_parameters, type_arguments))


def _dataclass_type_hints(
clazz: type,
clazz_frame: types.FrameType = None,
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
) -> Mapping[str, type]:
localns = clazz_frame.f_locals if clazz_frame else None
if not _is_generic_alias_of_dataclass(clazz):
return get_type_hints(clazz, localns=localns)
# dataclass is generic
generic_type_hints = get_type_hints(typing_inspect.get_origin(clazz), localns)
generic_params_map = dict(generic_params_to_args if generic_params_to_args else {})

def _get_hint(_t: type) -> type:
if isinstance(_t, TypeVar):
return generic_params_map[_t]
return _t

return {
field_name: _get_hint(typ) for field_name, typ in generic_type_hints.items()
}


def NewType(
name: str,
typ: Type[_U],
Expand Down
50 changes: 49 additions & 1 deletion tests/test_class_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer
from marshmallow.validate import Validator

from marshmallow_dataclass import class_schema, NewType
from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass


class TestClassSchema(unittest.TestCase):
Expand Down Expand Up @@ -401,6 +401,54 @@ class J:
[validator_a, validator_b, validator_c, validator_d],
)

def test_generic_dataclass(self):
T = typing.TypeVar("T")

@dataclasses.dataclass
class SimpleGeneric(typing.Generic[T]):
data: T

@dataclasses.dataclass
class NestedFixed:
data: SimpleGeneric[int]

@dataclasses.dataclass
class NestedGeneric(typing.Generic[T]):
data: SimpleGeneric[T]

self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int]))
self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric))

schema_s = class_schema(SimpleGeneric[str])()
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
with self.assertRaises(ValidationError):
schema_s.load({"data": 2})

schema_nested = class_schema(NestedFixed)()
self.assertEqual(
NestedFixed(data=SimpleGeneric(1)),
schema_nested.load({"data": {"data": 1}}),
)
self.assertEqual(
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
{"data": {"data": 1}},
)
with self.assertRaises(ValidationError):
schema_nested.load({"data": {"data": "str"}})

schema_nested_generic = class_schema(NestedGeneric[int])()
self.assertEqual(
NestedGeneric(data=SimpleGeneric(1)),
schema_nested_generic.load({"data": {"data": 1}}),
)
self.assertEqual(
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
{"data": {"data": 1}},
)
with self.assertRaises(ValidationError):
schema_nested_generic.load({"data": {"data": "str"}})

def test_recursive_reference(self):
@dataclasses.dataclass
class Tree:
Expand Down