Skip to content

Commit

Permalink
Fix tests after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
mvanderlee committed Jun 24, 2024
1 parent 8443336 commit 8a0f837
Showing 1 changed file with 53 additions and 29 deletions.
82 changes: 53 additions & 29 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def dataclass(
frozen: bool = False,
base_schema: Optional[Type[marshmallow.Schema]] = None,
cls_frame: Optional[types.FrameType] = None,
) -> Type[_U]: ...
) -> Type[_U]:
...


@overload
Expand All @@ -215,7 +216,8 @@ def dataclass(
frozen: bool = False,
base_schema: Optional[Type[marshmallow.Schema]] = None,
cls_frame: Optional[types.FrameType] = None,
) -> Callable[[Type[_U]], Type[_U]]: ...
) -> Callable[[Type[_U]], Type[_U]]:
...


# _cls should never be specified by keyword, so start it with an
Expand Down Expand Up @@ -280,13 +282,15 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:


@overload
def add_schema(_cls: Type[_U]) -> Type[_U]: ...
def add_schema(_cls: Type[_U]) -> Type[_U]:
...


@overload
def add_schema(
base_schema: Optional[Type[marshmallow.Schema]] = None,
) -> Callable[[Type[_U]], Type[_U]]: ...
) -> Callable[[Type[_U]], Type[_U]]:
...


@overload
Expand All @@ -295,7 +299,8 @@ def add_schema(
base_schema: Optional[Type[marshmallow.Schema]] = None,
cls_frame: Optional[types.FrameType] = None,
stacklevel: int = 1,
) -> Type[_U]: ...
) -> Type[_U]:
...


def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1):
Expand Down Expand Up @@ -348,7 +353,8 @@ def class_schema(
*,
globalns: Optional[Dict[str, Any]] = None,
localns: Optional[Dict[str, Any]] = None,
) -> Type[marshmallow.Schema]: ...
) -> Type[marshmallow.Schema]:
...


@overload
Expand All @@ -358,7 +364,8 @@ def class_schema(
clazz_frame: Optional[types.FrameType] = None,
*,
globalns: Optional[Dict[str, Any]] = None,
) -> Type[marshmallow.Schema]: ...
) -> Type[marshmallow.Schema]:
...


def class_schema(
Expand Down Expand Up @@ -573,7 +580,8 @@ def _internal_class_schema(
# https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined]
else:
class_name = clazz.__name__
# generic aliases do not have a __name__ prior python 3.10
class_name = getattr(clazz, "__name__", repr(clazz))

schema_ctx.seen_classes[clazz] = class_name

Expand Down Expand Up @@ -613,11 +621,20 @@ def _internal_class_schema(
# Determine whether we should include non-init fields
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)

# Update the schema members to contain marshmallow fields instead of dataclass fields
type_hints = {}
if not is_generic_type(clazz):
type_hints = _get_type_hints(clazz, schema_ctx)

attributes.update(
(
field.name,
field_for_schema(
_get_field_type_hints(field, schema_ctx),
_field_for_schema(
(
type_hints[field.name]
if not is_generic_type(clazz)
else _get_generic_type_hints(field.type, schema_ctx)
),
_get_field_default(field),
field.metadata,
base_schema,
Expand Down Expand Up @@ -710,7 +727,7 @@ def _field_for_generic_type(
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}

if origin in (list, List):
child_type = field_for_schema(
child_type = _field_for_schema(
arguments[0],
base_schema=base_schema,
)
Expand All @@ -726,15 +743,15 @@ def _field_for_generic_type(
):
from . import collection_field

child_type = field_for_schema(
child_type = _field_for_schema(
arguments[0],
base_schema=base_schema,
)
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
if origin in (set, Set):
from . import collection_field

child_type = field_for_schema(
child_type = _field_for_schema(
arguments[0],
base_schema=base_schema,
)
Expand All @@ -744,7 +761,7 @@ def _field_for_generic_type(
if origin in (frozenset, FrozenSet):
from . import collection_field

child_type = field_for_schema(
child_type = _field_for_schema(
arguments[0],
base_schema=base_schema,
)
Expand All @@ -753,7 +770,7 @@ def _field_for_generic_type(
)
if origin in (tuple, Tuple):
children = tuple(
field_for_schema(
_field_for_schema(
arg,
base_schema=base_schema,
)
Expand Down Expand Up @@ -980,7 +997,7 @@ def _field_for_schema(
)
else:
subtyp = Any
return field_for_schema(subtyp, default, metadata, base_schema)
return _field_for_schema(subtyp, default, metadata, base_schema)

annotated_field = _field_for_annotated_type(typ, **metadata)
if annotated_field:
Expand Down Expand Up @@ -1081,30 +1098,37 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
)


def _get_field_type_hints(
field: dataclasses.Field,
schema_ctx: Optional[_SchemaContext] = None,
) -> type:
"""typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works."""

class X:
x: field.type # type: ignore[name-defined]

def _get_type_hints(
obj,
schema_ctx: _SchemaContext,
):
if sys.version_info >= (3, 9):
type_hints = get_type_hints(
X,
obj,
globalns=schema_ctx.globalns,
localns=schema_ctx.localns,
include_extras=True,
)["x"]
)
else:
type_hints = get_type_hints(
X, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)["x"]
obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)

return type_hints


def _get_generic_type_hints(
obj,
schema_ctx: _SchemaContext,
) -> type:
"""typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works."""

class X:
x: obj # type: ignore[name-defined]

return _get_type_hints(X, schema_ctx)["x"]


def _is_generic_alias(clazz: type) -> bool:
"""
Check if given class is a generic alias of a class is
Expand Down

0 comments on commit 8a0f837

Please sign in to comment.