Skip to content

Commit

Permalink
Merge branch 'feat/allow_empty_dicts' into feat/unop_dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
nielstron committed Jan 31, 2024
2 parents 5dee071 + 0f29c2f commit 81da686
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion opshin/optimize/optimize_const_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def generic_visit(self, node: AST):
if any(
isinstance(node_eval, t)
for t in ACCEPTED_ATOMIC_TYPES + [list, dict, PlutusData]
):
) and not (node_eval == [] or node_eval == {}):
new_node = Constant(node_eval, None)
copy_location(new_node, node)
return new_node
Expand Down
13 changes: 13 additions & 0 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2702,3 +2702,16 @@ def validator(_: None) -> Dict[int, int]:
"""
res = eval_uplc_value(source_code, Unit(), constant_folding=True)
self.assertEqual(res, {})

def test_empty_dict_displaced_constant_folding(self):
source_code = """
from typing import Dict, List, Union
VAR: Dict[bytes, int] = {}
def validator(b: Dict[int, Dict[bytes, int]]) -> Dict[bytes, int]:
a = b.get(0, VAR)
return a
"""
res = eval_uplc_value(source_code, {1: {b"": 0}}, constant_folding=True)
self.assertEqual(res, {})
4 changes: 2 additions & 2 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def constant_type(c):
), "Constant lists must contain elements of a single type only"
return InstanceType(ListType(first_typ))
if isinstance(c, dict):
assert len(c) > 0, "Lists must be non-empty"
assert len(c) > 0, "Dicts must be non-empty"
first_key_typ = constant_type(next(iter(c.keys())))
first_value_typ = constant_type(next(iter(c.values())))
assert all(
Expand Down Expand Up @@ -1085,7 +1085,7 @@ def visit_While(self, node: For) -> bool:
def visit_Return(self, node: Return) -> bool:
assert (
self.func_rettyp >= node.typ
), f"Function '{node.name}' annotated return type does not match actual return type"
), f"Function annotated return type does not match actual return type"
return True

def check_fulfills(self, node: FunctionDef):
Expand Down

0 comments on commit 81da686

Please sign in to comment.