From 850df2e572dd84428ae3181a3c15ec4bffff2c10 Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Thu, 19 Sep 2024 17:19:16 -0700 Subject: [PATCH 1/4] fix: use correct class name when overriding model property --- .changeset/fix-model-override.md | 7 ++ .../parser/properties/__init__.py | 2 +- .../parser/properties/merge_properties.py | 38 ++++++ .../parser/properties/model_property.py | 80 ++++++++----- .../parser/properties/protocol.py | 4 + tests/conftest.py | 21 +++- .../test_properties/test_merge_properties.py | 110 +++++++++++++++++- .../test_properties/test_model_property.py | 21 +++- 8 files changed, 241 insertions(+), 42 deletions(-) create mode 100644 .changeset/fix-model-override.md diff --git a/.changeset/fix-model-override.md b/.changeset/fix-model-override.md new file mode 100644 index 000000000..7b80e1448 --- /dev/null +++ b/.changeset/fix-model-override.md @@ -0,0 +1,7 @@ +--- +default: patch +--- + +# Fix overriding of object property class + +Fixed issue #1121, where redefining an object property within an `allOf` would not use the correct class name if the property's type was changed from one object type to another. diff --git a/openapi_python_client/parser/properties/__init__.py b/openapi_python_client/parser/properties/__init__.py index 02a6fdafe..d35edfa61 100644 --- a/openapi_python_client/parser/properties/__init__.py +++ b/openapi_python_client/parser/properties/__init__.py @@ -125,7 +125,7 @@ def _property_from_ref( return prop, schemas -def property_from_data( # noqa: PLR0911, PLR0912 +def property_from_data( # noqa: PLR0911 name: str, required: bool, data: oai.Reference | oai.Schema, diff --git a/openapi_python_client/parser/properties/merge_properties.py b/openapi_python_client/parser/properties/merge_properties.py index dc7b3e5eb..0dc98d0fd 100644 --- a/openapi_python_client/parser/properties/merge_properties.py +++ b/openapi_python_client/parser/properties/merge_properties.py @@ -3,6 +3,7 @@ from openapi_python_client.parser.properties.date import DateProperty from openapi_python_client.parser.properties.datetime import DateTimeProperty from openapi_python_client.parser.properties.file import FileProperty +from openapi_python_client.parser.properties.model_property import ModelProperty __all__ = ["merge_properties"] @@ -75,6 +76,9 @@ def _merge_same_type(prop1: Property, prop2: Property) -> Property | None | Prop # It's always OK to redefine a property with everything exactly the same return prop1 + if isinstance(prop1, ModelProperty) and isinstance(prop2, ModelProperty): + return _merge_models(prop1, prop2) + if isinstance(prop1, ListProperty) and isinstance(prop2, ListProperty): inner_property = merge_properties(prop1.inner_property, prop2.inner_property) # type: ignore if isinstance(inner_property, PropertyError): @@ -86,6 +90,24 @@ def _merge_same_type(prop1: Property, prop2: Property) -> Property | None | Prop return _merge_common_attributes(prop1, prop2) +def _merge_models(prop1: ModelProperty, prop2: ModelProperty) -> Property | PropertyError: + # Ideally, we would treat this case the same as a schema that consisted of "allOf: [prop1, prop2]", + # applying the property merge logic recursively and creating a new third schema if the result could + # not be fully described by one or the other. But for now we will just handle the common case where + # B is an object type that extends A and fully includes it, with no changes to any of A's properties; + # in that case, it is valid to just reuse the model class for B. + for prop in [prop1, prop2]: + if prop.needs_processing(): + return PropertyError(f"Schema for {prop} in allOf was not processed", data=prop) + if _model_is_extension_of(prop1, prop2): + extended_model = prop1 + elif _model_is_extension_of(prop2, prop1): + extended_model = prop2 + else: + return PropertyError(detail="unable to merge two unrelated object types for this property") + return _merge_common_attributes(extended_model, prop1, prop2) + + def _merge_string_with_format(prop1: Property, prop2: Property) -> Property | None | PropertyError: """Merge a string that has no format with a string that has a format""" # Here we need to use the DateProperty/DateTimeProperty/FileProperty as the base so that we preserve @@ -166,3 +188,19 @@ def _merge_common_attributes(base: PropertyT, *extend_with: PropertyProtocol) -> def _values_are_subset(prop1: EnumProperty, prop2: EnumProperty) -> bool: return set(prop1.values.items()) <= set(prop2.values.items()) + + +def _model_is_extension_of(extended_model: ModelProperty, base_model: ModelProperty) -> bool: + def _list_is_extension_of(extended_list: list[Property], base_list: list[Property]) -> bool: + for p2 in base_list: + if not [p1 for p1 in extended_list if _property_is_extension_of(p2, p1)]: + return False + return True + + return _list_is_extension_of( + extended_model.required_properties, base_model.required_properties + ) and _list_is_extension_of(extended_model.optional_properties, base_model.optional_properties) + + +def _property_is_extension_of(extended_prop: PropertyProtocol, base_prop: PropertyProtocol) -> bool: + return base_prop.name == extended_prop.name and merge_properties(base_prop, extended_prop) == extended_prop diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index 0dc13be54..3c5da7c60 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from itertools import chain from typing import Any, ClassVar, NamedTuple @@ -14,6 +15,15 @@ from .schemas import Class, ReferencePath, Schemas, parse_reference_path +@dataclass +class ModelDetails: + required_properties: list[Property] | None = None + optional_properties: list[Property] | None = None + additional_properties: Property | None = None + relative_imports: set[str] | None = None + lazy_imports: set[str] | None = None + + @define class ModelProperty(PropertyProtocol): """A property which refers to another Schema""" @@ -27,11 +37,7 @@ class ModelProperty(PropertyProtocol): data: oai.Schema description: str roots: set[ReferencePath | utils.ClassName] - required_properties: list[Property] | None - optional_properties: list[Property] | None - relative_imports: set[str] | None - lazy_imports: set[str] | None - additional_properties: Property | None + details: ModelDetails _json_type_string: ClassVar[str] = "Dict[str, Any]" template: ClassVar[str] = "model_property.py.jinja" @@ -75,22 +81,18 @@ def build( class_string = title class_info = Class.from_string(string=class_string, config=config) model_roots = {*roots, class_info.name} - required_properties: list[Property] | None = None - optional_properties: list[Property] | None = None - relative_imports: set[str] | None = None - lazy_imports: set[str] | None = None - additional_properties: Property | None = None + details = ModelDetails() if process_properties: data_or_err, schemas = _process_property_data( data=data, schemas=schemas, class_info=class_info, config=config, roots=model_roots ) if isinstance(data_or_err, PropertyError): return data_or_err, schemas - property_data, additional_properties = data_or_err - required_properties = property_data.required_props - optional_properties = property_data.optional_props - relative_imports = property_data.relative_imports - lazy_imports = property_data.lazy_imports + property_data, details.additional_properties = data_or_err + details.required_properties = property_data.required_props + details.optional_properties = property_data.optional_props + details.relative_imports = property_data.relative_imports + details.lazy_imports = property_data.lazy_imports for root in roots: if isinstance(root, utils.ClassName): continue @@ -100,11 +102,7 @@ def build( class_info=class_info, data=data, roots=model_roots, - required_properties=required_properties, - optional_properties=optional_properties, - relative_imports=relative_imports, - lazy_imports=lazy_imports, - additional_properties=additional_properties, + details=details, description=data.description or "", default=None, required=required, @@ -125,6 +123,31 @@ def build( ) return prop, schemas + def needs_processing(self) -> bool: + return not ( + isinstance(self.details.required_properties, list) and isinstance(self.details.optional_properties, list) + ) + + @property + def required_properties(self) -> list[Property]: + return self.details.required_properties or [] + + @property + def optional_properties(self) -> list[Property]: + return self.details.optional_properties or [] + + @property + def additional_properties(self) -> Property | None: + return self.details.additional_properties + + @property + def relative_imports(self) -> set[str]: + return self.details.relative_imports or set() + + @property + def lazy_imports(self) -> set[str] | None: + return self.details.lazy_imports or set() + @classmethod def convert_value(cls, value: Any) -> Value | None | PropertyError: if value is not None: @@ -132,7 +155,7 @@ def convert_value(cls, value: Any) -> Value | None | PropertyError: return None def __attrs_post_init__(self) -> None: - if self.relative_imports: + if self.details.relative_imports: self.set_relative_imports(self.relative_imports) @property @@ -175,7 +198,7 @@ def set_relative_imports(self, relative_imports: set[str]) -> None: Args: relative_imports: The set of relative import strings """ - object.__setattr__(self, "relative_imports", {ri for ri in relative_imports if self.self_import not in ri}) + self.details.relative_imports = {ri for ri in relative_imports if self.self_import not in ri} def set_lazy_imports(self, lazy_imports: set[str]) -> None: """Set the lazy imports set for this ModelProperty, filtering out self imports @@ -183,7 +206,7 @@ def set_lazy_imports(self, lazy_imports: set[str]) -> None: Args: lazy_imports: The set of lazy import strings """ - object.__setattr__(self, "lazy_imports", {li for li in lazy_imports if self.self_import not in li}) + self.details.lazy_imports = {li for li in lazy_imports if self.self_import not in li} def get_type_string( self, @@ -289,9 +312,7 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: if not isinstance(sub_model, ModelProperty): return PropertyError("Cannot take allOf a non-object") # Properties of allOf references first should be processed first - if not ( - isinstance(sub_model.required_properties, list) and isinstance(sub_model.optional_properties, list) - ): + if sub_model.needs_processing(): return PropertyError(f"Reference {sub_model.name} in allOf was not processed", data=sub_prop) for prop in chain(sub_model.required_properties, sub_model.optional_properties): err = _add_if_no_conflict(prop) @@ -437,9 +458,10 @@ def process_model(model_prop: ModelProperty, *, schemas: Schemas, config: Config property_data, additional_properties = data_or_err - object.__setattr__(model_prop, "required_properties", property_data.required_props) - object.__setattr__(model_prop, "optional_properties", property_data.optional_props) + model_prop.details.required_properties = property_data.required_props + model_prop.details.optional_properties = property_data.optional_props + model_prop.details.additional_properties = additional_properties model_prop.set_relative_imports(property_data.relative_imports) model_prop.set_lazy_imports(property_data.lazy_imports) - object.__setattr__(model_prop, "additional_properties", additional_properties) + return schemas diff --git a/openapi_python_client/parser/properties/protocol.py b/openapi_python_client/parser/properties/protocol.py index c9555949d..5fe02387f 100644 --- a/openapi_python_client/parser/properties/protocol.py +++ b/openapi_python_client/parser/properties/protocol.py @@ -185,3 +185,7 @@ def is_base_type(self) -> bool: ListProperty.__name__, UnionProperty.__name__, } + + def needs_processing(self) -> bool: + """Returns true if the parser should call process_model() on this property in a second pass.""" + return False diff --git a/tests/conftest.py b/tests/conftest.py index c01b4ce87..c9970cf3f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,7 @@ UnionProperty, ) from openapi_python_client.parser.properties.float import FloatProperty +from openapi_python_client.parser.properties.model_property import ModelDetails from openapi_python_client.parser.properties.protocol import PropertyType, Value from openapi_python_client.schema.openapi_schema_pydantic import Parameter from openapi_python_client.schema.parameter_location import ParameterLocation @@ -64,15 +65,25 @@ def _factory(**kwargs): "class_info": Class(name=ClassName("MyClass", ""), module_name=PythonIdentifier("my_module", "")), "data": oai.Schema.model_construct(), "roots": set(), - "required_properties": None, - "optional_properties": None, - "relative_imports": None, - "lazy_imports": None, - "additional_properties": None, "python_name": "", "example": "", **kwargs, } + # shortcuts for setting attributes within ModelDetails + if "details" not in kwargs: + detail_args = {} + for arg_name in [ + "required_properties", + "optional_properties", + "additional_properties", + "relative_imports", + "lazy_imports", + ]: + if arg_name in kwargs: + detail_args[arg_name] = kwargs[arg_name] + kwargs.pop(arg_name) + kwargs["details"] = ModelDetails(**detail_args) + return ModelProperty(**kwargs) return _factory diff --git a/tests/test_parser/test_properties/test_merge_properties.py b/tests/test_parser/test_properties/test_merge_properties.py index 12ddb79fa..3919732c0 100644 --- a/tests/test_parser/test_properties/test_merge_properties.py +++ b/tests/test_parser/test_properties/test_merge_properties.py @@ -6,6 +6,7 @@ from openapi_python_client.parser.properties.float import FloatProperty from openapi_python_client.parser.properties.int import IntProperty from openapi_python_client.parser.properties.merge_properties import merge_properties +from openapi_python_client.parser.properties.model_property import ModelDetails from openapi_python_client.parser.properties.protocol import Value from openapi_python_client.parser.properties.schemas import Class from openapi_python_client.parser.properties.string import StringProperty @@ -27,7 +28,7 @@ def test_merge_basic_attributes_same_type( float_property_factory(default=Value("1.5", 1.5)), string_property_factory(default=StringProperty.convert_value("x")), list_property_factory(), - model_property_factory(), + model_property_factory(required_properties=[], optional_properties=[]), ] for basic_prop in basic_props: with_required = evolve(basic_prop, required=True) @@ -237,3 +238,110 @@ def test_merge_lists(int_property_factory, list_property_factory, string_propert ) assert isinstance(merge_properties(list_prop_1, list_prop_3), PropertyError) + + +def test_merge_related_models(model_property_factory, string_property_factory, config): + base_model = model_property_factory( + name="BaseModel", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1", description="base description"), + ], + optional_properties=[ + string_property_factory(name="opt_1"), + ], + ), + description="desc_1", + example="example_1", + class_info=Class.from_string(string="BaseModel", config=config), + ) + derived_model = model_property_factory( + name="DerivedModel", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1", description="extended description"), + string_property_factory(name="req_2"), + ], + optional_properties=[ + string_property_factory(name="opt_1"), + string_property_factory(name="opt_2"), + ], + ), + description="desc_2", + class_info=Class.from_string(string="DerivedModel", config=config), + ) + + assert merge_properties(base_model, derived_model) == evolve(derived_model, example=base_model.example) + assert merge_properties(derived_model, base_model) == evolve( + derived_model, description=base_model.description, example=base_model.example + ) + + +def test_merge_models_fails_for_unrelated_required_property(model_property_factory, string_property_factory, config): + model_1 = model_property_factory( + name="model_1", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1"), + ], + optional_properties=[ + string_property_factory(name="opt_1"), + ], + ), + description="desc_1", + example="example_1", + class_info=Class.from_string(string="Model1", config=config), + ) + model_2 = model_property_factory( + name="model_2", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_2"), + ], + optional_properties=[ + string_property_factory(name="opt_1"), + string_property_factory(name="opt_2"), + ], + ), + description="desc_2", + class_info=Class.from_string(string="Model2", config=config), + ) + + result = merge_properties(model_1, model_2) + assert isinstance(result, PropertyError) + assert result.detail == "unable to merge two unrelated object types for this property" + + +def test_merge_models_fails_for_unrelated_optional_property(model_property_factory, string_property_factory, config): + model_1 = model_property_factory( + name="model_1", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1"), + ], + optional_properties=[ + string_property_factory(name="opt_1"), + ], + ), + description="desc_1", + example="example_1", + class_info=Class.from_string(string="Model1", config=config), + ) + model_2 = model_property_factory( + name="model_2", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1"), + string_property_factory(name="req_2"), + ], + optional_properties=[ + string_property_factory(name="opt_2"), + ], + ), + description="desc_2", + class_info=Class.from_string(string="Model2", config=config), + ) + + result = merge_properties(model_1, model_2) + assert isinstance(result, PropertyError) + assert result.detail == "unable to merge two unrelated object types for this property" diff --git a/tests/test_parser/test_properties/test_model_property.py b/tests/test_parser/test_properties/test_model_property.py index 8adc88e39..68eda7141 100644 --- a/tests/test_parser/test_properties/test_model_property.py +++ b/tests/test_parser/test_properties/test_model_property.py @@ -6,7 +6,12 @@ import openapi_python_client.schema as oai from openapi_python_client.parser.errors import PropertyError from openapi_python_client.parser.properties import Schemas, StringProperty -from openapi_python_client.parser.properties.model_property import ANY_ADDITIONAL_PROPERTY, _process_properties +from openapi_python_client.parser.properties.model_property import ( + ANY_ADDITIONAL_PROPERTY, + ModelDetails, + ModelProperty, + _process_properties, +) MODULE_NAME = "openapi_python_client.parser.properties.model_property" @@ -675,7 +680,7 @@ def test_process_model_error(self, mocker, model_property_factory, config): from openapi_python_client.parser.properties import Schemas from openapi_python_client.parser.properties.model_property import process_model - model_prop = model_property_factory() + model_prop: ModelProperty = model_property_factory(details=ModelDetails()) schemas = Schemas() process_property_data = mocker.patch(f"{MODULE_NAME}._process_property_data") process_property_data.return_value = (PropertyError(), schemas) @@ -683,9 +688,10 @@ def test_process_model_error(self, mocker, model_property_factory, config): result = process_model(model_prop=model_prop, schemas=schemas, config=config) assert result == PropertyError() - assert model_prop.required_properties is None - assert model_prop.optional_properties is None - assert model_prop.relative_imports is None + assert model_prop.needs_processing() + assert model_prop.required_properties == [] + assert model_prop.optional_properties == [] + assert model_prop.relative_imports == set() assert model_prop.additional_properties is None def test_process_model(self, mocker, model_property_factory, config): @@ -721,6 +727,9 @@ def test_set_relative_imports(model_property_factory): class_info = Class("ClassName", module_name="module_name") relative_imports = {"from typing import List", f"from ..models.{class_info.module_name} import {class_info.name}"} - model_property = model_property_factory(class_info=class_info, relative_imports=relative_imports) + model_property = model_property_factory( + class_info=class_info, + details=ModelDetails(relative_imports=relative_imports), + ) assert model_property.relative_imports == {"from typing import List"} From 5cd16ac385d2abc403a32c6af1da9874ea5557c6 Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Thu, 10 Oct 2024 13:15:52 -0700 Subject: [PATCH 2/4] misc implementation changes --- .changeset/fix-model-override.md | 4 +- .../parser/properties/merge_properties.py | 98 +++++++++++++++---- .../parser/properties/model_property.py | 17 ++-- .../test_properties/test_merge_properties.py | 81 ++++++++------- 4 files changed, 133 insertions(+), 67 deletions(-) diff --git a/.changeset/fix-model-override.md b/.changeset/fix-model-override.md index 7b80e1448..d4ea23a05 100644 --- a/.changeset/fix-model-override.md +++ b/.changeset/fix-model-override.md @@ -4,4 +4,6 @@ default: patch # Fix overriding of object property class -Fixed issue #1121, where redefining an object property within an `allOf` would not use the correct class name if the property's type was changed from one object type to another. +Fixed issue #1123, in which a property could end up with the wrong type when combining two object schemas with `allOf`, if the type of the property was itself an object but had a different schema in each. Previously, if the property's type was A in the first schema and B in the second, the resulting schema would use type A for the property. + +The new behavior is, that the generator will test whether one of the types A/B is derived from the other. "Derived" here means that the result of `allOf[A, B]` would be exactly identical to B. If so, it will use the class name of B. If not, it will attempt to merge A and B with the usual `allOf` logic to create a new inline schema. diff --git a/openapi_python_client/parser/properties/merge_properties.py b/openapi_python_client/parser/properties/merge_properties.py index 0dc98d0fd..61af04365 100644 --- a/openapi_python_client/parser/properties/merge_properties.py +++ b/openapi_python_client/parser/properties/merge_properties.py @@ -1,9 +1,13 @@ from __future__ import annotations +from itertools import chain +from openapi_python_client import utils +from openapi_python_client.config import Config from openapi_python_client.parser.properties.date import DateProperty from openapi_python_client.parser.properties.datetime import DateTimeProperty from openapi_python_client.parser.properties.file import FileProperty -from openapi_python_client.parser.properties.model_property import ModelProperty +from openapi_python_client.parser.properties.model_property import ModelDetails, ModelProperty, _gather_property_data +from openapi_python_client.parser.properties.schemas import Class, Schemas __all__ = ["merge_properties"] @@ -27,7 +31,12 @@ STRING_WITH_FORMAT_TYPES = (DateProperty, DateTimeProperty, FileProperty) -def merge_properties(prop1: Property, prop2: Property) -> Property | PropertyError: # noqa: PLR0911 +def merge_properties( + prop1: Property, + prop2: Property, + parent_name: str, + config: Config, +) -> Property | PropertyError: # noqa: PLR0911 """Attempt to create a new property that incorporates the behavior of both. This is used when merging schemas with allOf, when two schemas define a property with the same name. @@ -54,7 +63,7 @@ def merge_properties(prop1: Property, prop2: Property) -> Property | PropertyErr if isinstance(prop1, EnumProperty) or isinstance(prop2, EnumProperty): return _merge_with_enum(prop1, prop2) - if (merged := _merge_same_type(prop1, prop2)) is not None: + if (merged := _merge_same_type(prop1, prop2, parent_name, config)) is not None: return merged if (merged := _merge_numeric(prop1, prop2)) is not None: @@ -68,7 +77,7 @@ def merge_properties(prop1: Property, prop2: Property) -> Property | PropertyErr ) -def _merge_same_type(prop1: Property, prop2: Property) -> Property | None | PropertyError: +def _merge_same_type(prop1: Property, prop2: Property, parent_name: str, config: Config) -> Property | None | PropertyError: if type(prop1) is not type(prop2): return None @@ -77,10 +86,10 @@ def _merge_same_type(prop1: Property, prop2: Property) -> Property | None | Prop return prop1 if isinstance(prop1, ModelProperty) and isinstance(prop2, ModelProperty): - return _merge_models(prop1, prop2) + return _merge_models(prop1, prop2, parent_name, config) if isinstance(prop1, ListProperty) and isinstance(prop2, ListProperty): - inner_property = merge_properties(prop1.inner_property, prop2.inner_property) # type: ignore + inner_property = merge_properties(prop1.inner_property, prop2.inner_property, "", config) # type: ignore if isinstance(inner_property, PropertyError): return PropertyError(detail=f"can't merge list properties: {inner_property.detail}") prop1.inner_property = inner_property @@ -90,7 +99,7 @@ def _merge_same_type(prop1: Property, prop2: Property) -> Property | None | Prop return _merge_common_attributes(prop1, prop2) -def _merge_models(prop1: ModelProperty, prop2: ModelProperty) -> Property | PropertyError: +def _merge_models(prop1: ModelProperty, prop2: ModelProperty, parent_name: str, config: Config) -> Property | PropertyError: # Ideally, we would treat this case the same as a schema that consisted of "allOf: [prop1, prop2]", # applying the property merge logic recursively and creating a new third schema if the result could # not be fully described by one or the other. But for now we will just handle the common case where @@ -98,14 +107,59 @@ def _merge_models(prop1: ModelProperty, prop2: ModelProperty) -> Property | Prop # in that case, it is valid to just reuse the model class for B. for prop in [prop1, prop2]: if prop.needs_processing(): + # This means not all of the details of the schema have been filled in, possibly due to a + # forward reference. That may be resolved in a later pass, but for now we can't proceed. return PropertyError(f"Schema for {prop} in allOf was not processed", data=prop) - if _model_is_extension_of(prop1, prop2): - extended_model = prop1 - elif _model_is_extension_of(prop2, prop1): - extended_model = prop2 - else: - return PropertyError(detail="unable to merge two unrelated object types for this property") - return _merge_common_attributes(extended_model, prop1, prop2) + + # Detect whether one of the schemas is derived from the other-- that is, if it is (or is equivalent + # to) the result of taking the other type and adding/modifying properties with allOf. If so, then + # we can simply use the class of the derived type. We will still call _merge_common_attributes in + # case any metadata like "description" has been modified. + if _model_is_extension_of(prop1, prop2, parent_name, config): + return _merge_common_attributes(prop1, prop2) + elif _model_is_extension_of(prop2, prop1, parent_name, config): + return _merge_common_attributes(prop2, prop1, prop2) + + # Neither of the schemas is a superset of the other, so merging them will result in a new type. + merged_props: dict[str, Property] = {p.name: p for p in chain(prop1.required_properties, prop1.optional_properties)} + for model in [prop1, prop2]: + for sub_prop in chain(model.required_properties, model.optional_properties): + if sub_prop.name in merged_props: + merged_prop = merge_properties(merged_props[sub_prop.name], sub_prop, parent_name, config) + if isinstance(merged_prop, PropertyError): + return merged_prop + merged_props[sub_prop.name] = merged_prop + else: + merged_props[sub_prop.name] = sub_prop + + prop_data = _gather_property_data(merged_props.values(), Schemas()) + + name = prop2.name + class_string = f"{utils.pascal_case(parent_name)}{utils.pascal_case(name)}" + class_info = Class.from_string(string=class_string, config=config) + roots = prop1.roots.union(prop2.roots).difference({prop1.class_info.name, prop2.class_info.name}) + roots.add(class_info.name) + prop_details = ModelDetails( + required_properties=prop_data.required_props, + optional_properties=prop_data.optional_props, + additional_properties=None, + relative_imports=prop_data.relative_imports, + lazy_imports=prop_data.lazy_imports, + ) + prop = ModelProperty( + class_info=class_info, + data=prop2.data, # TODO: not sure what this should be + roots=roots, + details=prop_details, + description=prop2.description or prop1.description, + default=None, + required=prop2.required or prop1.required, + name=name, + python_name=utils.PythonIdentifier(value=name, prefix=config.field_prefix), + example=prop2.example or prop1.example, + ) + + return prop def _merge_string_with_format(prop1: Property, prop2: Property) -> Property | None | PropertyError: @@ -190,17 +244,19 @@ def _values_are_subset(prop1: EnumProperty, prop2: EnumProperty) -> bool: return set(prop1.values.items()) <= set(prop2.values.items()) -def _model_is_extension_of(extended_model: ModelProperty, base_model: ModelProperty) -> bool: - def _list_is_extension_of(extended_list: list[Property], base_list: list[Property]) -> bool: +def _model_is_extension_of(extended_model: ModelProperty, base_model: ModelProperty, parent_name: str, config: Config) -> bool: + def _properties_are_extension_of(extended_list: list[Property], base_list: list[Property]) -> bool: for p2 in base_list: - if not [p1 for p1 in extended_list if _property_is_extension_of(p2, p1)]: + if not [p1 for p1 in extended_list if _property_is_extension_of(p2, p1, parent_name, config)]: return False return True - return _list_is_extension_of( + return _properties_are_extension_of( extended_model.required_properties, base_model.required_properties - ) and _list_is_extension_of(extended_model.optional_properties, base_model.optional_properties) + ) and _properties_are_extension_of(extended_model.optional_properties, base_model.optional_properties) -def _property_is_extension_of(extended_prop: PropertyProtocol, base_prop: PropertyProtocol) -> bool: - return base_prop.name == extended_prop.name and merge_properties(base_prop, extended_prop) == extended_prop +def _property_is_extension_of(extended_prop: PropertyProtocol, base_prop: PropertyProtocol, parent_name: str, config: Config) -> bool: + return base_prop.name == extended_prop.name and ( + base_prop == extended_prop or merge_properties(base_prop, extended_prop, parent_name, config) == extended_prop + ) diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index 3c5da7c60..a8f0523dc 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from itertools import chain -from typing import Any, ClassVar, NamedTuple +from typing import Any, ClassVar, Iterable, NamedTuple from attrs import define, evolve @@ -261,7 +261,7 @@ class _PropertyData(NamedTuple): schemas: Schemas -def _process_properties( # noqa: PLR0912, PLR0911 +def _process_properties( # noqa: PLR0911 *, data: oai.Schema, schemas: Schemas, @@ -273,15 +273,13 @@ def _process_properties( # noqa: PLR0912, PLR0911 from .merge_properties import merge_properties properties: dict[str, Property] = {} - relative_imports: set[str] = set() - lazy_imports: set[str] = set() required_set = set(data.required or []) def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: nonlocal properties name_conflict = properties.get(new_prop.name) - merged_prop = merge_properties(name_conflict, new_prop) if name_conflict else new_prop + merged_prop = merge_properties(name_conflict, new_prop, class_name, config) if name_conflict else new_prop if isinstance(merged_prop, PropertyError): merged_prop.header = f"Found conflicting properties named {new_prop.name} when creating {class_name}" return merged_prop @@ -340,9 +338,15 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: if isinstance(prop_or_error, PropertyError): return prop_or_error + return _gather_property_data(properties.values(), schemas) + + +def _gather_property_data(properties: Iterable[Property], schemas: Schemas) -> _PropertyData: required_properties = [] optional_properties = [] - for prop in properties.values(): + relative_imports: set[str] = set() + lazy_imports: set[str] = set() + for prop in properties: if prop.required: required_properties.append(prop) else: @@ -350,7 +354,6 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: lazy_imports.update(prop.get_lazy_imports(prefix="..")) relative_imports.update(prop.get_imports(prefix="..")) - return _PropertyData( optional_props=optional_properties, required_props=required_properties, diff --git a/tests/test_parser/test_properties/test_merge_properties.py b/tests/test_parser/test_properties/test_merge_properties.py index 3919732c0..face86a24 100644 --- a/tests/test_parser/test_properties/test_merge_properties.py +++ b/tests/test_parser/test_properties/test_merge_properties.py @@ -21,6 +21,7 @@ def test_merge_basic_attributes_same_type( string_property_factory, list_property_factory, model_property_factory, + config, ): basic_props = [ boolean_property_factory(default=Value(python_code="True", raw_value="True")), @@ -32,16 +33,16 @@ def test_merge_basic_attributes_same_type( ] for basic_prop in basic_props: with_required = evolve(basic_prop, required=True) - assert merge_properties(basic_prop, with_required) == with_required - assert merge_properties(with_required, basic_prop) == with_required + assert merge_properties(basic_prop, with_required, "", config) == with_required + assert merge_properties(with_required, basic_prop, "", config) == with_required without_default = evolve(basic_prop, default=None) - assert merge_properties(basic_prop, without_default) == basic_prop - assert merge_properties(without_default, basic_prop) == basic_prop + assert merge_properties(basic_prop, without_default, "", config) == basic_prop + assert merge_properties(without_default, basic_prop, "", config) == basic_prop with_desc1 = evolve(basic_prop, description="desc1") with_desc2 = evolve(basic_prop, description="desc2") - assert merge_properties(basic_prop, with_desc1) == with_desc1 - assert merge_properties(with_desc1, basic_prop) == with_desc1 - assert merge_properties(with_desc1, with_desc2) == with_desc2 + assert merge_properties(basic_prop, with_desc1, "", config) == with_desc1 + assert merge_properties(with_desc1, basic_prop, "", config) == with_desc1 + assert merge_properties(with_desc1, with_desc2, "", config) == with_desc2 def test_incompatible_types( @@ -51,6 +52,7 @@ def test_incompatible_types( string_property_factory, list_property_factory, model_property_factory, + config, ): props = [ boolean_property_factory(default=True), @@ -64,21 +66,21 @@ def test_incompatible_types( for prop1, prop2 in permutations(props, 2): if {prop1.__class__, prop2.__class__} == {IntProperty, FloatProperty}: continue # the int+float case is covered in another test - error = merge_properties(prop1, prop2) + error = merge_properties(prop1, prop2, "", config) assert isinstance(error, PropertyError), f"Expected {type(prop1)} and {type(prop2)} to be incompatible" -def test_merge_int_with_float(int_property_factory, float_property_factory): +def test_merge_int_with_float(int_property_factory, float_property_factory, config): int_prop = int_property_factory(description="desc1") float_prop = float_property_factory(default=Value("2", 2), description="desc2") - assert merge_properties(int_prop, float_prop) == ( + assert merge_properties(int_prop, float_prop, "", config) == ( evolve(int_prop, default=Value("2", 2), description=float_prop.description) ) - assert merge_properties(float_prop, int_prop) == evolve(int_prop, default=Value("2", 2)) + assert merge_properties(float_prop, int_prop, "", config) == evolve(int_prop, default=Value("2", 2)) float_prop_with_non_int_default = evolve(float_prop, default=Value("2.5", 2.5)) - error = merge_properties(int_prop, float_prop_with_non_int_default) + error = merge_properties(int_prop, float_prop_with_non_int_default, "", config) assert isinstance(error, PropertyError), "Expected invalid default to error" assert error.detail == "Invalid int value: 2.5" @@ -90,6 +92,7 @@ def test_merge_with_any( float_property_factory, string_property_factory, model_property_factory, + config, ): original_desc = "description" props = [ @@ -101,8 +104,8 @@ def test_merge_with_any( ] any_prop = any_property_factory() for prop in props: - assert merge_properties(any_prop, prop) == prop - assert merge_properties(prop, any_prop) == prop + assert merge_properties(any_prop, prop, "", config) == prop + assert merge_properties(prop, any_prop, "", config) == prop def test_merge_enums(enum_property_factory, config): @@ -121,19 +124,19 @@ def test_merge_enums(enum_property_factory, config): enum_with_fewer_values.class_info = Class.from_string(string="FewerValuesEnum", config=config) enum_with_more_values.class_info = Class.from_string(string="MoreValuesEnum", config=config) - assert merge_properties(enum_with_fewer_values, enum_with_more_values) == evolve( + assert merge_properties(enum_with_fewer_values, enum_with_more_values, "", config) == evolve( enum_with_more_values, values=enum_with_fewer_values.values, class_info=enum_with_fewer_values.class_info, description=enum_with_fewer_values.description, ) - assert merge_properties(enum_with_more_values, enum_with_fewer_values) == evolve( + assert merge_properties(enum_with_more_values, enum_with_fewer_values, "", config) == evolve( enum_with_fewer_values, example=enum_with_more_values.example, ) -def test_merge_string_with_string_enum(string_property_factory, enum_property_factory): +def test_merge_string_with_string_enum(string_property_factory, enum_property_factory, config): values = {"A": "A", "B": "B"} string_prop = string_property_factory(default=Value("A", "A"), description="desc1", example="example1") enum_prop = enum_property_factory( @@ -144,8 +147,8 @@ def test_merge_string_with_string_enum(string_property_factory, enum_property_fa value_type=str, ) - assert merge_properties(string_prop, enum_prop) == evolve(enum_prop, required=True) - assert merge_properties(enum_prop, string_prop) == evolve( + assert merge_properties(string_prop, enum_prop, "", config) == evolve(enum_prop, required=True) + assert merge_properties(enum_prop, string_prop, "", config) == evolve( enum_prop, required=True, default=Value("test.A", "A"), @@ -154,7 +157,7 @@ def test_merge_string_with_string_enum(string_property_factory, enum_property_fa ) -def test_merge_int_with_int_enum(int_property_factory, enum_property_factory): +def test_merge_int_with_int_enum(int_property_factory, enum_property_factory, config): values = {"VALUE_1": 1, "VALUE_2": 2} int_prop = int_property_factory(default=Value("1", 1), description="desc1", example="example1") enum_prop = enum_property_factory( @@ -165,8 +168,8 @@ def test_merge_int_with_int_enum(int_property_factory, enum_property_factory): value_type=int, ) - assert merge_properties(int_prop, enum_prop) == evolve(enum_prop, required=True) - assert merge_properties(enum_prop, int_prop) == evolve( + assert merge_properties(int_prop, enum_prop, "", config) == evolve(enum_prop, required=True) + assert merge_properties(enum_prop, int_prop, "", config) == evolve( enum_prop, required=True, description=int_prop.description, example=int_prop.example ) @@ -178,6 +181,7 @@ def test_merge_with_incompatible_enum( string_property_factory, enum_property_factory, model_property_factory, + config, ): props = [ boolean_property_factory(), @@ -190,11 +194,11 @@ def test_merge_with_incompatible_enum( int_enum_prop = enum_property_factory(value_type=int) for prop in props: if not isinstance(prop, StringProperty): - assert isinstance(merge_properties(prop, string_enum_prop), PropertyError) - assert isinstance(merge_properties(string_enum_prop, prop), PropertyError) + assert isinstance(merge_properties(prop, string_enum_prop, "", config), PropertyError) + assert isinstance(merge_properties(string_enum_prop, prop, "", config), PropertyError) if not isinstance(prop, IntProperty): - assert isinstance(merge_properties(prop, int_enum_prop), PropertyError) - assert isinstance(merge_properties(int_enum_prop, prop), PropertyError) + assert isinstance(merge_properties(prop, int_enum_prop, "", config), PropertyError) + assert isinstance(merge_properties(int_enum_prop, prop, "", config), PropertyError) def test_merge_string_with_formatted_string( @@ -202,6 +206,7 @@ def test_merge_string_with_formatted_string( date_time_property_factory, file_property_factory, string_property_factory, + config, ): string_prop = string_property_factory(description="a plain string") string_prop_with_invalid_default = string_property_factory( @@ -213,19 +218,19 @@ def test_merge_string_with_formatted_string( file_property_factory(description="a file"), ] for formatted_prop in formatted_props: - merged1 = merge_properties(string_prop, formatted_prop) + merged1 = merge_properties(string_prop, formatted_prop, "", config) assert isinstance(merged1, formatted_prop.__class__) assert merged1.description == formatted_prop.description - merged2 = merge_properties(formatted_prop, string_prop) + merged2 = merge_properties(formatted_prop, string_prop, "", config) assert isinstance(merged2, formatted_prop.__class__) assert merged2.description == string_prop.description - assert isinstance(merge_properties(string_prop_with_invalid_default, formatted_prop), PropertyError) - assert isinstance(merge_properties(formatted_prop, string_prop_with_invalid_default), PropertyError) + assert isinstance(merge_properties(string_prop_with_invalid_default, formatted_prop, "", config), PropertyError) + assert isinstance(merge_properties(formatted_prop, string_prop_with_invalid_default, "", config), PropertyError) -def test_merge_lists(int_property_factory, list_property_factory, string_property_factory): +def test_merge_lists(int_property_factory, list_property_factory, string_property_factory, config): string_prop_1 = string_property_factory(description="desc1") string_prop_2 = string_property_factory(example="desc2") int_prop = int_property_factory() @@ -233,11 +238,11 @@ def test_merge_lists(int_property_factory, list_property_factory, string_propert list_prop_2 = list_property_factory(inner_property=string_prop_2) list_prop_3 = list_property_factory(inner_property=int_prop) - assert merge_properties(list_prop_1, list_prop_2) == evolve( - list_prop_1, inner_property=merge_properties(string_prop_1, string_prop_2) + assert merge_properties(list_prop_1, list_prop_2, "", config) == evolve( + list_prop_1, inner_property=merge_properties(string_prop_1, string_prop_2, "", config) ) - assert isinstance(merge_properties(list_prop_1, list_prop_3), PropertyError) + assert isinstance(merge_properties(list_prop_1, list_prop_3, "", config), PropertyError) def test_merge_related_models(model_property_factory, string_property_factory, config): @@ -271,8 +276,8 @@ def test_merge_related_models(model_property_factory, string_property_factory, c class_info=Class.from_string(string="DerivedModel", config=config), ) - assert merge_properties(base_model, derived_model) == evolve(derived_model, example=base_model.example) - assert merge_properties(derived_model, base_model) == evolve( + assert merge_properties(base_model, derived_model, "", config) == evolve(derived_model, example=base_model.example) + assert merge_properties(derived_model, base_model, "", config) == evolve( derived_model, description=base_model.description, example=base_model.example ) @@ -307,7 +312,7 @@ def test_merge_models_fails_for_unrelated_required_property(model_property_facto class_info=Class.from_string(string="Model2", config=config), ) - result = merge_properties(model_1, model_2) + result = merge_properties(model_1, model_2, "", config) assert isinstance(result, PropertyError) assert result.detail == "unable to merge two unrelated object types for this property" @@ -342,6 +347,6 @@ def test_merge_models_fails_for_unrelated_optional_property(model_property_facto class_info=Class.from_string(string="Model2", config=config), ) - result = merge_properties(model_1, model_2) + result = merge_properties(model_1, model_2, "", config) assert isinstance(result, PropertyError) assert result.detail == "unable to merge two unrelated object types for this property" From e6cdf95cc386f689da56bd81b73eab999b1c12b7 Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Tue, 5 Nov 2024 12:23:03 -0800 Subject: [PATCH 3/4] lint --- .../parser/properties/__init__.py | 2 +- .../parser/properties/merge_properties.py | 24 ++++++++++++------- .../parser/properties/model_property.py | 4 ++-- .../test_properties/test_merge_properties.py | 7 ++++-- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/openapi_python_client/parser/properties/__init__.py b/openapi_python_client/parser/properties/__init__.py index 9a1e050e8..94c6e3d08 100644 --- a/openapi_python_client/parser/properties/__init__.py +++ b/openapi_python_client/parser/properties/__init__.py @@ -137,7 +137,7 @@ def _property_from_ref( return prop, schemas -def property_from_data( # noqa: PLR0911 +def property_from_data( # noqa: PLR0911, PLR0912 name: str, required: bool, data: oai.Reference | oai.Schema, diff --git a/openapi_python_client/parser/properties/merge_properties.py b/openapi_python_client/parser/properties/merge_properties.py index aaf5c8d1f..cccc8c7b9 100644 --- a/openapi_python_client/parser/properties/merge_properties.py +++ b/openapi_python_client/parser/properties/merge_properties.py @@ -34,7 +34,7 @@ STRING_WITH_FORMAT_TYPES = (DateProperty, DateTimeProperty, FileProperty) -def merge_properties( # noqa:PLR0911 +def merge_properties( # noqa:PLR0911 prop1: Property, prop2: Property, parent_name: str, @@ -83,7 +83,9 @@ def merge_properties( # noqa:PLR0911 ) -def _merge_same_type(prop1: Property, prop2: Property, parent_name: str, config: Config) -> Property | None | PropertyError: +def _merge_same_type( + prop1: Property, prop2: Property, parent_name: str, config: Config +) -> Property | None | PropertyError: if type(prop1) is not type(prop2): return None @@ -105,7 +107,9 @@ def _merge_same_type(prop1: Property, prop2: Property, parent_name: str, config: return _merge_common_attributes(prop1, prop2) -def _merge_models(prop1: ModelProperty, prop2: ModelProperty, parent_name: str, config: Config) -> Property | PropertyError: +def _merge_models( + prop1: ModelProperty, prop2: ModelProperty, parent_name: str, config: Config +) -> Property | PropertyError: # Ideally, we would treat this case the same as a schema that consisted of "allOf: [prop1, prop2]", # applying the property merge logic recursively and creating a new third schema if the result could # not be fully described by one or the other. But for now we will just handle the common case where @@ -115,7 +119,7 @@ def _merge_models(prop1: ModelProperty, prop2: ModelProperty, parent_name: str, if prop.needs_post_processing(): # This means not all of the details of the schema have been filled in, possibly due to a # forward reference. That may be resolved in a later pass, but for now we can't proceed. - return PropertyError(f"Schema for {prop} in allOf was not processed", data=prop) + return PropertyError(f"Schema for {prop} in allOf was not processed", data=prop.data) # Detect whether one of the schemas is derived from the other-- that is, if it is (or is equivalent # to) the result of taking the other type and adding/modifying properties with allOf. If so, then @@ -137,9 +141,9 @@ def _merge_models(prop1: ModelProperty, prop2: ModelProperty, parent_name: str, merged_props[sub_prop.name] = merged_prop else: merged_props[sub_prop.name] = sub_prop - + prop_data = _gather_property_data(merged_props.values(), Schemas()) - + name = prop2.name class_string = f"{utils.pascal_case(parent_name)}{utils.pascal_case(name)}" class_info = Class.from_string(string=class_string, config=config) @@ -276,7 +280,9 @@ def _values_are_subset(prop1: EnumProperty, prop2: EnumProperty) -> bool: return set(prop1.values.items()) <= set(prop2.values.items()) -def _model_is_extension_of(extended_model: ModelProperty, base_model: ModelProperty, parent_name: str, config: Config) -> bool: +def _model_is_extension_of( + extended_model: ModelProperty, base_model: ModelProperty, parent_name: str, config: Config +) -> bool: def _properties_are_extension_of(extended_list: list[Property], base_list: list[Property]) -> bool: for p2 in base_list: if not [p1 for p1 in extended_list if _property_is_extension_of(p2, p1, parent_name, config)]: @@ -288,7 +294,9 @@ def _properties_are_extension_of(extended_list: list[Property], base_list: list[ ) and _properties_are_extension_of(extended_model.optional_properties, base_model.optional_properties) -def _property_is_extension_of(extended_prop: PropertyProtocol, base_prop: PropertyProtocol, parent_name: str, config: Config) -> bool: +def _property_is_extension_of( + extended_prop: Property, base_prop: Property, parent_name: str, config: Config +) -> bool: return base_prop.name == extended_prop.name and ( base_prop == extended_prop or merge_properties(base_prop, extended_prop, parent_name, config) == extended_prop ) diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index 0469b9ad1..e324d6056 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -342,8 +342,8 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: def _gather_property_data(properties: Iterable[Property], schemas: Schemas) -> _PropertyData: - required_properties = [] - optional_properties = [] + required_properties: list[Property] = [] + optional_properties: list[Property] = [] relative_imports: set[str] = set() lazy_imports: set[str] = set() for prop in properties: diff --git a/tests/test_parser/test_properties/test_merge_properties.py b/tests/test_parser/test_properties/test_merge_properties.py index 2e67377ff..16228869c 100644 --- a/tests/test_parser/test_properties/test_merge_properties.py +++ b/tests/test_parser/test_properties/test_merge_properties.py @@ -153,7 +153,8 @@ def test_merge_enums(literal_enums, enum_property_factory, literal_enum_property @pytest.mark.parametrize("literal_enums", (False, True)) def test_merge_string_with_string_enum( - literal_enums, string_property_factory, enum_property_factory, literal_enum_property_factory, config): + literal_enums, string_property_factory, enum_property_factory, literal_enum_property_factory, config +): string_prop = string_property_factory(default=Value("A", "A"), description="desc1", example="example1") enum_prop = ( literal_enum_property_factory( @@ -327,7 +328,9 @@ def test_merge_related_models(model_property_factory, string_property_factory, c class_info=Class.from_string(string="DerivedModel", config=config), ) - assert merge_properties(base_model, extension_model, "", config) == evolve(extension_model, example=base_model.example) + assert merge_properties(base_model, extension_model, "", config) == evolve( + extension_model, example=base_model.example + ) assert merge_properties(extension_model, base_model, "", config) == evolve( extension_model, description=base_model.description, example=base_model.example ) From c363974148c079f88d23d1c88a0a0786edde393b Mon Sep 17 00:00:00 2001 From: Eli Bishop Date: Tue, 5 Nov 2024 12:50:53 -0800 Subject: [PATCH 4/4] simplify a bit --- .../parser/properties/merge_properties.py | 33 +++---- .../parser/properties/model_property.py | 83 +++++++---------- .../test_properties/test_merge_properties.py | 48 ++++++++++ .../test_properties/test_model_property.py | 91 +++++++++---------- .../test_properties/test_protocol.py | 8 ++ 5 files changed, 146 insertions(+), 117 deletions(-) diff --git a/openapi_python_client/parser/properties/merge_properties.py b/openapi_python_client/parser/properties/merge_properties.py index cccc8c7b9..fdb479090 100644 --- a/openapi_python_client/parser/properties/merge_properties.py +++ b/openapi_python_client/parser/properties/merge_properties.py @@ -9,8 +9,8 @@ from openapi_python_client.parser.properties.datetime import DateTimeProperty from openapi_python_client.parser.properties.file import FileProperty from openapi_python_client.parser.properties.literal_enum_property import LiteralEnumProperty -from openapi_python_client.parser.properties.model_property import ModelDetails, ModelProperty, _gather_property_data -from openapi_python_client.parser.properties.schemas import Class, Schemas +from openapi_python_client.parser.properties.model_property import ModelProperty, _gather_property_data +from openapi_python_client.parser.properties.schemas import Class __all__ = ["merge_properties"] @@ -110,21 +110,19 @@ def _merge_same_type( def _merge_models( prop1: ModelProperty, prop2: ModelProperty, parent_name: str, config: Config ) -> Property | PropertyError: - # Ideally, we would treat this case the same as a schema that consisted of "allOf: [prop1, prop2]", - # applying the property merge logic recursively and creating a new third schema if the result could - # not be fully described by one or the other. But for now we will just handle the common case where - # B is an object type that extends A and fully includes it, with no changes to any of A's properties; - # in that case, it is valid to just reuse the model class for B. + # The logic here is basically equivalent to what we would do for a schema that was + # "allOf: [prop1, prop2]". We apply the property merge logic recursively and create a new third + # schema if the result cannot be fully described by one or the other. If it *can* be fully + # described by one or the other, then we can simply reuse the class for that one: for instance, + # in a common case where B is an object type that extends A and fully includes it, so that + # "allOf: [A, B]" would be the same as B, then it's valid to use the existing B model class. + # We would still call _merge_common_attributes in that case, to handle metadat like "description". for prop in [prop1, prop2]: if prop.needs_post_processing(): # This means not all of the details of the schema have been filled in, possibly due to a # forward reference. That may be resolved in a later pass, but for now we can't proceed. return PropertyError(f"Schema for {prop} in allOf was not processed", data=prop.data) - # Detect whether one of the schemas is derived from the other-- that is, if it is (or is equivalent - # to) the result of taking the other type and adding/modifying properties with allOf. If so, then - # we can simply use the class of the derived type. We will still call _merge_common_attributes in - # case any metadata like "description" has been modified. if _model_is_extension_of(prop1, prop2, parent_name, config): return _merge_common_attributes(prop1, prop2) elif _model_is_extension_of(prop2, prop1, parent_name, config): @@ -142,20 +140,13 @@ def _merge_models( else: merged_props[sub_prop.name] = sub_prop - prop_data = _gather_property_data(merged_props.values(), Schemas()) + prop_details = _gather_property_data(merged_props.values()) name = prop2.name class_string = f"{utils.pascal_case(parent_name)}{utils.pascal_case(name)}" class_info = Class.from_string(string=class_string, config=config) roots = prop1.roots.union(prop2.roots).difference({prop1.class_info.name, prop2.class_info.name}) roots.add(class_info.name) - prop_details = ModelDetails( - required_properties=prop_data.required_props, - optional_properties=prop_data.optional_props, - additional_properties=None, - relative_imports=prop_data.relative_imports, - lazy_imports=prop_data.lazy_imports, - ) prop = ModelProperty( class_info=class_info, data=oai.Schema.model_construct(allOf=[prop1.data, prop2.data]), @@ -294,9 +285,7 @@ def _properties_are_extension_of(extended_list: list[Property], base_list: list[ ) and _properties_are_extension_of(extended_model.optional_properties, base_model.optional_properties) -def _property_is_extension_of( - extended_prop: Property, base_prop: Property, parent_name: str, config: Config -) -> bool: +def _property_is_extension_of(extended_prop: Property, base_prop: Property, parent_name: str, config: Config) -> bool: return base_prop.name == extended_prop.name and ( base_prop == extended_prop or merge_properties(base_prop, extended_prop, parent_name, config) == extended_prop ) diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index e324d6056..fff99cc1a 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -1,10 +1,9 @@ from __future__ import annotations -from dataclasses import dataclass from itertools import chain -from typing import Any, ClassVar, Iterable, NamedTuple +from typing import Any, ClassVar, Iterable -from attrs import define, evolve +from attrs import define, evolve, field from ... import Config, utils from ... import schema as oai @@ -15,13 +14,15 @@ from .schemas import Class, ReferencePath, Schemas, parse_reference_path -@dataclass +@define class ModelDetails: + """Container for basic attributes of a model schema that can be computed separately""" + required_properties: list[Property] | None = None optional_properties: list[Property] | None = None additional_properties: Property | None = None - relative_imports: set[str] | None = None - lazy_imports: set[str] | None = None + relative_imports: set[str] = field(factory=set) + lazy_imports: set[str] = field(factory=set) @define @@ -88,11 +89,7 @@ def build( ) if isinstance(data_or_err, PropertyError): return data_or_err, schemas - property_data, details.additional_properties = data_or_err - details.required_properties = property_data.required_props - details.optional_properties = property_data.optional_props - details.relative_imports = property_data.relative_imports - details.lazy_imports = property_data.lazy_imports + details = data_or_err for root in roots: if isinstance(root, utils.ClassName): continue @@ -142,11 +139,11 @@ def additional_properties(self) -> Property | None: @property def relative_imports(self) -> set[str]: - return self.details.relative_imports or set() + return self.details.relative_imports @property def lazy_imports(self) -> set[str] | None: - return self.details.lazy_imports or set() + return self.details.lazy_imports @classmethod def convert_value(cls, value: Any) -> Value | None | PropertyError: @@ -253,14 +250,6 @@ def _resolve_naming_conflict(first: Property, second: Property, config: Config) return None -class _PropertyData(NamedTuple): - optional_props: list[Property] - required_props: list[Property] - relative_imports: set[str] - lazy_imports: set[str] - schemas: Schemas - - def _process_properties( # noqa: PLR0911 *, data: oai.Schema, @@ -268,7 +257,7 @@ def _process_properties( # noqa: PLR0911 class_name: utils.ClassName, config: Config, roots: set[ReferencePath | utils.ClassName], -) -> _PropertyData | PropertyError: +) -> tuple[ModelDetails | PropertyError, Schemas]: from . import property_from_data from .merge_properties import merge_properties @@ -303,19 +292,19 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: if isinstance(sub_prop, oai.Reference): ref_path = parse_reference_path(sub_prop.ref) if isinstance(ref_path, ParseError): - return PropertyError(detail=ref_path.detail, data=sub_prop) + return PropertyError(detail=ref_path.detail, data=sub_prop), schemas sub_model = schemas.classes_by_reference.get(ref_path) if sub_model is None: - return PropertyError(f"Reference {sub_prop.ref} not found") + return PropertyError(f"Reference {sub_prop.ref} not found"), schemas if not isinstance(sub_model, ModelProperty): - return PropertyError("Cannot take allOf a non-object") + return PropertyError("Cannot take allOf a non-object"), schemas # Properties of allOf references first should be processed first if sub_model.needs_post_processing(): - return PropertyError(f"Reference {sub_model.name} in allOf was not processed", data=sub_prop) + return PropertyError(f"Reference {sub_model.name} in allOf was not processed", data=sub_prop), schemas for prop in chain(sub_model.required_properties, sub_model.optional_properties): err = _add_if_no_conflict(prop) if err is not None: - return err + return err, schemas schemas.add_dependencies(ref_path=ref_path, roots=roots) else: unprocessed_props.extend(sub_prop.properties.items() if sub_prop.properties else []) @@ -336,12 +325,12 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: if not isinstance(prop_or_error, PropertyError): prop_or_error = _add_if_no_conflict(prop_or_error) if isinstance(prop_or_error, PropertyError): - return prop_or_error + return prop_or_error, schemas - return _gather_property_data(properties.values(), schemas) + return _gather_property_data(properties.values()), schemas -def _gather_property_data(properties: Iterable[Property], schemas: Schemas) -> _PropertyData: +def _gather_property_data(properties: Iterable[Property]) -> ModelDetails: required_properties: list[Property] = [] optional_properties: list[Property] = [] relative_imports: set[str] = set() @@ -350,12 +339,12 @@ def _gather_property_data(properties: Iterable[Property], schemas: Schemas) -> _ (required_properties if prop.required else optional_properties).append(prop) lazy_imports.update(prop.get_lazy_imports(prefix="..")) relative_imports.update(prop.get_imports(prefix="..")) - return _PropertyData( - optional_props=optional_properties, - required_props=required_properties, + return ModelDetails( + optional_properties=optional_properties, + required_properties=required_properties, relative_imports=relative_imports, lazy_imports=lazy_imports, - schemas=schemas, + additional_properties=None, ) @@ -410,13 +399,12 @@ def _process_property_data( class_info: Class, config: Config, roots: set[ReferencePath | utils.ClassName], -) -> tuple[tuple[_PropertyData, Property | None] | PropertyError, Schemas]: - property_data = _process_properties( +) -> tuple[ModelDetails | PropertyError, Schemas]: + model_details, schemas = _process_properties( data=data, schemas=schemas, class_name=class_info.name, config=config, roots=roots ) - if isinstance(property_data, PropertyError): - return property_data, schemas - schemas = property_data.schemas + if isinstance(model_details, PropertyError): + return model_details, schemas additional_properties, schemas = _get_additional_properties( schema_additional=data.additionalProperties, @@ -430,10 +418,11 @@ def _process_property_data( elif additional_properties is None: pass else: - property_data.relative_imports.update(additional_properties.get_imports(prefix="..")) - property_data.lazy_imports.update(additional_properties.get_lazy_imports(prefix="..")) + model_details = evolve(model_details, additional_properties=additional_properties) + model_details.relative_imports.update(additional_properties.get_imports(prefix="..")) + model_details.lazy_imports.update(additional_properties.get_lazy_imports(prefix="..")) - return (property_data, additional_properties), schemas + return model_details, schemas def process_model(model_prop: ModelProperty, *, schemas: Schemas, config: Config) -> Schemas | PropertyError: @@ -455,12 +444,8 @@ def process_model(model_prop: ModelProperty, *, schemas: Schemas, config: Config if isinstance(data_or_err, PropertyError): return data_or_err - property_data, additional_properties = data_or_err - - model_prop.details.required_properties = property_data.required_props - model_prop.details.optional_properties = property_data.optional_props - model_prop.details.additional_properties = additional_properties - model_prop.set_relative_imports(property_data.relative_imports) - model_prop.set_lazy_imports(property_data.lazy_imports) + model_prop.details = data_or_err + model_prop.set_relative_imports(data_or_err.relative_imports) + model_prop.set_lazy_imports(data_or_err.lazy_imports) return schemas diff --git a/tests/test_parser/test_properties/test_merge_properties.py b/tests/test_parser/test_properties/test_merge_properties.py index 16228869c..71b30901b 100644 --- a/tests/test_parser/test_properties/test_merge_properties.py +++ b/tests/test_parser/test_properties/test_merge_properties.py @@ -373,3 +373,51 @@ def test_merge_unrelated_models(model_property_factory, string_property_factory, assert [p.name for p in result.optional_properties] == ["opt_1", "opt_2"] assert result.class_info.name == "ParentSchemaPropName" assert result.description == model_2.description + + +def test_merge_models_with_incompatible_property( + model_property_factory, string_property_factory, int_property_factory, config +): + model_1 = model_property_factory( + name="propName", + details=ModelDetails( + required_properties=[ + string_property_factory(name="prop1", required=True), + ], + optional_properties=[], + ), + class_info=Class.from_string(string="Model1", config=config), + ) + model_2 = model_property_factory( + name="propName", + details=ModelDetails( + required_properties=[ + int_property_factory(name="prop1", required=True), + ], + optional_properties=[], + ), + class_info=Class.from_string(string="Model2", config=config), + ) + + result = merge_properties(model_1, model_2, "ParentSchema", config) + + assert isinstance(result, PropertyError) + assert result.detail == "str can't be merged with int" + + +def test_merge_models_not_yet_processed(model_property_factory, string_property_factory, int_property_factory, config): + model_1 = model_property_factory( + name="propName", + details=ModelDetails(required_properties=None, optional_properties=None), + class_info=Class.from_string(string="Model1", config=config), + ) + model_2 = model_property_factory( + name="propName", + details=ModelDetails(required_properties=None, optional_properties=None), + class_info=Class.from_string(string="Model2", config=config), + ) + + result = merge_properties(model_1, model_2, "ParentSchema", config) + + assert isinstance(result, PropertyError) + assert "not processed" in result.detail diff --git a/tests/test_parser/test_properties/test_model_property.py b/tests/test_parser/test_properties/test_model_property.py index 1fa0062b2..498877251 100644 --- a/tests/test_parser/test_properties/test_model_property.py +++ b/tests/test_parser/test_properties/test_model_property.py @@ -344,7 +344,7 @@ def test_conflicting_properties_different_types( } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -355,14 +355,14 @@ def test_process_properties_reference_not_exist(self, config): }, ) - result = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots={"root"}) + result, _ = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots={"root"}) assert isinstance(result, PropertyError) def test_process_properties_all_of_reference_not_exist(self, config): data = oai.Schema.model_construct(allOf=[oai.Reference.model_construct(ref="#/components/schema/NotExist")]) - result = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots={"root"}) + result, _ = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -370,15 +370,15 @@ def test_process_properties_model_property_roots(self, model_property_factory, c roots = {"root"} data = oai.Schema(properties={"test_model_property": oai.Schema.model_construct(type="object")}) - result = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots=roots) + result, _ = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots=roots) - assert all(root in result.optional_props[0].roots for root in roots) + assert all(root in result.optional_properties[0].roots for root in roots) def test_invalid_reference(self, config): data = oai.Schema.model_construct(allOf=[oai.Reference.model_construct(ref="ThisIsNotGood")]) schemas = Schemas() - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -390,7 +390,7 @@ def test_non_model_reference(self, enum_property_factory, config): } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -402,7 +402,7 @@ def test_reference_not_processed(self, model_property_factory, config): } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -425,8 +425,8 @@ def test_allof_string_and_string_enum( } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.required_props[0] == enum_property + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.required_properties[0] == enum_property def test_allof_string_enum_and_string( self, model_property_factory, enum_property_factory, string_property_factory, config @@ -448,8 +448,8 @@ def test_allof_string_enum_and_string( } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.optional_props[0] == enum_property + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.optional_properties[0] == enum_property def test_allof_int_and_int_enum(self, model_property_factory, enum_property_factory, int_property_factory, config): data = oai.Schema.model_construct( @@ -466,8 +466,8 @@ def test_allof_int_and_int_enum(self, model_property_factory, enum_property_fact } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.required_props[0] == enum_property + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.required_properties[0] == enum_property def test_allof_enum_incompatible_type( self, model_property_factory, enum_property_factory, int_property_factory, config @@ -486,7 +486,7 @@ def test_allof_enum_incompatible_type( } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) def test_allof_string_enums(self, model_property_factory, enum_property_factory, config): @@ -510,8 +510,8 @@ def test_allof_string_enums(self, model_property_factory, enum_property_factory, } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.required_props[0] == enum_property1 + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.required_properties[0] == enum_property1 def test_allof_int_enums(self, model_property_factory, enum_property_factory, config): data = oai.Schema.model_construct( @@ -534,8 +534,8 @@ def test_allof_int_enums(self, model_property_factory, enum_property_factory, co } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.required_props[0] == enum_property2 + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.required_properties[0] == enum_property2 def test_allof_enums_are_not_subsets(self, model_property_factory, enum_property_factory, config): data = oai.Schema.model_construct( @@ -558,7 +558,7 @@ def test_allof_enums_are_not_subsets(self, model_property_factory, enum_property } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) def test_duplicate_properties(self, model_property_factory, string_property_factory, config): @@ -573,9 +573,9 @@ def test_duplicate_properties(self, model_property_factory, string_property_fact } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.optional_props == [prop], "There should only be one copy of duplicate properties" + assert result.optional_properties == [prop], "There should only be one copy of duplicate properties" @pytest.mark.parametrize("first_required", [True, False]) @pytest.mark.parametrize("second_required", [True, False]) @@ -604,18 +604,18 @@ def test_mixed_requirements( ) roots = {"root"} - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots=roots) + result, schemas = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots=roots) required = first_required or second_required expected_prop = string_property_factory( required=required, ) - assert result.schemas.dependencies == {"/First": roots, "/Second": roots} + assert schemas.dependencies == {"/First": roots, "/Second": roots} if not required: - assert result.optional_props == [expected_prop] + assert result.optional_properties == [expected_prop] else: - assert result.required_props == [expected_prop] + assert result.required_properties == [expected_prop] def test_direct_properties_non_ref(self, string_property_factory, config): data = oai.Schema.model_construct( @@ -631,10 +631,10 @@ def test_direct_properties_non_ref(self, string_property_factory, config): ) schemas = Schemas() - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.optional_props == [string_property_factory(name="second", required=False)] - assert result.required_props == [string_property_factory(name="first", required=True)] + assert result.optional_properties == [string_property_factory(name="second", required=False)] + assert result.required_properties == [string_property_factory(name="first", required=True)] def test_conflicting_property_names(self, config): data = oai.Schema.model_construct( @@ -644,7 +644,7 @@ def test_conflicting_property_names(self, config): } ) schemas = Schemas() - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) def test_merge_inline_objects(self, model_property_factory, enum_property_factory, config): @@ -666,10 +666,10 @@ def test_merge_inline_objects(self, model_property_factory, enum_property_factor ) schemas = Schemas() - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert not isinstance(result, PropertyError) - assert len(result.optional_props) == 1 - prop1 = result.optional_props[0] + assert len(result.optional_properties) == 1 + prop1 = result.optional_properties[0] assert isinstance(prop1, StringProperty) assert prop1.description == "desc" assert prop1.default == StringProperty.convert_value("a") @@ -694,31 +694,30 @@ def test_process_model_error(self, mocker, model_property_factory, config): assert model_prop.relative_imports == set() assert model_prop.additional_properties is None - def test_process_model(self, mocker, model_property_factory, config): + def test_process_model(self, mocker, model_property_factory, string_property_factory, config): from openapi_python_client.parser.properties import Schemas - from openapi_python_client.parser.properties.model_property import _PropertyData, process_model + from openapi_python_client.parser.properties.model_property import ModelDetails, process_model model_prop = model_property_factory() schemas = Schemas() - property_data = _PropertyData( - required_props=["required"], - optional_props=["optional"], + model_details = ModelDetails( + required_properties=["required"], + optional_properties=["optional"], relative_imports={"relative"}, lazy_imports={"lazy"}, - schemas=schemas, + additional_properties=string_property_factory(), ) - additional_properties = True process_property_data = mocker.patch(f"{MODULE_NAME}._process_property_data") - process_property_data.return_value = ((property_data, additional_properties), schemas) + process_property_data.return_value = (model_details, schemas) result = process_model(model_prop=model_prop, schemas=schemas, config=config) assert result == schemas - assert model_prop.required_properties == property_data.required_props - assert model_prop.optional_properties == property_data.optional_props - assert model_prop.relative_imports == property_data.relative_imports - assert model_prop.lazy_imports == property_data.lazy_imports - assert model_prop.additional_properties == additional_properties + assert model_prop.required_properties == model_details.required_properties + assert model_prop.optional_properties == model_details.optional_properties + assert model_prop.relative_imports == model_details.relative_imports + assert model_prop.lazy_imports == model_details.lazy_imports + assert model_prop.additional_properties == model_details.additional_properties def test_set_relative_imports(model_property_factory): diff --git a/tests/test_parser/test_properties/test_protocol.py b/tests/test_parser/test_properties/test_protocol.py index 1d4111750..0b9c627f3 100644 --- a/tests/test_parser/test_properties/test_protocol.py +++ b/tests/test_parser/test_properties/test_protocol.py @@ -85,3 +85,11 @@ def test_get_base_json_type_string(quoted, expected, any_property_factory, mocke mocker.patch.object(AnyProperty, "_json_type_string", "str") p = any_property_factory() assert p.get_base_json_type_string(quoted=quoted) is expected + + +def test_needs_post_processing(any_property_factory, model_property_factory): + p1 = any_property_factory() + assert p1.needs_post_processing() is False + + p2 = model_property_factory() + assert p2.needs_post_processing() is True