Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add discriminator property support #214

Merged
merged 6 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions end_to_end_tests/baseline_openapi_3.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -2823,7 +2823,8 @@
"propertyName": "modelType",
"mapping": {
"type1": "#/components/schemas/ADiscriminatedUnionType1",
"type2": "#/components/schemas/ADiscriminatedUnionType2"
"type2": "#/components/schemas/ADiscriminatedUnionType2",
"type2-another-value": "#/components/schemas/ADiscriminatedUnionType2"
}
},
"oneOf": [
Expand All @@ -2841,15 +2842,17 @@
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary in order for the discriminator in the spec to really be valid: a discriminator property must be a required property in all of the variant schemas. My current implementation wouldn't actually catch a mistake like this, but I figured it was best to have the test spec be valid.

},
"ADiscriminatedUnionType2": {
"type": "object",
"properties": {
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
}
},
"parameters": {
Expand Down
9 changes: 6 additions & 3 deletions end_to_end_tests/baseline_openapi_3.1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2817,7 +2817,8 @@ info:
"propertyName": "modelType",
"mapping": {
"type1": "#/components/schemas/ADiscriminatedUnionType1",
"type2": "#/components/schemas/ADiscriminatedUnionType2"
"type2": "#/components/schemas/ADiscriminatedUnionType2",
"type2-another-value": "#/components/schemas/ADiscriminatedUnionType2"
}
},
"oneOf": [
Expand All @@ -2835,15 +2836,17 @@ info:
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
},
"ADiscriminatedUnionType2": {
"type": "object",
"properties": {
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
}
}
"parameters": {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
from typing import Any, Dict, List, Type, TypeVar, Union
from typing import Any, Dict, List, Type, TypeVar

from attrs import define as _attrs_define
from attrs import field as _attrs_field

from ..types import UNSET, Unset

T = TypeVar("T", bound="ADiscriminatedUnionType1")


@_attrs_define
class ADiscriminatedUnionType1:
"""
Attributes:
model_type (Union[Unset, str]):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes in this generated file and the next one are because I changed modelType to be required in the test spec.

model_type (str):
"""

model_type: Union[Unset, str] = UNSET
model_type: str
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
model_type = self.model_type

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if model_type is not UNSET:
field_dict["modelType"] = model_type
field_dict.update(
{
"modelType": model_type,
}
)

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
model_type = d.pop("modelType", UNSET)
model_type = d.pop("modelType")

a_discriminated_union_type_1 = cls(
model_type=model_type,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
from typing import Any, Dict, List, Type, TypeVar, Union
from typing import Any, Dict, List, Type, TypeVar

from attrs import define as _attrs_define
from attrs import field as _attrs_field

from ..types import UNSET, Unset

T = TypeVar("T", bound="ADiscriminatedUnionType2")


@_attrs_define
class ADiscriminatedUnionType2:
"""
Attributes:
model_type (Union[Unset, str]):
model_type (str):
"""

model_type: Union[Unset, str] = UNSET
model_type: str
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
model_type = self.model_type

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if model_type is not UNSET:
field_dict["modelType"] = model_type
field_dict.update(
{
"modelType": model_type,
}
)

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
model_type = d.pop("modelType", UNSET)
model_type = d.pop("modelType")

a_discriminated_union_type_2 = cls(
model_type=model_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,42 @@ def _parse_discriminated_union(
return data
if isinstance(data, Unset):
return data
try:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)

return componentsschemas_a_discriminated_union_type_0
except: # noqa: E722
pass
try:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)

return componentsschemas_a_discriminated_union_type_1
except: # noqa: E722
pass
return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data)
if not isinstance(data, dict):
raise TypeError()
if "modelType" in data:
_discriminator_value = data["modelType"]

def _parse_1(data: object) -> ADiscriminatedUnionType1:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)

return componentsschemas_a_discriminated_union_type_0

def _parse_2(data: object) -> ADiscriminatedUnionType2:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)

return componentsschemas_a_discriminated_union_type_1
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 68-72 here are boilerplate that comes from the generator's template macro for "call the deserializer for some type". The weird function name on line 67 is due to how the generator makes unique Python names for stuff in its internal data structure.


def _parse_3(data: object) -> ADiscriminatedUnionType2:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)

return componentsschemas_a_discriminated_union_type_1

_discriminator_mapping = {
"type1": _parse_1,
"type2": _parse_2,
"type2-another-value": _parse_3,
}
if _parse_fn := _discriminator_mapping.get(_discriminator_value):
return cast(
Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], _parse_fn(data)
)
raise TypeError("unrecognized value for property modelType")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 81-89 are from my new custom template logic for discriminators.


discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET))

Expand Down
1 change: 1 addition & 0 deletions openapi_python_client/parser/properties/model_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ModelProperty(PropertyProtocol):
relative_imports: set[str] | None
lazy_imports: set[str] | None
additional_properties: Property | None
ref_path: ReferencePath | None = None
_json_type_string: ClassVar[str] = "Dict[str, Any]"

template: ClassVar[str] = "model_property.py.jinja"
Expand Down
5 changes: 5 additions & 0 deletions openapi_python_client/parser/properties/protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from openapi_python_client.parser.properties.schemas import ReferencePath

__all__ = ["PropertyProtocol", "Value"]

from abc import abstractmethod
Expand Down Expand Up @@ -185,3 +187,6 @@ def is_base_type(self) -> bool:
ListProperty.__name__,
UnionProperty.__name__,
}

def get_ref_path(self) -> ReferencePath | None:
return self.ref_path if hasattr(self, "ref_path") else None
16 changes: 16 additions & 0 deletions openapi_python_client/parser/properties/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def parse_reference_path(ref_path_raw: str) -> Union[ReferencePath, ParseError]:
return cast(ReferencePath, parsed.fragment)


def get_reference_simple_name(ref_path: str) -> str:
"""
Takes a path like `/components/schemas/NameOfThing` and returns a string like `NameOfThing`.
"""
return ref_path.split("/", 3)[-1]


@define
class Class:
"""Represents Python class which will be generated from an OpenAPI schema"""
Expand Down Expand Up @@ -135,6 +142,15 @@ def update_schemas_with_data(
)
return prop

# Save the original path (/components/schemas/X) in the property. This is important because:
# 1. There are some contexts (such as a union with a discriminator) where we have a Property
# instance and we want to know what its path is, instead of the other way round.
# 2. Even though we did set prop.name to be the same as ref_path when we created it above,
# whenever there's a $ref to this property, we end up making a copy of it and changing
# the name. So we can't rely on prop.name always being the path.
if hasattr(prop, "ref_path"):
prop.ref_path = ref_path

schemas = evolve(schemas, classes_by_reference={ref_path: prop, **schemas.classes_by_reference})
return schemas

Expand Down
Loading