From 769c3f89d95692d88e68a265d49d3c04e9d733f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Wed, 31 Jan 2024 09:38:56 +0100 Subject: [PATCH 1/4] Move unary implementations into types classes --- opshin/compiler.py | 17 ++++--------- opshin/types.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index e01efd92..47737d5e 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -45,11 +45,6 @@ Or: plt.Or, } -UnaryOpMap = { - Not: {BoolInstanceType: plt.Not}, - USub: {IntegerInstanceType: lambda x: plt.SubtractInteger(plt.Integer(0), x)}, -} - def rec_constant_map_data(c): if isinstance(c, bool): @@ -208,13 +203,11 @@ def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST: return ops def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST: - opmap = UnaryOpMap.get(type(node.op)) - assert opmap is not None, f"Operator {type(node.op)} is not supported" - op = opmap.get(node.operand.typ) - assert ( - op is not None - ), f"Operator {type(node.op)} is not supported for type {node.operand.typ}" - return op(self.visit(node.operand)) + op = node.left.typ.unop(node.op) + return plt.Apply( + op, + self.visit(node.left), + ) def visit_Compare(self, node: TypedCompare) -> plt.AST: assert len(node.ops) == 1, "Only single comparisons are supported" diff --git a/opshin/types.py b/opshin/types.py index 6de3be33..09c6aa28 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -107,6 +107,40 @@ def _binop_bin_fun( f"{type(self).__name__} can not be used with operation {binop.__class__.__name__}" ) + def unop_type(self, unop: unaryop) -> "Type": + """ + Type of a unary operation on self. + """ + return FunctionType( + [InstanceType(self)], + InstanceType(self._unop_return_type(unop)), + ) + + def _unop_return_type(self, unop: unaryop) -> "Type": + """ + Return the type of a binary operation between self and other + """ + raise NotImplementedError( + f"{type(self).__name__} does not implement {unop.__class__.__name__}" + ) + + def unop(self, unop: unaryop) -> plt.AST: + """ + Implements a unary operation on self + """ + return OLambda( + ["self"], + self._unop_fun(unop)(OVar("self")), + ) + + def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: + """ + Returns a unary function that implements the unary operation on self. + """ + raise NotImplementedError( + f"{type(self).__name__} can not be used with operation {unop.__class__.__name__}" + ) + @dataclass(frozen=True, unsafe_hash=True) class Record: @@ -1541,6 +1575,24 @@ def _binop_bin_fun(self, binop: operator, other: AST): elif other.typ == StringInstanceType: return lambda x, y: StrIntMulImpl(y, x) + def _unop_return_type(self, unop: unaryop) -> "Type": + if isinstance(unop, USub): + return IntegerType() + elif isinstance(unop, UAdd): + return IntegerType() + elif isinstance(unop, Not): + return BoolType() + return super()._unop_return_type(unop) + + def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: + if isinstance(unop, USub): + return lambda x: plt.SubtractInteger(plt.Integer(0), x) + if isinstance(unop, UAdd): + return lambda x: x + if isinstance(unop, Not): + return lambda x: plt.NotEqualsInteger(x, plt.Integer(0)) + return super()._unop_fun(unop) + @dataclass(frozen=True, unsafe_hash=True) class StringType(AtomicType): @@ -1967,6 +2019,16 @@ def stringify(self, recursive: bool = False) -> plt.AST: ), ) + def _unop_return_type(self, unop: unaryop) -> "Type": + if isinstance(unop, Not): + return BoolType() + return super()._unop_return_type(unop) + + def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: + if isinstance(unop, Not): + return plt.Not + return super()._unop_fun(unop) + @dataclass(frozen=True, unsafe_hash=True) class UnitType(AtomicType): From f2efa22bafa48a7b3d7a87158102f50201fecc72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Wed, 31 Jan 2024 09:43:56 +0100 Subject: [PATCH 2/4] Fixes --- opshin/compiler.py | 4 ++-- opshin/types.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index 47737d5e..3fd174b9 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -203,10 +203,10 @@ def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST: return ops def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST: - op = node.left.typ.unop(node.op) + op = node.operand.typ.unop(node.op) return plt.Apply( op, - self.visit(node.left), + self.visit(node.operand), ) def visit_Compare(self, node: TypedCompare) -> plt.AST: diff --git a/opshin/types.py b/opshin/types.py index 09c6aa28..1cea8228 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -1034,6 +1034,16 @@ def _binop_bin_fun(self, binop: operator, other: AST): ): return plt.AppendList + def _unop_return_type(self, unop: unaryop) -> "Type": + if isinstance(unop, Not): + return BoolType() + return super()._unop_return_type(unop) + + def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: + if isinstance(unop, Not): + return lambda x: plt.IteNullList(x, plt.Bool(True), plt.Bool(False)) + return super()._unop_fun(unop) + @dataclass(frozen=True, unsafe_hash=True) class DictType(ClassType): @@ -1321,6 +1331,16 @@ def CustomMapFilterList( ) return OLambda(["self"], mapped_attrs) + def _unop_return_type(self, unop: unaryop) -> "Type": + if isinstance(unop, Not): + return BoolType() + return super()._unop_return_type(unop) + + def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: + if isinstance(unop, Not): + return lambda x: plt.IteNullList(x, plt.Bool(True), plt.Bool(False)) + return super()._unop_fun(unop) + @dataclass(frozen=True, unsafe_hash=True) class FunctionType(ClassType): @@ -1382,6 +1402,12 @@ def binop_type(self, binop: operator, other: "Type") -> "Type": def binop(self, binop: operator, other: AST) -> plt.AST: return self.typ.binop(binop, other) + def unop_type(self, unop: unaryop) -> "Type": + return self.typ.unop_type(unop) + + def unop(self, unop: unaryop) -> plt.AST: + return self.typ.unop(unop) + @dataclass(frozen=True, unsafe_hash=True) class IntegerType(AtomicType): From 7e651cf76e2580852395f1cca2fc0381ffe8563f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Wed, 31 Jan 2024 10:02:19 +0100 Subject: [PATCH 3/4] Add tests and fix for unop implementations of lists, dict and int --- opshin/tests/test_ops.py | 53 ++++++++++++++++++++++++++++++++++++++++ opshin/type_inference.py | 2 +- opshin/types.py | 2 +- 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/opshin/tests/test_ops.py b/opshin/tests/test_ops.py index c204ab24..d13a3f4a 100644 --- a/opshin/tests/test_ops.py +++ b/opshin/tests/test_ops.py @@ -112,6 +112,24 @@ def validator(x: int) -> int: ret = eval_uplc_value(source_code, x) self.assertEqual(ret, -x, "not returned wrong value") + @given(x=st.integers()) + def test_uadd_int(self, x): + source_code = """ +def validator(x: int) -> int: + return +x + """ + ret = eval_uplc_value(source_code, x) + self.assertEqual(ret, +x, "not returned wrong value") + + @given(x=st.integers()) + def test_not_int(self, x): + source_code = """ +def validator(x: int) -> bool: + return not x + """ + ret = eval_uplc_value(source_code, x) + self.assertEqual(bool(ret), not x, "not returned wrong value") + @given(x=st.integers(), y=st.integers()) def test_add_int(self, x, y): source_code = """ @@ -465,6 +483,15 @@ def validator(x: List[bytes], y: bytes) -> bool: ret = eval_uplc_value(source_code, xs, y) self.assertEqual(ret, y in xs, "list in returned wrong value") + @given(x=st.lists(st.integers())) + def test_not_list(self, x): + source_code = """ +def validator(x: List[int]) -> bool: + return not x + """ + ret = eval_uplc_value(source_code, x) + self.assertEqual(bool(ret), not x, "not returned wrong value") + @given(x=st.binary(), y=st.binary()) def test_eq_bytes(self, x, y): source_code = """ @@ -759,6 +786,32 @@ def validator(x: List[int]) -> str: ret, exp, "integer list string formatting returned wrong value" ) + @given(x=st.dictionaries(st.integers(), st.integers())) + def test_not_dict(self, x): + source_code = """ +def validator(x: Dict[int, int]) -> bool: + return not x + """ + ret = eval_uplc_value(source_code, x) + self.assertEqual(bool(ret), not x, "not returned wrong value") + + @given(xs=st.dictionaries(st.integers(), st.integers()), y=st.integers()) + def test_index_dict(self, xs, y): + source_code = """ +from typing import Dict, List, Union +def validator(x: Dict[int, int], y: int) -> int: + return x[y] + """ + try: + exp = xs[y] + except KeyError: + exp = None + try: + ret = eval_uplc_value(source_code, xs, y) + except Exception as e: + ret = None + self.assertEqual(ret, exp, "list index returned wrong value") + @given(xs=st.dictionaries(formattable_text, st.integers())) @example(dict()) @example({"": 0}) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index bfa26784..8b3ef3e3 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -650,7 +650,7 @@ def visit_BoolOp(self, node: BoolOp) -> TypedBoolOp: def visit_UnaryOp(self, node: UnaryOp) -> TypedUnaryOp: tu = copy(node) tu.operand = self.visit(node.operand) - tu.typ = tu.operand.typ + tu.typ = tu.operand.typ.typ.unop_type(node.op).rettyp return tu def visit_Subscript(self, node: Subscript) -> TypedSubscript: diff --git a/opshin/types.py b/opshin/types.py index 1cea8228..89acfc5e 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -1616,7 +1616,7 @@ def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: if isinstance(unop, UAdd): return lambda x: x if isinstance(unop, Not): - return lambda x: plt.NotEqualsInteger(x, plt.Integer(0)) + return lambda x: plt.EqualsInteger(x, plt.Integer(0)) return super()._unop_fun(unop) From 5dee07104c85d9f445e876e12c74107a0dd991ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Wed, 31 Jan 2024 10:09:41 +0100 Subject: [PATCH 4/4] Add unop implementations for bytes, str and None --- opshin/tests/test_ops.py | 26 ++++++++++++++++++++++++++ opshin/types.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/opshin/tests/test_ops.py b/opshin/tests/test_ops.py index d13a3f4a..5e32f962 100644 --- a/opshin/tests/test_ops.py +++ b/opshin/tests/test_ops.py @@ -852,3 +852,29 @@ def validator(x: Anything) -> str: self.assertEqual( ret, exp, "raw cbor string formatting returned wrong value" ) + + @given(x=st.text()) + def test_not_string(self, x): + source_code = """ +def validator(x: str) -> bool: + return not x + """ + ret = eval_uplc_value(source_code, x.encode("utf8")) + self.assertEqual(bool(ret), not x, "not returned wrong value") + + @given(x=st.binary()) + def test_not_bytes(self, x): + source_code = """ +def validator(x: bytes) -> bool: + return not x + """ + ret = eval_uplc_value(source_code, x) + self.assertEqual(bool(ret), not x, "not returned wrong value") + + def test_not_unit(self): + source_code = """ +def validator(x: None) -> bool: + return not x + """ + ret = eval_uplc_value(source_code, uplc.BuiltinUnit()) + self.assertEqual(bool(ret), not None, "not returned wrong value") diff --git a/opshin/types.py b/opshin/types.py index 89acfc5e..965caa05 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -1669,6 +1669,18 @@ def _binop_bin_fun(self, binop: operator, other: AST): if other.typ == IntegerInstanceType: return StrIntMulImpl + def _unop_return_type(self, unop: unaryop) -> "Type": + if isinstance(unop, Not): + return BoolType() + return super()._unop_return_type(unop) + + def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: + if isinstance(unop, Not): + return lambda x: plt.EqualsInteger( + plt.LengthOfByteString(plt.EncodeUtf8(x)), plt.Integer(0) + ) + return super()._unop_fun(unop) + @dataclass(frozen=True, unsafe_hash=True) class ByteStringType(AtomicType): @@ -2010,6 +2022,18 @@ def _binop_bin_fun(self, binop: operator, other: AST): if other.typ == IntegerInstanceType: return ByteStrIntMulImpl + def _unop_return_type(self, unop: unaryop) -> "Type": + if isinstance(unop, Not): + return BoolType() + return super()._unop_return_type(unop) + + def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: + if isinstance(unop, Not): + return lambda x: plt.EqualsInteger( + plt.LengthOfByteString(x), plt.Integer(0) + ) + return super()._unop_fun(unop) + @dataclass(frozen=True, unsafe_hash=True) class BoolType(AtomicType): @@ -2069,6 +2093,16 @@ def cmp(self, op: cmpop, o: "Type") -> plt.AST: def stringify(self, recursive: bool = False) -> plt.AST: return OLambda(["self"], plt.Text("None")) + def _unop_return_type(self, unop: unaryop) -> "Type": + if isinstance(unop, Not): + return BoolType() + return super()._unop_return_type(unop) + + def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: + if isinstance(unop, Not): + return lambda x: plt.Bool(True) + return super()._unop_fun(unop) + IntegerInstanceType = InstanceType(IntegerType()) StringInstanceType = InstanceType(StringType())