Skip to content

Commit

Permalink
mostly there
Browse files Browse the repository at this point in the history
  • Loading branch information
matthew-giglio committed Nov 5, 2024
1 parent dfd31f6 commit 8b86656
Show file tree
Hide file tree
Showing 136 changed files with 21,856 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def check_wrapped_aliases_v1_only(self) -> Self:
version_compat = self.version
use_wrapped_aliases = self.wrapped_aliases

if use_wrapped_aliases and version_compat != PydanticVersionCompatibility.V1:
if use_wrapped_aliases and version_compat == PydanticVersionCompatibility.Both:
raise ValueError(
"Wrapped aliases are only supported in Pydantic V1, please update your `version` field to be 'v1' to continue using wrapped aliases."
"Wrapped aliases are not supported for `both`, please update your `version` field to be 'v1' or `v2` to continue using wrapped aliases."
)

if self.enum_type != "literals":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,22 +262,22 @@ def add_method_unsafe(
) -> AST.FunctionDeclaration:
return self._pydantic_model.add_method(declaration=declaration, decorator=decorator)

def set_root_type_v1_only(
def set_root_type(
self,
root_type: ir_types.TypeReference,
annotation: Optional[AST.Expression] = None,
is_forward_ref: bool = False,
) -> None:
self.set_root_type_unsafe_v1_only(
self.set_root_type_unsafe(
root_type=self.get_type_hint_for_type_reference(root_type),
annotation=annotation,
is_forward_ref=is_forward_ref,
)

def set_root_type_unsafe_v1_only(
def set_root_type_unsafe(
self, root_type: AST.TypeHint, annotation: Optional[AST.Expression] = None, is_forward_ref: bool = False
) -> None:
self._pydantic_model.set_root_type_unsafe_v1_only(root_type=root_type, annotation=annotation)
self._pydantic_model.set_root_type_unsafe(root_type=root_type, annotation=annotation)

def add_ghost_reference(self, type_id: ir_types.TypeId) -> None:
self._pydantic_model.add_ghost_reference(
Expand All @@ -286,10 +286,7 @@ def add_ghost_reference(self, type_id: ir_types.TypeId) -> None:

def finish(self) -> None:
if self._custom_config.include_validators:
if (
self._pydantic_model._v1_root_type is None
and self._custom_config.version == PydanticVersionCompatibility.V1
):
if self._pydantic_model.root_type is None:
self._pydantic_model.add_partial_class()
self._get_validators_generator().add_validators()
if self._model_contains_forward_refs or self._force_update_forward_refs:
Expand All @@ -309,11 +306,11 @@ def finish(self) -> None:
self._pydantic_model.finish()

def _get_validators_generator(self) -> ValidatorsGenerator:
v1_root_type = self._pydantic_model.get_root_type_unsafe_v1_only()
if v1_root_type is not None and self._custom_config.version == PydanticVersionCompatibility.V1:
root_type = self._pydantic_model.get_root_type_unsafe()
if root_type is not None:
return PydanticV1CustomRootTypeValidatorsGenerator(
model=self._pydantic_model,
root_type=v1_root_type,
root_type=root_type,
)
else:
unique_name = []
Expand Down Expand Up @@ -374,4 +371,4 @@ def __exit__(
excinst: Optional[BaseException],
exctb: Optional[TracebackType],
) -> None:
self.finish()
self.finish()
Original file line number Diff line number Diff line change
Expand Up @@ -362,27 +362,26 @@ def get_dict_method(writer: AST.NodeWriter) -> None:
)
)

if self._custom_config.version == PydanticVersionCompatibility.V1:
external_pydantic_model.set_root_type_unsafe_v1_only(
is_forward_ref=True,
root_type=root_type,
annotation=AST.Expression(
AST.FunctionInvocation(
function_definition=Pydantic.Field(),
kwargs=[
(
"discriminator",
AST.Expression(
f'"{self._get_discriminant_attr_name()}"',
),
)
],
)
external_pydantic_model.set_root_type_unsafe(
is_forward_ref=True,
root_type=root_type,
annotation=AST.Expression(
AST.FunctionInvocation(
function_definition=Pydantic.Field(),
kwargs=[
(
"discriminator",
AST.Expression(
f'"{self._get_discriminant_attr_name()}"',
),
)
],
)
# can't use discriminator without single variant pydantic models
# https://github.com/pydantic/pydantic/pull/3639
if len(internal_single_union_types) != 1 else None,
)
# can't use discriminator without single variant pydantic models
# https://github.com/pydantic/pydantic/pull/3639
if len(internal_single_union_types) != 1 else None,
)

def _create_body_writer(
self,
Expand Down Expand Up @@ -499,4 +498,4 @@ def assert_never(arg: Never) -> Never:


def get_field_name_for_single_property(property: ir_types.SingleUnionTypeProperty) -> str:
return property.name.name.snake_case.safe_name
return property.name.name.snake_case.safe_name
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def generate(
should_export=True,
)
else:
# NOTE: We validate the config to ensure wrapped aliases are only available for Pydantic V1 users.
# As such, we force the root field to be __root__ as opposed to conditional based on the Pydantic version.
BUILDER_PARAMETER_NAME = "value"
with FernAwarePydanticModel(
class_name=self._context.get_class_name_for_type_id(self._name.type_id, as_request=False),
Expand All @@ -62,7 +60,7 @@ def generate(
docstring=self._docs,
snippet=self._snippet,
) as pydantic_model:
pydantic_model.set_root_type_v1_only(self._alias.alias_of)
pydantic_model.set_root_type(self._alias.alias_of)
pydantic_model.add_method(
name=self._get_getter_name(self._alias.alias_of),
parameters=[],
Expand Down Expand Up @@ -147,4 +145,4 @@ def __init__(
as_request=False,
)

# generate_snippet delegates to the parent class AliasSnippetGenerator
# generate_snippet delegates to the parent class AliasSnippetGenerator
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .pydantic_v1_custom_root_type_validators_generator import (
from .pydantic_custom_root_type_validators_generator import (
PydanticV1CustomRootTypeValidatorsGenerator,
)
from .pydantic_validators_generator import PydanticValidatorsGenerator
from .validators_generator import ValidatorsGenerator

__all__ = ["ValidatorsGenerator", "PydanticValidatorsGenerator", "PydanticV1CustomRootTypeValidatorsGenerator"]
__all__ = ["ValidatorsGenerator", "PydanticValidatorsGenerator", "PydanticV1CustomRootTypeValidatorsGenerator"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fern_python.pydantic_codegen import PydanticModel

from .validator_generators import (
PydanticV1CustomRootTypeValidatorGenerator,
PydanticCustomRootTypeValidatorGenerator,
ValidatorGenerator,
)
from .validators_generator import ValidatorsGenerator
Expand All @@ -14,7 +14,7 @@ class PydanticV1CustomRootTypeValidatorsGenerator(ValidatorsGenerator):
def __init__(self, root_type: AST.TypeHint, model: PydanticModel):
super().__init__(model=model)
self._root_type = root_type
self._root_type_generator = PydanticV1CustomRootTypeValidatorGenerator(
self._root_type_generator = PydanticCustomRootTypeValidatorGenerator(
root_type=root_type,
model=model,
reference_to_validators_class=self._get_reference_to_validators_class(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .field_validator_generator import FieldValidatorGenerator
from .pydantic_v1_custom_root_type_validator_generator import (
PydanticV1CustomRootTypeValidatorGenerator,
from .pydantic_custom_root_type_validator_generator import (
PydanticCustomRootTypeValidatorGenerator,
)
from .root_validator_generator import RootValidatorGenerator
from .validator_generator import ValidatorGenerator
Expand All @@ -9,5 +9,5 @@
"ValidatorGenerator",
"FieldValidatorGenerator",
"RootValidatorGenerator",
"PydanticV1CustomRootTypeValidatorGenerator",
"PydanticCustomRootTypeValidatorGenerator",
]
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Tuple

from fern_python.codegen import AST
from fern_python.external_dependencies.pydantic import PydanticVersionCompatibility
from fern_python.pydantic_codegen import PydanticModel

from .validator_generator import ValidatorGenerator


class PydanticV1CustomRootTypeValidatorGenerator(ValidatorGenerator):
class PydanticCustomRootTypeValidatorGenerator(ValidatorGenerator):
_DECORATOR_FUNCTION_NAME = "validate"
_VALIDATOR_CLASS_VALIDATORS_CLASS_VAR = "_validators"

Expand All @@ -23,6 +24,7 @@ def add_validator_to_model(self) -> None:
def _write_validator_body(self, writer: AST.NodeWriter) -> None:
ROOT_VARIABLE_NAME = "value"
INDIVIDUAL_VALIDATOR_NAME = "validator"
ROOT_PROPERTY_NAME = self._get_root_property_name()

writer.write(f"{ROOT_VARIABLE_NAME} = ")
writer.write_node(
Expand All @@ -33,7 +35,7 @@ def _write_validator_body(self, writer: AST.NodeWriter) -> None:
),
args=[
AST.Expression(self._root_type),
AST.Expression(f'{PydanticModel.VALIDATOR_VALUES_PARAMETER_NAME}.get("__root__")'),
AST.Expression(f'{PydanticModel.VALIDATOR_VALUES_PARAMETER_NAME}.get("{ROOT_PROPERTY_NAME}")'),
],
),
)
Expand All @@ -44,7 +46,7 @@ def _write_validator_body(self, writer: AST.NodeWriter) -> None:
".".join(
(
*self._reference_to_validators_class,
PydanticV1CustomRootTypeValidatorGenerator._VALIDATOR_CLASS_VALIDATORS_CLASS_VAR,
PydanticCustomRootTypeValidatorGenerator._VALIDATOR_CLASS_VALIDATORS_CLASS_VAR,
)
)
+ ":"
Expand All @@ -62,7 +64,7 @@ def _write_validator_body(self, writer: AST.NodeWriter) -> None:
)
writer.write_line()
writer.write_line(
"return " + f'{{ **{PydanticModel.VALIDATOR_VALUES_PARAMETER_NAME}, "__root__": {ROOT_VARIABLE_NAME} }}'
f"return {{ **{PydanticModel.VALIDATOR_VALUES_PARAMETER_NAME}, '{ROOT_PROPERTY_NAME}': {ROOT_VARIABLE_NAME} }}"
)

def add_to_validators_class(self, validators_class: AST.ClassDeclaration) -> None:
Expand All @@ -71,20 +73,20 @@ def add_to_validators_class(self, validators_class: AST.ClassDeclaration) -> Non
validator_type = AST.TypeHint.callable([self._root_type], self._root_type)
validators_class.add_class_var(
AST.VariableDeclaration(
name=PydanticV1CustomRootTypeValidatorGenerator._VALIDATOR_CLASS_VALIDATORS_CLASS_VAR,
name=PydanticCustomRootTypeValidatorGenerator._VALIDATOR_CLASS_VALIDATORS_CLASS_VAR,
type_hint=AST.TypeHint.class_var(AST.TypeHint.list(validator_type)),
initializer=AST.Expression("[]"),
)
)
validators_class.add_method(
declaration=AST.FunctionDeclaration(
name=PydanticV1CustomRootTypeValidatorGenerator._DECORATOR_FUNCTION_NAME,
name=PydanticCustomRootTypeValidatorGenerator._DECORATOR_FUNCTION_NAME,
signature=AST.FunctionSignature(
parameters=[AST.FunctionParameter(name=VALIDATOR_PARAMETER, type_hint=validator_type)],
return_type=AST.TypeHint.none(),
),
body=AST.CodeWriter(
f"cls.{PydanticV1CustomRootTypeValidatorGenerator._VALIDATOR_CLASS_VALIDATORS_CLASS_VAR}"
f"cls.{PydanticCustomRootTypeValidatorGenerator._VALIDATOR_CLASS_VALIDATORS_CLASS_VAR}"
+ f".append({VALIDATOR_PARAMETER})"
),
),
Expand All @@ -94,7 +96,7 @@ def add_to_validators_class(self, validators_class: AST.ClassDeclaration) -> Non
def write_example_for_docstring(self, writer: AST.NodeWriter) -> None:

reference_to_decorator = ".".join(
(*self._reference_to_validators_class, PydanticV1CustomRootTypeValidatorGenerator._DECORATOR_FUNCTION_NAME)
(*self._reference_to_validators_class, PydanticCustomRootTypeValidatorGenerator._DECORATOR_FUNCTION_NAME)
)

with writer.indent():
Expand All @@ -109,3 +111,6 @@ def write_example_for_docstring(self, writer: AST.NodeWriter) -> None:

with writer.indent():
writer.write_line("...")

def _get_root_property_name(self) -> str:
return "__root__" if self._model._version == PydanticVersionCompatibility.V1 else "root"
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
)
self._has_aliases = False
self._version = version
self._v1_root_type: Optional[AST.TypeHint] = None
self._root_type: Optional[AST.TypeHint] = None
self._fields: List[PydanticField] = []
self._extra_fields = extra_fields
self._frozen = frozen
Expand Down Expand Up @@ -250,15 +250,16 @@ def add_root_validator(
),
)

def set_root_type_unsafe_v1_only(
def set_root_type_unsafe(
self, root_type: AST.TypeHint, annotation: Optional[AST.Expression] = None
) -> None:
if self._version != PydanticVersionCompatibility.V1:
raise RuntimeError("Overriding root types is only available in Pydantic v1")

if self._v1_root_type is not None:
if self._version == PydanticVersionCompatibility.Both:
raise RuntimeError("Overriding root types is only available in Pydantic v1 or v2")
if self._root_type is not None:
raise RuntimeError("__root__ was already added")
self._v1_root_type = root_type

self.root_type = root_type

root_type_with_annotation = (
AST.TypeHint.annotated(
Expand All @@ -269,12 +270,19 @@ def set_root_type_unsafe_v1_only(
else root_type
)

self._class_declaration.add_statement(
AST.VariableDeclaration(name="__root__", type_hint=root_type_with_annotation)
)
if self._version == PydanticVersionCompatibility.V1:
self._class_declaration.add_statement(
AST.VariableDeclaration(name="__root__", type_hint=root_type_with_annotation)
)

if self._version == PydanticVersionCompatibility.V2:
self._class_declaration.add_statement(
AST.VariableDeclaration(name="root", type_hint=root_type_with_annotation)
)


def get_root_type_unsafe_v1_only(self) -> Optional[AST.TypeHint]:
return self._v1_root_type
def get_root_type_unsafe(self) -> Optional[AST.TypeHint]:
return self.root_type

def add_inner_class(self, inner_class: AST.ClassDeclaration) -> None:
self._class_declaration.add_class(declaration=inner_class)
Expand Down Expand Up @@ -508,4 +516,4 @@ def write(writer: AST.NodeWriter) -> None:
def get_named_import_or_throw(reference: AST.Reference) -> str:
if reference.import_ is None or reference.import_.named_import is None:
raise RuntimeError("No named import defined on reference")
return reference.import_.named_import
return reference.import_.named_import
Loading

0 comments on commit 8b86656

Please sign in to comment.