Skip to content

Commit

Permalink
separate ast2ast, add constantfolding ast step
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Jul 1, 2024
1 parent d50f58b commit 02d6943
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 49 deletions.
4 changes: 3 additions & 1 deletion docs/source/example_bernstein_vazirani.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@
"source": [
"from qlasskit import qlassf, Qint\n",
"\n",
"\n",
"@qlassf\n",
"def oracle(x: Qint[4]) -> bool:\n",
" s = Qint4(14)\n",
" return (x[0] & s[0]) ^ (x[1] & s[1]) ^ (x[2] & s[2]) ^ (x[3] & s[3])\n",
"\n",
"oracle.export(\"qiskit\").draw(\"mpl\")\n"
"\n",
"oracle.export(\"qiskit\").draw(\"mpl\")"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ include = [
"qlasskit.boolopt",
"qlasskit.types",
"qlasskit.ast2logic",
"qlasskit.ast2ast",
"qlasskit.qcircuit",
"qlasskit.compiler",
"qlasskit.algorithms",
Expand Down
17 changes: 17 additions & 0 deletions qlasskit/ast2ast/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2023-2024 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .ast2ast import ast2ast # noqa: F401
from .astrewriter import ASTRewriter # noqa: F401
from .constantfolder import ConstantFolder # noqa: F401
45 changes: 45 additions & 0 deletions qlasskit/ast2ast/ast2ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2023-2024 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from ast import NodeTransformer

from .astrewriter import ASTRewriter
from .constantfolder import ConstantFolder


class IndexReplacer(NodeTransformer):
"""Replace Index with its content (for python < 3.9)"""

def generic_visit(self, node):
return super().generic_visit(node)

def visit_Index(self, node):
return self.visit(node.value)


def ast2ast(a_tree):
# print(ast.dump(a_tree))

# Replace indexes with its content if python < 3.9
if sys.version_info < (3, 9):
a_tree = IndexReplacer().visit(a_tree)

# Fold constants
a_tree = ConstantFolder().visit(a_tree)

# Rewrite the ast
a_tree = ASTRewriter().visit(a_tree)

# print(ast.dump(a_tree))
return a_tree
27 changes: 5 additions & 22 deletions qlasskit/ast2ast.py → qlasskit/ast2ast/astrewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,8 @@
# limitations under the License.
import ast
import copy
import sys

from .ast2logic import flatten


class IndexReplacer(ast.NodeTransformer):
"""Replace Index with its content (for python < 3.9)"""

def generic_visit(self, node):
return super().generic_visit(node)

def visit_Index(self, node):
return self.visit(node.value)
from ..ast2logic import flatten


class IsNamePresent(ast.NodeTransformer):
Expand Down Expand Up @@ -580,6 +569,10 @@ def visit_Call(self, node):
return node

def visit_BinOp(self, node):
# Check if we have two constants
# if isinstance(node.right, ast.Constant) and isinstance(node.left, ast.Constant):
# # return a constant evaluting the inner

# rewrite the ** operator to be a series of multiplications
if isinstance(node.op, ast.Pow):
if (
Expand All @@ -599,13 +592,3 @@ def visit_BinOp(self, node):
):
return ast.Constant(value=1)
return super().generic_visit(node)


def ast2ast(a_tree):
# print(ast.dump(a_tree))
if sys.version_info < (3, 9):
a_tree = IndexReplacer().visit(a_tree)

a_tree = ASTRewriter().visit(a_tree)
# print(ast.dump(a_tree))
return a_tree
123 changes: 123 additions & 0 deletions qlasskit/ast2ast/constantfolder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2023-2024 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import operator


class ConstantFolder(ast.NodeTransformer):
def __init__(self):
self.builtin_funcs = {
"abs": abs,
"len": len,
"min": min,
"max": max,
"sum": sum,
"any": any,
"all": all,
"chr": chr,
"ord": ord,
# "range": range, # This is handled differently
}

def visit_Compare(self, node):
self.generic_visit(node)
if len(node.ops) == 1 and len(node.comparators) == 1:
if isinstance(node.left, ast.Constant) and isinstance(
node.comparators[0], ast.Constant
):
op = {
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
ast.Is: operator.is_,
ast.IsNot: operator.is_not,
ast.In: lambda x, y: x in y,
ast.NotIn: lambda x, y: x not in y,
}.get(type(node.ops[0]))
if op:
result = op(node.left.value, node.comparators[0].value)
return ast.Constant(value=result)
return node

def visit_UnaryOp(self, node):
self.generic_visit(node)
if isinstance(node.operand, ast.Constant):
op = {
ast.UAdd: operator.pos,
ast.USub: operator.neg,
ast.Not: operator.not_,
ast.Invert: operator.invert,
}.get(type(node.op))
if op:
return ast.Constant(op(node.operand.value)) # type: ignore
return node

def visit_BinOp(self, node):
self.generic_visit(node)
if isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant):
op = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
ast.LShift: operator.lshift,
ast.RShift: operator.rshift,
ast.BitOr: operator.or_,
ast.BitXor: operator.xor,
ast.BitAnd: operator.and_,
}.get(type(node.op))
if op:
return ast.Constant(op(node.left.value, node.right.value))
return node

