Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
KerstenBreuer committed Nov 15, 2023
1 parent 2ea64e3 commit ea26a79
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 47 deletions.
2 changes: 1 addition & 1 deletion src/schemapack/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def validator(self) -> JsonSchemaValidator:
Raises:
JsonSchemaError: If the schema is invalid.
"""
return get_json_schema_validator(self.json_schema)
return get_json_schema_validator(dict(self.json_schema))

@model_validator(mode="after")
def trigger_validator_construction(self) -> "ContentSchema":
Expand Down
32 changes: 21 additions & 11 deletions src/schemapack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import json
import os
import typing
from collections.abc import Mapping
from contextlib import contextmanager
from pathlib import Path
Expand All @@ -30,7 +31,7 @@
from immutabledict import immutabledict
from jsonschema.protocols import Validator as JsonSchemaValidator
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from pydantic_core import core_schema


class DecodeError(ValueError):
Expand Down Expand Up @@ -122,13 +123,22 @@ class FrozenDict(immutabledict[_K, _V_co]):

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Get a pydantic core schema for this class."""
return core_schema.no_info_after_validator_function(
cls,
handler(dict),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: dict(instance)
),
)
cls, source: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""Get the pydantic core schema for this type."""
# Validate the type against a dict:
# (this will have the side effect of converting the instance to dict even if it
# is already a dict a FrozenDict)
args = typing.get_args(source)
if args:
if len(args) != 2:
raise TypeError(
"Expected exactly two (or no) type arguments for FrozenDict, got"
+ f" {len(args)}"
)
dict_schema = handler.generate_schema(dict[args[0], args[1]]) # type: ignore
else:
dict_schema = handler.generate_schema(dict)

# Uses cls as validator function to convert the dict to a FrozenDict:
return core_schema.no_info_after_validator_function(cls, dict_schema)
37 changes: 3 additions & 34 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,7 @@
# limitations under the License.
#

from schemapack.utils import FrozenDict

from re import I
from typing import Any

from immutabledict import immutabledict
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

# class immutabledict(immutabledict_):
# """Wrapper around immutabledict to make it pydantic compatible."""


class FrozenDict(immutabledict):
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
handler(dict),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: dict(instance)
),
)


class InnerTest(BaseModel):
dict_: FrozenDict


class Test(BaseModel):
inner: FrozenDict[str, InnerTest]


test = Test(inner={"test": {"dict_": {1: 2}}})
test = FrozenDict({1: {2: 3}})
test[1][2] = 4
2 changes: 1 addition & 1 deletion tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test dummy."""
"""Tests the load module."""

from pathlib import Path

Expand Down
20 changes: 20 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,23 @@ class Test(BaseModel):
hash(test1)

assert test1 == test2


def test_frozen_dict_nesting():
"""Test hashing and comparison of pydantic models using FrozenDicts."""

class Inner(BaseModel):
dict_: FrozenDict

class Test(BaseModel):
inner: FrozenDict[str, Inner]

test_dict = {1: 2}

test1 = Test.model_validate({"inner": {"test": {"dict_": test_dict}}})
assert isinstance(test1.inner["test"], Inner)

test2 = Test(
inner=FrozenDict({"test": FrozenDict({"dict_": FrozenDict(test_dict)})})
)
assert test1 == test2

0 comments on commit ea26a79

Please sign in to comment.