From 0496ff886af2a73718b2ceb250e49078905fc666 Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Mon, 28 Oct 2024 11:30:15 -0700 Subject: [PATCH 1/6] add discriminator property support --- end_to_end_tests/__init__.py | 3 + end_to_end_tests/baseline_openapi_3.0.json | 6 +- end_to_end_tests/baseline_openapi_3.1.yaml | 6 +- end_to_end_tests/end_to_end_live_tests.py | 33 ++ .../models/a_discriminated_union_type_1.py | 18 +- .../models/a_discriminated_union_type_2.py | 18 +- .../models/model_with_discriminated_union.py | 45 ++- end_to_end_tests/test_end_to_end.py | 18 +- .../parser/properties/schemas.py | 7 + .../parser/properties/union.py | 134 ++++++- openapi_python_client/schema/__init__.py | 2 + .../union_property.py.jinja | 49 ++- .../properties_test_helpers.py | 16 + .../test_parser/test_properties/test_union.py | 364 +++++++++++++++++- 14 files changed, 655 insertions(+), 64 deletions(-) create mode 100644 end_to_end_tests/end_to_end_live_tests.py create mode 100644 tests/test_parser/test_properties/properties_test_helpers.py diff --git a/end_to_end_tests/__init__.py b/end_to_end_tests/__init__.py index 1bf33f63f..b91590c3b 100644 --- a/end_to_end_tests/__init__.py +++ b/end_to_end_tests/__init__.py @@ -1 +1,4 @@ """ Generate a complete client and verify that it is correct """ +import pytest + +pytest.register_assert_rewrite('end_to_end_tests.end_to_end_live_tests') diff --git a/end_to_end_tests/baseline_openapi_3.0.json b/end_to_end_tests/baseline_openapi_3.0.json index e5bbaf6fc..1f3232e69 100644 --- a/end_to_end_tests/baseline_openapi_3.0.json +++ b/end_to_end_tests/baseline_openapi_3.0.json @@ -2841,7 +2841,8 @@ "modelType": { "type": "string" } - } + }, + "required": ["modelType"] }, "ADiscriminatedUnionType2": { "type": "object", @@ -2849,7 +2850,8 @@ "modelType": { "type": "string" } - } + }, + "required": ["modelType"] } }, "parameters": { diff --git a/end_to_end_tests/baseline_openapi_3.1.yaml b/end_to_end_tests/baseline_openapi_3.1.yaml index b6a6941e2..373660695 100644 --- a/end_to_end_tests/baseline_openapi_3.1.yaml +++ b/end_to_end_tests/baseline_openapi_3.1.yaml @@ -2835,7 +2835,8 @@ info: "modelType": { "type": "string" } - } + }, + "required": ["modelType"] }, "ADiscriminatedUnionType2": { "type": "object", @@ -2843,7 +2844,8 @@ info: "modelType": { "type": "string" } - } + }, + "required": ["modelType"] } } "parameters": { diff --git a/end_to_end_tests/end_to_end_live_tests.py b/end_to_end_tests/end_to_end_live_tests.py new file mode 100644 index 000000000..df82a8354 --- /dev/null +++ b/end_to_end_tests/end_to_end_live_tests.py @@ -0,0 +1,33 @@ +import importlib +from typing import Any + +import pytest + + +def live_tests_3_x(): + _test_model_with_discriminated_union() + + +def _import_model(module_name, class_name: str) -> Any: + module = importlib.import_module(f"my_test_api_client.models.{module_name}") + module = importlib.reload(module) # avoid test contamination from previous import + return getattr(module, class_name) + + +def _test_model_with_discriminated_union(): + ModelType1Class = _import_model("a_discriminated_union_type_1", "ADiscriminatedUnionType1") + ModelType2Class = _import_model("a_discriminated_union_type_2", "ADiscriminatedUnionType2") + ModelClass = _import_model("model_with_discriminated_union", "ModelWithDiscriminatedUnion") + + assert ( + ModelClass.from_dict({"discriminated_union": {"modelType": "type1"}}) == + ModelClass(discriminated_union=ModelType1Class.from_dict({"modelType": "type1"})) + ) + assert ( + ModelClass.from_dict({"discriminated_union": {"modelType": "type2"}}) == + ModelClass(discriminated_union=ModelType2Class.from_dict({"modelType": "type2"})) + ) + with pytest.raises(TypeError): + ModelClass.from_dict({"discriminated_union": {"modelType": "type3"}}) + with pytest.raises(TypeError): + ModelClass.from_dict({"discriminated_union": {}}) diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py b/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py index cb1184b18..ed02a0aaf 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py @@ -1,10 +1,8 @@ -from typing import Any, Dict, List, Type, TypeVar, Union +from typing import Any, Dict, List, Type, TypeVar from attrs import define as _attrs_define from attrs import field as _attrs_field -from ..types import UNSET, Unset - T = TypeVar("T", bound="ADiscriminatedUnionType1") @@ -12,10 +10,10 @@ class ADiscriminatedUnionType1: """ Attributes: - model_type (Union[Unset, str]): + model_type (str): """ - model_type: Union[Unset, str] = UNSET + model_type: str additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> Dict[str, Any]: @@ -23,16 +21,18 @@ def to_dict(self) -> Dict[str, Any]: field_dict: Dict[str, Any] = {} field_dict.update(self.additional_properties) - field_dict.update({}) - if model_type is not UNSET: - field_dict["modelType"] = model_type + field_dict.update( + { + "modelType": model_type, + } + ) return field_dict @classmethod def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: d = src_dict.copy() - model_type = d.pop("modelType", UNSET) + model_type = d.pop("modelType") a_discriminated_union_type_1 = cls( model_type=model_type, diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py b/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py index 734f3bef4..93ee2cbb9 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py @@ -1,10 +1,8 @@ -from typing import Any, Dict, List, Type, TypeVar, Union +from typing import Any, Dict, List, Type, TypeVar from attrs import define as _attrs_define from attrs import field as _attrs_field -from ..types import UNSET, Unset - T = TypeVar("T", bound="ADiscriminatedUnionType2") @@ -12,10 +10,10 @@ class ADiscriminatedUnionType2: """ Attributes: - model_type (Union[Unset, str]): + model_type (str): """ - model_type: Union[Unset, str] = UNSET + model_type: str additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> Dict[str, Any]: @@ -23,16 +21,18 @@ def to_dict(self) -> Dict[str, Any]: field_dict: Dict[str, Any] = {} field_dict.update(self.additional_properties) - field_dict.update({}) - if model_type is not UNSET: - field_dict["modelType"] = model_type + field_dict.update( + { + "modelType": model_type, + } + ) return field_dict @classmethod def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: d = src_dict.copy() - model_type = d.pop("modelType", UNSET) + model_type = d.pop("modelType") a_discriminated_union_type_2 = cls( model_type=model_type, diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py b/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py index e03a6e698..4e9cdfd86 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py @@ -59,23 +59,34 @@ def _parse_discriminated_union( return data if isinstance(data, Unset): return data - try: - if not isinstance(data, dict): - raise TypeError() - componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data) - - return componentsschemas_a_discriminated_union_type_0 - except: # noqa: E722 - pass - try: - if not isinstance(data, dict): - raise TypeError() - componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data) - - return componentsschemas_a_discriminated_union_type_1 - except: # noqa: E722 - pass - return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data) + if not isinstance(data, dict): + raise TypeError() + if "modelType" in data: + _discriminator_value = data["modelType"] + + def _parse_componentsschemas_a_discriminated_union_type_1(data: object) -> ADiscriminatedUnionType1: + if not isinstance(data, dict): + raise TypeError() + componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType1.from_dict(data) + + return componentsschemas_a_discriminated_union_type_1 + + def _parse_componentsschemas_a_discriminated_union_type_2(data: object) -> ADiscriminatedUnionType2: + if not isinstance(data, dict): + raise TypeError() + componentsschemas_a_discriminated_union_type_2 = ADiscriminatedUnionType2.from_dict(data) + + return componentsschemas_a_discriminated_union_type_2 + + _discriminator_mapping = { + "type1": _parse_componentsschemas_a_discriminated_union_type_1, + "type2": _parse_componentsschemas_a_discriminated_union_type_2, + } + if _parse_fn := _discriminator_mapping.get(_discriminator_value): + return cast( + Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], _parse_fn(data) + ) + raise TypeError("unrecognized value for property modelType") discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET)) diff --git a/end_to_end_tests/test_end_to_end.py b/end_to_end_tests/test_end_to_end.py index a448a0698..46755e6ca 100644 --- a/end_to_end_tests/test_end_to_end.py +++ b/end_to_end_tests/test_end_to_end.py @@ -1,13 +1,17 @@ +import os import shutil from filecmp import cmpfiles, dircmp from pathlib import Path -from typing import Dict, List, Optional, Set +import sys +from typing import Callable, Dict, List, Optional, Set import pytest from click.testing import Result from typer.testing import CliRunner from openapi_python_client.cli import app +from .end_to_end_live_tests import live_tests_3_x + def _compare_directories( @@ -83,6 +87,7 @@ def run_e2e_test( golden_record_path: str = "golden-record", output_path: str = "my-test-api-client", expected_missing: Optional[Set[str]] = None, + live_tests: Optional[Callable[[str], None]] = None, ) -> Result: output_path = Path.cwd() / output_path shutil.rmtree(output_path, ignore_errors=True) @@ -97,6 +102,13 @@ def run_e2e_test( _compare_directories( gr_path, output_path, expected_differences=expected_differences, expected_missing=expected_missing ) + if live_tests: + old_path = sys.path.copy() + sys.path.insert(0, str(output_path)) + try: + live_tests() + finally: + sys.path = old_path import mypy.api @@ -131,11 +143,11 @@ def _run_command(command: str, extra_args: Optional[List[str]] = None, openapi_d def test_baseline_end_to_end_3_0(): - run_e2e_test("baseline_openapi_3.0.json", [], {}) + run_e2e_test("baseline_openapi_3.0.json", [], {}, live_tests=live_tests_3_x) def test_baseline_end_to_end_3_1(): - run_e2e_test("baseline_openapi_3.1.yaml", [], {}) + run_e2e_test("baseline_openapi_3.1.yaml", [], {}, live_tests=live_tests_3_x) def test_3_1_specific_features(): diff --git a/openapi_python_client/parser/properties/schemas.py b/openapi_python_client/parser/properties/schemas.py index dad89a572..c8d0cde12 100644 --- a/openapi_python_client/parser/properties/schemas.py +++ b/openapi_python_client/parser/properties/schemas.py @@ -46,6 +46,13 @@ def parse_reference_path(ref_path_raw: str) -> Union[ReferencePath, ParseError]: return cast(ReferencePath, parsed.fragment) +def get_reference_simple_name(ref_path: str) -> str: + """ + Takes a path like `/components/schemas/NameOfThing` and returns a string like `NameOfThing`. + """ + return ref_path.split("/", 3)[-1] + + @define class Class: """Represents Python class which will be generated from an OpenAPI schema""" diff --git a/openapi_python_client/parser/properties/union.py b/openapi_python_client/parser/properties/union.py index 8b7b02a48..e03aacf6f 100644 --- a/openapi_python_client/parser/properties/union.py +++ b/openapi_python_client/parser/properties/union.py @@ -10,7 +10,14 @@ from ...utils import PythonIdentifier from ..errors import ParseError, PropertyError from .protocol import PropertyProtocol, Value -from .schemas import Schemas +from .schemas import Schemas, get_reference_simple_name, parse_reference_path + + +@define +class DiscriminatorDefinition: + property_name: str + value_to_model_map: dict[str, PropertyProtocol] + # Every value in the map is really a ModelProperty, but this avoids circular imports @define @@ -24,6 +31,7 @@ class UnionProperty(PropertyProtocol): description: str | None example: str | None inner_properties: list[PropertyProtocol] + discriminators: list[DiscriminatorDefinition] | None = None template: ClassVar[str] = "union_property.py.jinja" @classmethod @@ -67,16 +75,7 @@ def build( return PropertyError(detail=f"Invalid property in union {name}", data=sub_prop_data), schemas sub_properties.append(sub_prop) - def flatten_union_properties(sub_properties: list[PropertyProtocol]) -> list[PropertyProtocol]: - flattened = [] - for sub_prop in sub_properties: - if isinstance(sub_prop, UnionProperty): - flattened.extend(flatten_union_properties(sub_prop.inner_properties)) - else: - flattened.append(sub_prop) - return flattened - - sub_properties = flatten_union_properties(sub_properties) + sub_properties, discriminators_list = _flatten_union_properties(sub_properties) prop = UnionProperty( name=name, @@ -92,6 +91,17 @@ def flatten_union_properties(sub_properties: list[PropertyProtocol]) -> list[Pro default_or_error.data = data return default_or_error, schemas prop = evolve(prop, default=default_or_error) + + if data.discriminator: + discriminator_or_error = _parse_discriminator(data.discriminator, sub_properties, schemas) + if isinstance(discriminator_or_error, PropertyError): + return discriminator_or_error, schemas + discriminators_list = [discriminator_or_error, *discriminators_list] + if discriminators_list: + if error := _validate_discriminators(discriminators_list): + return error, schemas + prop = evolve(prop, discriminators=discriminators_list) + return prop, schemas def convert_value(self, value: Any) -> Value | None | PropertyError: @@ -189,3 +199,105 @@ def validate_location(self, location: oai.ParameterLocation) -> ParseError | Non if evolve(cast(Property, inner_prop), required=self.required).validate_location(location) is not None: return ParseError(detail=f"{self.get_type_string()} is not allowed in {location}") return None + + +def _flatten_union_properties( + sub_properties: list[PropertyProtocol], +) -> tuple[list[PropertyProtocol], list[DiscriminatorDefinition]]: + flattened = [] + discriminators = [] + for sub_prop in sub_properties: + if isinstance(sub_prop, UnionProperty): + if sub_prop.discriminators: + discriminators.extend(sub_prop.discriminators) + new_props, new_discriminators = _flatten_union_properties(sub_prop.inner_properties) + flattened.extend(new_props) + discriminators.extend(new_discriminators) + else: + flattened.append(sub_prop) + return flattened, discriminators + + +def _parse_discriminator( + data: oai.Discriminator, + subtypes: list[PropertyProtocol], + schemas: Schemas, +) -> DiscriminatorDefinition | PropertyError: + from .model_property import ModelProperty + + # See: https://spec.openapis.org/oas/v3.1.0.html#discriminator-object + + def _find_top_level_model(matching_model: ModelProperty) -> ModelProperty | None: + # This is needed because, when we built the union list, $refs were changed into a copy of + # the type they referred to, without preserving the original name. We need to know that + # every type in the discriminator is a $ref to a top-level type and we need its name. + for prop in schemas.classes_by_reference.values(): + if isinstance(prop, ModelProperty): + if prop.class_info == matching_model.class_info: + return prop + return None + + model_types_by_name: dict[str, PropertyProtocol] = {} + for model in subtypes: + # Note, model here can never be a UnionProperty, because we've already done + # flatten_union_properties() before this point. + if not isinstance(model, ModelProperty): + return PropertyError( + detail="All schema variants must be objects when using a discriminator", + ) + top_level_model = _find_top_level_model(model) + if not top_level_model: + return PropertyError( + detail="Inline schema declarations are not allowed when using a discriminator", + ) + name = top_level_model.name + if name.startswith("/components/schemas/"): + name = get_reference_simple_name(name) + model_types_by_name[name] = top_level_model + + # The discriminator can specify an explicit mapping of values to types, but it doesn't + # have to; the default behavior is that the value for each type is simply its name. + mapping: dict[str, PropertyProtocol] = model_types_by_name.copy() + if data.mapping: + for discriminator_value, model_ref in data.mapping.items(): + ref_path = parse_reference_path( + model_ref if model_ref.startswith("#/components/schemas/") else f"#/components/schemas/{model_ref}" + ) + if isinstance(ref_path, ParseError) or ref_path not in schemas.classes_by_reference: + return PropertyError(detail=f'Invalid reference "{model_ref}" in discriminator mapping') + name = get_reference_simple_name(ref_path) + if not (lookup_model := model_types_by_name.get(name)): + return PropertyError( + detail=f'Discriminator mapping referred to "{model_ref}" which is not one of the schema variants', + ) + for original_value in (name for name, m in model_types_by_name.items() if m == lookup_model): + mapping.pop(original_value) + mapping[discriminator_value] = lookup_model + else: + mapping = model_types_by_name + + return DiscriminatorDefinition(property_name=data.propertyName, value_to_model_map=mapping) + + +def _validate_discriminators( + discriminators: list[DiscriminatorDefinition], +) -> PropertyError | None: + from .model_property import ModelProperty + + prop_names_values_classes = [ + (discriminator.property_name, key, cast(ModelProperty, model).class_info.name) + for discriminator in discriminators + for key, model in discriminator.value_to_model_map.items() + ] + for p, v in {(p, v) for p, v, _ in prop_names_values_classes}: + if len({c for p1, v1, c in prop_names_values_classes if (p1, v1) == (p, v)}) > 1: + return PropertyError(f'Discriminator property "{p}" had more than one schema for value "{v}"') + return None + + # TODO: We should also validate that property_name refers to a property that 1. exists, + # 2. is required, 3. is a string (in all of these models). However, currently we can't + # do that because, at the time this function is called, the ModelProperties within the + # union haven't yet been post-processed and so we don't have full information about + # their properties. To fix this, we may need to generalize the post-processing phase so + # that any Property type, not just ModelProperty, can say it needs post-processing; then + # we can defer _validate_discriminators till that phase. diff --git a/openapi_python_client/schema/__init__.py b/openapi_python_client/schema/__init__.py index d3de0e493..bfc0a0b5b 100644 --- a/openapi_python_client/schema/__init__.py +++ b/openapi_python_client/schema/__init__.py @@ -1,4 +1,5 @@ __all__ = [ + "Discriminator", "MediaType", "OpenAPI", "Operation", @@ -17,6 +18,7 @@ from .data_type import DataType from .openapi_schema_pydantic import ( + Discriminator, MediaType, OpenAPI, Operation, diff --git a/openapi_python_client/templates/property_templates/union_property.py.jinja b/openapi_python_client/templates/property_templates/union_property.py.jinja index dbf7ee9dc..0a0b1d49f 100644 --- a/openapi_python_client/templates/property_templates/union_property.py.jinja +++ b/openapi_python_client/templates/property_templates/union_property.py.jinja @@ -1,3 +1,36 @@ +{% macro construct_inner_property(inner_property) %} +{% import "property_templates/" + inner_property.template as inner_template %} +{% if inner_template.check_type_for_construct %} +if not {{ inner_template.check_type_for_construct(inner_property, "data") }}: + raise TypeError() +{% endif %} +{{ inner_template.construct(inner_property, "data") }} +return {{ inner_property.python_name }} +{%- endmacro %} + +{% macro construct_discriminator_lookup(property) %} +{% set _discriminator_properties = [] -%} +{% for discriminator in property.discriminators %} +{{- _discriminator_properties.append(discriminator.property_name) or "" -}} +if not isinstance(data, dict): + raise TypeError() +if "{{ discriminator.property_name }}" in data: + _discriminator_value = data["{{ discriminator.property_name }}"] + {% for model in discriminator.value_to_model_map.values() %} + def _parse_{{ model.python_name }}(data: object) -> {{ model.get_type_string() }}: +{{ construct_inner_property(model) | indent(12, True) }} + {% endfor %} + _discriminator_mapping = { + {% for value, model in discriminator.value_to_model_map.items() %} + "{{ value }}": _parse_{{ model.python_name }}, + {% endfor %} + } + if _parse_fn := _discriminator_mapping.get(_discriminator_value): + return cast({{ property.get_type_string() }}, _parse_fn(data)) +{% endfor %} +raise TypeError(f"unrecognized value for property {{ _discriminator_properties | join(' or ') }}") +{% endmacro %} + {% macro construct(property, source) %} def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_string() }}: {% if "None" in property.get_type_strings_in_union(json=True, multipart=False) %} @@ -8,6 +41,9 @@ def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_stri if isinstance(data, Unset): return data {% endif %} +{% if property.discriminators %} +{{ construct_discriminator_lookup(property) | indent(4, True) }} +{% else %} {% set ns = namespace(contains_unmodified_properties = false) %} {% for inner_property in property.inner_properties %} {% import "property_templates/" + inner_property.template as inner_template %} @@ -17,24 +53,17 @@ def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_stri {% endif %} {% if inner_template.check_type_for_construct and (not loop.last or ns.contains_unmodified_properties) %} try: - if not {{ inner_template.check_type_for_construct(inner_property, "data") }}: - raise TypeError() - {{ inner_template.construct(inner_property, "data") | indent(8) }} - return {{ inner_property.python_name }} +{{ construct_inner_property(inner_property) | indent(8, True) }} except: # noqa: E722 pass {% else %}{# Don't do try/except for the last one nor any properties with no type checking #} - {% if inner_template.check_type_for_construct %} - if not {{ inner_template.check_type_for_construct(inner_property, "data") }}: - raise TypeError() - {% endif %} - {{ inner_template.construct(inner_property, "data") | indent(4) }} - return {{ inner_property.python_name }} +{{ construct_inner_property(inner_property) | indent(4, True) }} {% endif %} {% endfor %} {% if ns.contains_unmodified_properties %} return cast({{ property.get_type_string() }}, data) {% endif %} +{% endif %} {{ property.python_name }} = _parse_{{ property.python_name }}({{ source }}) {% endmacro %} diff --git a/tests/test_parser/test_properties/properties_test_helpers.py b/tests/test_parser/test_properties/properties_test_helpers.py new file mode 100644 index 000000000..a154bcb62 --- /dev/null +++ b/tests/test_parser/test_properties/properties_test_helpers.py @@ -0,0 +1,16 @@ +import re +from typing import Any, Union + +from openapi_python_client.parser.errors import PropertyError +from openapi_python_client.parser.properties.property import Property + + +def assert_prop_error( + p: Union[Property, PropertyError], + message_regex: str, + data: Any = None, +) -> None: + assert isinstance(p, PropertyError) + assert re.search(message_regex, p.detail) + if data is not None: + assert p.data == data diff --git a/tests/test_parser/test_properties/test_union.py b/tests/test_parser/test_properties/test_union.py index acbbd06d6..27b3c77c5 100644 --- a/tests/test_parser/test_properties/test_union.py +++ b/tests/test_parser/test_properties/test_union.py @@ -1,8 +1,18 @@ +from typing import Dict, List, Optional, Tuple, Union + +import pytest +from attr import evolve + import openapi_python_client.schema as oai +from openapi_python_client.config import Config from openapi_python_client.parser.errors import ParseError, PropertyError from openapi_python_client.parser.properties import Schemas, UnionProperty +from openapi_python_client.parser.properties.model_property import ModelProperty +from openapi_python_client.parser.properties.property import Property from openapi_python_client.parser.properties.protocol import Value +from openapi_python_client.parser.properties.schemas import Class from openapi_python_client.schema import DataType, ParameterLocation +from tests.test_parser.test_properties.properties_test_helpers import assert_prop_error def test_property_from_data_union(union_property_factory, date_time_property_factory, string_property_factory, config): @@ -33,6 +43,206 @@ def test_property_from_data_union(union_property_factory, date_time_property_fac assert s == Schemas() +def _make_basic_model( + name: str, + props: Dict[str, oai.Schema], + required_prop: Optional[str], + schemas: Schemas, + config: Config, +) -> Tuple[ModelProperty, Schemas]: + model, schemas = ModelProperty.build( + data=oai.Schema.model_construct( + required=[required_prop] if required_prop else [], + title=name, + properties=props, + ), + name=name or "some_generated_name", + schemas=schemas, + required=False, + parent_name="", + config=config, + roots={"root"}, + process_properties=True, + ) + assert isinstance(model, ModelProperty) + if name: + schemas = evolve( + schemas, classes_by_reference={**schemas.classes_by_reference, f"/components/schemas/{name}": model} + ) + return model, schemas + + +def _assert_valid_discriminator( + p: Union[Property, PropertyError], + expected_discriminators: List[Tuple[str, Dict[str, Class]]], +) -> None: + assert isinstance(p, UnionProperty) + assert p.discriminators + assert [(d[0], {key: model.class_info for key, model in d[1].items()}) for d in expected_discriminators] == [ + (d.property_name, {key: model.class_info for key, model in d.value_to_model_map.items()}) + for d in p.discriminators + ] + + +def test_discriminator_with_explicit_mapping(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={ + # mappings can use either a fully-qualified schema reference or just the schema name + "type1": "#/components/schemas/Model1", + "type2": "Model2", + }, + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator(p, [("type", {"type1": model1, "type2": model2})]) + + +def test_discriminator_with_implicit_mapping(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator(p, [("type", {"Model1": model1, "Model2": model2})]) + + +def test_discriminator_with_partial_explicit_mapping(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={ + "type1": "#/components/schemas/Model1", + # no value specified for Model2, so it defaults to just "Model2" + }, + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator(p, [("type", {"type1": model1, "Model2": model2})]) + + +def test_discriminators_in_nested_unions_same_property(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + model3, schemas = _make_basic_model("Model3", props, "type", schemas, config) + model4, schemas = _make_basic_model("Model4", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct(propertyName="type"), + ), + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model3"), + oai.Reference(ref="#/components/schemas/Model4"), + ], + discriminator=oai.Discriminator.model_construct(propertyName="type"), + ), + ], + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator( + p, + [ + ("type", {"Model1": model1, "Model2": model2}), + ("type", {"Model3": model3, "Model4": model4}), + ], + ) + + +def test_discriminators_in_nested_unions_different_property(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props1 = {"type": oai.Schema.model_construct(type="string")} + props2 = {"other": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props1, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props1, "type", schemas, config) + model3, schemas = _make_basic_model("Model3", props2, "other", schemas, config) + model4, schemas = _make_basic_model("Model4", props2, "other", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct(propertyName="type"), + ), + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model3"), + oai.Reference(ref="#/components/schemas/Model4"), + ], + discriminator=oai.Discriminator.model_construct(propertyName="other"), + ), + ], + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + _assert_valid_discriminator( + p, + [ + ("type", {"Model1": model1, "Model2": model2}), + ("other", {"Model3": model3, "Model4": model4}), + ], + ) + + def test_build_union_property_invalid_property(config): name = "bad_union" required = True @@ -42,7 +252,7 @@ def test_build_union_property_invalid_property(config): p, s = UnionProperty.build( name=name, required=required, data=data, schemas=Schemas(), parent_name="parent", config=config ) - assert p == PropertyError(detail=f"Invalid property in union {name}", data=reference) + assert_prop_error(p, f"Invalid property in union {name}", data=reference) def test_invalid_default(config): @@ -82,3 +292,155 @@ def test_not_required_in_path(config): err = prop.validate_location(ParameterLocation.PATH) assert isinstance(err, ParseError) + + +@pytest.mark.parametrize("bad_ref", ["#/components/schemas/UnknownModel", "http://remote/Model2"]) +def test_discriminator_invalid_reference(bad_ref, config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={ + "Model1": "#/components/schemas/Model1", + "Model2": bad_ref, + }, + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, "^Invalid reference") + + +def test_discriminator_mapping_uses_schema_not_in_list(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + model3, schemas = _make_basic_model("Model3", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={ + "Model1": "#/components/schemas/Model1", + "Model3": "#/components/schemas/Model3", + }, + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, "not one of the schema variants") + + +def test_discriminator_invalid_variant_is_not_object(config, string_property_factory): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model_type, schemas = _make_basic_model("ModelType", props, "type", schemas, config) + string_type = string_property_factory() + schemas = evolve( + schemas, + classes_by_reference={ + **schemas.classes_by_reference, + "/components/schemas/StringType": string_type, + }, + ) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/ModelType"), + oai.Reference(ref="#/components/schemas/StringType"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, "must be objects") + + +def test_discriminator_invalid_inline_schema_variant(config, string_property_factory): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Schema.model_construct( + type="object", + properties=props, + ), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + ), + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, "Inline schema") + + +def test_conflicting_discriminator_mappings(config): + from openapi_python_client.parser.properties import Schemas, property_from_data + + schemas = Schemas() + props = {"type": oai.Schema.model_construct(type="string")} + model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) + model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) + model3, schemas = _make_basic_model("Model3", props, "type", schemas, config) + model4, schemas = _make_basic_model("Model4", props, "type", schemas, config) + data = oai.Schema.model_construct( + oneOf=[ + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model1"), + oai.Reference(ref="#/components/schemas/Model2"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={"a": "Model1", "b": "Model2"}, + ), + ), + oai.Schema.model_construct( + oneOf=[ + oai.Reference(ref="#/components/schemas/Model3"), + oai.Reference(ref="#/components/schemas/Model4"), + ], + discriminator=oai.Discriminator.model_construct( + propertyName="type", + mapping={"a": "Model3", "x": "Model4"}, + ), + ), + ], + ) + + p, schemas = property_from_data( + name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config + ) + assert_prop_error(p, '"type" had more than one schema for value "a"') From b75adda0bcf64009120f8b6eedeaf55bf95a5fde Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Wed, 30 Oct 2024 15:25:37 -0700 Subject: [PATCH 2/6] simply the discriminator logic a bit --- .../models/model_with_discriminated_union.py | 16 +- .../parser/properties/model_property.py | 1 + .../parser/properties/protocol.py | 5 + .../parser/properties/schemas.py | 9 ++ .../parser/properties/union.py | 139 ++++++++++-------- .../union_property.py.jinja | 8 +- .../test_parser/test_properties/test_union.py | 43 +----- 7 files changed, 104 insertions(+), 117 deletions(-) diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py b/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py index 4e9cdfd86..1a0918332 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py @@ -64,23 +64,23 @@ def _parse_discriminated_union( if "modelType" in data: _discriminator_value = data["modelType"] - def _parse_componentsschemas_a_discriminated_union_type_1(data: object) -> ADiscriminatedUnionType1: + def _parse_1(data: object) -> ADiscriminatedUnionType1: if not isinstance(data, dict): raise TypeError() - componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType1.from_dict(data) + componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data) - return componentsschemas_a_discriminated_union_type_1 + return componentsschemas_a_discriminated_union_type_0 - def _parse_componentsschemas_a_discriminated_union_type_2(data: object) -> ADiscriminatedUnionType2: + def _parse_2(data: object) -> ADiscriminatedUnionType2: if not isinstance(data, dict): raise TypeError() - componentsschemas_a_discriminated_union_type_2 = ADiscriminatedUnionType2.from_dict(data) + componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data) - return componentsschemas_a_discriminated_union_type_2 + return componentsschemas_a_discriminated_union_type_1 _discriminator_mapping = { - "type1": _parse_componentsschemas_a_discriminated_union_type_1, - "type2": _parse_componentsschemas_a_discriminated_union_type_2, + "type1": _parse_1, + "type2": _parse_2, } if _parse_fn := _discriminator_mapping.get(_discriminator_value): return cast( diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index 0dc13be54..fd3e24ac8 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -32,6 +32,7 @@ class ModelProperty(PropertyProtocol): relative_imports: set[str] | None lazy_imports: set[str] | None additional_properties: Property | None + ref_path: ReferencePath | None = None _json_type_string: ClassVar[str] = "Dict[str, Any]" template: ClassVar[str] = "model_property.py.jinja" diff --git a/openapi_python_client/parser/properties/protocol.py b/openapi_python_client/parser/properties/protocol.py index c9555949d..17b55c3f1 100644 --- a/openapi_python_client/parser/properties/protocol.py +++ b/openapi_python_client/parser/properties/protocol.py @@ -1,5 +1,7 @@ from __future__ import annotations +from openapi_python_client.parser.properties.schemas import ReferencePath + __all__ = ["PropertyProtocol", "Value"] from abc import abstractmethod @@ -185,3 +187,6 @@ def is_base_type(self) -> bool: ListProperty.__name__, UnionProperty.__name__, } + + def get_ref_path(self) -> ReferencePath | None: + return self.ref_path if hasattr(self, "ref_path") else None diff --git a/openapi_python_client/parser/properties/schemas.py b/openapi_python_client/parser/properties/schemas.py index c8d0cde12..a1243ddb1 100644 --- a/openapi_python_client/parser/properties/schemas.py +++ b/openapi_python_client/parser/properties/schemas.py @@ -142,6 +142,15 @@ def update_schemas_with_data( ) return prop + # Save the original path (/components/schemas/X) in the property. This is important because: + # 1. There are some contexts (such as a union with a discriminator) where we have a Property + # instance and we want to know what its path is, instead of the other way round. + # 2. Even though we did set prop.name to be the same as ref_path when we created it above, + # whenever there's a $ref to this property, we end up making a copy of it and changing + # the name. So we can't rely on prop.name always being the path. + if hasattr(prop, "ref_path"): + prop.ref_path = ref_path + schemas = evolve(schemas, classes_by_reference={ref_path: prop, **schemas.classes_by_reference}) return schemas diff --git a/openapi_python_client/parser/properties/union.py b/openapi_python_client/parser/properties/union.py index e03aacf6f..6ee629f37 100644 --- a/openapi_python_client/parser/properties/union.py +++ b/openapi_python_client/parser/properties/union.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Any, ClassVar, cast +from typing import Any, ClassVar, OrderedDict, cast from attr import define, evolve @@ -15,6 +15,27 @@ @define class DiscriminatorDefinition: + """Represents a discriminator that can optionally be specified for a union type. + + Normally, a UnionProperty has either zero or one of these. However, a nested union + could have more than one, as we accumulate all the discriminators when we flatten + out the nested schemas. For example: + + anyOf: + - anyOf: + - $ref: "#/components/schemas/Cat" + - $ref: "#/components/schemas/Dog" + discriminator: + propertyName: mammalType + - anyOf: + - $ref: "#/components/schemas/Condor" + - $ref: "#/components/schemas/Chicken" + discriminator: + propertyName: birdType + + In this example there are four schemas and two discriminators. The deserializer + logic will check for the mammalType property first, then birdType. + """ property_name: str value_to_model_map: dict[str, PropertyProtocol] # Every value in the map is really a ModelProperty, but this avoids circular imports @@ -75,7 +96,7 @@ def build( return PropertyError(detail=f"Invalid property in union {name}", data=sub_prop_data), schemas sub_properties.append(sub_prop) - sub_properties, discriminators_list = _flatten_union_properties(sub_properties) + sub_properties, discriminators_from_nested_unions = _flatten_union_properties(sub_properties) prop = UnionProperty( name=name, @@ -92,15 +113,14 @@ def build( return default_or_error, schemas prop = evolve(prop, default=default_or_error) + all_discriminators = discriminators_from_nested_unions if data.discriminator: discriminator_or_error = _parse_discriminator(data.discriminator, sub_properties, schemas) if isinstance(discriminator_or_error, PropertyError): return discriminator_or_error, schemas - discriminators_list = [discriminator_or_error, *discriminators_list] - if discriminators_list: - if error := _validate_discriminators(discriminators_list): - return error, schemas - prop = evolve(prop, discriminators=discriminators_list) + all_discriminators = [discriminator_or_error, *all_discriminators] + if all_discriminators: + prop = evolve(prop, discriminators=all_discriminators) return prop, schemas @@ -227,15 +247,33 @@ def _parse_discriminator( # See: https://spec.openapis.org/oas/v3.1.0.html#discriminator-object - def _find_top_level_model(matching_model: ModelProperty) -> ModelProperty | None: - # This is needed because, when we built the union list, $refs were changed into a copy of - # the type they referred to, without preserving the original name. We need to know that - # every type in the discriminator is a $ref to a top-level type and we need its name. - for prop in schemas.classes_by_reference.values(): - if isinstance(prop, ModelProperty): - if prop.class_info == matching_model.class_info: - return prop - return None + # Conditions that must be true when there is a discriminator: + # 1. Every type in the anyOf/oneOf list must be a $ref to a named schema, such as + # #/components/schemas/X, rather than an inline schema. This is important because + # we may need to use the schema's simple name (X). + # 2. There must be a propertyName, representing a property that exists in every + # schema in that list (although we can't currently enforce the latter condition, + # because those properties haven't been parsed yet at this point.) + # + # There *may* also be a mapping of lookup values (the possible values of the property) + # to schemas. Schemas can be referenced either by a full path or a name: + # mapping: + # value_for_a: "#/components/schemas/ModelA" + # value_for_b: ModelB # equivalent to "#/components/schemas/ModelB" + # + # For any type that isn't specified in the mapping (or if the whole mapping is omitted) + # the default lookup value for each schema is the same as the schema name. So this-- + # mapping: + # value_for_a: "#/components/schemas/ModelA" + # --is exactly equivalent to this: + # discriminator: + # propertyName: modelType + # mapping: + # value_for_a: "#/components/schemas/ModelA" + # ModelB: "#/components/schemas/ModelB" + + def _get_model_name(model: ModelProperty) -> str | None: + return get_reference_simple_name(model.ref_path) if model.ref_path else None model_types_by_name: dict[str, PropertyProtocol] = {} for model in subtypes: @@ -245,59 +283,32 @@ def _find_top_level_model(matching_model: ModelProperty) -> ModelProperty | None return PropertyError( detail="All schema variants must be objects when using a discriminator", ) - top_level_model = _find_top_level_model(model) - if not top_level_model: + name = _get_model_name(model) + if not name: return PropertyError( detail="Inline schema declarations are not allowed when using a discriminator", ) - name = top_level_model.name - if name.startswith("/components/schemas/"): - name = get_reference_simple_name(name) - model_types_by_name[name] = top_level_model - - # The discriminator can specify an explicit mapping of values to types, but it doesn't - # have to; the default behavior is that the value for each type is simply its name. - mapping: dict[str, PropertyProtocol] = model_types_by_name.copy() + model_types_by_name[name] = model + + mapping: dict[str, PropertyProtocol] = OrderedDict() # use ordered dict for test determinacy + unspecified_models = list(model_types_by_name.values()) if data.mapping: for discriminator_value, model_ref in data.mapping.items(): - ref_path = parse_reference_path( - model_ref if model_ref.startswith("#/components/schemas/") else f"#/components/schemas/{model_ref}" - ) - if isinstance(ref_path, ParseError) or ref_path not in schemas.classes_by_reference: - return PropertyError(detail=f'Invalid reference "{model_ref}" in discriminator mapping') - name = get_reference_simple_name(ref_path) - if not (lookup_model := model_types_by_name.get(name)): + if "/" in model_ref: + ref_path = parse_reference_path(model_ref) + if isinstance(ref_path, ParseError) or ref_path not in schemas.classes_by_reference: + return PropertyError(detail=f'Invalid reference "{model_ref}" in discriminator mapping') + name = get_reference_simple_name(ref_path) + else: + name = model_ref + model = model_types_by_name.get(name) + if not model: return PropertyError( - detail=f'Discriminator mapping referred to "{model_ref}" which is not one of the schema variants', + detail=f'Discriminator mapping referred to "{name}" which is not one of the schema variants', ) - for original_value in (name for name, m in model_types_by_name.items() if m == lookup_model): - mapping.pop(original_value) - mapping[discriminator_value] = lookup_model - else: - mapping = model_types_by_name - + mapping[discriminator_value] = model + unspecified_models.remove(model) + for model in unspecified_models: + if name := _get_model_name(model): + mapping[name] = model return DiscriminatorDefinition(property_name=data.propertyName, value_to_model_map=mapping) - - -def _validate_discriminators( - discriminators: list[DiscriminatorDefinition], -) -> PropertyError | None: - from .model_property import ModelProperty - - prop_names_values_classes = [ - (discriminator.property_name, key, cast(ModelProperty, model).class_info.name) - for discriminator in discriminators - for key, model in discriminator.value_to_model_map.items() - ] - for p, v in {(p, v) for p, v, _ in prop_names_values_classes}: - if len({c for p1, v1, c in prop_names_values_classes if (p1, v1) == (p, v)}) > 1: - return PropertyError(f'Discriminator property "{p}" had more than one schema for value "{v}"') - return None - - # TODO: We should also validate that property_name refers to a property that 1. exists, - # 2. is required, 3. is a string (in all of these models). However, currently we can't - # do that because, at the time this function is called, the ModelProperties within the - # union haven't yet been post-processed and so we don't have full information about - # their properties. To fix this, we may need to generalize the post-processing phase so - # that any Property type, not just ModelProperty, can say it needs post-processing; then - # we can defer _validate_discriminators till that phase. diff --git a/openapi_python_client/templates/property_templates/union_property.py.jinja b/openapi_python_client/templates/property_templates/union_property.py.jinja index 0a0b1d49f..89d6b6ffd 100644 --- a/openapi_python_client/templates/property_templates/union_property.py.jinja +++ b/openapi_python_client/templates/property_templates/union_property.py.jinja @@ -16,13 +16,13 @@ if not isinstance(data, dict): raise TypeError() if "{{ discriminator.property_name }}" in data: _discriminator_value = data["{{ discriminator.property_name }}"] - {% for model in discriminator.value_to_model_map.values() %} - def _parse_{{ model.python_name }}(data: object) -> {{ model.get_type_string() }}: -{{ construct_inner_property(model) | indent(12, True) }} + {% for value, model in discriminator.value_to_model_map.items() %} + def _parse_{{ loop.index }}(data: object) -> {{ model.get_type_string() }}: +{{ construct_inner_property(model) | indent(8, True) }} {% endfor %} _discriminator_mapping = { {% for value, model in discriminator.value_to_model_map.items() %} - "{{ value }}": _parse_{{ model.python_name }}, + "{{ value }}": _parse_{{ loop.index }}, {% endfor %} } if _parse_fn := _discriminator_mapping.get(_discriminator_value): diff --git a/tests/test_parser/test_properties/test_union.py b/tests/test_parser/test_properties/test_union.py index 27b3c77c5..b3305547b 100644 --- a/tests/test_parser/test_properties/test_union.py +++ b/tests/test_parser/test_properties/test_union.py @@ -10,7 +10,7 @@ from openapi_python_client.parser.properties.model_property import ModelProperty from openapi_python_client.parser.properties.property import Property from openapi_python_client.parser.properties.protocol import Value -from openapi_python_client.parser.properties.schemas import Class +from openapi_python_client.parser.properties.schemas import Class, ReferencePath from openapi_python_client.schema import DataType, ParameterLocation from tests.test_parser.test_properties.properties_test_helpers import assert_prop_error @@ -66,6 +66,7 @@ def _make_basic_model( ) assert isinstance(model, ModelProperty) if name: + model.ref_path = ReferencePath(f"/components/schemas/{name}") schemas = evolve( schemas, classes_by_reference={**schemas.classes_by_reference, f"/components/schemas/{name}": model} ) @@ -404,43 +405,3 @@ def test_discriminator_invalid_inline_schema_variant(config, string_property_fac name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config ) assert_prop_error(p, "Inline schema") - - -def test_conflicting_discriminator_mappings(config): - from openapi_python_client.parser.properties import Schemas, property_from_data - - schemas = Schemas() - props = {"type": oai.Schema.model_construct(type="string")} - model1, schemas = _make_basic_model("Model1", props, "type", schemas, config) - model2, schemas = _make_basic_model("Model2", props, "type", schemas, config) - model3, schemas = _make_basic_model("Model3", props, "type", schemas, config) - model4, schemas = _make_basic_model("Model4", props, "type", schemas, config) - data = oai.Schema.model_construct( - oneOf=[ - oai.Schema.model_construct( - oneOf=[ - oai.Reference(ref="#/components/schemas/Model1"), - oai.Reference(ref="#/components/schemas/Model2"), - ], - discriminator=oai.Discriminator.model_construct( - propertyName="type", - mapping={"a": "Model1", "b": "Model2"}, - ), - ), - oai.Schema.model_construct( - oneOf=[ - oai.Reference(ref="#/components/schemas/Model3"), - oai.Reference(ref="#/components/schemas/Model4"), - ], - discriminator=oai.Discriminator.model_construct( - propertyName="type", - mapping={"a": "Model3", "x": "Model4"}, - ), - ), - ], - ) - - p, schemas = property_from_data( - name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config - ) - assert_prop_error(p, '"type" had more than one schema for value "a"') From 29414518643acfaa72d97e21a17b64224cb0ef24 Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Thu, 31 Oct 2024 10:17:53 -0700 Subject: [PATCH 3/6] remove new category of tests for now --- end_to_end_tests/__init__.py | 3 --- end_to_end_tests/end_to_end_live_tests.py | 33 ----------------------- end_to_end_tests/test_end_to_end.py | 18 +++---------- 3 files changed, 3 insertions(+), 51 deletions(-) delete mode 100644 end_to_end_tests/end_to_end_live_tests.py diff --git a/end_to_end_tests/__init__.py b/end_to_end_tests/__init__.py index b91590c3b..1bf33f63f 100644 --- a/end_to_end_tests/__init__.py +++ b/end_to_end_tests/__init__.py @@ -1,4 +1 @@ """ Generate a complete client and verify that it is correct """ -import pytest - -pytest.register_assert_rewrite('end_to_end_tests.end_to_end_live_tests') diff --git a/end_to_end_tests/end_to_end_live_tests.py b/end_to_end_tests/end_to_end_live_tests.py deleted file mode 100644 index df82a8354..000000000 --- a/end_to_end_tests/end_to_end_live_tests.py +++ /dev/null @@ -1,33 +0,0 @@ -import importlib -from typing import Any - -import pytest - - -def live_tests_3_x(): - _test_model_with_discriminated_union() - - -def _import_model(module_name, class_name: str) -> Any: - module = importlib.import_module(f"my_test_api_client.models.{module_name}") - module = importlib.reload(module) # avoid test contamination from previous import - return getattr(module, class_name) - - -def _test_model_with_discriminated_union(): - ModelType1Class = _import_model("a_discriminated_union_type_1", "ADiscriminatedUnionType1") - ModelType2Class = _import_model("a_discriminated_union_type_2", "ADiscriminatedUnionType2") - ModelClass = _import_model("model_with_discriminated_union", "ModelWithDiscriminatedUnion") - - assert ( - ModelClass.from_dict({"discriminated_union": {"modelType": "type1"}}) == - ModelClass(discriminated_union=ModelType1Class.from_dict({"modelType": "type1"})) - ) - assert ( - ModelClass.from_dict({"discriminated_union": {"modelType": "type2"}}) == - ModelClass(discriminated_union=ModelType2Class.from_dict({"modelType": "type2"})) - ) - with pytest.raises(TypeError): - ModelClass.from_dict({"discriminated_union": {"modelType": "type3"}}) - with pytest.raises(TypeError): - ModelClass.from_dict({"discriminated_union": {}}) diff --git a/end_to_end_tests/test_end_to_end.py b/end_to_end_tests/test_end_to_end.py index 46755e6ca..a448a0698 100644 --- a/end_to_end_tests/test_end_to_end.py +++ b/end_to_end_tests/test_end_to_end.py @@ -1,17 +1,13 @@ -import os import shutil from filecmp import cmpfiles, dircmp from pathlib import Path -import sys -from typing import Callable, Dict, List, Optional, Set +from typing import Dict, List, Optional, Set import pytest from click.testing import Result from typer.testing import CliRunner from openapi_python_client.cli import app -from .end_to_end_live_tests import live_tests_3_x - def _compare_directories( @@ -87,7 +83,6 @@ def run_e2e_test( golden_record_path: str = "golden-record", output_path: str = "my-test-api-client", expected_missing: Optional[Set[str]] = None, - live_tests: Optional[Callable[[str], None]] = None, ) -> Result: output_path = Path.cwd() / output_path shutil.rmtree(output_path, ignore_errors=True) @@ -102,13 +97,6 @@ def run_e2e_test( _compare_directories( gr_path, output_path, expected_differences=expected_differences, expected_missing=expected_missing ) - if live_tests: - old_path = sys.path.copy() - sys.path.insert(0, str(output_path)) - try: - live_tests() - finally: - sys.path = old_path import mypy.api @@ -143,11 +131,11 @@ def _run_command(command: str, extra_args: Optional[List[str]] = None, openapi_d def test_baseline_end_to_end_3_0(): - run_e2e_test("baseline_openapi_3.0.json", [], {}, live_tests=live_tests_3_x) + run_e2e_test("baseline_openapi_3.0.json", [], {}) def test_baseline_end_to_end_3_1(): - run_e2e_test("baseline_openapi_3.1.yaml", [], {}, live_tests=live_tests_3_x) + run_e2e_test("baseline_openapi_3.1.yaml", [], {}) def test_3_1_specific_features(): From d189f771180f189555e7109aee24da871d548724 Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Tue, 5 Nov 2024 13:29:29 -0800 Subject: [PATCH 4/6] lint --- .../parser/properties/union.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/openapi_python_client/parser/properties/union.py b/openapi_python_client/parser/properties/union.py index 6ee629f37..6b758993e 100644 --- a/openapi_python_client/parser/properties/union.py +++ b/openapi_python_client/parser/properties/union.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Any, ClassVar, OrderedDict, cast +from typing import Any, ClassVar, Mapping, OrderedDict, cast from attr import define, evolve @@ -16,7 +16,7 @@ @define class DiscriminatorDefinition: """Represents a discriminator that can optionally be specified for a union type. - + Normally, a UnionProperty has either zero or one of these. However, a nested union could have more than one, as we accumulate all the discriminators when we flatten out the nested schemas. For example: @@ -36,8 +36,9 @@ class DiscriminatorDefinition: In this example there are four schemas and two discriminators. The deserializer logic will check for the mammalType property first, then birdType. """ + property_name: str - value_to_model_map: dict[str, PropertyProtocol] + value_to_model_map: Mapping[str, PropertyProtocol] # Every value in the map is really a ModelProperty, but this avoids circular imports @@ -260,7 +261,7 @@ def _parse_discriminator( # mapping: # value_for_a: "#/components/schemas/ModelA" # value_for_b: ModelB # equivalent to "#/components/schemas/ModelB" - # + # # For any type that isn't specified in the mapping (or if the whole mapping is omitted) # the default lookup value for each schema is the same as the schema name. So this-- # mapping: @@ -275,7 +276,7 @@ def _parse_discriminator( def _get_model_name(model: ModelProperty) -> str | None: return get_reference_simple_name(model.ref_path) if model.ref_path else None - model_types_by_name: dict[str, PropertyProtocol] = {} + model_types_by_name: dict[str, ModelProperty] = {} for model in subtypes: # Note, model here can never be a UnionProperty, because we've already done # flatten_union_properties() before this point. @@ -290,7 +291,7 @@ def _get_model_name(model: ModelProperty) -> str | None: ) model_types_by_name[name] = model - mapping: dict[str, PropertyProtocol] = OrderedDict() # use ordered dict for test determinacy + mapping: dict[str, ModelProperty] = OrderedDict() # use ordered dict for test determinacy unspecified_models = list(model_types_by_name.values()) if data.mapping: for discriminator_value, model_ref in data.mapping.items(): @@ -301,13 +302,13 @@ def _get_model_name(model: ModelProperty) -> str | None: name = get_reference_simple_name(ref_path) else: name = model_ref - model = model_types_by_name.get(name) - if not model: + mapped_model = model_types_by_name.get(name) + if not mapped_model: return PropertyError( detail=f'Discriminator mapping referred to "{name}" which is not one of the schema variants', ) - mapping[discriminator_value] = model - unspecified_models.remove(model) + mapping[discriminator_value] = mapped_model + unspecified_models.remove(mapped_model) for model in unspecified_models: if name := _get_model_name(model): mapping[name] = model From de1ddf33667318aad595227492109fed6436bb70 Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Tue, 5 Nov 2024 14:22:56 -0800 Subject: [PATCH 5/6] lint --- tests/test_parser/test_properties/test_protocol.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_parser/test_properties/test_protocol.py b/tests/test_parser/test_properties/test_protocol.py index 1d4111750..5c2f1c993 100644 --- a/tests/test_parser/test_properties/test_protocol.py +++ b/tests/test_parser/test_properties/test_protocol.py @@ -3,6 +3,7 @@ import pytest from openapi_python_client.parser.properties.protocol import Value +from openapi_python_client.parser.properties.schemas import ReferencePath def test_is_base_type(any_property_factory): @@ -85,3 +86,12 @@ def test_get_base_json_type_string(quoted, expected, any_property_factory, mocke mocker.patch.object(AnyProperty, "_json_type_string", "str") p = any_property_factory() assert p.get_base_json_type_string(quoted=quoted) is expected + + +def test_ref_path(any_property_factory, model_property_factory): + p1 = any_property_factory() + assert p1.get_ref_path() is None + + path = ReferencePath("/components/schemas/A") + p2 = model_property_factory(ref_path=path) + assert p2.get_ref_path() == path From 9071de05a4158ff397f74ac829ba3eb9f11daf62 Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Wed, 6 Nov 2024 13:21:14 -0800 Subject: [PATCH 6/6] handle a case where there's multiple values mapped to same type --- end_to_end_tests/baseline_openapi_3.0.json | 3 ++- end_to_end_tests/baseline_openapi_3.1.yaml | 3 ++- .../models/model_with_discriminated_union.py | 8 ++++++++ openapi_python_client/parser/properties/union.py | 4 +++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/end_to_end_tests/baseline_openapi_3.0.json b/end_to_end_tests/baseline_openapi_3.0.json index 1f3232e69..3da55d2df 100644 --- a/end_to_end_tests/baseline_openapi_3.0.json +++ b/end_to_end_tests/baseline_openapi_3.0.json @@ -2823,7 +2823,8 @@ "propertyName": "modelType", "mapping": { "type1": "#/components/schemas/ADiscriminatedUnionType1", - "type2": "#/components/schemas/ADiscriminatedUnionType2" + "type2": "#/components/schemas/ADiscriminatedUnionType2", + "type2-another-value": "#/components/schemas/ADiscriminatedUnionType2" } }, "oneOf": [ diff --git a/end_to_end_tests/baseline_openapi_3.1.yaml b/end_to_end_tests/baseline_openapi_3.1.yaml index 373660695..746ea91ef 100644 --- a/end_to_end_tests/baseline_openapi_3.1.yaml +++ b/end_to_end_tests/baseline_openapi_3.1.yaml @@ -2817,7 +2817,8 @@ info: "propertyName": "modelType", "mapping": { "type1": "#/components/schemas/ADiscriminatedUnionType1", - "type2": "#/components/schemas/ADiscriminatedUnionType2" + "type2": "#/components/schemas/ADiscriminatedUnionType2", + "type2-another-value": "#/components/schemas/ADiscriminatedUnionType2" } }, "oneOf": [ diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py b/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py index 1a0918332..6f37fe454 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py @@ -78,9 +78,17 @@ def _parse_2(data: object) -> ADiscriminatedUnionType2: return componentsschemas_a_discriminated_union_type_1 + def _parse_3(data: object) -> ADiscriminatedUnionType2: + if not isinstance(data, dict): + raise TypeError() + componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data) + + return componentsschemas_a_discriminated_union_type_1 + _discriminator_mapping = { "type1": _parse_1, "type2": _parse_2, + "type2-another-value": _parse_3, } if _parse_fn := _discriminator_mapping.get(_discriminator_value): return cast( diff --git a/openapi_python_client/parser/properties/union.py b/openapi_python_client/parser/properties/union.py index 6b758993e..efa48eda2 100644 --- a/openapi_python_client/parser/properties/union.py +++ b/openapi_python_client/parser/properties/union.py @@ -308,7 +308,9 @@ def _get_model_name(model: ModelProperty) -> str | None: detail=f'Discriminator mapping referred to "{name}" which is not one of the schema variants', ) mapping[discriminator_value] = mapped_model - unspecified_models.remove(mapped_model) + if mapped_model in unspecified_models: + # could've already been removed if more than one value is mapped to the same model + unspecified_models.remove(mapped_model) for model in unspecified_models: if name := _get_model_name(model): mapping[name] = model