Skip to content

Commit

Permalink
allow byvar access to non constant List and Tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Jul 6, 2024
1 parent 1217a17 commit 80a11c9
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 8 deletions.
47 changes: 47 additions & 0 deletions qlasskit/ast2ast/astrewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,53 @@ def visit_Subscript(self, node): # noqa: C901
if isinstance(_sval, ast.Name) and _sval.id in self.const:
node.slice = self.const[_sval.id]

# Handle inner access L[i]
elif (
isinstance(node, ast.Subscript)
and isinstance(node.value, ast.Name)
and isinstance(node.slice, ast.Name)
):
nname = node.value.id
iname = node.slice.id

def create_if_exp(i, max_i):
if i == max_i:
return ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
)
else:
next_i = i + 1
return ast.IfExp(
test=ast.Compare(
left=ast.Name(id=iname, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=i)],

),
body=ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
),
orelse=create_if_exp(next_i, max_i),
)

# Infer i and j sizes from env['a']
a_type = self.env[nname]

# self.env[nname] is a constant
if isinstance(a_type, ast.Tuple):
max_i = len(a_type.elts) - 1
# self.env[nname] is a type annotation
else:
outer_tuple = a_type.slice
max_i = len(outer_tuple.elts) - 1

# Create the IfExp structure
return create_if_exp(0, max_i)

# Handle inner access L[i][j]
elif (
isinstance(node, ast.Subscript)
Expand Down
16 changes: 8 additions & 8 deletions test/qlassf/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ def test_list_access_with_var(self):
compute_and_compare_results(self, qf)

def test_list_access_with_var_on_tuple(self):
# TODO: this fails on internal compiler
if self.compiler == "internal":
return

f = (
"def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n\tai,bi = ab\n"
"\td = c[ai] + c[bi]\n\treturn d"
Expand All @@ -107,10 +103,6 @@ def test_list_access_with_var_on_tuple(self):
compute_and_compare_results(self, qf)

def test_list_access_with_var_on_tuple2(self):
# TODO: this fails on internal compiler
if self.compiler == "internal":
return

f = (
"def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n"
"\td = c[ab[0]] + c[ab[1]]\n\treturn d"
Expand Down Expand Up @@ -146,3 +138,11 @@ def test_list_of_tuple_of_tuple2(self):
ttable = [(False, False), (True, False), (False, True), (True, True)]
tt = list(map(lambda e: (e[0], e[1], e[0] or e[1]), ttable))
qf.bind(io_list=tt)

def test_list_var_access(self):
f = (
"def test(a: Qlist[bool, 4], i: Qint[2]) -> bool:\n"
"\treturn a[i]"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
8 changes: 8 additions & 0 deletions test/qlassf/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,11 @@ def test_tuple_of_tuple_var_access(self):
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_tuple_var_access(self):
f = (
"def test(a: Tuple[bool, bool, bool, bool], i: Qint[2]) -> bool:\n"
"\treturn a[i]"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

0 comments on commit 80a11c9

Please sign in to comment.