def visit_Call(self, node):
self.generic_visit(node)
if isinstance(node.func, ast.Name) and node.func.id in self.builtin_funcs:
if all(isinstance(arg, ast.Constant) for arg in node.args):
func = self.builtin_funcs[node.func.id]
args = [arg.value for arg in node.args]
return ast.Constant(func(*args)) # type: ignore
return node

def visit_Subscript(self, node):
self.generic_visit(node)
if isinstance(node.value, ast.List) and isinstance(node.slice, ast.Constant):
if all(isinstance(elt, ast.Constant) for elt in node.value.elts):
try:
return node.value.elts[node.slice.value]
except IndexError:
pass
return node

def visit_If(self, node):
self.generic_visit(node)
if isinstance(node.test, ast.Constant):
if node.test.value:
return node.body
else:
return node.orelse
return node

def visit_IfExp(self, node):
self.generic_visit(node)
if isinstance(node.test, ast.Constant):
return node.body if node.test.value else node.orelse
return node
57 changes: 31 additions & 26 deletions test/test_ast_rewriter.py → test/test_ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,27 @@
import sys
import unittest

from qlasskit.ast2ast import ASTRewriter
from parameterized import parameterized

from qlasskit.ast2ast import ASTRewriter, ConstantFolder


class TestASTRewriter(unittest.TestCase):

def setUp(self):
self.rewriter = ASTRewriter()

def test_exponentiation_transformation(self):
code = "a = b ** 3"
tree = ast.parse(code)
new_tree = self.rewriter.visit(tree)

expected_code = "a = b * b * b"
expected_tree = ast.parse(expected_code)

self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree))

def test_non_exponentiation(self):
code = "a = b + 3"
tree = ast.parse(code)
new_tree = self.rewriter.visit(tree)

expected_code = "a = b + 3"
expected_tree = ast.parse(expected_code)

self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree))

def test_exponentiation_with_non_constant(self):
code = "a = b ** c"
@parameterized.expand(
[
("a = b ** 3", "a = b * b * b"),
("a = b + 3", "a = b + 3"),
("a = b ** c", "a = b ** c"),
]
)
def test_exponentiation_transformation(self, code, expected_code):
tree = ast.parse(code)
new_tree = self.rewriter.visit(tree)

expected_code = "a = b ** c"
expected_tree = ast.parse(expected_code)

self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree))

def test_exponentiation_with_zero(self):
Expand All @@ -68,3 +53,23 @@ def test_exponentiation_with_zero(self):
expected_code = "a = 1"
expected_tree = ast.parse(expected_code)
self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree))


class TestASTConstantFolder(unittest.TestCase):
def setUp(self):
self.rewriter = ConstantFolder()

@parameterized.expand(
[
("a + (13 - 12 + 1)", "a + 2"),
# ( "a + 13 - 12 + 1", "a + 2" ),
# ( "a + len([12])", "a + 1" ),
("if True: a \nelse: b", "a"),
("a if False else b", "b"),
]
)
def test_expected_code(self, code, expected_code):
tree = ast.parse(code)
new_tree = self.rewriter.visit(tree)
expected_tree = ast.parse(expected_code)
self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree))

0 comments on commit 02d6943

Please sign in to comment.