diff --git a/pandera/io.py b/pandera/io.py
index 10fb8989a..503567316 100644
--- a/pandera/io.py
+++ b/pandera/io.py
@@ -7,25 +7,27 @@ from pathlib import Path
 from typing import Dict, Optional, Union
 
+import numpy as np
 import pandas as pd
 
 import pandera.errors
 
 from . import dtypes
 from .checks import Check
-from .engines import pandas_engine
-from .schema_components import Column
+from .engines import numpy_engine, pandas_engine
+from .schema_components import Column, Index, MultiIndex, SeriesSchemaBase
 from .schema_statistics import get_dataframe_schema_statistics
 from .schemas import DataFrameSchema
 
 try:
     import black
+    import pyarrow
     import yaml
     from frictionless import Schema as FrictionlessSchema
 except ImportError as exc:  # pragma: no cover
     raise ImportError(
-        "IO and formatting requires 'pyyaml', 'black' and 'frictionless'"
-        "to be installed.\n"
+        "IO and formatting require 'pyyaml', 'black', 'frictionless' and "
+        "'pyarrow' to be installed.\n"
         "You can install pandera together with the IO dependencies with:\n"
         "pip install pandera[io]\n"
     ) from exc
@@ -246,8 +248,6 @@ def deserialize_schema(serialized_schema):
     :returns: the schema de-serialized into
         :class:`~pandera.schemas.DataFrameSchema`
     """
-    # pylint: disable=import-outside-toplevel
-    from pandera import Index, MultiIndex
 
     # GH#475
     serialized_schema = serialized_schema if serialized_schema else {}
@@ -806,3 +806,109 @@ def from_frictionless_schema(
         ),
     }
     return deserialize_schema(assembled_schema)
+
+
+def to_pyarrow_field(
+    name: str,
+    pandera_field: SeriesSchemaBase,
+) -> pyarrow.Field:
+    """
+    Convert a :class:`~pandera.schema_components.SeriesSchemaBase` to a
+    ``pyarrow.Field``.
+
+    :param name: name of the field
+    :param pandera_field: pandera Index or Column
+    :returns: ``pyarrow.Field`` representation of ``pandera_field``
+    """
+
+    pandera_dtype = pandera_field.dtype
+    pandas_dtype = pandas_engine.Engine.dtype(pandera_dtype).type
+
+    pandas_types = {
+        pd.BooleanDtype(): pyarrow.bool_(),
+        pd.Int8Dtype(): pyarrow.int8(),
+        pd.Int16Dtype(): pyarrow.int16(),
+        pd.Int32Dtype(): pyarrow.int32(),
+        pd.Int64Dtype(): pyarrow.int64(),
+        pd.UInt8Dtype(): pyarrow.uint8(),
+        pd.UInt16Dtype(): pyarrow.uint16(),
+        pd.UInt32Dtype(): pyarrow.uint32(),
+        pd.UInt64Dtype(): pyarrow.uint64(),
+        pd.Float32Dtype(): pyarrow.float32(),  # type: ignore[attr-defined]
+        pd.Float64Dtype(): pyarrow.float64(),  # type: ignore[attr-defined]
+        pd.StringDtype(): pyarrow.string(),
+    }
+
+    if pandas_dtype in pandas_types:
+        pyarrow_type = pandas_types[pandas_dtype]
+    elif isinstance(
+        pandera_dtype, (pandas_engine.Date, numpy_engine.DateTime64)
+    ):
+        pyarrow_type = pyarrow.date64()
+    elif isinstance(pandera_dtype, dtypes.Category):
+        # Categorical data types
+        pyarrow_type = pyarrow.dictionary(
+            pyarrow.int8(),
+            pandera_dtype.type.categories.inferred_type,
+            ordered=pandera_dtype.ordered,  # type: ignore[attr-defined]
+        )
+    elif pandas_dtype.type == np.object_:
+        pyarrow_type = pyarrow.string()
+    else:
+        pyarrow_type = pyarrow.from_numpy_dtype(pandas_dtype)
+
+    return pyarrow.field(name, pyarrow_type, pandera_field.nullable)
+
+
+def _get_index_name(level: int) -> str:
+    """Generate an index name for pyarrow if none is specified"""
+    return f"__index_level_{level}__"
+
+
+def to_pyarrow_schema(
+    dataframe_schema: DataFrameSchema,
+    preserve_index: Optional[bool] = None,
+) -> pyarrow.Schema:
+    """
+    Convert a :class:`~pandera.schemas.DataFrameSchema` to ``pyarrow.Schema``.
+
+    :param dataframe_schema: schema to convert to ``pyarrow.Schema``
+    :param preserve_index: whether to store the index as an additional
+        column (or columns, for a MultiIndex) in the resulting schema. The
+        default of None stores the index as a column, except for a
+        RangeIndex, which is stored as metadata only. Use
+        ``preserve_index=True`` to force the index to be stored as a column.
+    :returns: ``pyarrow.Schema`` representation of the ``DataFrameSchema``
+    """
+
+    # Columns that will be present in the pyarrow schema. Copy the mapping
+    # so the input schema is not mutated when index columns are added.
+    columns: Dict[str, SeriesSchemaBase] = dict(  # type: ignore[assignment]
+        dataframe_schema.columns
+    )
+
+    # pyarrow schema metadata
+    metadata: Dict[str, bytes] = {}
+
+    index = dataframe_schema.index
+    if index is None:
+        if preserve_index:
+            # Create column for RangeIndex
+            name = _get_index_name(0)
+            columns[name] = Index(dtypes.Int64, nullable=False, name=name)
+        else:
+            # Only preserve metadata of the index
+            meta_val = b'[{"kind": "range", "name": null, "step": 1}]'
+            metadata["index_columns"] = meta_val
+    elif preserve_index is not False:
+        # Add column(s) for index(es)
+        if isinstance(index, Index):
+            name = index.name or _get_index_name(0)
+            # Ensure the index is added at the beginning of the dictionary
+            columns = {**{name: index}, **columns}
+
+        elif isinstance(index, MultiIndex):
+            # Prepend the levels in their original order; enumerate before
+            # reversing so that generated names use the true level number
+            for level, value in reversed(list(enumerate(index.indexes))):
+                name = value.name or _get_index_name(level)
+                columns = {**{name: value}, **columns}
+
+    return pyarrow.schema(
+        [to_pyarrow_field(k, v) for k, v in columns.items()],
+        metadata=metadata,
+    )
diff --git a/setup.py b/setup.py
index 80fc0811f..be9d7f361 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
 _extras_require = {
     "strategies": ["hypothesis >= 5.41.1"],
     "hypotheses": ["scipy"],
-    "io": ["pyyaml >= 5.1", "black", "frictionless"],
+    "io": ["pyyaml >= 5.1", "black", "frictionless", "pyarrow"],
     "pyspark": ["pyspark >= 3.2.0"],
     "modin": ["modin", "ray <= 1.7.0", "dask"],
     "modin-ray": ["modin", "ray <= 1.7.0"],
diff --git a/tests/io/test_io.py b/tests/io/test_io.py
index adadc482c..24e92d1ba 100644
--- a/tests/io/test_io.py
+++ b/tests/io/test_io.py
@@ -4,6 +4,7 @@
 import tempfile
 from io import StringIO
 from pathlib import Path
+from typing import Optional, Type, Union
 from unittest import mock
 
 import pandas as pd
@@ -13,7 +14,9 @@
 import pandera
 import pandera.extensions as pa_ext
 import pandera.typing as pat
-from pandera.engines import pandas_engine
+from pandera import dtypes
+from pandera.engines import numpy_engine, pandas_engine
+from pandera.schema_components import Column
 
 try:
     from pandera import io
@@ -34,6 +37,14 @@
 SKIP_YAML_TESTS = PYYAML_VERSION is None or PYYAML_VERSION.release < (5, 1, 0)  # type: ignore
 
 
+try:
+    import pyarrow
+except ImportError:
+    SKIP_PYARROW_TESTS = True
+else:
+    SKIP_PYARROW_TESTS = False
+
+
 # skip all tests in module if "io" depends aren't installed
 pytestmark = pytest.mark.skipif(
     not HAS_IO, reason='needs "io" module dependencies'
@@ -1362,3 +1373,198 @@ def test_frictionless_schema_primary_key(frictionless_schema):
     assert schema.unique == frictionless_schema["primaryKey"]
     for key in frictionless_schema["primaryKey"]:
         assert not schema.columns[key].unique
+
+
+@pytest.mark.skipif(SKIP_PYARROW_TESTS, reason="pyarrow required")
+@pytest.mark.parametrize(
+    "pandera_dtype, expected_pyarrow_dtype",
+    [
+        (dtypes.Bool, pyarrow.bool_()),
+        (numpy_engine.Bool, pyarrow.bool_()),
+        (pandas_engine.BOOL, pyarrow.bool_()),
+        (dtypes.Int8, pyarrow.int8()),
+        (dtypes.Int16, pyarrow.int16()),
+        (dtypes.Int32, pyarrow.int32()),
+        (dtypes.Int64, pyarrow.int64()),
+        (numpy_engine.Int8, pyarrow.int8()),
+        (numpy_engine.Int16, pyarrow.int16()),
+        (numpy_engine.Int32, pyarrow.int32()),
+        (numpy_engine.Int64, pyarrow.int64()),
+        (pandas_engine.INT8, pyarrow.int8()),
+        (pandas_engine.INT16, pyarrow.int16()),
+        (pandas_engine.INT32, pyarrow.int32()),
+        (pandas_engine.INT64, pyarrow.int64()),
+        (dtypes.UInt8, pyarrow.uint8()),
+        (dtypes.UInt16, pyarrow.uint16()),
+        (dtypes.UInt32, pyarrow.uint32()),
+        (dtypes.UInt64, pyarrow.uint64()),
+        (numpy_engine.UInt8, pyarrow.uint8()),
+        (numpy_engine.UInt16, pyarrow.uint16()),
+        (numpy_engine.UInt32, pyarrow.uint32()),
+        (numpy_engine.UInt64, pyarrow.uint64()),
+        (pandas_engine.UINT8, pyarrow.uint8()),
+        (pandas_engine.UINT16, pyarrow.uint16()),
+        (pandas_engine.UINT32, pyarrow.uint32()),
+        (pandas_engine.UINT64, pyarrow.uint64()),
+        (dtypes.Float16, pyarrow.float16()),
+        (dtypes.Float32, pyarrow.float32()),
+        (dtypes.Float64, pyarrow.float64()),
+        (numpy_engine.Float16, pyarrow.float16()),
+        (numpy_engine.Float32, pyarrow.float32()),
+        (numpy_engine.Float64, pyarrow.float64()),
+        (pandas_engine.FLOAT32, pyarrow.float32()),
+        (pandas_engine.FLOAT64, pyarrow.float64()),
+        (dtypes.String, pyarrow.string()),
+        (numpy_engine.String, pyarrow.string()),
+        (pandas_engine.STRING, pyarrow.string()),
+        (pandas_engine.NpString, pyarrow.string()),
+        (numpy_engine.Object, pyarrow.string()),
+        (numpy_engine.Bytes, pyarrow.binary()),
+        (dtypes.Date, pyarrow.date64()),
+        (pandas_engine.Date, pyarrow.date64()),
+        (dtypes.Timestamp, pyarrow.timestamp("ns")),
+        (numpy_engine.DateTime64, pyarrow.date64()),
+        (pandas_engine.DateTime, pyarrow.timestamp("ns")),
+        (dtypes.Timedelta, pyarrow.duration("ns")),
+        (numpy_engine.Timedelta64, pyarrow.duration("ns")),
+        (
+            dtypes.Category(categories=["foo", "bar", "baz"], ordered=True),
+            pyarrow.dictionary(pyarrow.int8(), pyarrow.string(), ordered=True),
+        ),
+    ],
+)
+@pytest.mark.parametrize("nullable", [True, False])
+def test_to_pyarrow_field(
+    pandera_dtype: Union[Type[dtypes.DataType], dtypes.DataType],
+    nullable: bool,
+    expected_pyarrow_dtype: pyarrow.DataType,
+):
+    """Test that a pandera dtype is converted to the expected pyarrow type"""
+    name = "foo"
+
+    pandera_field = Column(pandera_dtype, nullable=nullable, name=name)
+    pyarrow_field = io.to_pyarrow_field(name, pandera_field)
+
+    assert pyarrow_field.type == expected_pyarrow_dtype
+    assert pyarrow_field.name == name
+    assert pyarrow_field.nullable == nullable
+
+
+@pytest.mark.skipif(SKIP_PYARROW_TESTS, reason="pyarrow required")
+@pytest.mark.parametrize(
+    "dataframe_schema, preserve_index, expected",
+    [
+        (
+            _create_schema("single"),
+            True,
+            pyarrow.schema(
+                [
+                    pyarrow.field("__index_level_0__", pyarrow.int64(), False),
+                    pyarrow.field("int_column", pyarrow.int64(), False),
+                    pyarrow.field("float_column", pyarrow.float64(), False),
+                    pyarrow.field("str_column", pyarrow.string(), False),
+                    pyarrow.field(
+                        "datetime_column", pyarrow.timestamp("ns"), False
+                    ),
+                    pyarrow.field(
+                        "timedelta_column", pyarrow.duration("ns"), False
+                    ),
+                    pyarrow.field(
+                        "optional_props_column", pyarrow.string(), True
+                    ),
+                ]
+            ),
+        ),
+        (
+            _create_schema(None),
+            None,
+            pyarrow.schema(
+                [
+                    pyarrow.field("int_column", pyarrow.int64(), False),
+                    pyarrow.field("float_column", pyarrow.float64(), False),
+                    pyarrow.field("str_column", pyarrow.string(), False),
+                    pyarrow.field(
+                        "datetime_column", pyarrow.timestamp("ns"), False
+                    ),
+                    pyarrow.field(
+                        "timedelta_column", pyarrow.duration("ns"), False
+                    ),
+                    pyarrow.field(
+                        "optional_props_column", pyarrow.string(), True
+                    ),
+                ]
+            ),
+        ),
+        (
+            _create_schema("multi"),
+            None,
+            pyarrow.schema(
+                [
+                    pyarrow.field("int_index0", pyarrow.int64(), False),
+                    pyarrow.field("int_index1", pyarrow.int64(), False),
+                    pyarrow.field("int_index2", pyarrow.int64(), False),
+                    pyarrow.field("int_column", pyarrow.int64(), False),
+                    pyarrow.field("float_column", pyarrow.float64(), False),
+                    pyarrow.field("str_column", pyarrow.string(), False),
+                    pyarrow.field(
+                        "datetime_column", pyarrow.timestamp("ns"), False
+                    ),
+                    pyarrow.field(
+                        "timedelta_column", pyarrow.duration("ns"), False
+                    ),
+                    pyarrow.field(
+                        "optional_props_column", pyarrow.string(), True
+                    ),
+                ]
+            ),
+        ),
+        (
+            _create_schema("multi"),
+            False,
+            pyarrow.schema(
+                [
+                    pyarrow.field("int_column", pyarrow.int64(), False),
+                    pyarrow.field("float_column", pyarrow.float64(), False),
+                    pyarrow.field("str_column", pyarrow.string(), False),
+                    pyarrow.field(
+                        "datetime_column", pyarrow.timestamp("ns"), False
+                    ),
+                    pyarrow.field(
+                        "timedelta_column", pyarrow.duration("ns"), False
+                    ),
+                    pyarrow.field(
+                        "optional_props_column", pyarrow.string(), True
+                    ),
+                ]
+            ),
+        ),
+        (
+            _create_schema_python_types(),
+            None,
+            pyarrow.schema(
+                [
+                    pyarrow.field("int_column", pyarrow.int64(), False),
+                    pyarrow.field("float_column", pyarrow.float64(), False),
+                    pyarrow.field("str_column", pyarrow.string(), False),
+                    pyarrow.field("object_column", pyarrow.string(), False),
+                ]
+            ),
+        ),
+    ],
+)
+def test_to_pyarrow_schema(
+    dataframe_schema: pandera.schemas.DataFrameSchema,
+    preserve_index: Optional[bool],
+    expected: pyarrow.Schema,
+):
+    """Test that a pandera schema is correctly converted to pyarrow.Schema"""
+
+    # Drop the column with no dtype specified
+    dataframe_schema.columns = {
+        k: v
+        for k, v in dataframe_schema.columns.items()
+        if k != "notype_column"
+    }
+
+    pyarrow_schema = io.to_pyarrow_schema(dataframe_schema, preserve_index)
+    assert expected.equals(pyarrow_schema)
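
Example usage of the new to_pyarrow_schema API (a minimal sketch, not part
of the diff: it assumes pandera is installed with the "io" extra so pyarrow
is available, and the schema below is purely illustrative):

    import pandera as pa
    from pandera import io

    schema = pa.DataFrameSchema(
        {
            "int_column": pa.Column(int, nullable=False),
            "str_column": pa.Column(str, nullable=True),
        }
    )

    # Convert the pandera schema to its pyarrow equivalent. With no index
    # component and preserve_index=None, only RangeIndex metadata is stored.
    pyarrow_schema = io.to_pyarrow_schema(schema)
    print(pyarrow_schema)
    # roughly:
    #   int_column: int64 not null
    #   str_column: string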