diff --git a/end_to_end_tests/baseline_openapi_3.0.json b/end_to_end_tests/baseline_openapi_3.0.json index e5bbaf6fc..3da55d2df 100644 --- a/end_to_end_tests/baseline_openapi_3.0.json +++ b/end_to_end_tests/baseline_openapi_3.0.json @@ -2823,7 +2823,8 @@ "propertyName": "modelType", "mapping": { "type1": "#/components/schemas/ADiscriminatedUnionType1", - "type2": "#/components/schemas/ADiscriminatedUnionType2" + "type2": "#/components/schemas/ADiscriminatedUnionType2", + "type2-another-value": "#/components/schemas/ADiscriminatedUnionType2" } }, "oneOf": [ @@ -2841,7 +2842,8 @@ "modelType": { "type": "string" } - } + }, + "required": ["modelType"] }, "ADiscriminatedUnionType2": { "type": "object", @@ -2849,7 +2851,8 @@ "modelType": { "type": "string" } - } + }, + "required": ["modelType"] } }, "parameters": { diff --git a/end_to_end_tests/baseline_openapi_3.1.yaml b/end_to_end_tests/baseline_openapi_3.1.yaml index b6a6941e2..746ea91ef 100644 --- a/end_to_end_tests/baseline_openapi_3.1.yaml +++ b/end_to_end_tests/baseline_openapi_3.1.yaml @@ -2817,7 +2817,8 @@ info: "propertyName": "modelType", "mapping": { "type1": "#/components/schemas/ADiscriminatedUnionType1", - "type2": "#/components/schemas/ADiscriminatedUnionType2" + "type2": "#/components/schemas/ADiscriminatedUnionType2", + "type2-another-value": "#/components/schemas/ADiscriminatedUnionType2" } }, "oneOf": [ @@ -2835,7 +2836,8 @@ info: "modelType": { "type": "string" } - } + }, + "required": ["modelType"] }, "ADiscriminatedUnionType2": { "type": "object", @@ -2843,7 +2845,8 @@ info: "modelType": { "type": "string" } - } + }, + "required": ["modelType"] } } "parameters": { diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py b/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py index cb1184b18..ed02a0aaf 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py @@ -1,10 +1,8 @@ -from typing import Any, Dict, List, Type, TypeVar, Union +from typing import Any, Dict, List, Type, TypeVar from attrs import define as _attrs_define from attrs import field as _attrs_field -from ..types import UNSET, Unset - T = TypeVar("T", bound="ADiscriminatedUnionType1") @@ -12,10 +10,10 @@ class ADiscriminatedUnionType1: """ Attributes: - model_type (Union[Unset, str]): + model_type (str): """ - model_type: Union[Unset, str] = UNSET + model_type: str additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> Dict[str, Any]: @@ -23,16 +21,18 @@ def to_dict(self) -> Dict[str, Any]: field_dict: Dict[str, Any] = {} field_dict.update(self.additional_properties) - field_dict.update({}) - if model_type is not UNSET: - field_dict["modelType"] = model_type + field_dict.update( + { + "modelType": model_type, + } + ) return field_dict @classmethod def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: d = src_dict.copy() - model_type = d.pop("modelType", UNSET) + model_type = d.pop("modelType") a_discriminated_union_type_1 = cls( model_type=model_type, diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py b/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py index 734f3bef4..93ee2cbb9 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py @@ -1,10 +1,8 @@ -from typing import Any, Dict, List, Type, TypeVar, Union +from typing import Any, Dict, List, Type, TypeVar from attrs import define as _attrs_define from attrs import field as _attrs_field -from ..types import UNSET, Unset - T = TypeVar("T", bound="ADiscriminatedUnionType2") @@ -12,10 +10,10 @@ class ADiscriminatedUnionType2: """ Attributes: - model_type (Union[Unset, str]): + model_type (str): """ - model_type: Union[Unset, str] = UNSET + model_type: str additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> Dict[str, Any]: @@ -23,16 +21,18 @@ def to_dict(self) -> Dict[str, Any]: field_dict: Dict[str, Any] = {} field_dict.update(self.additional_properties) - field_dict.update({}) - if model_type is not UNSET: - field_dict["modelType"] = model_type + field_dict.update( + { + "modelType": model_type, + } + ) return field_dict @classmethod def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: d = src_dict.copy() - model_type = d.pop("modelType", UNSET) + model_type = d.pop("modelType") a_discriminated_union_type_2 = cls( model_type=model_type, diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py b/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py index e03a6e698..6f37fe454 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py @@ -59,23 +59,42 @@ def _parse_discriminated_union( return data if isinstance(data, Unset): return data - try: - if not isinstance(data, dict): - raise TypeError() - componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data) - - return componentsschemas_a_discriminated_union_type_0 - except: # noqa: E722 - pass - try: - if not isinstance(data, dict): - raise TypeError() - componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data) - - return componentsschemas_a_discriminated_union_type_1 - except: # noqa: E722 - pass - return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data) + if not isinstance(data, dict): + raise TypeError() + if "modelType" in data: + _discriminator_value = data["modelType"] + + def _parse_1(data: object) -> ADiscriminatedUnionType1: + if not isinstance(data, dict): + raise TypeError() + componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data) + + return componentsschemas_a_discriminated_union_type_0 + + def _parse_2(data: object) -> ADiscriminatedUnionType2: + if not isinstance(data, dict): + raise TypeError() + componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data) + + return componentsschemas_a_discriminated_union_type_1 + + def _parse_3(data: object) -> ADiscriminatedUnionType2: + if not isinstance(data, dict): + raise TypeError() + componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data) + + return componentsschemas_a_discriminated_union_type_1 + + _discriminator_mapping = { + "type1": _parse_1, + "type2": _parse_2, + "type2-another-value": _parse_3, + } + if _parse_fn := _discriminator_mapping.get(_discriminator_value): + return cast( + Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], _parse_fn(data) + ) + raise TypeError("unrecognized value for property modelType") discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET)) diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index 0dc13be54..fd3e24ac8 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -32,6 +32,7 @@ class ModelProperty(PropertyProtocol): relative_imports: set[str] | None lazy_imports: set[str] | None additional_properties: Property | None + ref_path: ReferencePath | None = None _json_type_string: ClassVar[str] = "Dict[str, Any]" template: ClassVar[str] = "model_property.py.jinja" diff --git a/openapi_python_client/parser/properties/protocol.py b/openapi_python_client/parser/properties/protocol.py index c9555949d..17b55c3f1 100644 --- a/openapi_python_client/parser/properties/protocol.py +++ b/openapi_python_client/parser/properties/protocol.py @@ -1,5 +1,7 @@ from __future__ import annotations +from openapi_python_client.parser.properties.schemas import ReferencePath + __all__ = ["PropertyProtocol", "Value"] from abc import abstractmethod @@ -185,3 +187,6 @@ def is_base_type(self) -> bool: ListProperty.__name__, UnionProperty.__name__, } + + def get_ref_path(self) -> ReferencePath | None: + return self.ref_path if hasattr(self, "ref_path") else None diff --git a/openapi_python_client/parser/properties/schemas.py b/openapi_python_client/parser/properties/schemas.py index dad89a572..a1243ddb1 100644 --- a/openapi_python_client/parser/properties/schemas.py +++ b/openapi_python_client/parser/properties/schemas.py @@ -46,6 +46,13 @@ def parse_reference_path(ref_path_raw: str) -> Union[ReferencePath, ParseError]: return cast(ReferencePath, parsed.fragment) +def get_reference_simple_name(ref_path: str) -> str: + """ + Takes a path like `/components/schemas/NameOfThing` and returns a string like `NameOfThing`. + """ + return ref_path.split("/", 3)[-1] + + @define class Class: """Represents Python class which will be generated from an OpenAPI schema""" @@ -135,6 +142,15 @@ def update_schemas_with_data( ) return prop + # Save the original path (/components/schemas/X) in the property. This is important because: + # 1. There are some contexts (such as a union with a discriminator) where we have a Property + # instance and we want to know what its path is, instead of the other way round. + # 2. Even though we did set prop.name to be the same as ref_path when we created it above, + # whenever there's a $ref to this property, we end up making a copy of it and changing + # the name. So we can't rely on prop.name always being the path. + if hasattr(prop, "ref_path"): + prop.ref_path = ref_path + schemas = evolve(schemas, classes_by_reference={ref_path: prop, **schemas.classes_by_reference}) return schemas diff --git a/openapi_python_client/parser/properties/union.py b/openapi_python_client/parser/properties/union.py index 8b7b02a48..efa48eda2 100644 --- a/openapi_python_client/parser/properties/union.py +++ b/openapi_python_client/parser/properties/union.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Any, ClassVar, cast +from typing import Any, ClassVar, Mapping, OrderedDict, cast from attr import define, evolve @@ -10,7 +10,36 @@ from ...utils import PythonIdentifier from ..errors import ParseError, PropertyError from .protocol import PropertyProtocol, Value -from .schemas import Schemas +from .schemas import Schemas, get_reference_simple_name, parse_reference_path + + +@define +class DiscriminatorDefinition: + """Represents a discriminator that can optionally be specified for a union type. + + Normally, a UnionProperty has either zero or one of these. However, a nested union + could have more than one, as we accumulate all the discriminators when we flatten + out the nested schemas. For example: + + anyOf: + - anyOf: + - $ref: "#/components/schemas/Cat" + - $ref: "#/components/schemas/Dog" + discriminator: + propertyName: mammalType + - anyOf: + - $ref: "#/components/schemas/Condor" + - $ref: "#/components/schemas/Chicken" + discriminator: + propertyName: birdType + + In this example there are four schemas and two discriminators. The deserializer + logic will check for the mammalType property first, then birdType. + """ + + property_name: str + value_to_model_map: Mapping[str, PropertyProtocol] + # Every value in the map is really a ModelProperty, but this avoids circular imports @define @@ -24,6 +53,7 @@ class UnionProperty(PropertyProtocol): description: str | None example: str | None inner_properties: list[PropertyProtocol] + discriminators: list[DiscriminatorDefinition] | None = None template: ClassVar[str] = "union_property.py.jinja" @classmethod @@ -67,16 +97,7 @@ def build( return PropertyError(detail=f"Invalid property in union {name}", data=sub_prop_data), schemas sub_properties.append(sub_prop) - def flatten_union_properties(sub_properties: list[PropertyProtocol]) -> list[PropertyProtocol]: - flattened = [] - for sub_prop in sub_properties: - if isinstance(sub_prop, UnionProperty): - flattened.extend(flatten_union_properties(sub_prop.inner_properties)) - else: - flattened.append(sub_prop) - return flattened - - sub_properties = flatten_union_properties(sub_properties) + sub_properties, discriminators_from_nested_unions = _flatten_union_properties(sub_properties) prop = UnionProperty( name=name, @@ -92,6 +113,16 @@ def flatten_union_properties(sub_properties: list[PropertyProtocol]) -> list[Pro default_or_error.data = data return default_or_error, schemas prop = evolve(prop, default=default_or_error) + + all_discriminators = discriminators_from_nested_unions + if data.discriminator: + discriminator_or_error = _parse_discriminator(data.discriminator, sub_properties, schemas) + if isinstance(discriminator_or_error, PropertyError): + return discriminator_or_error, schemas + all_discriminators = [discriminator_or_error, *all_discriminators] + if all_discriminators: + prop = evolve(prop, discriminators=all_discriminators) + return prop, schemas def convert_value(self, value: Any) -> Value | None | PropertyError: @@ -189,3 +220,98 @@ def validate_location(self, location: oai.ParameterLocation) -> ParseError | Non if evolve(cast(Property, inner_prop), required=self.required).validate_location(location) is not None: return ParseError(detail=f"{self.get_type_string()} is not allowed in {location}") return None + + +def _flatten_union_properties( + sub_properties: list[PropertyProtocol], +) -> tuple[list[PropertyProtocol], list[DiscriminatorDefinition]]: + flattened = [] + discriminators = [] + for sub_prop in sub_properties: + if isinstance(sub_prop, UnionProperty): + if sub_prop.discriminators: + discriminators.extend(sub_prop.discriminators) + new_props, new_discriminators = _flatten_union_properties(sub_prop.inner_properties) + flattened.extend(new_props) + discriminators.extend(new_discriminators) + else: + flattened.append(sub_prop) + return flattened, discriminators + + +def _parse_discriminator( + data: oai.Discriminator, + subtypes: list[PropertyProtocol], + schemas: Schemas, +) -> DiscriminatorDefinition | PropertyError: + from .model_property import ModelProperty + + # See: https://spec.openapis.org/oas/v3.1.0.html#discriminator-object + + # Conditions that must be true when there is a discriminator: + # 1. Every type in the anyOf/oneOf list must be a $ref to a named schema, such as + # #/components/schemas/X, rather than an inline schema. This is important because + # we may need to use the schema's simple name (X). + # 2. There must be a propertyName, representing a property that exists in every + # schema in that list (although we can't currently enforce the latter condition, + # because those properties haven't been parsed yet at this point.) + # + # There *may* also be a mapping of lookup values (the possible values of the property) + # to schemas. Schemas can be referenced either by a full path or a name: + # mapping: + # value_for_a: "#/components/schemas/ModelA" + # value_for_b: ModelB # equivalent to "#/components/schemas/ModelB" + # + # For any type that isn't specified in the mapping (or if the whole mapping is omitted) + # the default lookup value for each schema is the same as the schema name. So this-- + # mapping: + # value_for_a: "#/components/schemas/ModelA" + # --is exactly equivalent to this: + # discriminator: + # propertyName: modelType + # mapping: + # value_for_a: "#/components/schemas/ModelA" + # ModelB: "#/components/schemas/ModelB" + + def _get_model_name(model: ModelProperty) -> str | None: + return get_reference_simple_name(model.ref_path) if model.ref_path else None + + model_types_by_name: dict[str, ModelProperty] = {} + for model in subtypes: + # Note, model here can never be a UnionProperty, because we've already done + # flatten_union_properties() before this point. + if not isinstance(model, ModelProperty): + return PropertyError( + detail="All schema variants must be objects when using a discriminator", + ) + name = _get_model_name(model) + if not name: + return PropertyError( + detail="Inline schema declarations are not allowed when using a discriminator", + ) + model_types_by_name[name] = model + + mapping: dict[str, ModelProperty] = OrderedDict() # use ordered dict for test determinacy + unspecified_models = list(model_types_by_name.values()) + if data.mapping: + for discriminator_value, model_ref in data.mapping.items(): + if "/" in model_ref: + ref_path = parse_reference_path(model_ref) + if isinstance(ref_path, ParseError) or ref_path not in schemas.classes_by_reference: + return PropertyError(detail=f'Invalid reference "{model_ref}" in discriminator mapping') + name = get_reference_simple_name(ref_path) + else: + name = model_ref + mapped_model = model_types_by_name.get(name) + if not mapped_model: + return PropertyError( + detail=f'Discriminator mapping referred to "{name}" which is not one of the schema variants', + ) + mapping[discriminator_value] = mapped_model + if mapped_model in unspecified_models: + # could've already been removed if more than one value is mapped to the same model + unspecified_models.remove(mapped_model) + for model in unspecified_models: + if name := _get_model_name(model): + mapping[name] = model + return DiscriminatorDefinition(property_name=data.propertyName, value_to_model_map=mapping) diff --git a/openapi_python_client/schema/__init__.py b/openapi_python_client/schema/__init__.py index d3de0e493..bfc0a0b5b 100644 --- a/openapi_python_client/schema/__init__.py +++ b/openapi_python_client/schema/__init__.py @@ -1,4 +1,5 @@ __all__ = [ + "Discriminator", "MediaType", "OpenAPI", "Operation", @@ -17,6 +18,7 @@ from .data_type import DataType from .openapi_schema_pydantic import ( + Discriminator, MediaType, OpenAPI, Operation, diff --git a/openapi_python_client/templates/property_templates/union_property.py.jinja b/openapi_python_client/templates/property_templates/union_property.py.jinja index dbf7ee9dc..89d6b6ffd 100644 --- a/openapi_python_client/templates/property_templates/union_property.py.jinja +++ b/openapi_python_client/templates/property_templates/union_property.py.jinja @@ -1,3 +1,36 @@ +{% macro construct_inner_property(inner_property) %} +{% import "property_templates/" + inner_property.template as inner_template %} +{% if inner_template.check_type_for_construct %} +if not {{ inner_template.check_type_for_construct(inner_property, "data") }}: + raise TypeError() +{% endif %} +{{ inner_template.construct(inner_property, "data") }} +return {{ inner_property.python_name }} +{%- endmacro %} + +{% macro construct_discriminator_lookup(property) %} +{% set _discriminator_properties = [] -%} +{% for discriminator in property.discriminators %} +{{- _discriminator_properties.append(discriminator.property_name) or "" -}} +if not isinstance(data, dict): + raise TypeError() +if "{{ discriminator.property_name }}" in data: + _discriminator_value = data["{{ discriminator.property_name }}"] + {% for value, model in discriminator.value_to_model_map.items() %} + def _parse_{{ loop.index }}(data: object) -> {{ model.get_type_string() }}: +{{ construct_inner_property(model) | indent(8, True) }} + {% endfor %} + _discriminator_mapping = { + {% for value, model in discriminator.value_to_model_map.items() %} + "{{ value }}": _parse_{{ loop.index }}, + {% endfor %} + } + if _parse_fn := _discriminator_mapping.get(_discriminator_value): + return cast({{ property.get_type_string() }}, _parse_fn(data)) +{% endfor %} +raise TypeError(f"unrecognized value for property {{ _discriminator_properties | join(' or ') }}") +{% endmacro %} + {% macro construct(property, source) %} def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_string() }}: {% if "None" in property.get_type_strings_in_union(json=True, multipart=False) %} @@ -8,6 +41,9 @@ def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_stri if isinstance(data, Unset): return data {% endif %} +{% if property.discriminators %} +{{ construct_discriminator_lookup(property) | indent(4, True) }} +{% else %} {% set ns = namespace(contains_unmodified_properties = false) %} {% for inner_property in property.inner_properties %} {% import "property_templates/" + inner_property.template as inner_template %} @@ -17,24 +53,17 @@ def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_stri {% endif %} {% if inner_template.check_type_for_construct and (not loop.last or ns.contains_unmodified_properties) %} try: - if not {{ inner_template.check_type_for_construct(inner_property, "data") }}: - raise TypeError() - {{ inner_template.construct(inner_property, "data") | indent(8) }} - return {{ inner_property.python_name }} +{{ construct_inner_property(inner_property) | indent(8, True) }} except: # noqa: E722 pass {% else %}{# Don't do try/except for the last one nor any properties with no type checking #} - {% if inner_template.check_type_for_construct %} - if not {{ inner_template.check_type_for_construct(inner_property, "data") }}: - raise TypeError() - {% endif %} - {{ inner_template.construct(inner_property, "data") | indent(4) }} - return {{ inner_property.python_name }} +{{ construct_inner_property(inner_property) | indent(4, True) }} {% endif %} {% endfor %} {% if ns.contains_unmodified_properties %} return cast({{ property.get_type_string() }}, data) {% endif %} +{% endif %} {{ property.python_name }} = _parse_{{ property.python_name }}({{ source }}) {% endmacro %} diff --git a/tests/test_parser/test_properties/properties_test_helpers.py b/tests/test_parser/test_properties/properties_test_helpers.py new file mode 100644 index 000000000..a154bcb62 --- /dev/null +++ b/tests/test_parser/test_properties/properties_test_helpers.py @@ -0,0 +1,16 @@ +import re +from typing import Any, Union + +from openapi_python_client.parser.errors import PropertyError +from openapi_python_client.parser.properties.property import Property + + +def assert_prop_error( + p: Union[Property, PropertyError], + message_regex: str, + data: Any = None, +) -> None: + assert isinstance(p, PropertyError) + assert re.search(message_regex, p.detail) + if data is not None: + assert p.data == data diff --git a/tests/test_parser/test_properties/test_protocol.py b/tests/test_parser/test_properties/test_protocol.py index 1d4111750..5c2f1c993 100644 --- a/tests/test_parser/test_properties/test_protocol.py +++ b/tests/test_parser/test_properties/test_protocol.py @@ -3,6 +3,7 @@ import pytest from openapi_python_client.parser.properties.protocol import Value +from openapi_python_client.parser.properties.schemas import ReferencePath def test_is_base_type(any_property_factory): @@ -85,3 +86,12 @@ def test_get_base_json_type_string(quoted, expected, any_property_factory, mocke mocker.patch.object(AnyProperty, "_json_type_string", "str") p = any_property_factory() assert p.get_base_json_type_string(quoted=quoted) is expected + + +def test_ref_path(any_property_factory, model_property_factory): + p1 = any_property_factory() + assert p1.get_ref_path() is None + + path = ReferencePath("/components/schemas/A") + p2 = model_property_factory(ref_path=path) + assert p2.get_ref_path() == path diff --git a/tests/test_parser/test_properties/test_union.py b/tests/test_parser/test_properties/test_union.py index acbbd06d6..b3305547b 100644 --- a/tests/test_parser/test_properties/test_union.py +++ b/tests/test_parser/test_properties/test_union.py @@ -1,8 +1,18 @@ +from typing import Dict, List, Optional, Tuple, Union + +import pytest +from attr import evolve + import openapi_python_client.schema as oai +from openapi_python_client.config import Config from openapi_python_client.parser.errors import ParseError, PropertyError from openapi_python_client.parser.properties import Schemas, UnionProperty +from openapi_python_client.parser.properties.model_property import ModelProperty +from openapi_python_client.parser.properties.property import Property from openapi_python_client.parser.properties.protocol import Value +from openapi_python_client.parser.properties.schemas import Class, ReferencePath from openapi_python_client.schema import DataType, ParameterLocation +from tests.test_parser.test_properties.properties_test_helpers import assert_prop_error def test_property_from_data_union(union_property_factory, date_time_property_factory, string_property_factory, config): @@ -33,6 +43,207 @@ def test_property_from_data_union(union_property_factory, date_time_property_fac assert s == Schemas() +def _make_basic_model( + name: str, + props: Dict[str, oai.Schema], + required_prop: Optional[str], + schemas: Schemas, + config: Config, +) -> Tuple[ModelProperty, Schemas]: + model, schemas = ModelProperty.build( + data=oai.Schema.model_construct( + required=[required_prop] if required_prop else [], + title=name, + properties=props, + ), + name=name or "some_generated_name", + schemas=schemas, + required=False, + parent_name="", + config=config, + roots={"root"}, + process_properties=True, + ) + assert isinstance(model, ModelProperty) + if name: + model.ref_path = ReferencePath(f"/components/schemas/{name}") + schemas = evolve( + schemas, classes_by_reference={**schemas.classes_by_reference, f"/components/schemas/{name}": model} + ) + return model, schemas + + +def _assert_valid_discriminator( + p: Union[Property, PropertyError], + expected_discriminators: List[Tuple[str, Dict[str, Class]]], +) -> None: + assert isinstance(p, UnionProperty) + assert p.discriminators + assert [(d[0], {key: model.class_info for key, model in d[1].items()}) for d in expected_discriminators] == [ + (d.property_name, {key: model.class_info for key, model in d.value_to_model_map.items()}) + for d in p.discriminators + ] + + +def test_discriminator_with_explicit_mapping(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={ + # mappings can use either a fully-qualified schema reference or just the schema name + "type1": "#/components/schemas/Model1", + "type2": "Model2", + }, + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator(p, [("type", {"type1": model1, "type2": model2})]) + + +def test_discriminator_with_implicit_mapping(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator(p, [("type", {"Model1": model1, "Model2": model2})]) + + +def test_discriminator_with_partial_explicit_mapping(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={ + "type1": "#/components/schemas/Model1", + # no value specified for Model2, so it defaults to just "Model2" + }, + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator(p, [("type", {"type1": model1, "Model2": model2})]) + + +def test_discriminators_in_nested_unions_same_property(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + model3, schemas = _make_basic_model("Model3", props, "type", schemas, config) + model4, schemas = _make_basic_model("Model4", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct(propertyName="type"), + ), + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model3"), + oai.Reference(ref="#/components/schemas/Model4"), + ], + discriminator=oai.Discriminator.model_construct(propertyName="type"), + ), + ], + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator( + p, + [ + ("type", {"Model1": model1, "Model2": model2}), + ("type", {"Model3": model3, "Model4": model4}), + ], + ) + + +def test_discriminators_in_nested_unions_different_property(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props1 = {"type": oai.Schema.model_construct(type="string")} + props2 = {"other": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props1, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props1, "type", schemas, config) + model3, schemas = _make_basic_model("Model3", props2, "other", schemas, config) + model4, schemas = _make_basic_model("Model4", props2, "other", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct(propertyName="type"), + ), + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model3"), + oai.Reference(ref="#/components/schemas/Model4"), + ], + discriminator=oai.Discriminator.model_construct(propertyName="other"), + ), + ], + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator( + p, + [ + ("type", {"Model1": model1, "Model2": model2}), + ("other", {"Model3": model3, "Model4": model4}), + ], + ) + + def test_build_union_property_invalid_property(config): name = "bad_union" required = True @@ -42,7 +253,7 @@ def test_build_union_property_invalid_property(config): p, s = UnionProperty.build( name=name, required=required, data=data, schemas=Schemas(), parent_name="parent", config=config ) - assert p == PropertyError(detail=f"Invalid property in union {name}", data=reference) + assert_prop_error(p, f"Invalid property in union {name}", data=reference) def test_invalid_default(config): @@ -82,3 +293,115 @@ def test_not_required_in_path(config): err = prop.validate_location(ParameterLocation.PATH) assert isinstance(err, ParseError) + + +@pytest.mark.parametrize("bad_ref", ["#/components/schemas/UnknownModel", "http://remote/Model2"]) +def test_discriminator_invalid_reference(bad_ref, config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={ + "Model1": "#/components/schemas/Model1", + "Model2": bad_ref, + }, + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, "^Invalid reference") + + +def test_discriminator_mapping_uses_schema_not_in_list(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + model3, schemas = _make_basic_model("Model3", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={ + "Model1": "#/components/schemas/Model1", + "Model3": "#/components/schemas/Model3", + }, + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, "not one of the schema variants") + + +def test_discriminator_invalid_variant_is_not_object(config, string_property_factory): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model_type, schemas = _make_basic_model("ModelType", props, "type", schemas, config) + string_type = string_property_factory() + schemas = evolve( + schemas, + classes_by_reference={ + **schemas.classes_by_reference, + "/components/schemas/StringType": string_type, + }, + ) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/ModelType"), + oai.Reference(ref="#/components/schemas/StringType"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, "must be objects") + + +def test_discriminator_invalid_inline_schema_variant(config, string_property_factory): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Schema.model_construct( + type="object", + properties=props, + ), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, "Inline schema")