Skip to content

Commit

Permalink
Fix Decimal comparision to load decimal value at load time instead of…
Browse files Browse the repository at this point in the history
… check time

* Added custom tag "!decimal" and "!decimallist" and use them to load decimal value as decimal
  • Loading branch information
anshuldata committed Aug 12, 2024
1 parent 6cf036e commit 5cc3c8c
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 141 deletions.
17 changes: 0 additions & 17 deletions bft/cases/parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math
from decimal import Decimal
from typing import BinaryIO, Iterable, List

from bft.core.yaml_parser import BaseYamlParser, BaseYamlVisitor
Expand Down Expand Up @@ -39,24 +38,8 @@ def __normalize_yaml_literal(self, value, data_type):
return math.nan
else:
raise ValueError(f"Unrecognized float string literal {value}")
if data_type.startswith("dec"):
ret_val = self._normalize_decimal_type(value)
return ret_val
return value

def _normalize_decimal_type(self, val):
if val is None:
return val
if not isinstance(val, list):
return Decimal(str(val))
converted_list = []
for v in val:
if v is not None:
converted_list.append(Decimal(v))
else:
converted_list.append(v)
return converted_list

def visit_literal(self, lit):
value = self._get_or_die(lit, "value")
data_type = self._get_or_die(lit, "type")
Expand Down
23 changes: 22 additions & 1 deletion bft/core/yaml_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import BinaryIO, Generic, Iterable, List, TypeVar

import yaml
Expand Down Expand Up @@ -90,7 +91,27 @@ class BaseYamlParser(ABC, Generic[T]):
def get_visitor(self) -> BaseYamlVisitor[T]:
pass

def get_loader(self):
loader = yaml.SafeLoader
"""Add tag "!decimal" to the loader """
loader.add_constructor("!decimal", self.decimal_constructor)
loader.add_constructor("!decimallist", self.list_of_decimal_constructor)
return loader

def decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
return self.get_decimal_value(loader, node)

def get_decimal_value(self, loader: yaml.SafeLoader, node: yaml.ScalarNode):
value = loader.construct_scalar(node)
if isinstance(value, str) and value.lower() == 'null':
return None
return Decimal(value)

def list_of_decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
return [self.get_decimal_value(loader, item) for item in node.value]

def parse(self, f: BinaryIO) -> List[T]:
objs = yaml.load_all(f, SafeLoader)
loader = self.get_loader()
objs = yaml.load_all(f, loader)
visitor = self.get_visitor()
return [visitor.visit(obj) for obj in objs]
23 changes: 23 additions & 0 deletions bft/core/yaml_parser_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from decimal import Decimal
from typing import NamedTuple

from bft.core.yaml_parser import BaseYamlParser


class TestDecimalResult(NamedTuple):
cases: Decimal | list[Decimal]

class TestCaseVisitor():
def visit(self, testcase):
return TestDecimalResult(testcase)
class DecimalTestCaseParser(BaseYamlParser[TestDecimalResult]):
def get_visitor(self) -> TestCaseVisitor:
return TestCaseVisitor()

def test_yaml_parser_decimal_tag():
parser = DecimalTestCaseParser()
# parser returns list of parsed values
assert parser.parse(b"!decimal 1") == [TestDecimalResult(Decimal('1'))]
assert parser.parse(b"!decimal 1.78766") == [TestDecimalResult(Decimal('1.78766'))]
assert parser.parse(b"!decimal null") == [TestDecimalResult(None)]
assert parser.parse(b"!decimallist [1.2, null, 7.547]") == [TestDecimalResult([Decimal('1.2'), None, Decimal('7.547')])]
7 changes: 1 addition & 6 deletions bft/testers/duckdb/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import type_to_dialect_type, compareDecimalResult
from bft.utils.utils import type_to_dialect_type

