Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Support wrapped aliases for Pydantic v2 #5103

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
3,002 changes: 1,461 additions & 1,541 deletions generators/go-v2/dynamic-snippets/src/__test__/__snapshots__/ir.test.ts.snap

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def check_wrapped_aliases_v1_only(self) -> Self:
version_compat = self.version
use_wrapped_aliases = self.wrapped_aliases

if use_wrapped_aliases and version_compat != PydanticVersionCompatibility.V1:
if use_wrapped_aliases and version_compat == PydanticVersionCompatibility.Both:
raise ValueError(
"Wrapped aliases are only supported in Pydantic V1, please update your `version` field to be 'v1' to continue using wrapped aliases."
"Wrapped aliases are not supported for `both`, please update your `version` field to be 'v1' or `v2` to continue using wrapped aliases."
)

if self.enum_type != "literals":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import fern.ir.resources as ir_types

from fern_python.codegen import AST, LocalClassReference, SourceFile
from fern_python.external_dependencies.pydantic import PydanticVersionCompatibility
from fern_python.pydantic_codegen import PydanticField, PydanticModel

from ..context import PydanticGeneratorContext
Expand Down Expand Up @@ -371,4 +370,4 @@ def __exit__(
excinst: Optional[BaseException],
exctb: Optional[TracebackType],
) -> None:
self.finish()
self.finish()
Original file line number Diff line number Diff line change
Expand Up @@ -498,4 +498,4 @@ def assert_never(arg: Never) -> Never:


def get_field_name_for_single_property(property: ir_types.SingleUnionTypeProperty) -> str:
return property.name.name.snake_case.safe_name
return property.name.name.snake_case.safe_name
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,19 @@ def generate(
docstring=self._docs,
snippet=self._snippet,
) as pydantic_model:
root_name = "__root__" if self._custom_config.version == "v1" else "root"
pydantic_model.set_root_type(self._alias.alias_of)
pydantic_model.add_method(
name=self._get_getter_name(self._alias.alias_of),
parameters=[],
return_type=self._alias.alias_of,
body=AST.CodeWriter("return self.__root__"),
body=AST.CodeWriter(f"return self.{root_name}"),
)
pydantic_model.add_method(
name=self._get_builder_name(self._alias.alias_of),
parameters=[(BUILDER_PARAMETER_NAME, self._alias.alias_of)],
return_type=ir_types.TypeReference.factory.named(declared_type_name_to_named_type(self._name)),
body=AST.CodeWriter(f"return {pydantic_model.get_class_name()}(__root__={BUILDER_PARAMETER_NAME})"),
body=AST.CodeWriter(f"return {pydantic_model.get_class_name()}({root_name}={BUILDER_PARAMETER_NAME})"),
decorator=AST.ClassMethodDecorator.STATIC,
)

Expand Down Expand Up @@ -145,4 +146,4 @@ def __init__(
as_request=False,
)

# generate_snippet delegates to the parent class AliasSnippetGenerator
# generate_snippet delegates to the parent class AliasSnippetGenerator
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .pydantic_validators_generator import PydanticValidatorsGenerator
from .validators_generator import ValidatorsGenerator

__all__ = ["ValidatorsGenerator", "PydanticValidatorsGenerator", "PydanticV1CustomRootTypeValidatorsGenerator"]
__all__ = ["ValidatorsGenerator", "PydanticValidatorsGenerator", "PydanticV1CustomRootTypeValidatorsGenerator"]
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,4 @@ def write_example_for_docstring(self, writer: AST.NodeWriter) -> None:

def _get_root_property_name(self) -> str:
return "__root__" if self._model._version == PydanticVersionCompatibility.V1 else "root"

Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,7 @@ def add_field(self, unsanitized_field: PydanticField) -> None:
self._has_aliases |= is_aliased

default_value = (
(
AST.Expression("None")
if unsanitized_field.type_hint.is_optional and self._require_optional_fields is False
else None
)
(AST.Expression("None") if unsanitized_field.type_hint.is_optional and self._require_optional_fields is False else None)
if field.default_value is None
else field.default_value
)
Expand Down Expand Up @@ -141,22 +137,16 @@ def add_field(self, unsanitized_field: PydanticField) -> None:
type_hint=aliased_type_hint,
)

self._class_declaration.add_class_var(
AST.VariableDeclaration(name=field.name, type_hint=field.type_hint, initializer=initializer)
)
self._class_declaration.add_class_var(AST.VariableDeclaration(name=field.name, type_hint=field.type_hint, initializer=initializer))

self._fields.append(field)

def get_public_fields(self) -> List[PydanticField]:
return self._fields

