Skip to content

Commit

Permalink
Fix bug with repeated use of the sane generic dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
dairiki committed Jan 12, 2023
1 parent de27c32 commit 024375f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 15 deletions.
72 changes: 58 additions & 14 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class User:
Any,
Callable,
Dict,
Generic,
Hashable,
List,
Mapping,
NewType as typing_NewType,
Expand Down Expand Up @@ -367,6 +369,43 @@ def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
return dataclasses.fields(clazz)


class InvalidStateError(Exception):
"""Raised when an operation is performed on a future that is not
allowed in the current state.
"""


class _Future(Generic[_U]):
"""The _Future class allows deferred access to a result that is not
yet available.
"""

_done: bool
_result: _U

def __init__(self) -> None:
self._done = False

def done(self) -> bool:
"""Return ``True`` if the value is available"""
return self._done

def result(self) -> _U:
"""Return the deferred value.
Raises ``InvalidStateError`` if the value has not been set.
"""
if self.done():
return self._result
raise InvalidStateError("result has not been set")

def set_result(self, result: _U) -> None:
if self.done():
raise InvalidStateError("result has already been set")
self._result = result
self._done = True


@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
def _internal_class_schema(
clazz: type,
Expand All @@ -377,7 +416,8 @@ def _internal_class_schema(
# generic aliases do not have a __name__ prior python 3.10
_name = getattr(clazz, "__name__", repr(clazz))

_RECURSION_GUARD.seen_classes[clazz] = _name
future: _Future[Type[marshmallow.Schema]] = _Future()
_RECURSION_GUARD.seen_classes[clazz] = future
try:
fields = _dataclass_fields(clazz)
except TypeError: # Not a dataclass
Expand Down Expand Up @@ -430,8 +470,11 @@ def _internal_class_schema(
if field.init
)

schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes)
return cast(Type[marshmallow.Schema], schema_class)
schema_class: Type[marshmallow.Schema] = type(
_name, (_base_schema(clazz, base_schema),), attributes
)
future.set_result(schema_class)
return schema_class


def _field_by_type(
Expand Down Expand Up @@ -769,17 +812,18 @@ def field_for_schema(

# Nested marshmallow dataclass
# it would be just a class name instead of actual schema util the schema is not ready yet
nested_schema = getattr(typ, "Schema", None)

# Nested dataclasses
forward_reference = getattr(typ, "__forward_arg__", None)

nested = (
nested_schema
or forward_reference
or _RECURSION_GUARD.seen_classes.get(typ)
or _internal_class_schema(typ, base_schema, typ_frame, generic_params_to_args)
)
if typ in _RECURSION_GUARD.seen_classes:
nested = _RECURSION_GUARD.seen_classes[typ].result
elif hasattr(typ, "Schema"):
nested = typ.Schema
elif hasattr(typ, "__forward_arg__"):
# FIXME: is this still used?
nested = typ.__forward_arg__
else:
assert isinstance(typ, Hashable)
nested = _internal_class_schema(
typ, base_schema, typ_frame, generic_params_to_args
)

return marshmallow.fields.Nested(nested, **metadata)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_class_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,11 @@ class BB(typing.Generic[T]):

@dataclasses.dataclass
class Nested:
y: BB[AA]
x: BB[float]
z: BB[float]
# if y is the first field in this class, deserialisation will fail.
# see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027
y: BB[AA]

schema_nested = class_schema(Nested)()
self.assertEqual(
Expand Down

0 comments on commit 024375f

Please sign in to comment.