From 733ceda408441d3d1f08e027a767f0099e2e52d5 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Wed, 29 Jan 2025 09:41:36 -0800 Subject: [PATCH] Only activate dataclass typing if the dataclass contains Kauldron type annotations PiperOrigin-RevId: 721010355 --- kauldron/typing/type_check.py | 35 +++++++++++++++++++++++++++++- kauldron/typing/type_check_test.py | 31 ++++++++++++++++++++++++-- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/kauldron/typing/type_check.py b/kauldron/typing/type_check.py index 89650583..37de42b6 100644 --- a/kauldron/typing/type_check.py +++ b/kauldron/typing/type_check.py @@ -410,11 +410,44 @@ def _dataclass_checker_lookup( ) -> typeguard.TypeCheckerCallable | None: """Lookup function to register custom dataclass checkers in typeguard.""" del args, extras - if dataclasses.is_dataclass(origin_type): + # Due to conflict with mixing Kauldron with other non-Kauldron jaxtyping + # objects, we only activate dataclass support for dataclasses annotated with + # Kauldron types. + # We do this by recursively checking if any dataclass attribute is annotated + # as a Kauldron type. + if _is_kd_dataclass(origin_type): return _custom_dataclass_checker return None +@functools.cache +def _is_kd_dataclass(obj) -> bool: + return _is_kd_dataclass_inner(obj, visited=set()) + + +def _is_kd_dataclass_inner(obj, visited) -> bool: + if not dataclasses.is_dataclass(obj): + return False + + visited.add(obj) + hints = typing.get_type_hints(obj) + return any(_is_kd_type(t, visited=visited) for t in hints.values()) + + +def _is_kd_type(t: Any, visited: set[Any]) -> bool: + if t in visited: # Cycle + return False + origin = typing.get_origin(t) + if origin is None: + if inspect.getattr_static(t, "_kd_repr", None): + return True + return _is_kd_dataclass_inner(t, visited) + if origin in [Union, types.UnionType]: + return any(_is_kd_type(t, visited=visited) for t in typing.get_args(t)) + # Could recurse into dict, list,... too + return _is_kd_dataclass_inner(t, visited) + + def add_custom_checker_lookup_fn(lookup_fn): """Add custom array spec checker lookup function to typeguard.""" # Add custom array spec checker lookup function to typguard diff --git a/kauldron/typing/type_check_test.py b/kauldron/typing/type_check_test.py index 26075a95..aaa2c7f8 100644 --- a/kauldron/typing/type_check_test.py +++ b/kauldron/typing/type_check_test.py @@ -13,7 +13,10 @@ # limitations under the License. import dataclasses + +import jaxtyping as jt from kauldron.typing import Float, TypeCheckError, typechecked # pylint: disable=g-multiple-import,g-importing-member +from kauldron.typing import type_check import numpy as np import pytest @@ -72,7 +75,6 @@ def _foo(a: cls) -> Float["B T"]: def test_dataclass(): - pytest.skip("Currently disabled") @dataclasses.dataclass class A: @@ -105,7 +107,6 @@ class TestB: def test_nested_dataclass(): - pytest.skip("Currently disabled") @typechecked def _foo(b: TestB) -> TestA: @@ -116,3 +117,29 @@ def _foo(b: TestB) -> TestA: with pytest.raises(TypeCheckError): # Wrong shape _foo(TestB(a=TestA(a=np.zeros((2, 2, 2))))) + + +@dataclasses.dataclass +class NestedA: + x: int + y: dict[str, bool] + z: "NestedA" # Recursive dataclass has to be defined in the global scope. + + +def test_union_type(): + + assert not type_check._is_kd_dataclass(NestedA) + + @dataclasses.dataclass + class A: + x: int + y: jt.Float[jt.Array, "T B"] # jaxtyping is not a Kauldron dataclass + + assert not type_check._is_kd_dataclass(A) + + @dataclasses.dataclass + class B: + x: int + y: Float["T B"] + + assert type_check._is_kd_dataclass(B)