def add_private_instance_field(
self, name: str, type_hint: AST.TypeHint, default_factory: Optional[AST.Expression] = None
) -> None:
def add_private_instance_field(self, name: str, type_hint: AST.TypeHint, default_factory: Optional[AST.Expression] = None) -> None:
if not name.startswith("_"):
raise RuntimeError(
f"Private pydantic field {name} in {self._class_declaration.name} does not start with an underscore"
)
raise RuntimeError(f"Private pydantic field {name} in {self._class_declaration.name} does not start with an underscore")
self._class_declaration.add_class_var(
AST.VariableDeclaration(
name=name,
Expand Down Expand Up @@ -240,25 +230,22 @@ def add_root_validator(
declaration=AST.FunctionDeclaration(
name=validator_name,
signature=AST.FunctionSignature(
parameters=[
AST.FunctionParameter(name=PydanticModel.VALIDATOR_VALUES_PARAMETER_NAME, type_hint=value_type)
],
parameters=[AST.FunctionParameter(name=PydanticModel.VALIDATOR_VALUES_PARAMETER_NAME, type_hint=value_type)],
return_type=value_type,
),
body=body,
decorators=[self._universal_root_validator(pre)],
),
)

def set_root_type_unsafe(
self, root_type: AST.TypeHint, annotation: Optional[AST.Expression] = None
) -> None:
def set_root_type_unsafe(self, root_type: AST.TypeHint, annotation: Optional[AST.Expression] = None) -> None:
if self._version == PydanticVersionCompatibility.Both:
raise RuntimeError("Overriding root types is only available in Pydantic v1 or v2")

if self._root_type is not None:
raise RuntimeError("__root__ was already added")



self.root_type = root_type

root_type_with_annotation = (
Expand All @@ -270,16 +257,11 @@ def set_root_type_unsafe(
else root_type
)

if self._version == PydanticVersionCompatibility.V1:
self._class_declaration.add_statement(
AST.VariableDeclaration(name="__root__", type_hint=root_type_with_annotation)
)

if self._version == PydanticVersionCompatibility.V2:
self._class_declaration.add_statement(
AST.VariableDeclaration(name="root", type_hint=root_type_with_annotation)
)

if self._version == PydanticVersionCompatibility.V1:
self._class_declaration.add_statement(AST.VariableDeclaration(name="__root__", type_hint=root_type_with_annotation))

if self._version == PydanticVersionCompatibility.V2:
self._class_declaration.add_statement(AST.VariableDeclaration(name="root", type_hint=root_type_with_annotation))

def get_root_type_unsafe(self) -> Optional[AST.TypeHint]:
return self.root_type
Expand All @@ -296,8 +278,7 @@ def add_partial_class(self) -> None:
extends=[
dataclasses.replace(
base_model,
qualified_name_excluding_import=base_model.qualified_name_excluding_import
+ (PydanticModel._PARTIAL_CLASS_NAME,),
qualified_name_excluding_import=base_model.qualified_name_excluding_import + (PydanticModel._PARTIAL_CLASS_NAME,),
)
for base_model in self._base_models
]
Expand Down Expand Up @@ -380,10 +361,7 @@ def write_extras(writer: AST.NodeWriter) -> None:
writer.write_node(v2_model_config)

if (
(
self._version == PydanticVersionCompatibility.Both
and (v1_config_class is not None or v2_model_config is not None)
)
(self._version == PydanticVersionCompatibility.Both and (v1_config_class is not None or v2_model_config is not None))
or (self._version == PydanticVersionCompatibility.V1 and v1_config_class is not None)
or (self._version == PydanticVersionCompatibility.V2 and v2_model_config is not None)
):
Expand Down
6 changes: 3 additions & 3 deletions generators/python/tests/sdk/test_custom_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ def test_parse_wrapped_aliases() -> None:
"wrapped_aliases": True,
},
}
with pytest.raises(pydantic.ValidationError, match="Wrapped aliases are only supported in Pydantic V1, please update your `version` field to be 'v1' to continue using wrapped aliases."):
SDKCustomConfig.parse_obj(v2)
sdk_custom_config = SDKCustomConfig.parse_obj(v2)
assert sdk_custom_config.pydantic_config.version == "v2" and sdk_custom_config.pydantic_config.wrapped_aliases is True

both = {
"pydantic_config": {
"version": "both",
"wrapped_aliases": True,
},
}
with pytest.raises(pydantic.ValidationError, match="Wrapped aliases are only supported in Pydantic V1, please update your `version` field to be 'v1' to continue using wrapped aliases."):
with pytest.raises(pydantic.ValidationError, match="Wrapped aliases are not supported for `both`, please update your `version` field to be 'v1' or `v2` to continue using wrapped aliases."):
SDKCustomConfig.parse_obj(both)
Loading