Skip to content

Commit

Permalink
Merge pull request #325 from OpShin/feat/unop_dynamic
Browse files Browse the repository at this point in the history
Make unary operations dynamic and implement more thereof
  • Loading branch information
nielstron authored Jan 31, 2024
2 parents 02d994e + 81da686 commit a56c7f4
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 13 deletions.
17 changes: 5 additions & 12 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@
Or: plt.Or,
}

UnaryOpMap = {
Not: {BoolInstanceType: plt.Not},
USub: {IntegerInstanceType: lambda x: plt.SubtractInteger(plt.Integer(0), x)},
}


def rec_constant_map_data(c):
if isinstance(c, bool):
Expand Down Expand Up @@ -208,13 +203,11 @@ def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST:
return ops

def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST:
opmap = UnaryOpMap.get(type(node.op))
assert opmap is not None, f"Operator {type(node.op)} is not supported"
op = opmap.get(node.operand.typ)
assert (
op is not None
), f"Operator {type(node.op)} is not supported for type {node.operand.typ}"
return op(self.visit(node.operand))
op = node.operand.typ.unop(node.op)
return plt.Apply(
op,
self.visit(node.operand),
)

def visit_Compare(self, node: TypedCompare) -> plt.AST:
assert len(node.ops) == 1, "Only single comparisons are supported"
Expand Down
79 changes: 79 additions & 0 deletions opshin/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,24 @@ def validator(x: int) -> int:
ret = eval_uplc_value(source_code, x)
self.assertEqual(ret, -x, "not returned wrong value")

@given(x=st.integers())
def test_uadd_int(self, x):
source_code = """
def validator(x: int) -> int:
return +x
"""
ret = eval_uplc_value(source_code, x)
self.assertEqual(ret, +x, "not returned wrong value")

@given(x=st.integers())
def test_not_int(self, x):
source_code = """
def validator(x: int) -> bool:
return not x
"""
ret = eval_uplc_value(source_code, x)
self.assertEqual(bool(ret), not x, "not returned wrong value")

@given(x=st.integers(), y=st.integers())
def test_add_int(self, x, y):
source_code = """
Expand Down Expand Up @@ -465,6 +483,15 @@ def validator(x: List[bytes], y: bytes) -> bool:
ret = eval_uplc_value(source_code, xs, y)
self.assertEqual(ret, y in xs, "list in returned wrong value")

@given(x=st.lists(st.integers()))
def test_not_list(self, x):
source_code = """
def validator(x: List[int]) -> bool:
return not x
"""
ret = eval_uplc_value(source_code, x)
self.assertEqual(bool(ret), not x, "not returned wrong value")

@given(x=st.binary(), y=st.binary())
def test_eq_bytes(self, x, y):
source_code = """
Expand Down Expand Up @@ -759,6 +786,32 @@ def validator(x: List[int]) -> str:
ret, exp, "integer list string formatting returned wrong value"
)

@given(x=st.dictionaries(st.integers(), st.integers()))
def test_not_dict(self, x):
source_code = """
def validator(x: Dict[int, int]) -> bool:
return not x
"""
ret = eval_uplc_value(source_code, x)
self.assertEqual(bool(ret), not x, "not returned wrong value")

@given(xs=st.dictionaries(st.integers(), st.integers()), y=st.integers())
def test_index_dict(self, xs, y):
source_code = """
from typing import Dict, List, Union
def validator(x: Dict[int, int], y: int) -> int:
return x[y]
"""
try:
exp = xs[y]
except KeyError:
exp = None
try:
ret = eval_uplc_value(source_code, xs, y)
except Exception as e:
ret = None
self.assertEqual(ret, exp, "list index returned wrong value")

@given(xs=st.dictionaries(formattable_text, st.integers()))
@example(dict())
@example({"": 0})
Expand Down Expand Up @@ -799,3 +852,29 @@ def validator(x: Anything) -> str:
self.assertEqual(
ret, exp, "raw cbor string formatting returned wrong value"
)

@given(x=st.text())
def test_not_string(self, x):
source_code = """
def validator(x: str) -> bool:
return not x
"""
ret = eval_uplc_value(source_code, x.encode("utf8"))
self.assertEqual(bool(ret), not x, "not returned wrong value")

@given(x=st.binary())
def test_not_bytes(self, x):
source_code = """
def validator(x: bytes) -> bool:
return not x
"""
ret = eval_uplc_value(source_code, x)
self.assertEqual(bool(ret), not x, "not returned wrong value")

def test_not_unit(self):
source_code = """
def validator(x: None) -> bool:
return not x
"""
ret = eval_uplc_value(source_code, uplc.BuiltinUnit())
self.assertEqual(bool(ret), not None, "not returned wrong value")
2 changes: 1 addition & 1 deletion opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def visit_BoolOp(self, node: BoolOp) -> TypedBoolOp:
def visit_UnaryOp(self, node: UnaryOp) -> TypedUnaryOp:
tu = copy(node)
tu.operand = self.visit(node.operand)
tu.typ = tu.operand.typ
tu.typ = tu.operand.typ.typ.unop_type(node.op).rettyp
return tu

def visit_Subscript(self, node: Subscript) -> TypedSubscript:
Expand Down
122 changes: 122 additions & 0 deletions opshin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,40 @@ def _binop_bin_fun(
f"{type(self).__name__} can not be used with operation {binop.__class__.__name__}"
)

def unop_type(self, unop: unaryop) -> "Type":
"""
Type of a unary operation on self.
"""
return FunctionType(
[InstanceType(self)],
InstanceType(self._unop_return_type(unop)),
)