type_map = {
"i8": "TINYINT",
Expand Down Expand Up @@ -140,11 +140,6 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
elif case.result.type.startswith("fp") and case.result.value and result:
if math.isclose(result, case.result.value, rel_tol=1e-7):
return SqlCaseResult.success()
elif case.result.type.startswith("dec") and case.result.value and result:
if compareDecimalResult(case.result.value, Decimal(str(result))):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
else:
if result == case.result.value:
return SqlCaseResult.success()
Expand Down
7 changes: 1 addition & 6 deletions bft/testers/snowflake/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import type_to_dialect_type, compareDecimalResult
from bft.utils.utils import type_to_dialect_type

type_map = {
"fp64": "FLOAT",
Expand Down Expand Up @@ -169,11 +169,6 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
elif case.result.type.startswith("fp") and case.result.value and result:
if math.isclose(result, case.result.value, rel_tol=1e-7):
return SqlCaseResult.success()
elif case.result.type.startswith("dec") and case.result.value and result:
if compareDecimalResult(case.result.value, Decimal(str(result))):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
else:
if result == case.result.value:
return SqlCaseResult.success()
Expand Down
19 changes: 0 additions & 19 deletions bft/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,3 @@ def type_to_dialect_type(type: str, type_map: Dict[str, str])->str:
return type_val
# transform parameterized type name to have dialect type
return type.replace(type_to_check, type_val).replace("<", "(").replace(">", ")")

def compareDecimalResult(expected_result: Decimal, actual_result: Decimal)->bool:
'''
Compares non-null decimal type based on scale of expected_result
:param expected_result: expected result. Its scale is considered to be the scale to compare
:param actual_result:
:return: bool
'''
# make scale of actual_result to same as expected_result for comparison
scale = abs(expected_result.as_tuple().exponent)
rounding_format = Decimal(f"1.{'0' * scale}")
try:
# set thread precison to 38 since database decimal support max 38
with localcontext(prec=38) as ctx:
rounded_result = actual_result.quantize(rounding_format, rounding=ROUND_DOWN)
except Exception as e:
print(f"Exception while rounding: {e}")
return False
return rounded_result == expected_result
12 changes: 0 additions & 12 deletions bft/utils/utils_test.py

This file was deleted.

34 changes: 17 additions & 17 deletions cases/arithmetic_decimal/avg_decimal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,37 @@ cases:
id: basic
description: Basic examples without any special cases
args:
- value: [0, -1, 2, 20]
- value: !decimallist [0, -1, 2, 20]
type: decimal<38, 0>
result:
value: 5.25
value: !decimal 5.25
type: decimal<38, 2>
- group: basic
args:
- value: [2000000, -3217908, 629000, -100000, 0, 987654]
- value: !decimallist [2000000, -3217908, 629000, -100000, 0, 987654]
type: decimal<38, 0>
result:
value: 49791
value: !decimal 49791
type: decimal<38, 5>
- group: basic
args:
- value: [2.5, 0, 5.0, -2.5, -7.5]
- value: !decimallist [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<38, 2>
result:
value: -0.5
value: !decimal -0.5
type: decimal<38, 2>
- group: basic
args:
- value: [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875]
type: decimal<38, 22>
- value: !decimallist [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875]
type: decimal<38, 14>
result:
value: 4.12500053644180
type: decimal<38, 22>
value: !decimal 4.12500053644181
type: decimal<38, 14>
- group:
id: overflow
description: Examples demonstrating overflow behavior
args:
- value: [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999]
- value: !decimallist [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999]
type: decimal<38, 0>
options:
overflow: ERROR
Expand All @@ -45,22 +45,22 @@ cases:
id: null_handling
description: Examples with null as unput or output
args:
- value: [Null, Null, Null]
- value: !decimallist [Null, Null, Null]
type: decimal<38, 0>
result:
value: Null
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: []
- value: !decimallist []
type: decimal<38, 0>
result:
value: Null
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: [200000, Null, 629000, -10000, 0, 987621]
- value: !decimallist [200000, Null, 629000, -10000, 0, 987621]
type: decimal<38, 0>
result:
value: 361324.2
value: !decimal 361324.2
type: decimal<38, 2>
48 changes: 24 additions & 24 deletions cases/arithmetic_decimal/max_decimal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,87 +5,87 @@ cases:
id: basic
description: Basic examples without any special cases
args:
- value: [20, -3, 1, -10, 0, 5]
- value: !decimallist [20, -3, 1, -10, 0, 5]
type: decimal<38, 0>
result:
value: 20
value: !decimal 20
type: decimal<38, 0>
- group: basic
args:
- value: [-32768, 32767, 20000, -30000]
- value: !decimallist [-32768, 32767, 20000, -30000]
type: decimal<38, 0>
result:
value: 32767
value: !decimal 32767
type: decimal<38, 0>
- group: basic
args:
- value: [-214748648, 214748647, 21470048, 4000000]
- value: !decimallist [-214748648, 214748647, 21470048, 4000000]
type: decimal<38, 0>
result:
value: 214748647
value: !decimal 214748647
type: decimal<38, 0>
- group: basic
args:
- value: [2000000000, -3217908979, 629000000, -100000000, 0, 987654321]
- value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321]
type: decimal<38, 0>
result:
value: 2000000000
value: !decimal 2000000000
type: decimal<38, 0>
- group: basic
args:
- value: [2.5, 0, 5.0, -2.5, -7.5]
- value: !decimallist [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<38, 2>
result:
value: 5.0
value: !decimal 5.0
type: decimal<38, 2>
- group: basic
args:
- value: [99999999999999999999999999999999999999, 0, -99999999999999999999999999999999999998, 111111111, -76]
- value: !decimallist [99999999999999999999999999999999999999, 0, -99999999999999999999999999999999999998, 111111111, -76]
type: decimal<38, 0>
result:
value: 99999999999999999999999999999999999999
value: !decimal 99999999999999999999999999999999999999
type: decimal<38, 0>
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: [Null, Null, Null]
- value: !decimallist [Null, Null, Null]
type: decimal<38, 0>
result:
value: Null
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: []
- value: !decimallist []
type: decimal<38, 0>
result:
value: Null
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: [2000000000, Null, 629000000, -100000000, Null, 987654321]
- value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321]
type: decimal<38, 0>
result:
value: 2000000000
value: !decimal 2000000000
type: decimal<38, 0>
- group: null_handling
args:
- value: [Null, Null]
- value: !decimallist [Null, Null]
type: decimal<38, 0>
result:
value: Null
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: []
- value: !decimallist []
type: decimal<38, 0>
result:
value: Null
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, Null, 11111111111111111111111111111111111111, Null]
- value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, Null, 11111111111111111111111111111111111111, Null]
type: decimal<38, 0>
result:
value: 99999999999999999999999999999999999999
value: !decimal 99999999999999999999999999999999999999
type: decimal<38, 0>
Loading

0 comments on commit 5cc3c8c

Please sign in to comment.