From 024375ffcdf4517471eacb68a52ce24b7ce85cc1 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 12 Jan 2023 12:18:54 -0800 Subject: [PATCH] Fix bug with repeated use of the sane generic dataclass See https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1358658963 --- marshmallow_dataclass/__init__.py | 72 +++++++++++++++++++++++++------ tests/test_class_schema.py | 2 +- 2 files changed, 59 insertions(+), 15 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 2515ab8..15311e9 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -46,6 +46,8 @@ class User: Any, Callable, Dict, + Generic, + Hashable, List, Mapping, NewType as typing_NewType, @@ -367,6 +369,43 @@ def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: return dataclasses.fields(clazz) +class InvalidStateError(Exception): + """Raised when an operation is performed on a future that is not + allowed in the current state. + """ + + +class _Future(Generic[_U]): + """The _Future class allows deferred access to a result that is not + yet available. + """ + + _done: bool + _result: _U + + def __init__(self) -> None: + self._done = False + + def done(self) -> bool: + """Return ``True`` if the value is available""" + return self._done + + def result(self) -> _U: + """Return the deferred value. + + Raises ``InvalidStateError`` if the value has not been set. + """ + if self.done(): + return self._result + raise InvalidStateError("result has not been set") + + def set_result(self, result: _U) -> None: + if self.done(): + raise InvalidStateError("result has already been set") + self._result = result + self._done = True + + @lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, @@ -377,7 +416,8 @@ def _internal_class_schema( # generic aliases do not have a __name__ prior python 3.10 _name = getattr(clazz, "__name__", repr(clazz)) - _RECURSION_GUARD.seen_classes[clazz] = _name + future: _Future[Type[marshmallow.Schema]] = _Future() + _RECURSION_GUARD.seen_classes[clazz] = future try: fields = _dataclass_fields(clazz) except TypeError: # Not a dataclass @@ -430,8 +470,11 @@ def _internal_class_schema( if field.init ) - schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes) - return cast(Type[marshmallow.Schema], schema_class) + schema_class: Type[marshmallow.Schema] = type( + _name, (_base_schema(clazz, base_schema),), attributes + ) + future.set_result(schema_class) + return schema_class def _field_by_type( @@ -769,17 +812,18 @@ def field_for_schema( # Nested marshmallow dataclass # it would be just a class name instead of actual schema util the schema is not ready yet - nested_schema = getattr(typ, "Schema", None) - - # Nested dataclasses - forward_reference = getattr(typ, "__forward_arg__", None) - - nested = ( - nested_schema - or forward_reference - or _RECURSION_GUARD.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema, typ_frame, generic_params_to_args) - ) + if typ in _RECURSION_GUARD.seen_classes: + nested = _RECURSION_GUARD.seen_classes[typ].result + elif hasattr(typ, "Schema"): + nested = typ.Schema + elif hasattr(typ, "__forward_arg__"): + # FIXME: is this still used? + nested = typ.__forward_arg__ + else: + assert isinstance(typ, Hashable) + nested = _internal_class_schema( + typ, base_schema, typ_frame, generic_params_to_args + ) return marshmallow.fields.Nested(nested, **metadata) diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 10487e1..66580b4 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -462,11 +462,11 @@ class BB(typing.Generic[T]): @dataclasses.dataclass class Nested: + y: BB[AA] x: BB[float] z: BB[float] # if y is the first field in this class, deserialisation will fail. # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 - y: BB[AA] schema_nested = class_schema(Nested)() self.assertEqual(