def _unop_return_type(self, unop: unaryop) -> "Type":
"""
Return the type of a binary operation between self and other
"""
raise NotImplementedError(
f"{type(self).__name__} does not implement {unop.__class__.__name__}"
)

def unop(self, unop: unaryop) -> plt.AST:
"""
Implements a unary operation on self
"""
return OLambda(
["self"],
self._unop_fun(unop)(OVar("self")),
)

def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]:
"""
Returns a unary function that implements the unary operation on self.
"""
raise NotImplementedError(
f"{type(self).__name__} can not be used with operation {unop.__class__.__name__}"
)


@dataclass(frozen=True, unsafe_hash=True)
class Record:
Expand Down Expand Up @@ -1000,6 +1034,16 @@ def _binop_bin_fun(self, binop: operator, other: AST):
):
return plt.AppendList

def _unop_return_type(self, unop: unaryop) -> "Type":
if isinstance(unop, Not):
return BoolType()
return super()._unop_return_type(unop)

def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]:
if isinstance(unop, Not):
return lambda x: plt.IteNullList(x, plt.Bool(True), plt.Bool(False))
return super()._unop_fun(unop)


@dataclass(frozen=True, unsafe_hash=True)
class DictType(ClassType):
Expand Down Expand Up @@ -1287,6 +1331,16 @@ def CustomMapFilterList(
)
return OLambda(["self"], mapped_attrs)

def _unop_return_type(self, unop: unaryop) -> "Type":
if isinstance(unop, Not):
return BoolType()
return super()._unop_return_type(unop)

def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]:
if isinstance(unop, Not):
return lambda x: plt.IteNullList(x, plt.Bool(True), plt.Bool(False))
return super()._unop_fun(unop)


@dataclass(frozen=True, unsafe_hash=True)
class FunctionType(ClassType):
Expand Down Expand Up @@ -1348,6 +1402,12 @@ def binop_type(self, binop: operator, other: "Type") -> "Type":
def binop(self, binop: operator, other: AST) -> plt.AST:
return self.typ.binop(binop, other)

def unop_type(self, unop: unaryop) -> "Type":
return self.typ.unop_type(unop)

def unop(self, unop: unaryop) -> plt.AST:
return self.typ.unop(unop)


@dataclass(frozen=True, unsafe_hash=True)
class IntegerType(AtomicType):
Expand Down Expand Up @@ -1541,6 +1601,24 @@ def _binop_bin_fun(self, binop: operator, other: AST):
elif other.typ == StringInstanceType:
return lambda x, y: StrIntMulImpl(y, x)

def _unop_return_type(self, unop: unaryop) -> "Type":
if isinstance(unop, USub):
return IntegerType()
elif isinstance(unop, UAdd):
return IntegerType()
elif isinstance(unop, Not):
return BoolType()
return super()._unop_return_type(unop)

def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]:
if isinstance(unop, USub):
return lambda x: plt.SubtractInteger(plt.Integer(0), x)
if isinstance(unop, UAdd):
return lambda x: x
if isinstance(unop, Not):
return lambda x: plt.EqualsInteger(x, plt.Integer(0))
return super()._unop_fun(unop)


@dataclass(frozen=True, unsafe_hash=True)
class StringType(AtomicType):
Expand Down Expand Up @@ -1591,6 +1669,18 @@ def _binop_bin_fun(self, binop: operator, other: AST):
if other.typ == IntegerInstanceType:
return StrIntMulImpl

def _unop_return_type(self, unop: unaryop) -> "Type":
if isinstance(unop, Not):
return BoolType()
return super()._unop_return_type(unop)

def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]:
if isinstance(unop, Not):
return lambda x: plt.EqualsInteger(
plt.LengthOfByteString(plt.EncodeUtf8(x)), plt.Integer(0)
)
return super()._unop_fun(unop)


@dataclass(frozen=True, unsafe_hash=True)
class ByteStringType(AtomicType):
Expand Down Expand Up @@ -1932,6 +2022,18 @@ def _binop_bin_fun(self, binop: operator, other: AST):
if other.typ == IntegerInstanceType:
return ByteStrIntMulImpl

def _unop_return_type(self, unop: unaryop) -> "Type":
if isinstance(unop, Not):
return BoolType()
return super()._unop_return_type(unop)

def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]:
if isinstance(unop, Not):
return lambda x: plt.EqualsInteger(
plt.LengthOfByteString(x), plt.Integer(0)
)
return super()._unop_fun(unop)


@dataclass(frozen=True, unsafe_hash=True)
class BoolType(AtomicType):
Expand Down Expand Up @@ -1967,6 +2069,16 @@ def stringify(self, recursive: bool = False) -> plt.AST:
),
)

def _unop_return_type(self, unop: unaryop) -> "Type":
if isinstance(unop, Not):
return BoolType()
return super()._unop_return_type(unop)

def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]:
if isinstance(unop, Not):
return plt.Not
return super()._unop_fun(unop)


@dataclass(frozen=True, unsafe_hash=True)
class UnitType(AtomicType):
Expand All @@ -1981,6 +2093,16 @@ def cmp(self, op: cmpop, o: "Type") -> plt.AST:
def stringify(self, recursive: bool = False) -> plt.AST:
return OLambda(["self"], plt.Text("None"))

def _unop_return_type(self, unop: unaryop) -> "Type":
if isinstance(unop, Not):
return BoolType()
return super()._unop_return_type(unop)

def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]:
if isinstance(unop, Not):
return lambda x: plt.Bool(True)
return super()._unop_fun(unop)


IntegerInstanceType = InstanceType(IntegerType())
StringInstanceType = InstanceType(StringType())
Expand Down

0 comments on commit a56c7f4

Please sign in to comment.