try:
    # Runtime type of PEP 604 unions (``X | Y``). It lives in the standard
    # library on Python 3.10+; mypy's ``mypy.types.UnionType`` is an internal
    # static-analysis class and must not be used (or required) at runtime.
    from types import UnionType as _RuntimeUnionType
except ImportError:  # pragma: no cover - depends on interpreter version
    _RuntimeUnionType = None  # type: ignore[assignment,misc]

# ``typing.Union`` and the string aliases are always registered; the runtime
# union type is added only when the interpreter provides it.
_UNION_EQUIVALENTS = [Union, "UnionType", "union"]
if _RuntimeUnionType is not None:
    _UNION_EQUIVALENTS.insert(0, _RuntimeUnionType)


@Engine.register_dtype(equivalents=_UNION_EQUIVALENTS)
@dtypes.immutable(init=True)
class PythonUnion(PythonGenericType):
    """A datatype to support python ``Union`` annotations.

    The union to validate against (for example ``Union[str, float]``) is
    stored in ``generic_type``; its members are read from ``__args__``.
    """

    # Fall back to ``typing.Union`` where ``types.UnionType`` does not exist.
    type = _RuntimeUnionType if _RuntimeUnionType is not None else Union

    def __init__(  # pylint:disable=super-init-not-called
        self, generic_type: Optional[Type] = None
    ) -> None:
        """Store the union annotation this dtype validates against.

        :param generic_type: the union annotation. When ``None`` the
            attribute is left untouched.
        """
        if generic_type is not None:
            # plain attribute assignment is bypassed on purpose, matching the
            # ``dtypes.immutable`` decorator applied to this class
            object.__setattr__(self, "generic_type", generic_type)

    def check(
        self,
        pandera_dtype: dtypes.DataType,
        data_container: Optional[PandasObject] = None,
    ) -> Union[bool, Iterable[bool]]:
        """Check that data container has the expected type.

        :param pandera_dtype: dtype of the data being validated; it is
            normalized with ``Engine.dtype``.
        :param data_container: optional pandas object whose elements are
            checked one by one.
        :returns: ``False`` when ``pandera_dtype`` is neither the ``object``
            dtype nor the dtype of one of the union members; ``True`` when
            the dtype is accepted and no data container is given; otherwise
            the result of mapping ``self._check_type`` over
            ``data_container``.
        """
        pandera_dtype = Engine.dtype(pandera_dtype)

        # ``generic_type`` is optional, so an instance created without a
        # union has no ``__args__``; only ``object`` is accepted then.
        member_types = getattr(
            self.generic_type, "__args__", ()  # pylint: disable=no-member
        )
        candidate_types = (object, *member_types)

        # The underlying pandas dtype must be ``object`` or the dtype of one
        # of the union members. The generator keeps resolution lazy: members
        # after the first match are never passed to ``Engine.dtype``.
        # NOTE(review): ``Engine.dtype`` may raise for members without a
        # registered pandera dtype - confirm whether that should be tolerated.
        if pandera_dtype not in (
            Engine.dtype(candidate) for candidate in candidate_types
        ):
            return False

        if data_container is None:
            return True
        return data_container.map(self._check_type)  # type: ignore[operator]
+733,44 @@ class PointTuple(NamedTuple): "tuple_column": pa.Column(Tuple[int, str, float]), "typeddict_column": pa.Column(PointDict), "namedtuple_column": pa.Column(PointTuple), + "column_union_float": pa.Column(Union[str, float]), + "column_union_str": pa.Column(Union[str, float]), + "column_union_obj": pa.Column(Union[str, float]), }, ) data = pd.DataFrame( { - "dict_column": [{"foo": 1, "bar": 2}], - "list_column": [[1.0]], - "tuple_column": [(1, "bar", 1.0)], - "typeddict_column": [PointDict(x=2.1, y=4.8)], - "namedtuple_column": [PointTuple(x=9.2, y=1.6)], + "dict_column": [{"foo": 1, "bar": 2}, {"foobar": 3}], + "list_column": [[1.0], [2.0]], + "tuple_column": [(1, "bar", 1.0), (2, "foobar", 2.0)], + "typeddict_column": [ + PointDict(x=2.1, y=4.8), + PointDict(x=2.5, y=9.0), + ], + "namedtuple_column": [ + PointTuple(x=9.2, y=1.6), + PointTuple(x=2.5, y=1.4), + ], + "column_union_float": [1.0, 2.0], + "column_union_str": ["foo", "bar"], + "column_union_obj": [12.0, "foo"], } ) schema.validate(data) + + float_or_str_schema = pa.DataFrameSchema( + { + "column_union": pa.Column(Union[str, float]), + }, + ) + + int_data = pd.DataFrame( + { + "column_union": [1, 2], + } + ) + + with pytest.raises(pa.errors.SchemaError): + float_or_str_schema.validate(int_data) diff --git a/tests/strategies/test_strategies.py b/tests/strategies/test_strategies.py index 279185e7d..ab2ff2864 100644 --- a/tests/strategies/test_strategies.py +++ b/tests/strategies/test_strategies.py @@ -42,6 +42,7 @@ pandas_engine.PythonTuple, pandas_engine.PythonTypedDict, pandas_engine.PythonNamedTuple, + pandas_engine.PythonUnion, ] ) SUPPORTED_DTYPES = set()