diff --git a/pydeequ/check_functions.py b/pydeequ/check_functions.py new file mode 100644 index 0000000..7d2d62d --- /dev/null +++ b/pydeequ/check_functions.py @@ -0,0 +1,2 @@ +def is_one(x): + return x == 1 / 1 diff --git a/pydeequ/checks.py b/pydeequ/checks.py index ebaaa49..c082309 100644 --- a/pydeequ/checks.py +++ b/pydeequ/checks.py @@ -3,8 +3,10 @@ from pyspark.sql import SparkSession +from pydeequ.check_functions import is_one from pydeequ.scala_utils import ScalaFunction1, to_scala_seq + # TODO implement custom assertions # TODO implement all methods without outside class dependencies # TODO Integration with Constraints @@ -564,8 +566,10 @@ def hasPattern(self, column, pattern, assertion=None, name=None, hint=None): :param str hint: A hint that states why a constraint could have failed. :return: hasPattern self: A Check object that runs the condition on the column. """ - assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) if assertion \ - else getattr(self._Check, "hasPattern$default$2")() + if not assertion: + assertion = is_one + + assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) name = self._jvm.scala.Option.apply(name) hint = self._jvm.scala.Option.apply(hint) pattern_regex = self._jvm.scala.util.matching.Regex(pattern, None) @@ -779,19 +783,25 @@ def isGreaterThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None): self._Check = self._Check.isGreaterThanOrEqualTo(columnA, columnB, assertion_func, hint) return self - def isContainedIn(self, column, allowed_values): + def isContainedIn(self, column, allowed_values, assertion=None, hint=None): """ Asserts that every non-null value in a column is contained in a set of predefined values :param str column: Column in DataFrame to run the assertion on. :param list[str] allowed_values: A function that accepts allowed values for the column. + :param lambda assertion: A function that accepts an int or float parameter. :param str hint: A hint that states why a constraint could have failed. :return: isContainedIn self: A Check object that runs the assertion on the columns. """ arr = self._spark_session.sparkContext._gateway.new_array(self._jvm.java.lang.String, len(allowed_values)) for i in range(len(allowed_values)): arr[i] = allowed_values[i] - self._Check = self._Check.isContainedIn(column, arr) + + if not assertion: + assertion = is_one + assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) + hint = self._jvm.scala.Option.apply(hint) + self._Check = self._Check.isContainedIn(column, arr, assertion_func, hint) return self def evaluate(self, context): diff --git a/tests/test_checks.py b/tests/test_checks.py index 40634eb..d4782a6 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -431,10 +431,12 @@ def isGreaterThan(self, columnA, columnB, assertion=None, hint=None): df = VerificationResult.checkResultsAsDataFrame(self.spark, result) return df.select("constraint_status").collect() - def isContainedIn(self, column, allowed_values): + def isContainedIn(self, column, allowed_values, assertion=None, hint=None): check = Check(self.spark, CheckLevel.Warning, "test isContainedIn") result = ( - VerificationSuite(self.spark).onData(self.df).addCheck(check.isContainedIn(column, allowed_values)).run() + VerificationSuite(self.spark).onData(self.df).addCheck( + check.isContainedIn(column, allowed_values, assertion=assertion, hint=hint) + ).run() ) df = VerificationResult.checkResultsAsDataFrame(self.spark, result) @@ -1134,6 +1136,11 @@ def test_fail_satisfies(self): def test_hasPattern(self): self.assertEqual(self.hasPattern("ssn", "\d{3}\-\d{2}\-\d{4}", lambda x: x == 2 / 3), [Row(constraint_status="Success")]) + # Default assertion is 1, thus failure + self.assertEqual(self.hasPattern("ssn", "\d{3}\-\d{2}\-\d{4}"), [Row(constraint_status="Failure")]) + self.assertEqual( + self.hasPattern("ssn", "\d{3}\-\d{2}\-\d{4}", lambda x: x == 2 / 3, hint="it be should be above 0.66"), + [Row(constraint_status="Success")]) @pytest.mark.xfail(reason="@unittest.expectedFailure") def test_fail_hasPattern(self): @@ -1206,6 +1213,12 @@ def test_fail_isGreaterThan(self): self.assertEqual(self.isGreaterThan("h", "f", lambda x: x == 1), [Row(constraint_status="Success")]) def test_isContainedIn(self): + # test all variants for assertion and hint + self.assertEqual( + self.isContainedIn("a", ["foo", "bar", "baz"], lambda x: x == 1), [Row(constraint_status="Success")]) + self.assertEqual( + self.isContainedIn("a", ["foo", "bar", "baz"], lambda x: x == 1, hint="it should be 1"), + [Row(constraint_status="Success")]) self.assertEqual(self.isContainedIn("a", ["foo", "bar", "baz"]), [Row(constraint_status="Success")]) # A none value makes the test still pass self.assertEqual(self.isContainedIn("c", ["5", "6"]), [Row(constraint_status="Success")])