Skip to content

Commit

Permalink
Handle timestamp and nans in removing multi index failure cases # 1469
Browse files Browse the repository at this point in the history
Signed-off-by: Rory <[email protected]>
  • Loading branch information
rorymcstay committed Feb 27, 2024
1 parent f86675e commit a770bda
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 1 deletion.
17 changes: 16 additions & 1 deletion pandera/backends/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Union,
)

import numpy as np
import pandas as pd

from pandera.api.base.checks import CheckResult
Expand Down Expand Up @@ -42,6 +43,13 @@ class ColumnInfo(NamedTuple):
regex_match_patterns: List


_MULTIINDEX_HANDLED_TYPES = {
"Timestamp": pd.Timestamp,
"NaT": pd.NaT,
"nan": np.nan,
}


FieldCheckObj = Union[pd.Series, pd.DataFrame]

T = TypeVar(
Expand Down Expand Up @@ -167,7 +175,14 @@ def drop_invalid_rows(self, check_obj, error_handler: SchemaErrorHandler):
if isinstance(check_obj.index, pd.MultiIndex):
# MultiIndex values are saved on the error as strings so need to be cast back
# to their original types
index_tuples = err.failure_cases["index"].apply(eval)
index_tuples = (
err.failure_cases["index"]
.astype(str)
.apply(lambda i: eval(i, _MULTIINDEX_HANDLED_TYPES))
)
# type check on a column of index.
if len(index_tuples) == 1 and index_tuples[0] is None:
continue
index_values = pd.MultiIndex.from_tuples(index_tuples)

mask = ~check_obj.index.isin(index_values)
Expand Down
139 changes: 139 additions & 0 deletions tests/core/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2565,3 +2565,142 @@ def test_get_schema_metadata():
}
}
assert expected == metadata


@pytest.mark.parametrize(
"schema, obj, expected_obj, check_dtype",
[
(
DataFrameSchema(
columns={
"temperature": Column(float, nullable=False),
},
index=MultiIndex(
[
Index(pd.Timestamp, name="timestamp"),
Index(str, name="city"),
]
),
drop_invalid_rows=True,
),
pd.DataFrame(
{
"temperature": [
3.0,
4.0,
5.0,
5.0,
np.nan,
2.0,
],
},
index=pd.MultiIndex.from_tuples(
(
(pd.Timestamp("2022-01-01"), "Paris"),
(pd.Timestamp("2023-01-01"), "Paris"),
(pd.Timestamp("2024-01-01"), "Paris"),
(pd.Timestamp("2022-01-01"), "Oslo"),
(pd.Timestamp("2023-01-01"), "Oslo"),
(pd.Timestamp("2024-01-01"), "Oslo"),
),
names=["timestamp", "city"],
),
),
pd.DataFrame(
{
"temperature": [3.0, 4.0, 5.0, 5.0, 2.0],
},
index=pd.MultiIndex.from_tuples(
(
(pd.Timestamp("2022-01-01"), "Paris"),
(pd.Timestamp("2023-01-01"), "Paris"),
(pd.Timestamp("2024-01-01"), "Paris"),
(pd.Timestamp("2022-01-01"), "Oslo"),
(pd.Timestamp("2024-01-01"), "Oslo"),
),
names=["timestamp", "city"],
),
),
True,
),
(
DataFrameSchema(
columns={
"temperature": Column(float, nullable=False),
},
index=MultiIndex(
[
Index(pd.Timestamp, name="timestamp"),
Index(str, name="city"),
]
),
drop_invalid_rows=True,
),
pd.DataFrame(
{
"temperature": [
3.0,
4.0,
5.0,
-1.0,
np.nan,
-2.0,
4.0,
5.0,
2.0,
],
},
index=pd.MultiIndex.from_tuples(
(
(pd.Timestamp("2022-01-01"), "Paris"),
(pd.Timestamp("2023-01-01"), "Paris"),
(pd.Timestamp("2024-01-01"), "Paris"),
(pd.Timestamp("2022-01-01"), "Oslo"),
(pd.Timestamp("2023-01-01"), "Oslo"),
(pd.Timestamp("2024-01-01"), "Oslo"),
(
pd.Timestamp("2024-01-01", tz="Europe/London"),
"London",
),
(pd.Timestamp(pd.NaT), "Frankfurt"),
(pd.Timestamp("2024-01-01"), 6),
),
names=["timestamp", "city"],
),
),
pd.DataFrame(
{
"temperature": [3.0, 4.0, 5.0, -1.0, -2.0, 4],
},
index=pd.MultiIndex.from_tuples(
(
(pd.Timestamp("2022-01-01"), "Paris"),
(pd.Timestamp("2023-01-01"), "Paris"),
(pd.Timestamp("2024-01-01"), "Paris"),
(pd.Timestamp("2022-01-01"), "Oslo"),
(pd.Timestamp("2024-01-01"), "Oslo"),
(
pd.Timestamp("2024-01-01", tz="Europe/London"),
"London",
),
),
names=["timestamp", "city"],
),
),
False,
),
],
)
def test_drop_invalid_for_multi_index_with_datetime(
schema, obj, expected_obj, check_dtype
):
"""Test drop_invalid_rows works as expected on multi-index dataframes"""
actual_obj = schema.validate(obj, lazy=True)

# the datatype of the index is not casted, In this cases its an object
pd.testing.assert_frame_equal(
actual_obj,
expected_obj,
check_dtype=check_dtype,
check_index_type=check_dtype,
)

0 comments on commit a770bda

Please sign in to comment.