From 65beceb27f76cedc3fc1788a0f428679a1d6063e Mon Sep 17 00:00:00 2001 From: SCM Date: Fri, 23 Aug 2024 19:52:40 +0100 Subject: [PATCH 01/14] add fraction dunders --- opshin/std/fractions.py | 116 ++++++++++++++++++++++-- opshin/tests/test_std/test_fractions.py | 103 +++++++++++++++++++++ 2 files changed, 210 insertions(+), 9 deletions(-) diff --git a/opshin/std/fractions.py b/opshin/std/fractions.py index caad5f83..d76f149a 100644 --- a/opshin/std/fractions.py +++ b/opshin/std/fractions.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from pycardano import Datum as Anything, PlutusData from typing import Dict, List, Union +from typing import Self from opshin.std.math import * @@ -17,9 +18,106 @@ class Fraction(PlutusData): numerator: int denominator: int + def norm(self) -> Self: + """Restores the invariant that num/denom are in the smallest possible denomination and denominator > 0""" + return _norm_gcd_fraction(_norm_signs_fraction(self)) + + def ceil(self) -> int: + return ( + self.numerator + self.denominator - sign(self.denominator) + ) // self.denominator + + def __add__(self, other: Self) -> Self: + """returns self + other""" + return Fraction( + (self.numerator * other.denominator) + (other.numerator * self.denominator), + self.denominator * other.denominator, + ) + + def __neg__( + self, + ) -> Self: + """returns -self""" + return Fraction(-self.numerator, self.denominator) + + def __sub__(self, other: Self) -> Self: + """returns self - other""" + return Fraction( + (self.numerator * other.denominator) - (other.numerator * self.denominator), + self.denominator * other.denominator, + ) + + def __mul__(self, other: Self) -> Self: + """returns self * other""" + return Fraction( + self.numerator * other.numerator, self.denominator * other.denominator + ) + + def __truediv__(self, other: Self) -> Self: + """returns self / other""" + return Fraction( + self.numerator * other.denominator, self.denominator * other.numerator + ) + + def __ge__(self, other: Self) -> Self: + """returns self >= other""" + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator >= self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator <= self.denominator * other.numerator + ) + return res + + def __le__(self, other: Self) -> Self: + """returns self <= other""" + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator <= self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator >= self.denominator * other.numerator + ) + return res + + def __eq__(self, other: Self) -> Self: + """returns self == other""" + return self.numerator * other.denominator == self.denominator * other.numerator + + def __lt__(self, other: Self) -> Self: + """returns self < other""" + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator < self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator > self.denominator * other.numerator + ) + return res + + def __gt__(self, other: Self) -> Self: + """returns self > other""" + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator > self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator < self.denominator * other.numerator + ) + return res + + def __floordiv__(self, other: Self) -> int: + x = self / other + return x.numerator // x.denominator + def add_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns a + b""" + """returns self + other""" return Fraction( (a.numerator * b.denominator) + (b.numerator * a.denominator), a.denominator * b.denominator, @@ -32,17 +130,17 @@ def neg_fraction(a: Fraction) -> Fraction: def sub_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns a - b""" + """returns self - other""" return add_fraction(a, neg_fraction(b)) def mul_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns a * b""" + """returns self * other""" return Fraction(a.numerator * b.numerator, a.denominator * b.denominator) def div_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns a / b""" + """returns self / other""" return Fraction(a.numerator * b.denominator, a.denominator * b.numerator) @@ -63,7 +161,7 @@ def norm_fraction(a: Fraction) -> Fraction: def ge_fraction(a: Fraction, b: Fraction) -> bool: - """returns a >= b""" + """returns self >= other""" if a.denominator * b.denominator >= 0: res = a.numerator * b.denominator >= a.denominator * b.numerator else: @@ -72,7 +170,7 @@ def ge_fraction(a: Fraction, b: Fraction) -> bool: def le_fraction(a: Fraction, b: Fraction) -> bool: - """returns a <= b""" + """returns self <= other""" if a.denominator * b.denominator >= 0: res = a.numerator * b.denominator <= a.denominator * b.numerator else: @@ -81,12 +179,12 @@ def le_fraction(a: Fraction, b: Fraction) -> bool: def eq_fraction(a: Fraction, b: Fraction) -> bool: - """returns a == b""" + """returns self == other""" return a.numerator * b.denominator == a.denominator * b.numerator def lt_fraction(a: Fraction, b: Fraction) -> bool: - """returns a < b""" + """returns self < other""" if a.denominator * b.denominator >= 0: res = a.numerator * b.denominator < a.denominator * b.numerator else: @@ -95,7 +193,7 @@ def lt_fraction(a: Fraction, b: Fraction) -> bool: def gt_fraction(a: Fraction, b: Fraction) -> bool: - """returns a > b""" + """returns self > other""" if a.denominator * b.denominator >= 0: res = a.numerator * b.denominator > a.denominator * b.numerator else: diff --git a/opshin/tests/test_std/test_fractions.py b/opshin/tests/test_std/test_fractions.py index 60330b81..f5abf9de 100644 --- a/opshin/tests/test_std/test_fractions.py +++ b/opshin/tests/test_std/test_fractions.py @@ -24,6 +24,15 @@ def test_add(a: oc_fractions.Fraction, b: oc_fractions.Fraction): ), "Invalid add" +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_add_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_added = a + b + oc_normalized = native_fraction_from_oc_fraction(oc_added) + assert oc_normalized == ( + native_fraction_from_oc_fraction(a) + native_fraction_from_oc_fraction(b) + ), "Invalid add" + + @hypothesis.given(denormalized_fractions, denormalized_fractions) def test_sub(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_subbed = oc_fractions.sub_fraction(a, b) @@ -33,6 +42,15 @@ def test_sub(a: oc_fractions.Fraction, b: oc_fractions.Fraction): ), "Invalid sub" +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_sub_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_subbed = a - b + oc_normalized = native_fraction_from_oc_fraction(oc_subbed) + assert oc_normalized == ( + native_fraction_from_oc_fraction(a) - native_fraction_from_oc_fraction(b) + ), "Invalid sub" + + @hypothesis.given(denormalized_fractions) def test_neg(a: oc_fractions.Fraction): oc_negged = oc_fractions.neg_fraction(a) @@ -40,6 +58,13 @@ def test_neg(a: oc_fractions.Fraction): assert oc_normalized == -native_fraction_from_oc_fraction(a), "Invalid neg" +@hypothesis.given(denormalized_fractions) +def test_neg_dunder(a: oc_fractions.Fraction): + oc_negged = -a + oc_normalized = native_fraction_from_oc_fraction(oc_negged) + assert oc_normalized == -native_fraction_from_oc_fraction(a), "Invalid neg" + + @hypothesis.given(denormalized_fractions, denormalized_fractions) def test_mul(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_mulled = oc_fractions.mul_fraction(a, b) @@ -49,6 +74,15 @@ def test_mul(a: oc_fractions.Fraction, b: oc_fractions.Fraction): ), "Invalid mul" +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_mul_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_mulled = a * b + oc_normalized = native_fraction_from_oc_fraction(oc_mulled) + assert oc_normalized == ( + native_fraction_from_oc_fraction(a) * native_fraction_from_oc_fraction(b) + ), "Invalid mul" + + @hypothesis.given(denormalized_fractions, denormalized_fractions_non_null) def test_div(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_divved = oc_fractions.div_fraction(a, b) @@ -58,6 +92,15 @@ def test_div(a: oc_fractions.Fraction, b: oc_fractions.Fraction): ), "Invalid div" +@hypothesis.given(denormalized_fractions, denormalized_fractions_non_null) +def test_div_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_divved = a / b + oc_normalized = native_fraction_from_oc_fraction(oc_divved) + assert oc_normalized == ( + native_fraction_from_oc_fraction(a) / native_fraction_from_oc_fraction(b) + ), "Invalid div" + + @hypothesis.given(denormalized_fractions) def test_norm_sign(a: oc_fractions.Fraction): oc_normed = oc_fractions._norm_signs_fraction(a) @@ -76,6 +119,15 @@ def test_norm(a: oc_fractions.Fraction): assert oc_normed.denominator == oc_normalized.denominator, "Invalid norm" +@hypothesis.given(denormalized_fractions) +@hypothesis.example(oc_fractions.Fraction(0, -1)) +def test_norm_method(a: oc_fractions.Fraction): + oc_normed = a.norm() + oc_normalized = native_fraction_from_oc_fraction(a) + assert oc_normed.numerator == oc_normalized.numerator, "Invalid norm" + assert oc_normed.denominator == oc_normalized.denominator, "Invalid norm" + + @hypothesis.given(denormalized_fractions, denormalized_fractions) def test_ge(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_ge = oc_fractions.ge_fraction(a, b) @@ -83,6 +135,13 @@ def test_ge(a: oc_fractions.Fraction, b: oc_fractions.Fraction): assert oc_ge == ge, "Invalid ge" +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_ge_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_ge = a >= b + ge = native_fraction_from_oc_fraction(a) >= native_fraction_from_oc_fraction(b) + assert oc_ge == ge, "Invalid ge" + + @hypothesis.given(denormalized_fractions, denormalized_fractions) def test_le(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_le = oc_fractions.le_fraction(a, b) @@ -90,6 +149,13 @@ def test_le(a: oc_fractions.Fraction, b: oc_fractions.Fraction): assert oc_le == le, "Invalid le" +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_le_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_le = a <= b + le = native_fraction_from_oc_fraction(a) <= native_fraction_from_oc_fraction(b) + assert oc_le == le, "Invalid le" + + @hypothesis.given(denormalized_fractions, denormalized_fractions) def test_lt(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_lt = oc_fractions.lt_fraction(a, b) @@ -97,6 +163,13 @@ def test_lt(a: oc_fractions.Fraction, b: oc_fractions.Fraction): assert oc_lt == lt, "Invalid lt" +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_lt_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_lt = a < b + lt = native_fraction_from_oc_fraction(a) < native_fraction_from_oc_fraction(b) + assert oc_lt == lt, "Invalid lt" + + @hypothesis.given(denormalized_fractions, denormalized_fractions) def test_gt(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_gt = oc_fractions.gt_fraction(a, b) @@ -104,6 +177,13 @@ def test_gt(a: oc_fractions.Fraction, b: oc_fractions.Fraction): assert oc_gt == gt, "Invalid gt" +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_gt_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_gt = a > b + gt = native_fraction_from_oc_fraction(a) > native_fraction_from_oc_fraction(b) + assert oc_gt == gt, "Invalid gt" + + @hypothesis.given(denormalized_fractions, denormalized_fractions) def test_eq(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_eq = oc_fractions.eq_fraction(a, b) @@ -111,6 +191,13 @@ def test_eq(a: oc_fractions.Fraction, b: oc_fractions.Fraction): assert oc_eq == eq, "Invalid eq" +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_eq_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_eq = a == b + eq = native_fraction_from_oc_fraction(a) == native_fraction_from_oc_fraction(b) + assert oc_eq == eq, "Invalid eq" + + @hypothesis.given(denormalized_fractions) def test_floor(a: oc_fractions.Fraction): oc_floor = oc_fractions.floor_fraction(a) @@ -119,9 +206,25 @@ def test_floor(a: oc_fractions.Fraction): ), "Invalid floor" +@hypothesis.given(denormalized_fractions, denormalized_fractions_non_null) +def test_floor_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_floor = a // b + floor = native_fraction_from_oc_fraction(a) // native_fraction_from_oc_fraction(b) + + assert oc_floor == floor, "Invalid floor" + + @hypothesis.given(denormalized_fractions) def test_ceil(a: oc_fractions.Fraction): oc_ceil = oc_fractions.ceil_fraction(a) assert ( native_math.ceil(native_fraction_from_oc_fraction(a)) == oc_ceil ), "Invalid ceil" + + +@hypothesis.given(denormalized_fractions) +def test_ceil_method(a: oc_fractions.Fraction): + oc_ceil = a.ceil() + assert ( + native_math.ceil(native_fraction_from_oc_fraction(a)) == oc_ceil + ), "Invalid ceil" From d1f7791bc6a23a178e63037750baac093541f9cc Mon Sep 17 00:00:00 2001 From: SCM Date: Sat, 24 Aug 2024 14:21:59 +0100 Subject: [PATCH 02/14] fraction dunder union[Fraction, int] --- .../rewrite/rewrite_forbidden_overwrites.py | 1 + opshin/rewrite/rewrite_import_typing.py | 10 + opshin/rewrite/rewrite_scoping.py | 17 +- opshin/std/fractions.py | 255 +++++++++--------- opshin/tests/test_std/test_fractions.py | 124 ++------- opshin/type_inference.py | 21 +- 6 files changed, 184 insertions(+), 244 deletions(-) diff --git a/opshin/rewrite/rewrite_forbidden_overwrites.py b/opshin/rewrite/rewrite_forbidden_overwrites.py index 488e1a22..72347eeb 100644 --- a/opshin/rewrite/rewrite_forbidden_overwrites.py +++ b/opshin/rewrite/rewrite_forbidden_overwrites.py @@ -13,6 +13,7 @@ "List", "Dict", "Union", + "Self", # decorator and class name "dataclass", "PlutusData", diff --git a/opshin/rewrite/rewrite_import_typing.py b/opshin/rewrite/rewrite_import_typing.py index 6f58414f..a906e7aa 100644 --- a/opshin/rewrite/rewrite_import_typing.py +++ b/opshin/rewrite/rewrite_import_typing.py @@ -49,6 +49,16 @@ def visit_ClassDef(self, node: ClassDef) -> ClassDef: and arg.annotation.id == "Self" ): node.body[i].args.args[j].annotation.idSelf = node.name + if ( + isinstance(arg.annotation, Subscript) + and arg.annotation.value.id == "Union" + ): + for k, s in enumerate(arg.annotation.slice.elts): + if isinstance(s, Name) and s.id == "Self": + node.body[i].args.args[j].annotation.slice.elts[ + k + ].idSelf = node.name + if ( isinstance(attribute.returns, Name) and attribute.returns.id == "Self" diff --git a/opshin/rewrite/rewrite_scoping.py b/opshin/rewrite/rewrite_scoping.py index 560af9a3..ca401ad3 100644 --- a/opshin/rewrite/rewrite_scoping.py +++ b/opshin/rewrite/rewrite_scoping.py @@ -40,6 +40,7 @@ class RewriteScoping(CompilingNodeTransformer): step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope" latest_scope_id: int scopes: typing.List[typing.Tuple[OrderedSet, int]] + current_Self: typing.Tuple[str, str] def variable_scope_id(self, name: str) -> int: """find the id of the scope in which this variable is defined (closest to its usage)""" @@ -86,6 +87,9 @@ def visit_Module(self, node: Module) -> Module: def visit_Name(self, node: Name) -> Name: nc = copy(node) # setting is handled in either enclosing module or function + if node.id == "Self": + assert node.idSelf == self.current_Self[1] + nc.idSelf_new = self.current_Self[0] nc.id = self.map_name(node.id) return nc @@ -93,6 +97,7 @@ def visit_ClassDef(self, node: ClassDef) -> ClassDef: cp_node = RecordScoper.scope(node, self) for i, attribute in enumerate(cp_node.body): if isinstance(attribute, FunctionDef): + self.current_Self = (cp_node.name, cp_node.orig_name) cp_node.body[i] = self.visit_FunctionDef(attribute, method=True) return cp_node @@ -108,17 +113,9 @@ def visit_FunctionDef(self, node: FunctionDef, method: bool = False) -> Function a_cp = copy(a) self.set_variable_scope(a.arg) a_cp.arg = self.map_name(a.arg) - a_cp.annotation = ( - self.visit(a.annotation) - if not hasattr(a.annotation, "idSelf") - else a.annotation - ) + a_cp.annotation = self.visit(a.annotation) node_cp.args.args.append(a_cp) - node_cp.returns = ( - self.visit(node.returns) - if not hasattr(node.returns, "idSelf") - else node.returns - ) + node_cp.returns = self.visit(node.returns) # vars defined in this scope shallow_node_def_collector = ShallowNameDefCollector() for s in node.body: diff --git a/opshin/std/fractions.py b/opshin/std/fractions.py index d76f149a..5a098efc 100644 --- a/opshin/std/fractions.py +++ b/opshin/std/fractions.py @@ -27,12 +27,18 @@ def ceil(self) -> int: self.numerator + self.denominator - sign(self.denominator) ) // self.denominator - def __add__(self, other: Self) -> Self: + def __add__(self, other: Union[Self, int]) -> Self: """returns self + other""" - return Fraction( - (self.numerator * other.denominator) + (other.numerator * self.denominator), - self.denominator * other.denominator, - ) + if isinstance(other, Fraction): + return Fraction( + (self.numerator * other.denominator) + + (other.numerator * self.denominator), + self.denominator * other.denominator, + ) + else: + return Fraction( + self.numerator + (other * self.denominator), self.denominator + ) def __neg__( self, @@ -40,108 +46,136 @@ def __neg__( """returns -self""" return Fraction(-self.numerator, self.denominator) - def __sub__(self, other: Self) -> Self: + def __sub__(self, other: Union[Self, int]) -> Self: """returns self - other""" - return Fraction( - (self.numerator * other.denominator) - (other.numerator * self.denominator), - self.denominator * other.denominator, - ) + if isinstance(other, Fraction): + return Fraction( + (self.numerator * other.denominator) + - (other.numerator * self.denominator), + self.denominator * other.denominator, + ) + else: + return Fraction( + self.numerator - (other * self.denominator), self.denominator + ) - def __mul__(self, other: Self) -> Self: + def __mul__(self, other: Union[Self, int]) -> Self: """returns self * other""" - return Fraction( - self.numerator * other.numerator, self.denominator * other.denominator - ) + if isinstance(other, Fraction): + return Fraction( + self.numerator * other.numerator, self.denominator * other.denominator + ) + else: + return Fraction(self.numerator * other, self.denominator) - def __truediv__(self, other: Self) -> Self: + def __truediv__(self, other: Union[Self, int]) -> Self: """returns self / other""" - return Fraction( - self.numerator * other.denominator, self.denominator * other.numerator - ) + if isinstance(other, Fraction): + return Fraction( + self.numerator * other.denominator, self.denominator * other.numerator + ) + else: + return Fraction(self.numerator, self.denominator * other) - def __ge__(self, other: Self) -> Self: + def __ge__(self, other: Union[Self, int]) -> bool: """returns self >= other""" - if self.denominator * other.denominator >= 0: - res = ( - self.numerator * other.denominator >= self.denominator * other.numerator - ) + if isinstance(other, Fraction): + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator + >= self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator + <= self.denominator * other.numerator + ) + return res else: - res = ( - self.numerator * other.denominator <= self.denominator * other.numerator - ) - return res + if self.denominator >= 0: + res = self.numerator >= self.denominator * other + else: + res = self.numerator <= self.denominator * other + return res - def __le__(self, other: Self) -> Self: + def __le__(self, other: Union[Self, int]) -> bool: """returns self <= other""" - if self.denominator * other.denominator >= 0: - res = ( - self.numerator * other.denominator <= self.denominator * other.numerator - ) + if isinstance(other, Fraction): + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator + <= self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator + >= self.denominator * other.numerator + ) + return res else: - res = ( - self.numerator * other.denominator >= self.denominator * other.numerator - ) - return res + if self.denominator >= 0: + res = self.numerator <= self.denominator * other + else: + res = self.numerator >= self.denominator * other + return res - def __eq__(self, other: Self) -> Self: + def __eq__(self, other: Union[Self, int]) -> bool: """returns self == other""" - return self.numerator * other.denominator == self.denominator * other.numerator + if isinstance(other, Fraction): + return ( + self.numerator * other.denominator == self.denominator * other.numerator + ) + else: + return self.numerator == self.denominator * other - def __lt__(self, other: Self) -> Self: + def __lt__(self, other: Union[Self, int]) -> bool: """returns self < other""" - if self.denominator * other.denominator >= 0: - res = ( - self.numerator * other.denominator < self.denominator * other.numerator - ) + if isinstance(other, Fraction): + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator + < self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator + > self.denominator * other.numerator + ) + return res else: - res = ( - self.numerator * other.denominator > self.denominator * other.numerator - ) - return res + if self.denominator >= 0: + res = self.numerator < self.denominator * other + else: + res = self.numerator > self.denominator * other + return res - def __gt__(self, other: Self) -> Self: + def __gt__(self, other: Union[Self, int]) -> bool: """returns self > other""" - if self.denominator * other.denominator >= 0: - res = ( - self.numerator * other.denominator > self.denominator * other.numerator - ) + if isinstance(other, Fraction): + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator + > self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator + < self.denominator * other.numerator + ) + return res else: - res = ( - self.numerator * other.denominator < self.denominator * other.numerator - ) - return res - - def __floordiv__(self, other: Self) -> int: - x = self / other - return x.numerator // x.denominator - - -def add_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns self + other""" - return Fraction( - (a.numerator * b.denominator) + (b.numerator * a.denominator), - a.denominator * b.denominator, - ) - - -def neg_fraction(a: Fraction) -> Fraction: - """returns -a""" - return Fraction(-a.numerator, a.denominator) - - -def sub_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns self - other""" - return add_fraction(a, neg_fraction(b)) - - -def mul_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns self * other""" - return Fraction(a.numerator * b.numerator, a.denominator * b.denominator) - - -def div_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns self / other""" - return Fraction(a.numerator * b.denominator, a.denominator * b.numerator) + if self.denominator >= 0: + res = self.numerator > self.denominator * other + else: + res = self.numerator < self.denominator * other + return res + + def __floordiv__(self, other: Union[Self, int]) -> int: + if isinstance(other, Fraction): + x = self / other + return x.numerator // x.denominator + else: + return self.numerator // (other * self.denominator) def _norm_signs_fraction(a: Fraction) -> Fraction: @@ -160,50 +194,5 @@ def norm_fraction(a: Fraction) -> Fraction: return _norm_gcd_fraction(_norm_signs_fraction(a)) -def ge_fraction(a: Fraction, b: Fraction) -> bool: - """returns self >= other""" - if a.denominator * b.denominator >= 0: - res = a.numerator * b.denominator >= a.denominator * b.numerator - else: - res = a.numerator * b.denominator <= a.denominator * b.numerator - return res - - -def le_fraction(a: Fraction, b: Fraction) -> bool: - """returns self <= other""" - if a.denominator * b.denominator >= 0: - res = a.numerator * b.denominator <= a.denominator * b.numerator - else: - res = a.numerator * b.denominator >= a.denominator * b.numerator - return res - - -def eq_fraction(a: Fraction, b: Fraction) -> bool: - """returns self == other""" - return a.numerator * b.denominator == a.denominator * b.numerator - - -def lt_fraction(a: Fraction, b: Fraction) -> bool: - """returns self < other""" - if a.denominator * b.denominator >= 0: - res = a.numerator * b.denominator < a.denominator * b.numerator - else: - res = a.numerator * b.denominator > a.denominator * b.numerator - return res - - -def gt_fraction(a: Fraction, b: Fraction) -> bool: - """returns self > other""" - if a.denominator * b.denominator >= 0: - res = a.numerator * b.denominator > a.denominator * b.numerator - else: - res = a.numerator * b.denominator < a.denominator * b.numerator - return res - - -def floor_fraction(a: Fraction) -> int: - return a.numerator // a.denominator - - def ceil_fraction(a: Fraction) -> int: return (a.numerator + a.denominator - sign(a.denominator)) // a.denominator diff --git a/opshin/tests/test_std/test_fractions.py b/opshin/tests/test_std/test_fractions.py index f5abf9de..c3f4fc3d 100644 --- a/opshin/tests/test_std/test_fractions.py +++ b/opshin/tests/test_std/test_fractions.py @@ -1,30 +1,37 @@ import hypothesis import hypothesis.strategies as hst - +from typing import Union from opshin.std import fractions as oc_fractions +from opshin.tests.utils import eval_uplc import fractions as native_fractions import math as native_math +from uplc.ast import PlutusConstr + non_null = hst.one_of(hst.integers(min_value=1), hst.integers(max_value=-1)) denormalized_fractions = hst.builds(oc_fractions.Fraction, hst.integers(), non_null) denormalized_fractions_non_null = hst.builds(oc_fractions.Fraction, non_null, non_null) +denormalized_fractions_and_int = hst.one_of([denormalized_fractions, hst.integers()]) +denormalized_fractions_and_int_non_null = hst.one_of( + [denormalized_fractions_non_null, non_null] +) -def native_fraction_from_oc_fraction(f: oc_fractions.Fraction): - return native_fractions.Fraction(f.numerator, f.denominator) +def native_fraction_from_oc_fraction(f: Union[oc_fractions.Fraction, int]): + if isinstance(f, oc_fractions.Fraction): + return native_fractions.Fraction(f.numerator, f.denominator) + else: + return f -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_add(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_added = oc_fractions.add_fraction(a, b) - oc_normalized = native_fraction_from_oc_fraction(oc_added) - assert oc_normalized == ( - native_fraction_from_oc_fraction(a) + native_fraction_from_oc_fraction(b) - ), "Invalid add" +def plutus_to_native(f): + assert isinstance(f, PlutusConstr) + assert f.constructor == 1 + return native_fractions.Fraction(*[field.value for field in f.fields]) -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_add_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_added = a + b oc_normalized = native_fraction_from_oc_fraction(oc_added) @@ -33,16 +40,7 @@ def test_add_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): ), "Invalid add" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_sub(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_subbed = oc_fractions.sub_fraction(a, b) - oc_normalized = native_fraction_from_oc_fraction(oc_subbed) - assert oc_normalized == ( - native_fraction_from_oc_fraction(a) - native_fraction_from_oc_fraction(b) - ), "Invalid sub" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_sub_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_subbed = a - b oc_normalized = native_fraction_from_oc_fraction(oc_subbed) @@ -51,13 +49,6 @@ def test_sub_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): ), "Invalid sub" -@hypothesis.given(denormalized_fractions) -def test_neg(a: oc_fractions.Fraction): - oc_negged = oc_fractions.neg_fraction(a) - oc_normalized = native_fraction_from_oc_fraction(oc_negged) - assert oc_normalized == -native_fraction_from_oc_fraction(a), "Invalid neg" - - @hypothesis.given(denormalized_fractions) def test_neg_dunder(a: oc_fractions.Fraction): oc_negged = -a @@ -65,16 +56,7 @@ def test_neg_dunder(a: oc_fractions.Fraction): assert oc_normalized == -native_fraction_from_oc_fraction(a), "Invalid neg" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_mul(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_mulled = oc_fractions.mul_fraction(a, b) - oc_normalized = native_fraction_from_oc_fraction(oc_mulled) - assert oc_normalized == ( - native_fraction_from_oc_fraction(a) * native_fraction_from_oc_fraction(b) - ), "Invalid mul" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_mul_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_mulled = a * b oc_normalized = native_fraction_from_oc_fraction(oc_mulled) @@ -83,16 +65,7 @@ def test_mul_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): ), "Invalid mul" -@hypothesis.given(denormalized_fractions, denormalized_fractions_non_null) -def test_div(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_divved = oc_fractions.div_fraction(a, b) - oc_normalized = native_fraction_from_oc_fraction(oc_divved) - assert oc_normalized == ( - native_fraction_from_oc_fraction(a) / native_fraction_from_oc_fraction(b) - ), "Invalid div" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions_non_null) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int_non_null) def test_div_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_divved = a / b oc_normalized = native_fraction_from_oc_fraction(oc_divved) @@ -128,85 +101,42 @@ def test_norm_method(a: oc_fractions.Fraction): assert oc_normed.denominator == oc_normalized.denominator, "Invalid norm" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_ge(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_ge = oc_fractions.ge_fraction(a, b) - ge = native_fraction_from_oc_fraction(a) >= native_fraction_from_oc_fraction(b) - assert oc_ge == ge, "Invalid ge" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_ge_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_ge = a >= b ge = native_fraction_from_oc_fraction(a) >= native_fraction_from_oc_fraction(b) assert oc_ge == ge, "Invalid ge" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_le(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_le = oc_fractions.le_fraction(a, b) - le = native_fraction_from_oc_fraction(a) <= native_fraction_from_oc_fraction(b) - assert oc_le == le, "Invalid le" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_le_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_le = a <= b le = native_fraction_from_oc_fraction(a) <= native_fraction_from_oc_fraction(b) assert oc_le == le, "Invalid le" -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_lt(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_lt = oc_fractions.lt_fraction(a, b) - lt = native_fraction_from_oc_fraction(a) < native_fraction_from_oc_fraction(b) - assert oc_lt == lt, "Invalid lt" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_lt_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_lt = a < b lt = native_fraction_from_oc_fraction(a) < native_fraction_from_oc_fraction(b) assert oc_lt == lt, "Invalid lt" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_gt(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_gt = oc_fractions.gt_fraction(a, b) - gt = native_fraction_from_oc_fraction(a) > native_fraction_from_oc_fraction(b) - assert oc_gt == gt, "Invalid gt" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_gt_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_gt = a > b gt = native_fraction_from_oc_fraction(a) > native_fraction_from_oc_fraction(b) assert oc_gt == gt, "Invalid gt" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_eq(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_eq = oc_fractions.eq_fraction(a, b) - eq = native_fraction_from_oc_fraction(a) == native_fraction_from_oc_fraction(b) - assert oc_eq == eq, "Invalid eq" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_eq_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_eq = a == b eq = native_fraction_from_oc_fraction(a) == native_fraction_from_oc_fraction(b) assert oc_eq == eq, "Invalid eq" -@hypothesis.given(denormalized_fractions) -def test_floor(a: oc_fractions.Fraction): - oc_floor = oc_fractions.floor_fraction(a) - assert ( - native_math.floor(native_fraction_from_oc_fraction(a)) == oc_floor - ), "Invalid floor" - - -@hypothesis.given(denormalized_fractions, denormalized_fractions_non_null) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int_non_null) def test_floor_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): oc_floor = a // b floor = native_fraction_from_oc_fraction(a) // native_fraction_from_oc_fraction(b) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 4c7c3d3d..a1fd8117 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -369,7 +369,10 @@ def type_from_annotation(self, ann: expr): if isinstance(ann, Name): if ann.id in ATOMIC_TYPES: return ATOMIC_TYPES[ann.id] - v_t = self.variable_type(ann.id) + if ann.id == "Self": + v_t = self.variable_type(ann.idSelf_new) + else: + v_t = self.variable_type(ann.id) if isinstance(v_t, ClassType): return v_t raise TypeInferenceError( @@ -465,9 +468,19 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: ), f"The following Dunder methods are supported {list(DUNDER_MAP.values())}. Received {func.name} which is not supported" func.name = f"{n.name}_{attribute.name}" for arg in func.args.args: - assert ( - arg.annotation is None or arg.annotation.id != n.name - ), "Invalid Python, class name is undefined at this stage." + if not arg.annotation is None: + if isinstance(arg.annotation, ast.Name): + assert ( + arg.annotation is None or arg.annotation.id != n.name + ), "Invalid Python, class name is undefined at this stage." + elif ( + isinstance(arg.annotation, ast.Subscript) + and arg.annotation.value.id == "Union" + ): + for s in arg.annotation.slice.elts: + assert ( + s.id != n.name + ), "Invalid Python, class name is undefined at this stage." assert ( func.returns is None or func.returns.id != n.name ), "Invalid Python, class name is undefined at this stage" From 8b69eec1cbdb130065edb7b0b53157bc7705f701 Mon Sep 17 00:00:00 2001 From: SCM Date: Tue, 3 Sep 2024 18:25:38 +0100 Subject: [PATCH 03/14] Fix casting to Builtin --- opshin/compiler.py | 2 ++ opshin/std/fractions.py | 3 ++- opshin/tests/test_Unions.py | 19 +++++++++++++++++++ opshin/type_inference.py | 12 +++++++++--- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index ac8b0e27..b48471cf 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -393,6 +393,8 @@ def visit_Name(self, node: TypedName) -> plt.AST: if isinstance(node.typ, ClassType): # if this is not an instance but a class, call the constructor return node.typ.constr() + if hasattr(node, "is_wrapped"): + return transform_ext_params_map(node.typ)(plt.Force(plt.Var(node.id))) return plt.Force(plt.Var(node.id)) def visit_Expr(self, node: TypedExpr) -> CallAST: diff --git a/opshin/std/fractions.py b/opshin/std/fractions.py index 5a098efc..f6f236c9 100644 --- a/opshin/std/fractions.py +++ b/opshin/std/fractions.py @@ -37,7 +37,8 @@ def __add__(self, other: Union[Self, int]) -> Self: ) else: return Fraction( - self.numerator + (other * self.denominator), self.denominator + (self.numerator) + (other * self.denominator), + self.denominator, ) def __neg__( diff --git a/opshin/tests/test_Unions.py b/opshin/tests/test_Unions.py index b9632e2e..ef3b8930 100644 --- a/opshin/tests/test_Unions.py +++ b/opshin/tests/test_Unions.py @@ -305,3 +305,22 @@ def validator(x: Union[int, bytes, bool]) -> int: with self.assertRaises(CompilerError) as ce: res = eval_uplc_value(source_code, True) self.assertIsInstance(ce.exception.orig_err, AssertionError) + + @hypothesis.given(st.sampled_from([14, b""])) + def test_Union_builtin_cast(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def validator(x: Union[int,bytes]) -> int: + k: int = 0 + if isinstance(x, int): + k = x+5 + elif isinstance(x, bytes): + k = 7 + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 5 if isinstance(x, int) else 7 + self.assertEqual(res, real) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index a1fd8117..6801d2a7 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -279,6 +279,7 @@ class AggressiveTypeInferencer(CompilingNodeTransformer): def __init__(self, allow_isinstance_anything=False): self.allow_isinstance_anything = allow_isinstance_anything self.FUNCTION_ARGUMENT_REGISTRY = {} + self.wrapped = [] # A stack of dictionaries for storing scoped knowledge of variable types self.scopes = [INITIAL_SCOPE] @@ -625,15 +626,19 @@ def visit_If(self, node: If) -> TypedIf: ).visit(typed_if.test) # for the time of the branch, these types are cast initial_scope = copy(self.scopes[-1]) - self.implement_typechecks(typchecks) + wrapped = self.implement_typechecks(typchecks) + self.wrapped.extend(wrapped.keys()) typed_if.body = self.visit_sequence(node.body) + self.wrapped = [x for x in self.wrapped if x not in wrapped.keys()] + # save resulting types final_scope_body = copy(self.scopes[-1]) # reverse typechecks and remove typing of one branch self.scopes[-1] = initial_scope # for the time of the else branch, the inverse types hold - self.implement_typechecks(inv_typchecks) + wrapped = self.implement_typechecks(inv_typchecks) typed_if.orelse = self.visit_sequence(node.orelse) + self.wrapped = [x for x in self.wrapped if x not in wrapped.keys()] final_scope_else = self.scopes[-1] # unify the resulting branch scopes self.scopes[-1] = merge_scope(final_scope_body, final_scope_else) @@ -702,6 +707,8 @@ def visit_Name(self, node: Name) -> TypedName: else: # Make sure that the rhs of an assign is evaluated first tn.typ = self.variable_type(node.id) + if node.id in self.wrapped: + tn.is_wrapped = True return tn def visit_keyword(self, node: keyword) -> Typedkeyword: @@ -864,7 +871,6 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript: "Dict", "List", ]: - ts.value = ts.typ = self.type_from_annotation(ts) return ts From 2b35b250832e2b82fb1f46a7d05d45114601d29d Mon Sep 17 00:00:00 2001 From: SCM Date: Tue, 3 Sep 2024 21:53:37 +0100 Subject: [PATCH 04/14] check for is_wrapped being True --- opshin/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index b48471cf..163fa29e 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -393,7 +393,7 @@ def visit_Name(self, node: TypedName) -> plt.AST: if isinstance(node.typ, ClassType): # if this is not an instance but a class, call the constructor return node.typ.constr() - if hasattr(node, "is_wrapped"): + if hasattr(node, "is_wrapped") and node.is_wrapped: return transform_ext_params_map(node.typ)(plt.Force(plt.Var(node.id))) return plt.Force(plt.Var(node.id)) From 5771e2b32164ba65fd2b26b2733e3c9b579772dc Mon Sep 17 00:00:00 2001 From: SCM Date: Wed, 4 Sep 2024 18:27:10 +0100 Subject: [PATCH 05/14] fix fractions for python<3.11 --- opshin/fun_impls.py | 6 ++++++ opshin/std/fractions.py | 25 ++++++++++++------------- opshin/tests/test_std/test_fractions.py | 20 +++++++++++++++++++- opshin/type_inference.py | 17 ++++++++++++++--- 4 files changed, 51 insertions(+), 17 deletions(-) diff --git a/opshin/fun_impls.py b/opshin/fun_impls.py index 7a65dee8..ea19bb15 100644 --- a/opshin/fun_impls.py +++ b/opshin/fun_impls.py @@ -95,6 +95,12 @@ def type_from_args(self, args: typing.List[Type]) -> FunctionType: return FunctionType(args, BoolInstanceType) def impl_from_args(self, args: typing.List[Type]) -> plt.AST: + if not (isinstance(args[0], UnionType) or isinstance(args[0].typ, UnionType)): + if args[0].typ == args[1]: + return OLambda(["x"], plt.Bool(True)) + else: + return OLambda(["x"], plt.Bool(False)) + if isinstance(args[1], IntegerType): return OLambda( ["x"], diff --git a/opshin/std/fractions.py b/opshin/std/fractions.py index f6f236c9..9c96e6bd 100644 --- a/opshin/std/fractions.py +++ b/opshin/std/fractions.py @@ -7,7 +7,6 @@ from dataclasses import dataclass from pycardano import Datum as Anything, PlutusData from typing import Dict, List, Union -from typing import Self from opshin.std.math import * @@ -18,7 +17,7 @@ class Fraction(PlutusData): numerator: int denominator: int - def norm(self) -> Self: + def norm(self) -> "Fraction": """Restores the invariant that num/denom are in the smallest possible denomination and denominator > 0""" return _norm_gcd_fraction(_norm_signs_fraction(self)) @@ -27,7 +26,7 @@ def ceil(self) -> int: self.numerator + self.denominator - sign(self.denominator) ) // self.denominator - def __add__(self, other: Union[Self, int]) -> Self: + def __add__(self, other: Union["Fraction", int]) -> "Fraction": """returns self + other""" if isinstance(other, Fraction): return Fraction( @@ -43,11 +42,11 @@ def __add__(self, other: Union[Self, int]) -> Self: def __neg__( self, - ) -> Self: + ) -> "Fraction": """returns -self""" return Fraction(-self.numerator, self.denominator) - def __sub__(self, other: Union[Self, int]) -> Self: + def __sub__(self, other: Union["Fraction", int]) -> "Fraction": """returns self - other""" if isinstance(other, Fraction): return Fraction( @@ -60,7 +59,7 @@ def __sub__(self, other: Union[Self, int]) -> Self: self.numerator - (other * self.denominator), self.denominator ) - def __mul__(self, other: Union[Self, int]) -> Self: + def __mul__(self, other: Union["Fraction", int]) -> "Fraction": """returns self * other""" if isinstance(other, Fraction): return Fraction( @@ -69,7 +68,7 @@ def __mul__(self, other: Union[Self, int]) -> Self: else: return Fraction(self.numerator * other, self.denominator) - def __truediv__(self, other: Union[Self, int]) -> Self: + def __truediv__(self, other: Union["Fraction", int]) -> "Fraction": """returns self / other""" if isinstance(other, Fraction): return Fraction( @@ -78,7 +77,7 @@ def __truediv__(self, other: Union[Self, int]) -> Self: else: return Fraction(self.numerator, self.denominator * other) - def __ge__(self, other: Union[Self, int]) -> bool: + def __ge__(self, other: Union["Fraction", int]) -> bool: """returns self >= other""" if isinstance(other, Fraction): if self.denominator * other.denominator >= 0: @@ -99,7 +98,7 @@ def __ge__(self, other: Union[Self, int]) -> bool: res = self.numerator <= self.denominator * other return res - def __le__(self, other: Union[Self, int]) -> bool: + def __le__(self, other: Union["Fraction", int]) -> bool: """returns self <= other""" if isinstance(other, Fraction): if self.denominator * other.denominator >= 0: @@ -120,7 +119,7 @@ def __le__(self, other: Union[Self, int]) -> bool: res = self.numerator >= self.denominator * other return res - def __eq__(self, other: Union[Self, int]) -> bool: + def __eq__(self, other: Union["Fraction", int]) -> bool: """returns self == other""" if isinstance(other, Fraction): return ( @@ -129,7 +128,7 @@ def __eq__(self, other: Union[Self, int]) -> bool: else: return self.numerator == self.denominator * other - def __lt__(self, other: Union[Self, int]) -> bool: + def __lt__(self, other: Union["Fraction", int]) -> bool: """returns self < other""" if isinstance(other, Fraction): if self.denominator * other.denominator >= 0: @@ -150,7 +149,7 @@ def __lt__(self, other: Union[Self, int]) -> bool: res = self.numerator > self.denominator * other return res - def __gt__(self, other: Union[Self, int]) -> bool: + def __gt__(self, other: Union["Fraction", int]) -> bool: """returns self > other""" if isinstance(other, Fraction): if self.denominator * other.denominator >= 0: @@ -171,7 +170,7 @@ def __gt__(self, other: Union[Self, int]) -> bool: res = self.numerator < self.denominator * other return res - def __floordiv__(self, other: Union[Self, int]) -> int: + def __floordiv__(self, other: Union["Fraction", int]) -> int: if isinstance(other, Fraction): x = self / other return x.numerator // x.denominator diff --git a/opshin/tests/test_std/test_fractions.py b/opshin/tests/test_std/test_fractions.py index c3f4fc3d..148ef078 100644 --- a/opshin/tests/test_std/test_fractions.py +++ b/opshin/tests/test_std/test_fractions.py @@ -2,7 +2,7 @@ import hypothesis.strategies as hst from typing import Union from opshin.std import fractions as oc_fractions -from opshin.tests.utils import eval_uplc +from opshin.tests.utils import eval_uplc, eval_uplc_value import fractions as native_fractions import math as native_math @@ -21,6 +21,8 @@ def native_fraction_from_oc_fraction(f: Union[oc_fractions.Fraction, int]): if isinstance(f, oc_fractions.Fraction): return native_fractions.Fraction(f.numerator, f.denominator) + elif isinstance(f, PlutusConstr): + return native_fractions.Fraction(*[x.value for x in f.fields]) else: return f @@ -158,3 +160,19 @@ def test_ceil_method(a: oc_fractions.Fraction): assert ( native_math.ceil(native_fraction_from_oc_fraction(a)) == oc_ceil ), "Invalid ceil" + + +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_uplc(a, b): + source_code = """ +from opshin.std.fractions import * +from typing import Dict, List, Union + +def validator(a: Fraction, b: Union[Fraction, int]) -> Fraction: + return a+b +""" + ret = eval_uplc(source_code, a, b) + print(ret) + assert ( + native_fraction_from_oc_fraction(a) + native_fraction_from_oc_fraction(b) + ) == native_fraction_from_oc_fraction(ret), "invalid add" diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 6801d2a7..2a8e0330 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -367,6 +367,15 @@ def type_from_annotation(self, ann: expr): if isinstance(ann, Constant): if ann.value is None: return UnitType() + else: + for scope in reversed(self.scopes): + for key, value in scope.items(): + if ( + isinstance(value, RecordType) + and value.record.orig_name == ann.value + ): + return value + if isinstance(ann, Name): if ann.id in ATOMIC_TYPES: return ATOMIC_TYPES[ann.id] @@ -480,10 +489,12 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: ): for s in arg.annotation.slice.elts: assert ( - s.id != n.name + isinstance(s, Name) and s.id != n.name + ) or isinstance( + s, Constant ), "Invalid Python, class name is undefined at this stage." - assert ( - func.returns is None or func.returns.id != n.name + assert isinstance(func.returns, Constant) or ( + isinstance(func.returns, Name) and func.returns.id != n.name ), "Invalid Python, class name is undefined at this stage" ann = ast.Name(id=n.name, ctx=ast.Load()) custom_fix_missing_locations(ann, attribute.args.args[0]) From ab31e1ac35b4953e04a7fc6af8e851f205423593 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Wed, 4 Sep 2024 21:24:58 +0200 Subject: [PATCH 06/14] Apply suggestions from code review Also check unwrapping for bytes type --- opshin/tests/test_Unions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/opshin/tests/test_Unions.py b/opshin/tests/test_Unions.py index ef3b8930..a51ed7a8 100644 --- a/opshin/tests/test_Unions.py +++ b/opshin/tests/test_Unions.py @@ -318,9 +318,9 @@ def validator(x: Union[int,bytes]) -> int: if isinstance(x, int): k = x+5 elif isinstance(x, bytes): - k = 7 + k = len(x) return k """ res = eval_uplc_value(source_code, x) - real = x + 5 if isinstance(x, int) else 7 + real = x + 5 if isinstance(x, int) else len(x) self.assertEqual(res, real) From 0b7ad5006523fa6d6662eeac34e97a7bdc7279d8 Mon Sep 17 00:00:00 2001 From: Niels Date: Wed, 4 Sep 2024 21:28:21 +0200 Subject: [PATCH 07/14] Add failing test case for internally created, unioned values --- opshin/tests/test_Unions.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/opshin/tests/test_Unions.py b/opshin/tests/test_Unions.py index a51ed7a8..b124ee49 100644 --- a/opshin/tests/test_Unions.py +++ b/opshin/tests/test_Unions.py @@ -324,3 +324,29 @@ def validator(x: Union[int,bytes]) -> int: res = eval_uplc_value(source_code, x) real = x + 5 if isinstance(x, int) else len(x) self.assertEqual(res, real) + + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_builtin_cast_internal(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def foo(x: Union[int,bytes]) -> int: + k: int = 0 + if isinstance(x, int): + k = x+5 + elif isinstance(x, bytes): + k = len(x) + return k + +def validator(x: int) -> int: + if x > 5: + k = foo(x+1) + else: + k = foo(b"0"*x) + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 6 if isinstance(x, int) else len(x) + self.assertEqual(res, real) From 303821bbe914d305a4322688718f4434d9f06348 Mon Sep 17 00:00:00 2001 From: Niels Date: Wed, 4 Sep 2024 21:31:54 +0200 Subject: [PATCH 08/14] Add another failing testcase --- opshin/tests/test_Unions.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/opshin/tests/test_Unions.py b/opshin/tests/test_Unions.py index b124ee49..3bcdbcae 100644 --- a/opshin/tests/test_Unions.py +++ b/opshin/tests/test_Unions.py @@ -348,5 +348,25 @@ def validator(x: int) -> int: return k """ res = eval_uplc_value(source_code, x) - real = x + 6 if isinstance(x, int) else len(x) + real = x + 6 if x > 5 else len(x) + self.assertEqual(res, real) + + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_builtin_cast_direct(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def validator(x: int) -> int: + y: Union[int,bytes] = 5 if x > 5 else b"0"*x + k: int = 0 + if isinstance(y, int): + k = y+1 + elif isinstance(y, bytes): + k = len(y) + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 1 if x > 5 else len(x) self.assertEqual(res, real) From 5105204fbe7d9554f469c8dcbe84fa40056ef887 Mon Sep 17 00:00:00 2001 From: Niels Date: Wed, 4 Sep 2024 21:56:01 +0200 Subject: [PATCH 09/14] Fix relative import --- tests/test_std/test_fractions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_std/test_fractions.py b/tests/test_std/test_fractions.py index 148ef078..2c09b927 100644 --- a/tests/test_std/test_fractions.py +++ b/tests/test_std/test_fractions.py @@ -2,7 +2,7 @@ import hypothesis.strategies as hst from typing import Union from opshin.std import fractions as oc_fractions -from opshin.tests.utils import eval_uplc, eval_uplc_value +from ..utils import eval_uplc, eval_uplc_value import fractions as native_fractions import math as native_math From 366a0e65e828102689231fe5a00f8a8e693dc0f9 Mon Sep 17 00:00:00 2001 From: SCM Date: Thu, 5 Sep 2024 11:43:37 +0100 Subject: [PATCH 10/14] bug fixes --- opshin/compiler.py | 10 +++++++++- tests/test_Unions.py | 4 ++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index 163fa29e..1cf0b035 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -435,7 +435,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST: assert isinstance(t, InstanceType) # pass in all arguments evaluated with the statemonad a_int = self.visit(a) - if isinstance(t.typ, AnyType): + if isinstance(t.typ, AnyType) or isinstance(t.typ, UnionType): # if the function expects input of generic type data, wrap data before passing it inside a_int = transform_output_map(a.typ)(a_int) args.append(a_int) @@ -916,6 +916,14 @@ def visit_Dict(self, node: TypedDict) -> plt.AST: return l def visit_IfExp(self, node: TypedIfExp) -> plt.AST: + if isinstance(node.typ.typ, UnionType): + body = self.visit(node.body) + orelse = self.visit(node.orelse) + if not isinstance(node.body.typ, UnionType): + body = transform_output_map(node.body.typ)(body) + if not isinstance(node.orelse.typ, UnionType): + orelse = transform_output_map(node.orelse.typ)(orelse) + return plt.Ite(self.visit(node.test), body, orelse) return plt.Ite( self.visit(node.test), self.visit(node.body), diff --git a/tests/test_Unions.py b/tests/test_Unions.py index 867b4848..8f1c0a00 100644 --- a/tests/test_Unions.py +++ b/tests/test_Unions.py @@ -348,7 +348,7 @@ def validator(x: int) -> int: return k """ res = eval_uplc_value(source_code, x) - real = x + 6 if x > 5 else len(x) + real = x + 6 if x > 5 else len(b"0" * x) self.assertEqual(res, real) @hypothesis.given(st.sampled_from(range(14))) @@ -368,5 +368,5 @@ def validator(x: int) -> int: return k """ res = eval_uplc_value(source_code, x) - real = x + 1 if x > 5 else len(x) + real = 5 + 1 if x > 5 else len(b"0" * x) self.assertEqual(res, real) From aceec49be98ee36cd6cac89fafd610e19426beb8 Mon Sep 17 00:00:00 2001 From: Niels Date: Thu, 5 Sep 2024 14:37:18 +0200 Subject: [PATCH 11/14] Another exception for ifexpr --- tests/test_Unions.py | 54 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/test_Unions.py b/tests/test_Unions.py index 8f1c0a00..758e4e4c 100644 --- a/tests/test_Unions.py +++ b/tests/test_Unions.py @@ -370,3 +370,57 @@ def validator(x: int) -> int: res = eval_uplc_value(source_code, x) real = 5 + 1 if x > 5 else len(b"0" * x) self.assertEqual(res, real) + + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_cast_ifexpr(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + x: int + +@dataclass() +class B(PlutusData): + CONSTR_ID = 1 + y: bytes + +def foo(x: Union[A, B]) -> int: + k: int = x.x + 1 if isinstance(x, A) else len(x.y) + return k + +def validator(x: int) -> int: + if x > 5: + k = foo(A(x)) + else: + k = foo(B(b"0"*x)) + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 1 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) + + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_builtin_cast_ifexpr(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def foo(x: Union[int, bytes]) -> int: + k: int = x + 1 if isinstance(x, int) else len(x) + return k + +def validator(x: int) -> int: + if x > 5: + k = foo(x+1) + else: + k = foo(b"0"*x) + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 2 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) From dd526c0f7c9b71f234cc33f9db7b26f0ea4794a3 Mon Sep 17 00:00:00 2001 From: Niels Date: Thu, 5 Sep 2024 14:50:04 +0200 Subject: [PATCH 12/14] Add test cases for list filter expression These don't compile, which is not ideal but better than compiling and then crashing. --- tests/test_Unions.py | 60 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/test_Unions.py b/tests/test_Unions.py index 758e4e4c..dd14b787 100644 --- a/tests/test_Unions.py +++ b/tests/test_Unions.py @@ -420,6 +420,66 @@ def validator(x: int) -> int: else: k = foo(b"0"*x) return k +""" + res = eval_uplc_value(source_code, x) + real = x + 2 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) + + @unittest.skip("Throw compilation error, hence not critical") + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_cast_List(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + x: int + +@dataclass() +class B(PlutusData): + CONSTR_ID = 1 + y: bytes + +def foo(xs: List[Union[A, B]]) -> List[int]: + k: List[int] = [x.x + 1 for x in xs if isinstance(x, A)] + if not k: + k = [len(x.y) for x in xs if isinstance(x, B)] + return k + +def validator(x: int) -> int: + if x > 5: + k = foo([A(x)]) + else: + k = foo([B(b"0"*x)]) + return k[0] +""" + res = eval_uplc_value(source_code, x) + real = x + 1 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) + + @unittest.skip("Throw compilation error, hence not critical") + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_builtin_cast_ifexpr(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def foo(xs: List[Union[int, bytes]]) -> List[int]: + k: List[int] = [x + 1 for x in xs if isinstance(x, int)] + if not k: + k = [len(x) for x in xs if isinstance(x, bytes)] + return k + +def validator(x: int) -> int: + if x > 5: + k = foo(x+1) + else: + k = foo(b"0"*x) + return k[0] """ res = eval_uplc_value(source_code, x) real = x + 2 if x > 5 else len(b"0" * x) From 3d5e861b5a5875bed02be9f05590f265dcabb8c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Thu, 5 Sep 2024 15:56:51 +0200 Subject: [PATCH 13/14] Update tests/test_Unions.py --- tests/test_Unions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_Unions.py b/tests/test_Unions.py index dd14b787..2aad1f2d 100644 --- a/tests/test_Unions.py +++ b/tests/test_Unions.py @@ -462,7 +462,7 @@ def validator(x: int) -> int: @unittest.skip("Throw compilation error, hence not critical") @hypothesis.given(st.sampled_from(range(14))) - def test_Union_builtin_cast_ifexpr(self, x): + def test_Union_builtin_cast_List(self, x): source_code = """ from dataclasses import dataclass from typing import Dict, List, Union From 2eb83db8d4c8a5d5e5de4e28673aff40e65487f1 Mon Sep 17 00:00:00 2001 From: SCM Date: Fri, 4 Oct 2024 16:50:45 +0100 Subject: [PATCH 14/14] fix IfExp --- opshin/type_inference.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 2a8e0330..e31c6df6 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -648,6 +648,7 @@ def visit_If(self, node: If) -> TypedIf: self.scopes[-1] = initial_scope # for the time of the else branch, the inverse types hold wrapped = self.implement_typechecks(inv_typchecks) + self.wrapped.extend(wrapped.keys()) typed_if.orelse = self.visit_sequence(node.orelse) self.wrapped = [x for x in self.wrapped if x not in wrapped.keys()] final_scope_else = self.scopes[-1] @@ -1149,10 +1150,15 @@ def visit_IfExp(self, node: IfExp) -> TypedIfExp: self.allow_isinstance_anything ).visit(node_cp.test) prevtyps = self.implement_typechecks(typchecks) + self.wrapped.extend(prevtyps.keys()) node_cp.body = self.visit(node.body) + self.wrapped = [x for x in self.wrapped if x not in prevtyps.keys()] + self.implement_typechecks(prevtyps) prevtyps = self.implement_typechecks(inv_typchecks) + self.wrapped.extend(prevtyps.keys()) node_cp.orelse = self.visit(node.orelse) + self.wrapped = [x for x in self.wrapped if x not in prevtyps.keys()] self.implement_typechecks(prevtyps) if node_cp.body.typ >= node_cp.orelse.typ: node_cp.typ = node_cp.body.typ