Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

✅ TESTS ✅ chore: Setting the default spark_version value from pyspark.__version__ #237

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
10 changes: 8 additions & 2 deletions pydeequ/configs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import logging
from functools import lru_cache
import os
import re

import pyspark

SPARK_TO_DEEQU_COORD_MAPPING = {
"3.5": "com.amazon.deequ:deequ:2.0.7-spark-3.5",
Expand All @@ -22,7 +23,12 @@ def _extract_major_minor_versions(full_version: str):
@lru_cache(maxsize=None)
def _get_spark_version() -> str:
try:
spark_version = os.environ["SPARK_VERSION"]
spark_version = os.getenv("SPARK_VERSION")
if not spark_version:
spark_version = str(pyspark.__version__)
logging.info(
f"SPARK_VERSION environment variable is not set, using Spark version from PySpark {spark_version} for Deequ Maven jars"
)
except KeyError:
raise RuntimeError(f"SPARK_VERSION environment variable is required. Supported values are: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}")

Expand Down
35 changes: 34 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
import os
from unittest import mock

import pyspark
import pytest
from pydeequ.configs import _extract_major_minor_versions

from pydeequ.configs import _extract_major_minor_versions, _get_spark_version


@pytest.fixture
def mock_env(monkeypatch):
    """Run a test in an empty process environment.

    ``mock.patch.dict(os.environ, clear=True)`` empties ``os.environ`` for
    the duration of the test (removing SPARK_VERSION along with everything
    else) and restores the original environment on exit, so the code under
    test must fall back to ``pyspark.__version__``.
    """
    # clear=True already removes every variable, so the previous
    # monkeypatch.delenv("SPARK_VERSION", raising=False) call was a
    # guaranteed no-op and has been dropped; the ``monkeypatch`` parameter
    # is kept so the fixture signature is unchanged for pytest.
    with mock.patch.dict(os.environ, clear=True):
        yield


@pytest.mark.parametrize(
Expand All @@ -13,3 +25,24 @@
)
def test_extract_major_minor_versions(full_version, major_minor_version):
assert _extract_major_minor_versions(full_version) == major_minor_version


@pytest.mark.parametrize(
    "spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.1"), ("3.10.3", "3.10"), ("3.10", "3.10")]
)
def test__get_spark_versione(spark_version, expected, mock_env):
    """With no SPARK_VERSION in the environment, _get_spark_version should
    derive the major.minor version from ``pyspark.__version__``.

    The cache is cleared before and after each case so every parametrized
    run computes the value fresh instead of reusing a cached result.
    """
    _get_spark_version.cache_clear()
    try:
        with mock.patch.object(pyspark, "__version__", spark_version):
            assert _get_spark_version() == expected
    finally:
        # Leave no cached value behind for other tests.
        _get_spark_version.cache_clear()



@pytest.mark.parametrize(
    "spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.2"), ("3.10.3", "3.2"), ("3.10", "3.2")]
)
def test__get_spark_version_with_cache(spark_version, expected, mock_env):
    """Verify that ``_get_spark_version`` is memoized by ``lru_cache``.

    Every parametrized case expects "3.2": only the first call (with
    pyspark.__version__ mocked to "3.2.1") actually computes a value; the
    later cases get the cached result regardless of the mocked version.

    NOTE(review): this test never calls ``cache_clear`` itself, so it relies
    on the cache being empty when the first parametrized case runs — i.e. on
    test execution order and on sibling tests clearing the cache in their
    ``finally`` blocks. Presumably intentional, but fragile under test
    reordering or parallel execution — confirm.
    """
    with mock.patch.object(pyspark, "__version__", spark_version):
        assert _get_spark_version() == expected