Skip to content

Commit

Permalink
Support for Unions in schemas and validation
Browse files Browse the repository at this point in the history
fix: #1152
I would like pandera to support Union Type. That is the validation of a
Series/Column should allow multiple types.

1. Add a new PythonUnion type.
2. Add a new test to for the new UnionType.

Signed-off-by: karajan1001 <[email protected]>
  • Loading branch information
karajan1001 committed Jul 19, 2023
1 parent 19cc15d commit 273c49b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 7 deletions.
38 changes: 37 additions & 1 deletion pandera/engines/pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Union,
cast,
)
from mypy.types import UnionType

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -60,7 +61,6 @@
else:
from typing_extensions import TypedDict # noqa


try:
from typing import Literal # type: ignore
except ImportError:
Expand Down Expand Up @@ -1330,3 +1330,39 @@ def __init__( # pylint:disable=super-init-not-called

def __str__(self) -> str:
return str(NamedTuple.__name__)


@Engine.register_dtype(equivalents=[UnionType, Union, "UnionType", "union"])
@dtypes.immutable(init=True)
class PythonUnion(PythonGenericType):
"""A datatype to support python generics."""

type = UnionType

def __init__( # pylint:disable=super-init-not-called
self, generic_type: Optional[Type] = None
) -> None:
if generic_type is not None:
object.__setattr__(self, "generic_type", generic_type)

def check(
self,
pandera_dtype: dtypes.DataType,
data_container: Optional[PandasObject] = None,
) -> Union[bool, Iterable[bool]]:
"""Check that data container has the expected type."""
pandera_dtype = Engine.dtype(pandera_dtype)

pandas_types = [object]
pandas_types.extend(
self.generic_type.__args__ # pylint: disable=no-member
)

# the underlying pandas dtype must be an object
if pandera_dtype not in map(Engine.dtype, pandas_types):
return False

if data_container is None:

Check warning on line 1365 in pandera/engines/pandas_engine.py

View check run for this annotation

Codecov / codecov/patch

pandera/engines/pandas_engine.py#L1365

Added line #L1365 was not covered by tests
return True
else:
return data_container.map(self._check_type) # type: ignore[operator]
39 changes: 33 additions & 6 deletions tests/core/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
import sys
from decimal import Decimal
from typing import Any, Dict, List, NamedTuple, Tuple
from typing import Any, Dict, List, NamedTuple, Tuple, Union

import hypothesis
import numpy as np
Expand Down Expand Up @@ -733,17 +733,44 @@ class PointTuple(NamedTuple):
"tuple_column": pa.Column(Tuple[int, str, float]),
"typeddict_column": pa.Column(PointDict),
"namedtuple_column": pa.Column(PointTuple),
"column_union_float": pa.Column(Union[str, float]),
"column_union_str": pa.Column(Union[str, float]),
"column_union_obj": pa.Column(Union[str, float]),
},
)

data = pd.DataFrame(
{
"dict_column": [{"foo": 1, "bar": 2}],
"list_column": [[1.0]],
"tuple_column": [(1, "bar", 1.0)],
"typeddict_column": [PointDict(x=2.1, y=4.8)],
"namedtuple_column": [PointTuple(x=9.2, y=1.6)],
"dict_column": [{"foo": 1, "bar": 2}, {"foobar": 3}],
"list_column": [[1.0], [2.0]],
"tuple_column": [(1, "bar", 1.0), (2, "foobar", 2.0)],
"typeddict_column": [
PointDict(x=2.1, y=4.8),
PointDict(x=2.5, y=9.0),
],
"namedtuple_column": [
PointTuple(x=9.2, y=1.6),
PointTuple(x=2.5, y=1.4),
],
"column_union_float": [1.0, 2.0],
"column_union_str": ["foo", "bar"],
"column_union_obj": [12.0, "foo"],
}
)

schema.validate(data)

float_or_str_schema = pa.DataFrameSchema(
{
"column_union": pa.Column(Union[str, float]),
},
)

int_data = pd.DataFrame(
{
"column_union": [1, 2],
}
)

with pytest.raises(pa.errors.SchemaError):
float_or_str_schema.validate(int_data)
1 change: 1 addition & 0 deletions tests/strategies/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
pandas_engine.PythonTuple,
pandas_engine.PythonTypedDict,
pandas_engine.PythonNamedTuple,
pandas_engine.PythonUnion,
]
)
SUPPORTED_DTYPES = set()
Expand Down

0 comments on commit 273c49b

Please sign in to comment.