diff --git a/.flake8 b/.flake8
index b545e45a4f..329729ebc4 100644
--- a/.flake8
+++ b/.flake8
@@ -8,3 +8,5 @@ classmethod-decorators =
     classmethod
     validator
     root_validator
+per-file-ignores =
+    starlite/types/builtin_types.py:E800,F401
diff --git a/docs/index.md b/docs/index.md
index c09595a844..abf75ec5bf 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -81,7 +81,7 @@ class User(BaseModel):
     id: UUID4
 ```
 
-Alternatively, you can **use a dataclass** – either from dataclasses or from pydantic:
+Alternatively, you can **use a dataclass** – either from dataclasses or from pydantic, or a [`TypedDict`][typing.TypedDict]:
 
 ```python title="my_app/models/user.py"
 from uuid import UUID
diff --git a/docs/reference/utils/0-predicate-utils.md b/docs/reference/utils/0-predicate-utils.md
index 6131b9c5ab..781e56f86c 100644
--- a/docs/reference/utils/0-predicate-utils.md
+++ b/docs/reference/utils/0-predicate-utils.md
@@ -8,4 +8,10 @@
 
 ::: starlite.utils.predicates.is_class_and_subclass
 
+::: starlite.utils.predicates.is_dataclass_class_or_instance_typeguard
+
+::: starlite.utils.predicates.is_dataclass_class_typeguard
+
 ::: starlite.utils.predicates.is_optional_union
+
+::: starlite.utils.predicates.is_typeddict_typeguard
diff --git a/docs/usage/11-data-transfer-objects.md b/docs/usage/11-data-transfer-objects.md
index 4eda156cb8..301be9216a 100644
--- a/docs/usage/11-data-transfer-objects.md
+++ b/docs/usage/11-data-transfer-objects.md
@@ -2,7 +2,7 @@
 
 Starlite includes a [`DTOFactory`][starlite.dto.DTOFactory] class that allows you to create DTOs from pydantic models,
-dataclasses and any other class supported via plugins.
+dataclasses, [`TypedDict`][typing.TypedDict], and any other class supported via plugins.
 
 An instance of the factory must first be created, optionally passing plugins to it as a kwarg. It can then be used to
 create a [`DTO`][starlite.dto.DTO] by calling the instance like a function. Additionally, it can exclude (drop)
diff --git a/docs/usage/4-request-data/0-request-data.md b/docs/usage/4-request-data/0-request-data.md
index 83510f9f35..d30322ed80 100644
--- a/docs/usage/4-request-data/0-request-data.md
+++ b/docs/usage/4-request-data/0-request-data.md
@@ -17,7 +17,8 @@ async def create_user(data: User) -> User:
     ...
 ```
 
-The type of `data` does not need to be a pydantic model - it can be any supported type, e.g. a dataclass:
+The type of `data` does not need to be a pydantic model - it can be any supported type, e.g. a dataclass, or a
+[`TypedDict`][typing.TypedDict]:
 
 ```python
 from starlite import post
diff --git a/starlite/dto.py b/starlite/dto.py
index 989b094471..ab0a8d8c2e 100644
--- a/starlite/dto.py
+++ b/starlite/dto.py
@@ -1,4 +1,4 @@
-from dataclasses import asdict, is_dataclass
+from dataclasses import asdict
 from inspect import isawaitable
 from typing import (
     TYPE_CHECKING,
@@ -22,7 +22,13 @@
 
 from starlite.exceptions import ImproperlyConfiguredException
 from starlite.plugins import PluginProtocol, get_plugin_for_value
-from starlite.utils import convert_dataclass_to_model, is_async_callable
+from starlite.utils import (
+    convert_dataclass_to_model,
+    convert_typeddict_to_model,
+    is_async_callable,
+    is_dataclass_class_or_instance_typeguard,
+    is_typeddict_typeguard,
+)
 
 if TYPE_CHECKING:
     from typing import Awaitable
@@ -86,6 +92,8 @@ def from_model_instance(cls, model_instance: T) -> "DTO[T]":
             values = cast("Dict[str, Any]", result)
         elif isinstance(model_instance, BaseModel):
             values = model_instance.dict()
+        elif isinstance(model_instance, dict):
+            values = dict(model_instance)  # copy required as `_from_value_mapping()` mutates `values`.
         else:
             values = asdict(model_instance)
         return cls._from_value_mapping(mapping=values)
@@ -133,8 +141,10 @@ def to_model_instance(self) -> T:
 
 class DTOFactory:
     def __init__(self, plugins: Optional[List[PluginProtocol]] = None) -> None:
-        """Create [`DTO`][starlite.dto.DTO] types from pydantic models,
-        dataclasses and other types supported via plugins.
+        """Create [`DTO`][starlite.dto.DTO] types.
+
+        Pydantic models, [`TypedDict`][typing.TypedDict] and dataclasses are natively supported. Other types are
+        supported via plugins.
 
         Args:
             plugins (list[PluginProtocol] | None): Plugins used to support `DTO` construction from arbitrary types.
@@ -150,8 +160,8 @@ def __call__(
         field_definitions: Optional[Dict[str, Tuple[Any, Any]]] = None,
     ) -> Type[DTO[T]]:
         """
-        Given a supported model class - either pydantic, dataclass or a class supported via plugins,
-        create a DTO pydantic model class.
+        Given a supported model class - either pydantic, [`TypedDict`][typing.TypedDict], dataclass or a class supported
+        via plugins, create a DTO pydantic model class.
 
         An instance of the factory must first be created, passing any plugins to it. It can then be used to
         create a DTO by calling the instance like a function. Additionally, it can exclude (drop)
@@ -193,8 +203,8 @@ def create_obj(data: MyClassDTO) -> MyClass:
 
         Args:
             name (str): This becomes the name of the generated pydantic model.
-            source (type[T]): A type that is either a subclass of `BaseModel`, a `dataclass` or any other type with a
-                plugin registered.
+            source (type[T]): A type that is either a subclass of `BaseModel`, [`TypedDict`][typing.TypedDict], a
+                `dataclass` or any other type with a plugin registered.
             exclude (list[str] | None): Names of attributes on `source`. Named Attributes will not have a field
                 generated on the resultant pydantic model.
             field_mapping (dict[str, str | tuple[str, Any]] | None): Keys are names of attributes on `source`. Values
@@ -208,7 +218,8 @@ def create_obj(data: MyClassDTO) -> MyClass:
 
         Raises:
             [ImproperlyConfiguredException][starlite.exceptions.ImproperlyConfiguredException]: If `source` is not a
-                pydantic model or dataclass, and there is no plugin registered for its type.
+                pydantic model, [`TypedDict`][typing.TypedDict] or dataclass, and there is no plugin registered for its
+                type.
""" field_definitions = field_definitions or {} exclude = exclude or [] @@ -228,14 +239,17 @@ def create_obj(data: MyClassDTO) -> MyClass: def _get_fields_from_source( self, source: Type[T] # pyright: ignore ) -> Tuple[Dict[str, ModelField], Optional[PluginProtocol]]: - """Converts a `BaseModel` subclass, `dataclass` or any other type that - has a plugin registered into a mapping of `str` to `ModelField`.""" + """Converts a `BaseModel` subclass, [`TypedDict`][typing.TypedDict], + `dataclass` or any other type that has a plugin registered into a + mapping of `str` to `ModelField`.""" plugin: Optional[PluginProtocol] = None if issubclass(source, BaseModel): source.update_forward_refs() fields = source.__fields__ - elif is_dataclass(source): + elif is_dataclass_class_or_instance_typeguard(source): fields = convert_dataclass_to_model(source).__fields__ + elif is_typeddict_typeguard(source): + fields = convert_typeddict_to_model(source).__fields__ else: plugin = get_plugin_for_value(value=source, plugins=self.plugins) if not plugin: diff --git a/starlite/openapi/schema.py b/starlite/openapi/schema.py index b64f20aabd..0670ca767b 100644 --- a/starlite/openapi/schema.py +++ b/starlite/openapi/schema.py @@ -1,4 +1,3 @@ -from dataclasses import is_dataclass from datetime import datetime from decimal import Decimal from enum import Enum, EnumMeta @@ -32,7 +31,15 @@ ) from starlite.openapi.enums import OpenAPIFormat, OpenAPIType from starlite.openapi.utils import get_openapi_type_for_complex_type -from starlite.utils.model import convert_dataclass_to_model, create_parsed_model_field +from starlite.utils import ( + is_dataclass_class_or_instance_typeguard, + is_typeddict_typeguard, +) +from starlite.utils.model import ( + convert_dataclass_to_model, + convert_typeddict_to_model, + create_parsed_model_field, +) if TYPE_CHECKING: from starlite.plugins.base import PluginProtocol @@ -44,7 +51,7 @@ def normalize_example_value(value: Any) -> Any: value = round(float(value), 2) if isinstance(value, Enum): value = value.value - if is_dataclass(value): + if is_dataclass_class_or_instance_typeguard(value): value = convert_dataclass_to_model(value) if isinstance(value, BaseModel): value = value.dict() @@ -184,8 +191,10 @@ def get_schema_for_field_type(field: ModelField, plugins: List["PluginProtocol"] return TYPE_MAP[field_type].copy() if is_pydantic_model(field_type): return OpenAPI310PydanticSchema(schema_class=field_type) - if is_dataclass(field_type): + if is_dataclass_class_or_instance_typeguard(field_type): return OpenAPI310PydanticSchema(schema_class=convert_dataclass_to_model(field_type)) + if is_typeddict_typeguard(field_type): + return OpenAPI310PydanticSchema(schema_class=convert_typeddict_to_model(field_type)) if isinstance(field_type, EnumMeta): enum_values: List[Union[str, int]] = [v.value for v in field_type] # type: ignore openapi_type = OpenAPIType.STRING if isinstance(enum_values[0], str) else OpenAPIType.INTEGER diff --git a/starlite/types/builtin_types.py b/starlite/types/builtin_types.py new file mode 100644 index 0000000000..92a72be3cd --- /dev/null +++ b/starlite/types/builtin_types.py @@ -0,0 +1,27 @@ +# nopycln: file + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + + from typing import Type, Union + + from pydantic_factories.protocols import DataclassProtocol + from typing_extensions import TypeAlias, TypedDict + + +__all__ = [ + "DataclassClass", + "DataclassClassOrInstance", + "NoneType", + "TypedDictClass", +] + +DataclassClass: "TypeAlias" = "Type[DataclassProtocol]" + 
+DataclassClassOrInstance: "TypeAlias" = "Union[DataclassClass, DataclassProtocol]"
+
+NoneType = type(None)
+
+# mypy issue: https://github.com/python/mypy/issues/11030
+TypedDictClass: "TypeAlias" = "Type[TypedDict]"  # type:ignore[valid-type]
diff --git a/starlite/types/partial.py b/starlite/types/partial.py
index 1bf38b4616..7e56d8aa63 100644
--- a/starlite/types/partial.py
+++ b/starlite/types/partial.py
@@ -1,13 +1,29 @@
 from dataclasses import MISSING
 from dataclasses import Field as DataclassField
-from dataclasses import dataclass, is_dataclass
+from dataclasses import dataclass
 from inspect import getmro
-from typing import Any, Dict, Generic, Optional, Tuple, Type, TypeVar, get_type_hints
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    get_type_hints,
+)
 
 from pydantic import BaseModel, create_model
+from typing_extensions import TypedDict, get_args
 
 from starlite.exceptions import ImproperlyConfiguredException
-from starlite.utils.predicates import is_class_and_subclass
+from starlite.types.builtin_types import NoneType
+from starlite.utils.predicates import (
+    is_class_and_subclass,
+    is_dataclass_class_typeguard,
+    is_typeddict_typeguard,
+)
 
 try:
     # python 3.9 changed these variable
@@ -15,88 +31,111 @@
 except ImportError:  # pragma: no cover
     from typing import _GenericAlias as GenericAlias  # type: ignore
 
+if TYPE_CHECKING:
+    from typing import TypeAlias, Union  # noqa: F401 # nopycln: import
+
+    from starlite.types.builtin_types import DataclassClass, TypedDictClass
+
 T = TypeVar("T")
 
+SupportedTypes: "TypeAlias" = "Union[DataclassClass, Type[BaseModel], TypedDictClass]"
+"""Types that are supported by [`Partial`][starlite.types.partial.Partial]"""
+
 
 class Partial(Generic[T]):
-    """Partial is a special typing helper that takes a generic T, which must be
-    a dataclass or pydantic model class,
+    """Type generation for PATCH routes.
 
-    and returns to static type checkers a version of this T in which all fields
-    - and nested fields - are optional.
+    Partial is a special typing helper that takes a generic T, which must be a
+    [`TypedDict`][typing.TypedDict], dataclass or pydantic model class, and
+    returns to static type checkers a version of this T in which all fields -
+    and nested fields - are optional.
     """
 
-    _models: Dict[Type[T], Type[T]] = {}
+    _models: Dict[SupportedTypes, SupportedTypes] = {}
 
     def __class_getitem__(cls, item: Type[T]) -> Type[T]:
-        """Takes a pydantic model class or a dataclass and returns an all
-        optional version of that class.
+        """Takes a pydantic model class, [`TypedDict`][typing.TypedDict] or a
+        dataclass and returns an all optional version of that class.
 
         Args:
-            item: A pydantic model or dataclass class.
+            item: A pydantic model, [`TypedDict`][typing.TypedDict] or dataclass class.
 
         Returns:
-            A pydantic model or dataclass.
+            A pydantic model, [`TypedDict`][typing.TypedDict], or dataclass.
""" if item not in cls._models: if is_class_and_subclass(item, BaseModel): cls._create_partial_pydantic_model(item=item) - elif is_dataclass(item): + elif is_dataclass_class_typeguard(item): cls._create_partial_dataclass(item=item) + elif is_typeddict_typeguard(item): + cls._create_partial_typeddict(item=item) else: raise ImproperlyConfiguredException( - "The type argument T passed to Partial[T] must be a dataclass or pydantic model class" + "The type argument T passed to Partial[T] must be a `TypedDict`, dataclass or pydantic model class" ) - return cls._models[item] # pyright: ignore + return cls._models[item] # type:ignore[return-value] @classmethod def _create_partial_pydantic_model(cls, item: Type[BaseModel]) -> None: - """Receives a pydantic model class and returns an all optional subclass + """Receives a pydantic model class and creates an all optional subclass of it. Args: item: A pydantic model class. - - Returns: - A pydantic model class. """ field_definitions: Dict[str, Tuple[Any, None]] = {} for field_name, field_type in get_type_hints(item).items(): - if not isinstance(field_type, GenericAlias) or type(None) not in field_type.__args__: + if not isinstance(field_type, GenericAlias) or NoneType not in field_type.__args__: field_definitions[field_name] = (Optional[field_type], None) else: field_definitions[field_name] = (field_type, None) - cls._models[item] = create_model(f"Partial{item.__name__}", __base__=item, **field_definitions) # type: ignore + cls._models[item] = create_model(cls._create_partial_type_name(item), __base__=item, **field_definitions) # type: ignore @classmethod - def _create_partial_dataclass(cls, item: Type[T]) -> None: - """Receives a dataclass class and returns an all optional subclass of + def _create_partial_dataclass(cls, item: "DataclassClass") -> None: + """Receives a dataclass class and creates an all optional subclass of it. Args: item: A dataclass class. - - Returns: - A dataclass class. """ fields: Dict[str, DataclassField] = cls._create_optional_field_map(item) - partial_type: Type[T] = dataclass( # pyright: ignore - type(f"Partial{item.__name__}", (item,), {"__dataclass_fields__": fields}) + partial_type: "DataclassClass" = dataclass( + type(cls._create_partial_type_name(item), (item,), {"__dataclass_fields__": fields}) ) annotated_ancestors = [a for a in getmro(partial_type) if hasattr(a, "__annotations__")] for ancestor in annotated_ancestors: for field_name, annotation in ancestor.__annotations__.items(): - if not isinstance(annotation, GenericAlias) or type(None) not in annotation.__args__: + if not isinstance(annotation, GenericAlias) or NoneType not in annotation.__args__: partial_type.__annotations__[field_name] = Optional[annotation] else: partial_type.__annotations__[field_name] = annotation cls._models[item] = partial_type + @classmethod + def _create_partial_typeddict(cls, item: "TypedDictClass") -> None: + """Receives a typeddict class and creates a new type with all + attributes `Optional`. + + Args: + item: A [`TypedDict`][typing.TypeDict] class. 
+ """ + type_hints: Dict[str, Any] = {} + for key_name, value_type in get_type_hints(item).items(): + if NoneType in get_args(value_type): + type_hints[key_name] = value_type + continue + type_hints[key_name] = Optional[value_type] + type_name = cls._create_partial_type_name(item) + cls._models[item] = TypedDict(type_name, type_hints, total=False) # type:ignore + @staticmethod - def _create_optional_field_map(item: Type[T]) -> Dict[str, DataclassField]: + def _create_optional_field_map(item: "DataclassClass") -> Dict[str, DataclassField]: """Creates a map of field name to optional dataclass Fields for a given dataclass. @@ -107,10 +146,15 @@ def _create_optional_field_map(item: Type[T]) -> Dict[str, DataclassField]: A map of field name to optional dataclass fields. """ fields: Dict[str, DataclassField] = {} - for field_name, dataclass_field in item.__dataclass_fields__.items(): # type: ignore[attr-defined] - if not isinstance(dataclass_field.type, GenericAlias) or type(None) not in dataclass_field.type.__args__: + # https://github.com/python/typing/discussions/1056 + for field_name, dataclass_field in item.__dataclass_fields__.items(): # pyright:ignore + if not isinstance(dataclass_field.type, GenericAlias) or NoneType not in dataclass_field.type.__args__: dataclass_field.type = Optional[dataclass_field.type] if dataclass_field.default_factory is MISSING: dataclass_field.default = None if dataclass_field.default is MISSING else dataclass_field.default fields[field_name] = dataclass_field return fields + + @staticmethod + def _create_partial_type_name(item: SupportedTypes) -> str: + return f"Partial{item.__name__}" diff --git a/starlite/utils/__init__.py b/starlite/utils/__init__.py index 77829f33ea..d375cf4222 100644 --- a/starlite/utils/__init__.py +++ b/starlite/utils/__init__.py @@ -6,9 +6,20 @@ get_exception_handler, ) from .extractors import ConnectionDataExtractor, ResponseDataExtractor, obfuscate -from .model import convert_dataclass_to_model, create_parsed_model_field +from .model import ( + convert_dataclass_to_model, + convert_typeddict_to_model, + create_parsed_model_field, +) from .path import join_paths, normalize_path -from .predicates import is_async_callable, is_class_and_subclass, is_optional_union +from .predicates import ( + is_async_callable, + is_class_and_subclass, + is_dataclass_class_or_instance_typeguard, + is_dataclass_class_typeguard, + is_optional_union, + is_typeddict_typeguard, +) from .scope import get_serializer_from_scope from .sequence import find_index, unique from .serialization import default_serializer @@ -22,6 +33,7 @@ "as_async_callable_list", "async_partial", "convert_dataclass_to_model", + "convert_typeddict_to_model", "create_exception_response", "create_parsed_model_field", "default_serializer", @@ -30,8 +42,11 @@ "get_serializer_from_scope", "is_async_callable", "is_class_and_subclass", + "is_dataclass_class_or_instance_typeguard", + "is_dataclass_class_typeguard", "is_dependency_field", "is_optional_union", + "is_typeddict_typeguard", "join_paths", "normalize_path", "obfuscate", diff --git a/starlite/utils/model.py b/starlite/utils/model.py index 0e3f15a49e..9ec850c499 100644 --- a/starlite/utils/model.py +++ b/starlite/utils/model.py @@ -1,12 +1,13 @@ -from inspect import isclass from typing import TYPE_CHECKING, Any, Dict, Type, cast -from pydantic import BaseConfig, BaseModel, create_model +from pydantic import BaseConfig, BaseModel, create_model, create_model_from_typeddict from pydantic_factories.utils import create_model_from_dataclass if 
     from pydantic.fields import ModelField
 
+    from starlite.types.builtin_types import DataclassClassOrInstance, TypedDictClass
+
 
 class Config(BaseConfig):
     arbitrary_types_allowed = True
@@ -19,13 +20,28 @@ def create_parsed_model_field(value: Type[Any]) -> "ModelField":
     return cast("BaseModel", model).__fields__["value"]
 
 
-_dataclass_model_map: Dict[Any, Type[BaseModel]] = {}
+_type_model_map: Dict[Type[Any], Type[BaseModel]] = {}
+
+
+def convert_dataclass_to_model(dataclass_or_instance: "DataclassClassOrInstance") -> Type[BaseModel]:
+    """Converts a dataclass or dataclass instance to a pydantic model and
+    memoizes the result."""
+
+    if not isinstance(dataclass_or_instance, type):
+        dataclass = type(dataclass_or_instance)
+    else:
+        dataclass = dataclass_or_instance
+
+    existing = _type_model_map.get(dataclass)
+    if not existing:
+        _type_model_map[dataclass] = existing = create_model_from_dataclass(dataclass)
+    return existing
 
 
-def convert_dataclass_to_model(dataclass: Any) -> Type[BaseModel]:
-    """Converts a dataclass to a pydantic model and memoizes the result."""
-    if not isclass(dataclass) and hasattr(dataclass, "__class__"):
-        dataclass = dataclass.__class__
-    if not _dataclass_model_map.get(dataclass):
-        _dataclass_model_map[dataclass] = create_model_from_dataclass(dataclass)  # pyright: ignore
-    return _dataclass_model_map[dataclass]
+def convert_typeddict_to_model(typeddict: "TypedDictClass") -> Type[BaseModel]:
+    """Converts a [`TypedDict`][typing.TypedDict] to a pydantic model and
+    memoizes the result."""
+    existing = _type_model_map.get(typeddict)
+    if not existing:
+        _type_model_map[typeddict] = existing = create_model_from_typeddict(typeddict)
+    return existing
diff --git a/starlite/utils/predicates.py b/starlite/utils/predicates.py
index e485b3cec6..49d7895bd2 100644
--- a/starlite/utils/predicates.py
+++ b/starlite/utils/predicates.py
@@ -1,10 +1,13 @@
 import asyncio
 import functools
 import sys
+from dataclasses import is_dataclass
 from inspect import isclass
-from typing import Any, Awaitable, Callable, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Type, TypeVar, Union
 
-from typing_extensions import ParamSpec, TypeGuard, get_args, get_origin
+from typing_extensions import ParamSpec, TypeGuard, get_args, get_origin, is_typeddict
+
+from starlite.types.builtin_types import NoneType
 
 if sys.version_info >= (3, 10):
     from types import UnionType
@@ -13,6 +16,14 @@
 else:  # pragma: no cover
     UNION_TYPES = {Union}
 
+if TYPE_CHECKING:
+
+    from starlite.types.builtin_types import (
+        DataclassClass,
+        DataclassClassOrInstance,
+        TypedDictClass,
+    )
+
 P = ParamSpec("P")
 T = TypeVar("T")
 
@@ -66,4 +77,40 @@ def is_optional_union(annotation: Any) -> bool:
     Returns:
         True for a union, False otherwise.
     """
-    return get_origin(annotation) in UNION_TYPES and type(None) in get_args(annotation)
+    return get_origin(annotation) in UNION_TYPES and NoneType in get_args(annotation)
+
+
+def is_dataclass_class_typeguard(value: Any) -> "TypeGuard[DataclassClass]":
+    """Wrapper for `is_dataclass()` that narrows to type only, not instance.
+
+    Args:
+        value: tested to determine if type of `dataclass`.
+
+    Returns:
+        `True` if `value` is a `dataclass` type.
+    """
+    return is_dataclass(value) and isinstance(value, type)
+
+
+def is_dataclass_class_or_instance_typeguard(value: Any) -> "TypeGuard[DataclassClassOrInstance]":
+    """Wrapper for `is_dataclass()` that narrows type.
+
+    Args:
+        value: tested to determine if instance or type of `dataclass`.
+
+    Returns:
+        `True` if instance or type of `dataclass`.
+    """
+    return is_dataclass(value)
+
+
+def is_typeddict_typeguard(value: Any) -> "TypeGuard[TypedDictClass]":
+    """Wrapper for `is_typeddict()` that narrows type.
+
+    Args:
+        value: tested to determine if a `TypedDict` type.
+
+    Returns:
+        `True` if `value` is a `TypedDict` type.
+    """
+    return is_typeddict(value)
diff --git a/tests/__init__.py b/tests/__init__.py
index f05f7915a6..13bfb0deaf 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel, Field
 from pydantic.dataclasses import dataclass as pydantic_dataclass
 from pydantic_factories import ModelFactory
+from typing_extensions import TypedDict
 
 
 class Species(str, Enum):
@@ -55,3 +56,12 @@ class PydanticDataClassPerson:
     optional: Optional[str]
     complex: Dict[str, List[Dict[str, str]]]
     pets: Optional[List[Pet]] = None
+
+
+class TypedDictPerson(TypedDict):
+    first_name: str
+    last_name: str
+    id: str
+    optional: Optional[str]
+    complex: Dict[str, List[Dict[str, str]]]
+    pets: Optional[List[Pet]]
diff --git a/tests/dto_factory/test_dto_factory_integration.py b/tests/dto_factory/test_dto_factory_integration.py
index 69aa3d609d..d952339e40 100644
--- a/tests/dto_factory/test_dto_factory_integration.py
+++ b/tests/dto_factory/test_dto_factory_integration.py
@@ -8,7 +8,7 @@
 from starlite import DTOFactory, get, post
 from starlite.plugins.sql_alchemy import SQLAlchemyPlugin
 from starlite.testing import create_test_client
-from tests import Person, VanillaDataClassPerson
+from tests import Person, TypedDictPerson, VanillaDataClassPerson
 from tests.plugins.sql_alchemy_plugin import Pet
 
 
@@ -17,6 +17,7 @@
     [
         [Person, ["id"], {"complex": "ultra"}, []],
         [VanillaDataClassPerson, ["id"], {"complex": "ultra"}, []],
+        [TypedDictPerson, ["id"], {"complex": "ultra"}, []],
         [Pet, ["owner"], {"species": "kind"}, [SQLAlchemyPlugin()]],
     ],
 )
@@ -50,6 +51,7 @@ def get_handler() -> Any:
     [
         [Person, ["id"], {"complex": "ultra"}, {"special": (str, ...)}, []],
         [VanillaDataClassPerson, ["id"], {"complex": "ultra"}, {"special": (str, ...)}, []],
+        [TypedDictPerson, ["id"], {"complex": "ultra"}, {"special": (str, ...)}, []],
         [Pet, ["age"], {"species": "kind"}, {"special": (str, ...)}, [SQLAlchemyPlugin()]],
     ],
 )
diff --git a/tests/dto_factory/test_dto_factory_model_conversion.py b/tests/dto_factory/test_dto_factory_model_conversion.py
index 124abf1d95..29cd1efa03 100644
--- a/tests/dto_factory/test_dto_factory_model_conversion.py
+++ b/tests/dto_factory/test_dto_factory_model_conversion.py
@@ -5,20 +5,30 @@
 
 import pytest
 from pydantic_factories import ModelFactory
+from typing_extensions import is_typeddict
 
 from starlite import DTOFactory, ImproperlyConfiguredException
 from starlite.plugins.sql_alchemy import SQLAlchemyPlugin
 from starlite.plugins.tortoise_orm import TortoiseORMPlugin
-from tests import Person, Species, VanillaDataClassPerson
+from tests import Person, Species, TypedDictPerson, VanillaDataClassPerson
 from tests.plugins.sql_alchemy_plugin import Pet
 from tests.plugins.tortoise_orm import Tournament
 
 
+def _get_attribute_value(model_instance: Any, key: str) -> Any:
+    """Utility to support getting values from a class instance or a dict."""
+    try:
+        return model_instance.__getattribute__(key)
+    except AttributeError:
+        return model_instance[key]
+
+
 @pytest.mark.parametrize(
     "model, exclude, field_mapping, plugins",
     [
         [Person, [], {"complex": "ultra"}, []],
         [VanillaDataClassPerson, [], {"complex": "ultra"}, []],
+        [TypedDictPerson, [], {"complex": "ultra"}, []],
{"complex": "ultra"}, []], [Pet, ["age"], {"species": "kind"}, [SQLAlchemyPlugin()]], ], ) @@ -34,10 +44,12 @@ class DTOModelFactory(ModelFactory[MyDTO]): # type: ignore for key in dto_instance.__fields__: # type: ignore if key not in MyDTO.dto_field_mapping: - assert model_instance.__getattribute__(key) == dto_instance.__getattribute__(key) # type: ignore + attribute_value = _get_attribute_value(model_instance, key) + assert attribute_value == dto_instance.__getattribute__(key) # type: ignore else: original_key = MyDTO.dto_field_mapping[key] - assert model_instance.__getattribute__(original_key) == dto_instance.__getattribute__(key) # type: ignore + attribute_value = _get_attribute_value(model_instance, original_key) + assert attribute_value == dto_instance.__getattribute__(key) # type: ignore @pytest.mark.skipif(sys.version_info < (3, 9), reason="dataclasses behave differently in lower versions") @@ -46,6 +58,7 @@ class DTOModelFactory(ModelFactory[MyDTO]): # type: ignore [ [Person, ["id"], {"complex": "ultra"}, []], [VanillaDataClassPerson, ["id"], {"complex": "ultra"}, []], + [TypedDictPerson, ["id"], {"complex": "ultra"}, []], [Pet, ["age"], {"species": "kind"}, [SQLAlchemyPlugin()]], ], ) @@ -54,7 +67,7 @@ def test_conversion_from_model_instance( ) -> None: DTO = DTOFactory(plugins=plugins)("MyDTO", model, exclude=exclude, field_mapping=field_mapping) - if issubclass(model, (Person, VanillaDataClassPerson)): + if issubclass(model, (Person, VanillaDataClassPerson)) or is_typeddict(model): model_instance = model( first_name="moishe", last_name="zuchmir", @@ -74,10 +87,10 @@ def test_conversion_from_model_instance( dto_instance = DTO.from_model_instance(model_instance=model_instance) for key in dto_instance.__fields__: if key not in DTO.dto_field_mapping: - assert model_instance.__getattribute__(key) == dto_instance.__getattribute__(key) + assert _get_attribute_value(model_instance, key) == _get_attribute_value(dto_instance, key) else: original_key = DTO.dto_field_mapping[key] - assert model_instance.__getattribute__(original_key) == dto_instance.__getattribute__(key) + assert _get_attribute_value(model_instance, original_key) == _get_attribute_value(dto_instance, key) async def test_async_conversion_from_model_instance(scaffold_tortoise: Callable, anyio_backend: str) -> None: diff --git a/tests/openapi/test_schema.py b/tests/openapi/test_schema.py index 34cdb7ba4e..91073b5454 100644 --- a/tests/openapi/test_schema.py +++ b/tests/openapi/test_schema.py @@ -1,6 +1,8 @@ from typing import Generic, TypeVar +from unittest.mock import MagicMock import pytest +from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_openapi_schema.v3_1_0.example import Example from pydantic_openapi_schema.v3_1_0.schema import Schema @@ -10,12 +12,17 @@ from starlite.constants import EXTRA_KEY_REQUIRED from starlite.enums import ParamType from starlite.exceptions import ImproperlyConfiguredException +from starlite.openapi import schema from starlite.openapi.constants import ( EXTRA_TO_OPENAPI_PROPERTY_MAP, PYDANTIC_TO_OPENAPI_PROPERTY_MAP, ) -from starlite.openapi.schema import update_schema_with_field_info +from starlite.openapi.schema import ( + get_schema_for_field_type, + update_schema_with_field_info, +) from starlite.testing import create_test_client +from tests import TypedDictPerson def test_update_schema_with_field_info() -> None: @@ -102,3 +109,18 @@ def handler_function(dep: GenericType[int]) -> None: with pytest.raises(ImproperlyConfiguredException): 
         Starlite(route_handlers=[handler_function])
+
+
+def test_get_schema_for_field_type_typeddict(monkeypatch: pytest.MonkeyPatch) -> None:
+    return_value_mock = MagicMock()
+    convert_typeddict_to_model_mock = MagicMock(return_value=return_value_mock)
+    openapi_310_pydantic_schema_mock = MagicMock()
+    monkeypatch.setattr(schema, "OpenAPI310PydanticSchema", openapi_310_pydantic_schema_mock)
+    monkeypatch.setattr(schema, "convert_typeddict_to_model", convert_typeddict_to_model_mock)
+
+    class M(BaseModel):
+        data: TypedDictPerson
+
+    get_schema_for_field_type(M.__fields__["data"], [])
+    convert_typeddict_to_model_mock.assert_called_once_with(TypedDictPerson)
+    openapi_310_pydantic_schema_mock.assert_called_once_with(schema_class=return_value_mock)
diff --git a/tests/plugins/piccolo_orm/piccolo_app.py b/tests/plugins/piccolo_orm/piccolo_app.py
index 5cd0bba75e..c36ada291e 100644
--- a/tests/plugins/piccolo_orm/piccolo_app.py
+++ b/tests/plugins/piccolo_orm/piccolo_app.py
@@ -3,7 +3,7 @@
 https://github.com/piccolo-orm/piccolo/blob/master/tests/example_apps/music/piccolo_app.py
 """
 
-import os
+from pathlib import Path
 
 from piccolo.conf.apps import AppConfig
 
@@ -15,7 +15,7 @@
     Venue,
 )
 
-CURRENT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
+CURRENT_DIRECTORY = Path(__file__).parent
 
 APP_CONFIG = AppConfig(
     app_name="music",
@@ -26,6 +26,6 @@
         Concert,
         RecordingStudio,
     ],
-    migrations_folder_path=os.path.join(CURRENT_DIRECTORY, "piccolo_migrations"),
+    migrations_folder_path=str(CURRENT_DIRECTORY / "piccolo_migrations"),
     commands=[],
 )
diff --git a/tests/test_typing.py b/tests/test_typing.py
index 9885516dec..c0b315767e 100644
--- a/tests/test_typing.py
+++ b/tests/test_typing.py
@@ -1,12 +1,19 @@
 import dataclasses
-from typing import Any, Optional
+from typing import Any, Optional, get_type_hints
 
 import pytest
 from pydantic import BaseModel
+from typing_extensions import TypedDict, get_args
 
 from starlite.exceptions import ImproperlyConfiguredException
+from starlite.types.builtin_types import NoneType
 from starlite.types.partial import Partial
-from tests import Person, PydanticDataClassPerson, VanillaDataClassPerson
+from tests import (
+    Person,
+    PydanticDataClassPerson,
+    TypedDictPerson,
+    VanillaDataClassPerson,
+)
 
 try:
     from typing import _UnionGenericAlias as GenericAlias  # type: ignore
@@ -23,9 +30,9 @@ def test_partial_pydantic_model() -> None:
         assert field.allow_none
         assert not field.required
 
-    for annotation in partial.__annotations__.values():
+    for annotation in get_type_hints(partial).values():
         assert isinstance(annotation, GenericAlias)
-        assert type(None) in annotation.__args__
+        assert NoneType in get_args(annotation)
 
 
 @pytest.mark.parametrize("cls", [VanillaDataClassPerson, PydanticDataClassPerson])
@@ -36,11 +43,21 @@ def test_partial_dataclass(cls: Any) -> None:
 
     for field in partial.__dataclass_fields__.values():  # type: ignore
         assert field.default is None
-        assert type(None) in field.type.__args__
+        assert NoneType in get_args(field.type)
 
-    for annotation in partial.__annotations__.values():
+    for annotation in get_type_hints(partial).values():
         assert isinstance(annotation, GenericAlias)
-        assert type(None) in annotation.__args__
+        assert NoneType in get_args(annotation)
+
+
+def test_partial_typeddict() -> None:
+    partial = Partial[TypedDictPerson]
+
+    assert len(get_type_hints(partial)) == len(get_type_hints(TypedDictPerson))
+
+    for annotation in get_type_hints(partial).values():
+        assert isinstance(annotation, GenericAlias)
+        assert NoneType in get_args(annotation)
 
 
 def test_partial_pydantic_model_with_superclass() -> None:
@@ -58,7 +75,7 @@ class Child(Parent):
         assert field.allow_none
         assert not field.required
 
-    assert partial_child.__annotations__ == {
+    assert get_type_hints(partial_child) == {
         "parent_attribute": Optional[int],
         "child_attribute": Optional[int],
     }
@@ -79,12 +96,21 @@ class Child(Parent):
 
     for field in partial_child.__dataclass_fields__.values():  # type: ignore
         assert field.default is None
-        assert type(None) in field.type.__args__
+        assert NoneType in get_args(field.type)
 
-    assert partial_child.__annotations__ == {
-        "parent_attribute": Optional[int],
-        "child_attribute": Optional[int],
-    }
+    assert get_type_hints(partial_child) == {"parent_attribute": Optional[int], "child_attribute": Optional[int]}
+
+
+def test_partial_typeddict_with_superclass() -> None:
+    class Parent(TypedDict, total=True):
+        parent_attribute: int
+
+    class Child(Parent):
+        child_attribute: int
+
+    partial_child = Partial[Child]
+
+    assert get_type_hints(partial_child) == {"parent_attribute": Optional[int], "child_attribute": Optional[int]}
 
 
 class Foo:
diff --git a/tests/utils/test_model.py b/tests/utils/test_model.py
new file mode 100644
index 0000000000..711871bfeb
--- /dev/null
+++ b/tests/utils/test_model.py
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+from unittest.mock import MagicMock
+
+from typing_extensions import TypedDict
+
+from starlite.utils import model
+
+if TYPE_CHECKING:
+    from pytest import MonkeyPatch  # noqa: PT013
+
+
+def test_convert_dataclass_to_model_cache(monkeypatch: "MonkeyPatch") -> None:
+    @dataclass
+    class DC:
+        a: str
+        b: int
+
+    response_mock = MagicMock()
+    create_model_from_dataclass_mock = MagicMock(return_value=response_mock)
+    monkeypatch.setattr(model, "create_model_from_dataclass", create_model_from_dataclass_mock)
+    # test calling the function twice returns the expected response each time
+    for _ in range(2):
+        response = model.convert_dataclass_to_model(DC)
+        assert response is response_mock
+    # ensures that the work of the function has only been done once
+    create_model_from_dataclass_mock.assert_called_once_with(DC)
+
+
+def test_convert_typeddict_to_model_cache(monkeypatch: "MonkeyPatch") -> None:
+    class TD(TypedDict):
+        a: str
+        b: int
+
+    response_mock = MagicMock()
+    create_model_from_typeddict_mock = MagicMock(return_value=response_mock)
+    monkeypatch.setattr(model, "create_model_from_typeddict", create_model_from_typeddict_mock)
+    # test calling the function twice returns the expected response each time
+    for _ in range(2):
+        response = model.convert_typeddict_to_model(TD)
+        assert response is response_mock
+    # ensures that the work of the function has only been done once
+    create_model_from_typeddict_mock.assert_called_once_with(TD)
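
For context, a minimal usage sketch of the `TypedDict` support added by this patch. It is not part of the patch itself; the names `User`, `UserCreateDTO` and `update_user` are illustrative, and it assumes a Starlite install with these changes applied.

```python
from typing import Optional

from typing_extensions import TypedDict

from starlite import DTOFactory, Partial, patch


# A plain TypedDict can now be used anywhere a pydantic model or dataclass was accepted.
class User(TypedDict):
    id: str
    name: str
    email: Optional[str]


# DTOFactory accepts the TypedDict as `source`, here dropping the `id` field.
UserCreateDTO = DTOFactory()("UserCreateDTO", User, exclude=["id"])


# Partial[User] generates a TypedDict with every key optional (total=False),
# which suits PATCH payloads.
@patch("/users/{user_id:str}")
def update_user(user_id: str, data: Partial[User]) -> User:
    ...
```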