Skip to content

Commit

Permalink
Only activate dataclass typing if the dataclass contains Kauldron typ…
Browse files Browse the repository at this point in the history
…e annotations

PiperOrigin-RevId: 721010355
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Jan 29, 2025
1 parent 1e99e0d commit 733ceda
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
35 changes: 34 additions & 1 deletion kauldron/typing/type_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,44 @@ def _dataclass_checker_lookup(
) -> typeguard.TypeCheckerCallable | None:
"""Lookup function to register custom dataclass checkers in typeguard."""
del args, extras
if dataclasses.is_dataclass(origin_type):
# Due to conflict with mixing Kauldron with other non-Kauldron jaxtyping
# objects, we only activate dataclass support for dataclasses annotated with
# Kauldron types.
# We do this by recursively checking if any dataclass attribute is annotated
# as a Kauldron type.
if _is_kd_dataclass(origin_type):
return _custom_dataclass_checker
return None


@functools.cache
def _is_kd_dataclass(obj) -> bool:
return _is_kd_dataclass_inner(obj, visited=set())


def _is_kd_dataclass_inner(obj, visited) -> bool:
if not dataclasses.is_dataclass(obj):
return False

visited.add(obj)
hints = typing.get_type_hints(obj)
return any(_is_kd_type(t, visited=visited) for t in hints.values())


def _is_kd_type(t: Any, visited: set[Any]) -> bool:
if t in visited: # Cycle
return False
origin = typing.get_origin(t)
if origin is None:
if inspect.getattr_static(t, "_kd_repr", None):
return True
return _is_kd_dataclass_inner(t, visited)
if origin in [Union, types.UnionType]:
return any(_is_kd_type(t, visited=visited) for t in typing.get_args(t))
# Could recurse into dict, list,... too
return _is_kd_dataclass_inner(t, visited)


def add_custom_checker_lookup_fn(lookup_fn):
"""Add custom array spec checker lookup function to typeguard."""
# Add custom array spec checker lookup function to typguard
Expand Down
31 changes: 29 additions & 2 deletions kauldron/typing/type_check_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
# limitations under the License.

import dataclasses

import jaxtyping as jt
from kauldron.typing import Float, TypeCheckError, typechecked # pylint: disable=g-multiple-import,g-importing-member
from kauldron.typing import type_check
import numpy as np
import pytest

Expand Down Expand Up @@ -72,7 +75,6 @@ def _foo(a: cls) -> Float["B T"]:


def test_dataclass():
pytest.skip("Currently disabled")

@dataclasses.dataclass
class A:
Expand Down Expand Up @@ -105,7 +107,6 @@ class TestB:


def test_nested_dataclass():
pytest.skip("Currently disabled")

@typechecked
def _foo(b: TestB) -> TestA:
Expand All @@ -116,3 +117,29 @@ def _foo(b: TestB) -> TestA:
with pytest.raises(TypeCheckError):
# Wrong shape
_foo(TestB(a=TestA(a=np.zeros((2, 2, 2)))))


@dataclasses.dataclass
class NestedA:
x: int
y: dict[str, bool]
z: "NestedA" # Recursive dataclass has to be defined in the global scope.


def test_union_type():

assert not type_check._is_kd_dataclass(NestedA)

@dataclasses.dataclass
class A:
x: int
y: jt.Float[jt.Array, "T B"] # jaxtyping is not a Kauldron dataclass

assert not type_check._is_kd_dataclass(A)

@dataclasses.dataclass
class B:
x: int
y: Float["T B"]

assert type_check._is_kd_dataclass(B)

0 comments on commit 733ceda

Please sign in to comment.