From d618f9ecf8bebc86a95877b164c6b768e5a43c79 Mon Sep 17 00:00:00 2001 From: Arsen Gumin Date: Thu, 23 Nov 2023 22:26:45 +0300 Subject: [PATCH 1/5] chore: Setting the default spark_version value from pyspark.__version__ --- pydeequ/configs.py | 6 ++++-- tests/test_config.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/pydeequ/configs.py b/pydeequ/configs.py index c3c885d..fb184e2 100644 --- a/pydeequ/configs.py +++ b/pydeequ/configs.py @@ -2,7 +2,7 @@ from functools import lru_cache import os import re - +import pyspark SPARK_TO_DEEQU_COORD_MAPPING = { "3.3": "com.amazon.deequ:deequ:2.0.3-spark-3.3", @@ -23,7 +23,9 @@ def _extract_major_minor_versions(full_version: str): @lru_cache(maxsize=None) def _get_spark_version() -> str: try: - spark_version = os.environ["SPARK_VERSION"] + spark_version = os.getenv("SPARK_VERSION") + if not spark_version: + spark_version = str(pyspark.__version__) except KeyError: raise RuntimeError(f"SPARK_VERSION environment variable is required. Supported values are: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}") diff --git a/tests/test_config.py b/tests/test_config.py index c2956b3..7b05e9c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,9 @@ +from unittest import mock + +import pyspark import pytest -from pydeequ.configs import _extract_major_minor_versions + +from pydeequ.configs import _extract_major_minor_versions, _get_spark_version @pytest.mark.parametrize( @@ -13,3 +17,32 @@ ) def test_extract_major_minor_versions(full_version, major_minor_version): assert _extract_major_minor_versions(full_version) == major_minor_version + + +@pytest.mark.parametrize( + "spark_version, expected", + [ + ("3.2.1", "3.2"), + ("3.1", "3.1"), + ("3.10.3", "3.10"), + ("3.10", "3.10") + ] +) +def test__get_spark_version_without_cache(spark_version, expected): + with mock.patch.object(pyspark, "__version__", spark_version): + _get_spark_version.cache_clear() + assert _get_spark_version() == expected + + +@pytest.mark.parametrize( + "spark_version, expected", + [ + ("3.2.1", "3.2"), + ("3.1", "3.2"), + ("3.10.3", "3.2"), + ("3.10", "3.2") + ] +) +def test__get_spark_version_with_cache(spark_version, expected): + with mock.patch.object(pyspark, "__version__", spark_version): + assert _get_spark_version() == expected From 82b7cee99e53774a87e8d54dfdc7e861c9a5c369 Mon Sep 17 00:00:00 2001 From: Arsen Gumin Date: Wed, 14 Aug 2024 13:30:36 +0300 Subject: [PATCH 2/5] chore: Setting the default spark_version value from pyspark.__version__ --- pydeequ/configs.py | 4 ++++ tests/test_config.py | 30 +++++++++++++----------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pydeequ/configs.py b/pydeequ/configs.py index fb184e2..091486f 100644 --- a/pydeequ/configs.py +++ b/pydeequ/configs.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import logging from functools import lru_cache import os import re @@ -26,6 +27,9 @@ def _get_spark_version() -> str: spark_version = os.getenv("SPARK_VERSION") if not spark_version: spark_version = str(pyspark.__version__) + logging.info( + f"SPARK_VERSION environment variable is not set, using Spark version from PySpark {spark_version} for Deequ Maven jars" + ) except KeyError: raise RuntimeError(f"SPARK_VERSION environment variable is required. Supported values are: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}") diff --git a/tests/test_config.py b/tests/test_config.py index 7b05e9c..2ee431a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,4 @@ +import os from unittest import mock import pyspark @@ -6,6 +7,13 @@ from pydeequ.configs import _extract_major_minor_versions, _get_spark_version +@pytest.fixture +def mock_env(monkeypatch): + with mock.patch.dict(os.environ, clear=True): + monkeypatch.delenv("SPARK_VERSION", raising=False) + yield + + @pytest.mark.parametrize( "full_version, major_minor_version", [ @@ -20,29 +28,17 @@ def test_extract_major_minor_versions(full_version, major_minor_version): @pytest.mark.parametrize( - "spark_version, expected", - [ - ("3.2.1", "3.2"), - ("3.1", "3.1"), - ("3.10.3", "3.10"), - ("3.10", "3.10") - ] + "spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.1"), ("3.10.3", "3.10"), ("3.10", "3.10")] ) -def test__get_spark_version_without_cache(spark_version, expected): +def test__get_spark_version_without_cache(spark_version, expected, mock_env): with mock.patch.object(pyspark, "__version__", spark_version): - _get_spark_version.cache_clear() assert _get_spark_version() == expected + _get_spark_version.cache_clear() @pytest.mark.parametrize( - "spark_version, expected", - [ - ("3.2.1", "3.2"), - ("3.1", "3.2"), - ("3.10.3", "3.2"), - ("3.10", "3.2") - ] + "spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.2"), ("3.10.3", "3.2"), ("3.10", "3.2")] ) -def test__get_spark_version_with_cache(spark_version, expected): +def test__get_spark_version_with_cache(spark_version, expected, mock_env): with mock.patch.object(pyspark, "__version__", spark_version): assert _get_spark_version() == expected From 64ce29264869a284a1c5302bf3185f90723ab9f3 Mon Sep 17 00:00:00 2001 From: Arsen Gumin Date: Thu, 23 Nov 2023 22:26:45 +0300 Subject: [PATCH 3/5] chore: Setting the default spark_version value from pyspark.__version__ --- pydeequ/configs.py | 6 ++++-- tests/test_config.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/pydeequ/configs.py b/pydeequ/configs.py index d4d4b31..e976501 100644 --- a/pydeequ/configs.py +++ b/pydeequ/configs.py @@ -2,7 +2,7 @@ from functools import lru_cache import os import re - +import pyspark SPARK_TO_DEEQU_COORD_MAPPING = { "3.5": "com.amazon.deequ:deequ:2.0.7-spark-3.5", @@ -22,7 +22,9 @@ def _extract_major_minor_versions(full_version: str): @lru_cache(maxsize=None) def _get_spark_version() -> str: try: - spark_version = os.environ["SPARK_VERSION"] + spark_version = os.getenv("SPARK_VERSION") + if not spark_version: + spark_version = str(pyspark.__version__) except KeyError: raise RuntimeError(f"SPARK_VERSION environment variable is required. Supported values are: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}") diff --git a/tests/test_config.py b/tests/test_config.py index c2956b3..7b05e9c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,9 @@ +from unittest import mock + +import pyspark import pytest -from pydeequ.configs import _extract_major_minor_versions + +from pydeequ.configs import _extract_major_minor_versions, _get_spark_version @pytest.mark.parametrize( @@ -13,3 +17,32 @@ ) def test_extract_major_minor_versions(full_version, major_minor_version): assert _extract_major_minor_versions(full_version) == major_minor_version + + +@pytest.mark.parametrize( + "spark_version, expected", + [ + ("3.2.1", "3.2"), + ("3.1", "3.1"), + ("3.10.3", "3.10"), + ("3.10", "3.10") + ] +) +def test__get_spark_version_without_cache(spark_version, expected): + with mock.patch.object(pyspark, "__version__", spark_version): + _get_spark_version.cache_clear() + assert _get_spark_version() == expected + + +@pytest.mark.parametrize( + "spark_version, expected", + [ + ("3.2.1", "3.2"), + ("3.1", "3.2"), + ("3.10.3", "3.2"), + ("3.10", "3.2") + ] +) +def test__get_spark_version_with_cache(spark_version, expected): + with mock.patch.object(pyspark, "__version__", spark_version): + assert _get_spark_version() == expected From 3074e390febfdb47628478883679f74b8ec3524e Mon Sep 17 00:00:00 2001 From: Arsen Gumin Date: Wed, 14 Aug 2024 13:30:36 +0300 Subject: [PATCH 4/5] chore: Setting the default spark_version value from pyspark.__version__ --- pydeequ/configs.py | 4 ++++ tests/test_config.py | 30 +++++++++++++----------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pydeequ/configs.py b/pydeequ/configs.py index e976501..3f3d9ee 100644 --- a/pydeequ/configs.py +++ b/pydeequ/configs.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import logging from functools import lru_cache import os import re @@ -25,6 +26,9 @@ def _get_spark_version() -> str: spark_version = os.getenv("SPARK_VERSION") if not spark_version: spark_version = str(pyspark.__version__) + logging.info( + f"SPARK_VERSION environment variable is not set, using Spark version from PySpark {spark_version} for Deequ Maven jars" + ) except KeyError: raise RuntimeError(f"SPARK_VERSION environment variable is required. Supported values are: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}") diff --git a/tests/test_config.py b/tests/test_config.py index 7b05e9c..2ee431a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,4 @@ +import os from unittest import mock import pyspark @@ -6,6 +7,13 @@ from pydeequ.configs import _extract_major_minor_versions, _get_spark_version +@pytest.fixture +def mock_env(monkeypatch): + with mock.patch.dict(os.environ, clear=True): + monkeypatch.delenv("SPARK_VERSION", raising=False) + yield + + @pytest.mark.parametrize( "full_version, major_minor_version", [ @@ -20,29 +28,17 @@ def test_extract_major_minor_versions(full_version, major_minor_version): @pytest.mark.parametrize( - "spark_version, expected", - [ - ("3.2.1", "3.2"), - ("3.1", "3.1"), - ("3.10.3", "3.10"), - ("3.10", "3.10") - ] + "spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.1"), ("3.10.3", "3.10"), ("3.10", "3.10")] ) -def test__get_spark_version_without_cache(spark_version, expected): +def test__get_spark_version_without_cache(spark_version, expected, mock_env): with mock.patch.object(pyspark, "__version__", spark_version): - _get_spark_version.cache_clear() assert _get_spark_version() == expected + _get_spark_version.cache_clear() @pytest.mark.parametrize( - "spark_version, expected", - [ - ("3.2.1", "3.2"), - ("3.1", "3.2"), - ("3.10.3", "3.2"), - ("3.10", "3.2") - ] + "spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.2"), ("3.10.3", "3.2"), ("3.10", "3.2")] ) -def test__get_spark_version_with_cache(spark_version, expected): +def test__get_spark_version_with_cache(spark_version, expected, mock_env): with mock.patch.object(pyspark, "__version__", spark_version): assert _get_spark_version() == expected From 24fba1625e8b0607df2dc5c75673e085ab1a3db6 Mon Sep 17 00:00:00 2001 From: Steven Ayers Date: Sun, 29 Sep 2024 11:32:42 +0100 Subject: [PATCH 5/5] Fix tests --- tests/test_config.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 2ee431a..d61ccfd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -30,10 +30,14 @@ def test_extract_major_minor_versions(full_version, major_minor_version): @pytest.mark.parametrize( "spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.1"), ("3.10.3", "3.10"), ("3.10", "3.10")] ) -def test__get_spark_version_without_cache(spark_version, expected, mock_env): - with mock.patch.object(pyspark, "__version__", spark_version): - assert _get_spark_version() == expected +def test__get_spark_versione(spark_version, expected, mock_env): + try: _get_spark_version.cache_clear() + with mock.patch.object(pyspark, "__version__", spark_version): + assert _get_spark_version() == expected + finally: + _get_spark_version.cache_clear() + @pytest.mark.parametrize(