From 5cc3c8c9bc9acba114925a98cf54ed3484b3f7e3 Mon Sep 17 00:00:00 2001 From: Anshul Data Date: Mon, 12 Aug 2024 20:18:14 +0530 Subject: [PATCH] Fix Decimal comparison to load decimal value at load time instead of check time * Added custom tags "!decimal" and "!decimallist" and use them to load decimal values as Decimal --- bft/cases/parser.py | 17 -------- bft/core/yaml_parser.py | 23 ++++++++++- bft/core/yaml_parser_test.py | 23 +++++++++++ bft/testers/duckdb/runner.py | 7 +--- bft/testers/snowflake/runner.py | 7 +--- bft/utils/utils.py | 19 --------- bft/utils/utils_test.py | 12 ------ cases/arithmetic_decimal/avg_decimal.yaml | 34 ++++++++-------- cases/arithmetic_decimal/max_decimal.yaml | 48 +++++++++++------------ cases/arithmetic_decimal/min_decimal.yaml | 40 +++++++++---------- cases/arithmetic_decimal/sum_decimal.yaml | 30 +++++++------- dialects/duckdb.yaml | 4 -- 12 files changed, 123 insertions(+), 141 deletions(-) create mode 100644 bft/core/yaml_parser_test.py delete mode 100644 bft/utils/utils_test.py diff --git a/bft/cases/parser.py b/bft/cases/parser.py index d6172f4..e045e8f 100644 --- a/bft/cases/parser.py +++ b/bft/cases/parser.py @@ -1,5 +1,4 @@ import math -from decimal import Decimal from typing import BinaryIO, Iterable, List from bft.core.yaml_parser import BaseYamlParser, BaseYamlVisitor @@ -39,24 +38,8 @@ def __normalize_yaml_literal(self, value, data_type): return math.nan else: raise ValueError(f"Unrecognized float string literal {value}") - if data_type.startswith("dec"): - ret_val = self._normalize_decimal_type(value) - return ret_val return value - def _normalize_decimal_type(self, val): - if val is None: - return val - if not isinstance(val, list): - return Decimal(str(val)) - converted_list = [] - for v in val: - if v is not None: - converted_list.append(Decimal(v)) - else: - converted_list.append(v) - return converted_list - def visit_literal(self, lit): value = self._get_or_die(lit, 
"type") diff --git a/bft/core/yaml_parser.py b/bft/core/yaml_parser.py index 91ff4ad..322db92 100644 --- a/bft/core/yaml_parser.py +++ b/bft/core/yaml_parser.py @@ -1,5 +1,6 @@ import math from abc import ABC, abstractmethod +from decimal import Decimal from typing import BinaryIO, Generic, Iterable, List, TypeVar import yaml @@ -90,7 +91,27 @@ class BaseYamlParser(ABC, Generic[T]): def get_visitor(self) -> BaseYamlVisitor[T]: pass + def get_loader(self): + loader = yaml.SafeLoader + """Add tag "!decimal" to the loader """ + loader.add_constructor("!decimal", self.decimal_constructor) + loader.add_constructor("!decimallist", self.list_of_decimal_constructor) + return loader + + def decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): + return self.get_decimal_value(loader, node) + + def get_decimal_value(self, loader: yaml.SafeLoader, node: yaml.ScalarNode): + value = loader.construct_scalar(node) + if isinstance(value, str) and value.lower() == 'null': + return None + return Decimal(value) + + def list_of_decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): + return [self.get_decimal_value(loader, item) for item in node.value] + def parse(self, f: BinaryIO) -> List[T]: - objs = yaml.load_all(f, SafeLoader) + loader = self.get_loader() + objs = yaml.load_all(f, loader) visitor = self.get_visitor() return [visitor.visit(obj) for obj in objs] diff --git a/bft/core/yaml_parser_test.py b/bft/core/yaml_parser_test.py new file mode 100644 index 0000000..52262d4 --- /dev/null +++ b/bft/core/yaml_parser_test.py @@ -0,0 +1,23 @@ +from decimal import Decimal +from typing import NamedTuple + +from bft.core.yaml_parser import BaseYamlParser + + +class TestDecimalResult(NamedTuple): + cases: Decimal | list[Decimal] + +class TestCaseVisitor(): + def visit(self, testcase): + return TestDecimalResult(testcase) +class DecimalTestCaseParser(BaseYamlParser[TestDecimalResult]): + def get_visitor(self) -> TestCaseVisitor: + 
return TestCaseVisitor() + +def test_yaml_parser_decimal_tag(): + parser = DecimalTestCaseParser() + # parser returns list of parsed values + assert parser.parse(b"!decimal 1") == [TestDecimalResult(Decimal('1'))] + assert parser.parse(b"!decimal 1.78766") == [TestDecimalResult(Decimal('1.78766'))] + assert parser.parse(b"!decimal null") == [TestDecimalResult(None)] + assert parser.parse(b"!decimallist [1.2, null, 7.547]") == [TestDecimalResult([Decimal('1.2'), None, Decimal('7.547')])] diff --git a/bft/testers/duckdb/runner.py b/bft/testers/duckdb/runner.py index 19a7011..1883ac1 100644 --- a/bft/testers/duckdb/runner.py +++ b/bft/testers/duckdb/runner.py @@ -8,7 +8,7 @@ from bft.cases.runner import SqlCaseResult, SqlCaseRunner from bft.cases.types import Case from bft.dialects.types import SqlMapping -from bft.utils.utils import type_to_dialect_type, compareDecimalResult +from bft.utils.utils import type_to_dialect_type type_map = { "i8": "TINYINT", @@ -140,11 +140,6 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult: elif case.result.type.startswith("fp") and case.result.value and result: if math.isclose(result, case.result.value, rel_tol=1e-7): return SqlCaseResult.success() - elif case.result.type.startswith("dec") and case.result.value and result: - if compareDecimalResult(case.result.value, Decimal(str(result))): - return SqlCaseResult.success() - else: - return SqlCaseResult.mismatch(str(result)) else: if result == case.result.value: return SqlCaseResult.success() diff --git a/bft/testers/snowflake/runner.py b/bft/testers/snowflake/runner.py index ff3f4be..fe1e22b 100644 --- a/bft/testers/snowflake/runner.py +++ b/bft/testers/snowflake/runner.py @@ -12,7 +12,7 @@ from bft.cases.runner import SqlCaseResult, SqlCaseRunner from bft.cases.types import Case from bft.dialects.types import SqlMapping -from bft.utils.utils import type_to_dialect_type, compareDecimalResult +from bft.utils.utils import type_to_dialect_type type_map = { 
"fp64": "FLOAT", @@ -169,11 +169,6 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult: elif case.result.type.startswith("fp") and case.result.value and result: if math.isclose(result, case.result.value, rel_tol=1e-7): return SqlCaseResult.success() - elif case.result.type.startswith("dec") and case.result.value and result: - if compareDecimalResult(case.result.value, Decimal(str(result))): - return SqlCaseResult.success() - else: - return SqlCaseResult.mismatch(str(result)) else: if result == case.result.value: return SqlCaseResult.success() diff --git a/bft/utils/utils.py b/bft/utils/utils.py index 956a084..b553df1 100644 --- a/bft/utils/utils.py +++ b/bft/utils/utils.py @@ -25,22 +25,3 @@ def type_to_dialect_type(type: str, type_map: Dict[str, str])->str: return type_val # transform parameterized type name to have dialect type return type.replace(type_to_check, type_val).replace("<", "(").replace(">", ")") - -def compareDecimalResult(expected_result: Decimal, actual_result: Decimal)->bool: - ''' - Compares non-null decimal type based on scale of expected_result - :param expected_result: expected result. 
Its scale is considered to be the scale to compare - :param actual_result: - :return: bool - ''' - # make scale of actual_result to same as expected_result for comparison - scale = abs(expected_result.as_tuple().exponent) - rounding_format = Decimal(f"1.{'0' * scale}") - try: - # set thread precison to 38 since database decimal support max 38 - with localcontext(prec=38) as ctx: - rounded_result = actual_result.quantize(rounding_format, rounding=ROUND_DOWN) - except Exception as e: - print(f"Exception while rounding: {e}") - return False - return rounded_result == expected_result diff --git a/bft/utils/utils_test.py b/bft/utils/utils_test.py deleted file mode 100644 index aeeeff3..0000000 --- a/bft/utils/utils_test.py +++ /dev/null @@ -1,12 +0,0 @@ -from decimal import Decimal - -from bft.utils.utils import compareDecimalResult - - -def test_compare_decimal_result(): - assert compareDecimalResult(Decimal('1'), Decimal('1')) - assert compareDecimalResult(Decimal('99999999999999999999999999999999999999'), Decimal('99999999999999999999999999999999999999')) - assert compareDecimalResult(Decimal('1.75'), Decimal('1.75678')) - assert compareDecimalResult(Decimal('1.757'), Decimal('1.75678')) == False - assert compareDecimalResult(Decimal('2.33'), Decimal('2.330000000078644688')) - assert compareDecimalResult(Decimal('4.12500053644180'), Decimal('4.1250005364418029785156')) diff --git a/cases/arithmetic_decimal/avg_decimal.yaml b/cases/arithmetic_decimal/avg_decimal.yaml index a7776dc..3b3b612 100644 --- a/cases/arithmetic_decimal/avg_decimal.yaml +++ b/cases/arithmetic_decimal/avg_decimal.yaml @@ -5,37 +5,37 @@ cases: id: basic description: Basic examples without any special cases args: - - value: [0, -1, 2, 20] + - value: !decimallist [0, -1, 2, 20] type: decimal<38, 0> result: - value: 5.25 + value: !decimal 5.25 type: decimal<38, 2> - group: basic args: - - value: [2000000, -3217908, 629000, -100000, 0, 987654] + - value: !decimallist [2000000, -3217908, 629000, 
-100000, 0, 987654] type: decimal<38, 0> result: - value: 49791 + value: !decimal 49791 type: decimal<38, 5> - group: basic args: - - value: [2.5, 0, 5.0, -2.5, -7.5] + - value: !decimallist [2.5, 0, 5.0, -2.5, -7.5] type: decimal<38, 2> result: - value: -0.5 + value: !decimal -0.5 type: decimal<38, 2> - group: basic args: - - value: [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875] - type: decimal<38, 22> + - value: !decimallist [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875] + type: decimal<38, 14> result: - value: 4.12500053644180 - type: decimal<38, 22> + value: !decimal 4.12500053644181 + type: decimal<38, 14> - group: id: overflow description: Examples demonstrating overflow behavior args: - - value: [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999] + - value: !decimallist [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999] type: decimal<38, 0> options: overflow: ERROR @@ -45,22 +45,22 @@ cases: id: null_handling description: Examples with null as unput or output args: - - value: [Null, Null, Null] + - value: !decimallist [Null, Null, Null] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [] + - value: !decimallist [] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [200000, Null, 629000, -10000, 0, 987621] + - value: !decimallist [200000, Null, 629000, -10000, 0, 987621] type: decimal<38, 0> result: - value: 361324.2 + value: !decimal 361324.2 type: decimal<38, 2> diff --git a/cases/arithmetic_decimal/max_decimal.yaml b/cases/arithmetic_decimal/max_decimal.yaml index bcf560e..39d796d 100644 --- a/cases/arithmetic_decimal/max_decimal.yaml +++ b/cases/arithmetic_decimal/max_decimal.yaml @@ -5,87 +5,87 @@ cases: id: basic description: Basic examples without any 
special cases args: - - value: [20, -3, 1, -10, 0, 5] + - value: !decimallist [20, -3, 1, -10, 0, 5] type: decimal<38, 0> result: - value: 20 + value: !decimal 20 type: decimal<38, 0> - group: basic args: - - value: [-32768, 32767, 20000, -30000] + - value: !decimallist [-32768, 32767, 20000, -30000] type: decimal<38, 0> result: - value: 32767 + value: !decimal 32767 type: decimal<38, 0> - group: basic args: - - value: [-214748648, 214748647, 21470048, 4000000] + - value: !decimallist [-214748648, 214748647, 21470048, 4000000] type: decimal<38, 0> result: - value: 214748647 + value: !decimal 214748647 type: decimal<38, 0> - group: basic args: - - value: [2000000000, -3217908979, 629000000, -100000000, 0, 987654321] + - value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321] type: decimal<38, 0> result: - value: 2000000000 + value: !decimal 2000000000 type: decimal<38, 0> - group: basic args: - - value: [2.5, 0, 5.0, -2.5, -7.5] + - value: !decimallist [2.5, 0, 5.0, -2.5, -7.5] type: decimal<38, 2> result: - value: 5.0 + value: !decimal 5.0 type: decimal<38, 2> - group: basic args: - - value: [99999999999999999999999999999999999999, 0, -99999999999999999999999999999999999998, 111111111, -76] + - value: !decimallist [99999999999999999999999999999999999999, 0, -99999999999999999999999999999999999998, 111111111, -76] type: decimal<38, 0> result: - value: 99999999999999999999999999999999999999 + value: !decimal 99999999999999999999999999999999999999 type: decimal<38, 0> - group: id: null_handling description: Examples with null as unput or output args: - - value: [Null, Null, Null] + - value: !decimallist [Null, Null, Null] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [] + - value: !decimallist [] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [2000000000, Null, 629000000, -100000000, 
Null, 987654321] + - value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321] type: decimal<38, 0> result: - value: 2000000000 + value: !decimal 2000000000 type: decimal<38, 0> - group: null_handling args: - - value: [Null, Null] + - value: !decimallist [Null, Null] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [] + - value: !decimallist [] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, Null, 11111111111111111111111111111111111111, Null] + - value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, Null, 11111111111111111111111111111111111111, Null] type: decimal<38, 0> result: - value: 99999999999999999999999999999999999999 + value: !decimal 99999999999999999999999999999999999999 type: decimal<38, 0> diff --git a/cases/arithmetic_decimal/min_decimal.yaml b/cases/arithmetic_decimal/min_decimal.yaml index 0f09151..3ba403d 100644 --- a/cases/arithmetic_decimal/min_decimal.yaml +++ b/cases/arithmetic_decimal/min_decimal.yaml @@ -5,73 +5,73 @@ cases: id: basic description: Basic examples without any special cases args: - - value: [20, -3, 1, -10, 0, 5] + - value: !decimallist [20, -3, 1, -10, 0, 5] type: decimal<38, 0> result: - value: -10 + value: !decimal -10 type: decimal<38, 0> - group: basic args: - - value: [-32768, 32767, 20000, -30000] + - value: !decimallist [-32768, 32767, 20000, -30000] type: decimal<38, 0> result: - value: -32768 + value: !decimal -32768 type: decimal<38, 0> - group: basic args: - - value: [-214748648, 214748647, 21470048, 4000000] + - value: !decimallist [-214748648, 214748647, 21470048, 4000000] type: decimal<38, 0> result: - value: -214748648 + value: !decimal -214748648 type: decimal<38, 0> - group: basic args: - - value: 
[2000000000, -3217908979, 629000000, -100000000, 0, 987654321] + - value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321] type: decimal<38, 0> result: - value: -3217908979 + value: !decimal -3217908979 type: decimal<38, 0> - group: basic args: - - value: [2.5, 0, 5.0, -2.5, -7.5] + - value: !decimallist [2.5, 0, 5.0, -2.5, -7.5] type: decimal<38, 2> result: - value: -7.5 + value: !decimal -7.5 type: decimal<38, 2> - group: basic args: - - value: [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, -99999999999999999999999999999999999997, 0, 1111] + - value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, -99999999999999999999999999999999999997, 0, 1111] type: decimal<38, 0> result: - value: -99999999999999999999999999999999999998 + value: !decimal -99999999999999999999999999999999999998 type: decimal<38, 0> - group: id: null_handling description: Examples with null as unput or output args: - - value: [Null, Null, Null] + - value: !decimallist [Null, Null, Null] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [] + - value: !decimallist [] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [2000000000, Null, 629000000, -100000000, Null, 987654321] + - value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321] type: decimal<38, 0> result: - value: -100000000 + value: !decimal -100000000 type: decimal<38, 0> - group: null_handling args: - - value: [-99999999999999999999999999999999999998, Null, 99999999999999999999999999999999999999, Null] + - value: !decimallist [-99999999999999999999999999999999999998, Null, 99999999999999999999999999999999999999, Null] type: decimal<38, 0> result: - value: -99999999999999999999999999999999999998 + value: !decimal -99999999999999999999999999999999999998 type: 
decimal<38, 0> diff --git a/cases/arithmetic_decimal/sum_decimal.yaml b/cases/arithmetic_decimal/sum_decimal.yaml index ea76f84..b636e3a 100644 --- a/cases/arithmetic_decimal/sum_decimal.yaml +++ b/cases/arithmetic_decimal/sum_decimal.yaml @@ -5,37 +5,37 @@ cases: id: basic description: Basic examples without any special cases args: - - value: [0, -1, 2, 20] + - value: !decimallist [0, -1, 2, 20] type: decimal<38, 0> result: - value: 21 + value: !decimal 21 type: decimal<38, 0> - group: basic args: - - value: [2000000, -3217908, 629000, -100000, 0, 987654] + - value: !decimallist [2000000, -3217908, 629000, -100000, 0, 987654] type: decimal<38, 0> result: - value: 298746 + value: !decimal 298746 type: decimal<38, 0> - group: basic args: - - value: [2.5, 0, 5.0, -2.5, -7.5] + - value: !decimallist [2.5, 0, 5.0, -2.5, -7.5] type: decimal<38, 2> result: - value: -2.5 + value: !decimal -2.5 type: decimal<38, 2> - group: basic args: - - value: [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875] + - value: !decimallist [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875] type: decimal<38, 22> result: - value: 16.50000214576721 + value: !decimal 16.5000021457672119140625 type: decimal<38, 22> - group: id: overflow description: Examples demonstrating overflow behavior args: - - value: [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999] + - value: !decimallist [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999] type: decimal<38, 0> options: overflow: ERROR @@ -45,22 +45,22 @@ cases: id: null_handling description: Examples with null as unput or output args: - - value: [Null, Null, Null] + - value: !decimallist [Null, Null, Null] type: decimal<38, 0> result: - value: Null + value: !decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [] + - value: !decimallist [] type: decimal<38, 0> result: - value: Null + value: 
!decimal Null type: decimal<38, 0> - group: null_handling args: - - value: [200000, Null, 629000, -10000, 0, 987621] + - value: !decimallist [200000, Null, 629000, -10000, 0, 987621] type: decimal<38, 0> result: - value: 1806621 + value: !decimal 1806621 type: decimal<38, 0> diff --git a/dialects/duckdb.yaml b/dialects/duckdb.yaml index 3b18c22..d96f219 100644 --- a/dialects/duckdb.yaml +++ b/dialects/duckdb.yaml @@ -669,10 +669,6 @@ aggregate_functions: aggregate: true supported_kernels: - dec -- name: arithmetic_decimal.avg - aggregate: true - supported_kernels: - - dec - name: boolean.bool_and aggregate: true supported_kernels: