Skip to content

Commit

Permalink
add discriminator property support
Browse files Browse the repository at this point in the history
  • Loading branch information
eli-bl committed Oct 28, 2024
1 parent a0b1bb7 commit 0496ff8
Show file tree
Hide file tree
Showing 14 changed files with 655 additions and 64 deletions.
3 changes: 3 additions & 0 deletions end_to_end_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
""" Generate a complete client and verify that it is correct """
import pytest

pytest.register_assert_rewrite('end_to_end_tests.end_to_end_live_tests')
6 changes: 4 additions & 2 deletions end_to_end_tests/baseline_openapi_3.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -2841,15 +2841,17 @@
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
},
"ADiscriminatedUnionType2": {
"type": "object",
"properties": {
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
}
},
"parameters": {
Expand Down
6 changes: 4 additions & 2 deletions end_to_end_tests/baseline_openapi_3.1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2835,15 +2835,17 @@ info:
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
},
"ADiscriminatedUnionType2": {
"type": "object",
"properties": {
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
}
}
"parameters": {
Expand Down
33 changes: 33 additions & 0 deletions end_to_end_tests/end_to_end_live_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import importlib
from typing import Any

import pytest


def live_tests_3_x():
_test_model_with_discriminated_union()


def _import_model(module_name, class_name: str) -> Any:
module = importlib.import_module(f"my_test_api_client.models.{module_name}")
module = importlib.reload(module) # avoid test contamination from previous import
return getattr(module, class_name)


def _test_model_with_discriminated_union():
ModelType1Class = _import_model("a_discriminated_union_type_1", "ADiscriminatedUnionType1")
ModelType2Class = _import_model("a_discriminated_union_type_2", "ADiscriminatedUnionType2")
ModelClass = _import_model("model_with_discriminated_union", "ModelWithDiscriminatedUnion")

assert (
ModelClass.from_dict({"discriminated_union": {"modelType": "type1"}}) ==
ModelClass(discriminated_union=ModelType1Class.from_dict({"modelType": "type1"}))
)
assert (
ModelClass.from_dict({"discriminated_union": {"modelType": "type2"}}) ==
ModelClass(discriminated_union=ModelType2Class.from_dict({"modelType": "type2"}))
)
with pytest.raises(TypeError):
ModelClass.from_dict({"discriminated_union": {"modelType": "type3"}})
with pytest.raises(TypeError):
ModelClass.from_dict({"discriminated_union": {}})
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
from typing import Any, Dict, List, Type, TypeVar, Union
from typing import Any, Dict, List, Type, TypeVar

from attrs import define as _attrs_define
from attrs import field as _attrs_field

from ..types import UNSET, Unset

T = TypeVar("T", bound="ADiscriminatedUnionType1")


@_attrs_define
class ADiscriminatedUnionType1:
"""
Attributes:
model_type (Union[Unset, str]):
model_type (str):
"""

model_type: Union[Unset, str] = UNSET
model_type: str
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
model_type = self.model_type

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if model_type is not UNSET:
field_dict["modelType"] = model_type
field_dict.update(
{
"modelType": model_type,
}
)

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
model_type = d.pop("modelType", UNSET)
model_type = d.pop("modelType")

a_discriminated_union_type_1 = cls(
model_type=model_type,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
from typing import Any, Dict, List, Type, TypeVar, Union
from typing import Any, Dict, List, Type, TypeVar

from attrs import define as _attrs_define
from attrs import field as _attrs_field

from ..types import UNSET, Unset

T = TypeVar("T", bound="ADiscriminatedUnionType2")


@_attrs_define
class ADiscriminatedUnionType2:
"""
Attributes:
model_type (Union[Unset, str]):
model_type (str):
"""

model_type: Union[Unset, str] = UNSET
model_type: str
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
model_type = self.model_type

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if model_type is not UNSET:
field_dict["modelType"] = model_type
field_dict.update(
{
"modelType": model_type,
}
)

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
model_type = d.pop("modelType", UNSET)
model_type = d.pop("modelType")

a_discriminated_union_type_2 = cls(
model_type=model_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,34 @@ def _parse_discriminated_union(
return data
if isinstance(data, Unset):
return data
try:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)

return componentsschemas_a_discriminated_union_type_0
except: # noqa: E722
pass
try:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)

return componentsschemas_a_discriminated_union_type_1
except: # noqa: E722
pass
return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data)
if not isinstance(data, dict):
raise TypeError()
if "modelType" in data:
_discriminator_value = data["modelType"]

def _parse_componentsschemas_a_discriminated_union_type_1(data: object) -> ADiscriminatedUnionType1:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType1.from_dict(data)

return componentsschemas_a_discriminated_union_type_1

def _parse_componentsschemas_a_discriminated_union_type_2(data: object) -> ADiscriminatedUnionType2:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_2 = ADiscriminatedUnionType2.from_dict(data)

return componentsschemas_a_discriminated_union_type_2

_discriminator_mapping = {
"type1": _parse_componentsschemas_a_discriminated_union_type_1,
"type2": _parse_componentsschemas_a_discriminated_union_type_2,
}
if _parse_fn := _discriminator_mapping.get(_discriminator_value):
return cast(
Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], _parse_fn(data)
)
raise TypeError("unrecognized value for property modelType")

discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET))

Expand Down
18 changes: 15 additions & 3 deletions end_to_end_tests/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import os
import shutil
from filecmp import cmpfiles, dircmp
from pathlib import Path
from typing import Dict, List, Optional, Set
import sys
from typing import Callable, Dict, List, Optional, Set

import pytest
from click.testing import Result
from typer.testing import CliRunner

from openapi_python_client.cli import app
from .end_to_end_live_tests import live_tests_3_x



def _compare_directories(
Expand Down Expand Up @@ -83,6 +87,7 @@ def run_e2e_test(
golden_record_path: str = "golden-record",
output_path: str = "my-test-api-client",
expected_missing: Optional[Set[str]] = None,
live_tests: Optional[Callable[[str], None]] = None,
) -> Result:
output_path = Path.cwd() / output_path
shutil.rmtree(output_path, ignore_errors=True)
Expand All @@ -97,6 +102,13 @@ def run_e2e_test(
_compare_directories(
gr_path, output_path, expected_differences=expected_differences, expected_missing=expected_missing
)
if live_tests:
old_path = sys.path.copy()
sys.path.insert(0, str(output_path))
try:
live_tests()
finally:
sys.path = old_path

import mypy.api

Expand Down Expand Up @@ -131,11 +143,11 @@ def _run_command(command: str, extra_args: Optional[List[str]] = None, openapi_d


def test_baseline_end_to_end_3_0():
run_e2e_test("baseline_openapi_3.0.json", [], {})
run_e2e_test("baseline_openapi_3.0.json", [], {}, live_tests=live_tests_3_x)


def test_baseline_end_to_end_3_1():
run_e2e_test("baseline_openapi_3.1.yaml", [], {})
run_e2e_test("baseline_openapi_3.1.yaml", [], {}, live_tests=live_tests_3_x)


def test_3_1_specific_features():
Expand Down
7 changes: 7 additions & 0 deletions openapi_python_client/parser/properties/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def parse_reference_path(ref_path_raw: str) -> Union[ReferencePath, ParseError]:
return cast(ReferencePath, parsed.fragment)


def get_reference_simple_name(ref_path: str) -> str:
"""
Takes a path like `/components/schemas/NameOfThing` and returns a string like `NameOfThing`.
"""
return ref_path.split("/", 3)[-1]


@define
class Class:
"""Represents Python class which will be generated from an OpenAPI schema"""
Expand Down
Loading

0 comments on commit 0496ff8

Please sign in to comment.