
[PySpark] Improve validation performance by enabling cache()/`unpersist()` toggles (#1414)

* enables caching/unpersisting, tests and docs

Signed-off-by: Filipe Oliveira <[email protected]>

* improve code coverage through new test file for decorators

Signed-off-by: Filipe Oliveira <[email protected]>

* add check_obj as an arg and small improvements

Signed-off-by: Filipe Oliveira <[email protected]>

* change envvar from PYSPARK_UNPERSIST to PYSPARK_KEEP_CACHE

Signed-off-by: Filipe Oliveira <[email protected]>

* change docs

Signed-off-by: Filipe Oliveira <[email protected]>

* change envvar to make them generic

Signed-off-by: Filipe Oliveira <[email protected]>

---------

Signed-off-by: Filipe Oliveira <[email protected]>
filipeo2-mck authored Nov 20, 2023
1 parent af0e5c0 commit bc5e37a
Showing 8 changed files with 296 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@
dask-worker-space
spark-warehouse
docs/source/_contents
**.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
Binary file removed docs/.DS_Store
Binary file not shown.
40 changes: 40 additions & 0 deletions docs/source/pyspark_sql.rst
@@ -246,6 +246,46 @@ By default, validations are enabled and depth is set to ``SCHEMA_AND_DATA`` which
can be changed to ``SCHEMA_ONLY`` or ``DATA_ONLY`` as required by the use case.


Caching control
---------------

*new in 0.17.3*

Given Spark's lazy-evaluation architecture and Pandera's PySpark integration, which
relies internally on filtering conditions and *count* commands, the PySpark DataFrame
being validated by a Pandera schema may be reprocessed multiple times, since each
*count* command triggers a new underlying *Spark action*. This processing overhead is
directly related to the number of *schema* and *data* checks added to the Pandera schema.

To avoid this reprocessing overhead, Pandera allows you to cache the PySpark DataFrame
before validation starts, through two environment variables:

.. code-block:: bash

    export PANDERA_CACHE_DATAFRAME=True        # Default is False: do not `cache()` by default
    export PANDERA_KEEP_CACHED_DATAFRAME=True  # Default is False: `unpersist()` by default

The first variable controls whether the current DataFrame state should be cached in your
Spark session before the validation starts. The second controls whether that cached state
should be kept after the validation ends.
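
These settings can also be toggled programmatically through Pandera's global
configuration object (the same object exercised by the test suite in this commit).
A minimal sketch, assuming an existing ``schema`` (a ``DataFrameSchema``) and a
PySpark DataFrame ``df``:

.. code-block:: python

    from pandera.config import CONFIG

    CONFIG.pyspark_cache = True       # cache the DataFrame before validation runs
    CONFIG.pyspark_keep_cache = True  # keep the cached state after validation ends

    validated_df = schema.validate(df)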

.. note::

    Whether to cache is a trade-off: if you have enough memory to keep the
    dataframe cached, validation will be faster because the validation process
    reuses the cached state.

    Keeping the cached state after validation ends matters when the Pandera
    validation of a dataset is not a standalone process but one step of a
    pipeline: if, in a single Spark session, a pipeline uses Pandera to evaluate
    all input dataframes before transforming them into a result that is written
    to disk, it may make sense not to discard the cached states in that session.
    The already processed states of these dataframes will still be used after the
    validation ends, so keeping them in memory may be beneficial.


Registering Custom Checks
-------------------------

7 changes: 6 additions & 1 deletion pandera/backends/pyspark/container.py
@@ -11,7 +11,11 @@
from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler
from pandera.api.pyspark.types import is_table
from pandera.backends.pyspark.base import ColumnInfo, PysparkSchemaBackend
from pandera.backends.pyspark.decorators import ValidationScope, validate_scope
from pandera.backends.pyspark.decorators import (
    ValidationScope,
    validate_scope,
    cache_check_obj,
)
from pandera.backends.pyspark.error_formatters import scalar_failure_case
from pandera.config import CONFIG
from pandera.errors import (
@@ -102,6 +106,7 @@ def _data_checks(

return check_obj

@cache_check_obj()
def validate(
self,
check_obj: DataFrame,
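
With `cache_check_obj()` applied to `validate()`, turning the config flags on is enough
for a normal validation call to reuse a cached DataFrame. A rough, illustrative sketch
(the `schema` and `spark_df` names are assumptions, not part of this diff):

    from pandera.config import CONFIG

    CONFIG.pyspark_cache = True       # validate() will call check_obj.cache()
    CONFIG.pyspark_keep_cache = True  # ...and will skip the final unpersist()

    validated = schema.validate(spark_df)
    print(spark_df.is_cached)  # expected True: the cached state is kept after validation
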
77 changes: 74 additions & 3 deletions pandera/backends/pyspark/decorators.py
@@ -1,16 +1,19 @@
"""This module holds the decorators only valid for pyspark"""

import functools
import logging
import warnings
from contextlib import contextmanager
from enum import Enum
from typing import List, Type

import pyspark.sql

from pyspark.sql import DataFrame
from pandera.api.pyspark.types import PysparkDefaultTypes
from pandera.config import CONFIG, ValidationDepth
from pandera.errors import SchemaError

logger = logging.getLogger(__name__)


class ValidationScope(Enum):
"""Indicates whether a check/validator operates at a schema of data level."""
@@ -90,7 +93,7 @@ def _get_check_obj():
"""
if args:
for value in args:
if isinstance(value, pyspark.sql.DataFrame):
if isinstance(value, DataFrame):
return value

if scope == ValidationScope.SCHEMA:
@@ -126,3 +129,71 @@ def _get_check_obj():
return wrapper

return _wrapper


def cache_check_obj():
"""This decorator evaluates if `check_obj` should be cached before validation.
As each new data check added to the Pandera schema by the user triggers a new
Spark action, Spark reprocesses the `check_obj` DataFrame multiple times.
To prevent this waste of processing resources and to reduce validation times in
complex scenarios, the decorator created by this factory caches the `check_obj`
DataFrame before validation and unpersists it afterwards.
This decorator is meant to be used primarily in the `validate()` function
entrypoint.
The behavior of the resulting decorator depends on the `PANDERA_PYSPARK_CACHING` and
`PANDERA_KEEP_CACHED_DATAFRAME` (optional) environment variables.
Usage:
@cache_check_obj()
def validate(check_obj: DataFrame):
# ...
"""

    def _wrapper(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Skip caching if it is not enabled in the config
            if CONFIG.pyspark_cache is not True:
                return func(self, *args, **kwargs)

            check_obj: DataFrame = None

            # Check if the decorated function has a DataFrame object as a positional arg
            for arg in args:
                if isinstance(arg, DataFrame):
                    check_obj = arg
                    break

            # If it doesn't, fall back to kwargs and search for a `check_obj` key
            if check_obj is None:
                check_obj = kwargs.get("check_obj", None)

            if not isinstance(check_obj, DataFrame):
                raise ValueError(
                    "Expected to find a DataFrame object in an arg or a `check_obj` "
                    "kwarg in the decorated function "
                    f"`{func.__name__}`. Got {args=}/{kwargs=}"
                )

            @contextmanager
            def cached_check_obj():
                """Cache the dataframe and unpersist it after function execution."""
                logger.debug("Caching dataframe...")
                check_obj.cache()

                yield  # Execute the decorated function

                if not CONFIG.pyspark_keep_cache:
                    # `.unpersist()` is a no-op if the dataframe is not cached
                    logger.debug("Unpersisting dataframe...")
                    check_obj.unpersist()

            with cached_check_obj():
                return func(self, *args, **kwargs)

        return wrapper

    return _wrapper
12 changes: 12 additions & 0 deletions pandera/config.py
@@ -20,10 +20,14 @@ class PanderaConfig(BaseModel):
This should pick up environment variables automatically, e.g.:
export PANDERA_VALIDATION_ENABLED=False
export PANDERA_VALIDATION_DEPTH=DATA_ONLY
export PANDERA_CACHE_DATAFRAME=True
export PANDERA_KEEP_CACHED_DATAFRAME=True
"""

validation_enabled: bool = True
validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA
pyspark_cache: bool = False
pyspark_keep_cache: bool = False


# this config variable should be accessible globally
@@ -35,4 +39,12 @@ class PanderaConfig(BaseModel):
    validation_depth=os.environ.get(
        "PANDERA_VALIDATION_DEPTH", ValidationDepth.SCHEMA_AND_DATA
    ),
    pyspark_cache=os.environ.get(
        "PANDERA_CACHE_DATAFRAME",
        False,
    ),
    pyspark_keep_cache=os.environ.get(
        "PANDERA_KEEP_CACHED_DATAFRAME",
        False,
    ),
)
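
Since the `CONFIG` object above is instantiated at import time, these environment
variables must be set before `pandera` is first imported in the process; the
pydantic-based model then coerces the string values to booleans. A rough sketch of
that assumption:

    import os

    # must run before the first `pandera` import in this process
    os.environ["PANDERA_CACHE_DATAFRAME"] = "True"
    os.environ["PANDERA_KEEP_CACHED_DATAFRAME"] = "True"

    from pandera.config import CONFIG

    assert CONFIG.pyspark_cache is True       # the string "True" is coerced to a boolean
    assert CONFIG.pyspark_keep_cache is True
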
43 changes: 37 additions & 6 deletions tests/pyspark/test_pyspark_config.py
@@ -2,6 +2,7 @@
# pylint:disable=import-outside-toplevel,abstract-method

import pyspark.sql.types as T
import pytest

from pandera.config import CONFIG, ValidationDepth
from pandera.pyspark import (
@@ -24,7 +25,7 @@ def test_disable_validation(self, spark, sample_spark_schema):

CONFIG.validation_enabled = False

pandra_schema = DataFrameSchema(
pandera_schema = DataFrameSchema(
{
"product": Column(T.StringType(), Check.str_startswith("B")),
"price_val": Column(T.IntegerType()),
@@ -41,10 +42,12 @@ class TestSchema(DataFrameModel):
expected = {
"validation_enabled": False,
"validation_depth": ValidationDepth.SCHEMA_AND_DATA,
"pyspark_cache": False,
"pyspark_keep_cache": False,
}

assert CONFIG.dict() == expected
assert pandra_schema.validate(input_df)
assert pandera_schema.validate(input_df)
assert TestSchema.validate(input_df)

# pylint:disable=too-many-locals
Expand All @@ -63,6 +66,8 @@ def test_schema_only(self, spark, sample_spark_schema):
expected = {
"validation_enabled": True,
"validation_depth": ValidationDepth.SCHEMA_ONLY,
"pyspark_cache": False,
"pyspark_keep_cache": False,
}
assert CONFIG.dict() == expected

@@ -132,7 +137,7 @@ def test_data_only(self, spark, sample_spark_schema):
CONFIG.validation_enabled = True
CONFIG.validation_depth = ValidationDepth.DATA_ONLY

pandra_schema = DataFrameSchema(
pandera_schema = DataFrameSchema(
{
"product": Column(T.StringType(), Check.str_startswith("B")),
"price_val": Column(T.IntegerType()),
Expand All @@ -141,11 +146,13 @@ def test_data_only(self, spark, sample_spark_schema):
expected = {
"validation_enabled": True,
"validation_depth": ValidationDepth.DATA_ONLY,
"pyspark_cache": False,
"pyspark_keep_cache": False,
}
assert CONFIG.dict() == expected

input_df = spark_df(spark, self.sample_data, sample_spark_schema)
output_dataframeschema_df = pandra_schema.validate(input_df)
output_dataframeschema_df = pandera_schema.validate(input_df)
expected_dataframeschema = {
"DATA": {
"DATAFRAME_CHECK": [
@@ -217,7 +224,7 @@ def test_schema_and_data(self, spark, sample_spark_schema):
CONFIG.validation_enabled = True
CONFIG.validation_depth = ValidationDepth.SCHEMA_AND_DATA

pandra_schema = DataFrameSchema(
pandera_schema = DataFrameSchema(
{
"product": Column(T.StringType(), Check.str_startswith("B")),
"price_val": Column(T.IntegerType()),
Expand All @@ -226,11 +233,13 @@ def test_schema_and_data(self, spark, sample_spark_schema):
expected = {
"validation_enabled": True,
"validation_depth": ValidationDepth.SCHEMA_AND_DATA,
"pyspark_cache": False,
"pyspark_keep_cache": False,
}
assert CONFIG.dict() == expected

input_df = spark_df(spark, self.sample_data, sample_spark_schema)
output_dataframeschema_df = pandra_schema.validate(input_df)
output_dataframeschema_df = pandera_schema.validate(input_df)
expected_dataframeschema = {
"DATA": {
"DATAFRAME_CHECK": [
@@ -326,3 +335,25 @@ class TestSchema(DataFrameModel):
dict(output_dataframemodel_df.pandera.errors["SCHEMA"])
== expected_dataframemodel["SCHEMA"]
)

@pytest.mark.parametrize("cache_enabled", [True, False])
@pytest.mark.parametrize("keep_cache_enabled", [True, False])
# pylint:disable=too-many-locals
def test_pyspark_cache_settings(
self,
cache_enabled,
keep_cache_enabled,
):
"""This function validates setters and getters for cache/keep_cache options."""
# Set expected properties in Config object
CONFIG.pyspark_cache = cache_enabled
CONFIG.pyspark_keep_cache = keep_cache_enabled

# Evaluate expected Config
expected = {
"validation_enabled": True,
"validation_depth": ValidationDepth.SCHEMA_AND_DATA,
"pyspark_cache": cache_enabled,
"pyspark_keep_cache": keep_cache_enabled,
}
assert CONFIG.dict() == expected
