diff --git a/opshin/builder.py b/opshin/builder.py index 75726f45..f3bb2d25 100644 --- a/opshin/builder.py +++ b/opshin/builder.py @@ -1,6 +1,7 @@ import copy import dataclasses import enum +import functools import json import types import typing @@ -203,9 +204,9 @@ def compile( return plt_code -def _compile( +@functools.lru_cache(maxsize=32) +def _static_compile( source_code: str, - *args: typing.Union[pycardano.Datum, uplc_ast.Constant], contract_file: str = "", force_three_params=False, validator_function_name="validator", @@ -229,7 +230,34 @@ def _compile( constant_folding=constant_folding, allow_isinstance_anything=allow_isinstance_anything, ) + return code + + +def _compile( + source_code: str, + *args: typing.Union[pycardano.Datum, uplc_ast.Constant], + contract_file: str = "", + force_three_params=False, + validator_function_name="validator", + optimize_patterns=True, + remove_dead_code=True, + constant_folding=False, + allow_isinstance_anything=False, +): + """ + Expects a python module and returns the build artifacts from compiling it + """ + code = _static_compile( + source_code, + contract_file=contract_file, + force_three_params=force_three_params, + validator_function_name=validator_function_name, + optimize_patterns=optimize_patterns, + remove_dead_code=remove_dead_code, + constant_folding=constant_folding, + allow_isinstance_anything=allow_isinstance_anything, + ) code = _apply_parameters(code, *args) return code diff --git a/opshin/compiler.py b/opshin/compiler.py index 928ba6ab..9abc9ce1 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -855,7 +855,7 @@ def visit_ClassDef(self, node: TypedClassDef) -> CallAST: def visit_Attribute(self, node: TypedAttribute) -> plt.AST: assert isinstance( - node.typ, InstanceType + node.value.typ, InstanceType ), "Can only access attributes of instances" obj = self.visit(node.value) attr = node.value.typ.attribute(node.attr) diff --git a/opshin/tests/test_misc.py b/opshin/tests/test_misc.py index 3fa37b4b..633f385d 100644 --- a/opshin/tests/test_misc.py +++ b/opshin/tests/test_misc.py @@ -2885,3 +2885,42 @@ def validator(x: bool) -> str: """ res_true = eval_uplc_value(source_code, 1) res_false = eval_uplc_value(source_code, 0) + + @unittest.expectedFailure + def test_class_attribute_access(self): + source_code = """ +from dataclasses import dataclass +from pycardano import Datum as Anything, PlutusData +from typing import Dict, List, Union + +@dataclass +class A(PlutusData): + CONSTR_ID = 15 + a: int + b: bytes + d: List[int] + +def validator(_: None) -> int: + return A.CONSTR_ID + """ + builder._compile(source_code) + + def test_constr_id_access(self): + source_code = """ +from dataclasses import dataclass +from pycardano import Datum as Anything, PlutusData +from typing import Dict, List, Union + +@dataclass +class A(PlutusData): + CONSTR_ID = 15 + a: int + b: bytes + d: List[int] + +def validator(_: None) -> int: + return A(0, b"", [1,2]).CONSTR_ID + """ + res = eval_uplc_value(source_code, Unit()) + + self.assertEqual(15, res, "Invalid constr id") diff --git a/opshin/tests/utils.py b/opshin/tests/utils.py index 1a7e9415..290b2364 100644 --- a/opshin/tests/utils.py +++ b/opshin/tests/utils.py @@ -1,4 +1,5 @@ import dataclasses +import functools import typing import pycardano