diff --git a/docs/source/example_bernstein_vazirani.ipynb b/docs/source/example_bernstein_vazirani.ipynb index 49beaeae..1cf434a9 100644 --- a/docs/source/example_bernstein_vazirani.ipynb +++ b/docs/source/example_bernstein_vazirani.ipynb @@ -36,12 +36,14 @@ "source": [ "from qlasskit import qlassf, Qint\n", "\n", + "\n", "@qlassf\n", "def oracle(x: Qint[4]) -> bool:\n", " s = Qint4(14)\n", " return (x[0] & s[0]) ^ (x[1] & s[1]) ^ (x[2] & s[2]) ^ (x[3] & s[3])\n", "\n", - "oracle.export(\"qiskit\").draw(\"mpl\")\n" + "\n", + "oracle.export(\"qiskit\").draw(\"mpl\")" ] }, { diff --git a/pyproject.toml b/pyproject.toml index a2fd99ca..023cd535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ include = [ "qlasskit.boolopt", "qlasskit.types", "qlasskit.ast2logic", + "qlasskit.ast2ast", "qlasskit.qcircuit", "qlasskit.compiler", "qlasskit.algorithms", diff --git a/qlasskit/ast2ast/__init__.py b/qlasskit/ast2ast/__init__.py new file mode 100644 index 00000000..281ffc29 --- /dev/null +++ b/qlasskit/ast2ast/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2023-2024 Davide Gessa + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .ast2ast import ast2ast # noqa: F401 +from .astrewriter import ASTRewriter # noqa: F401 +from .constantfolder import ConstantFolder # noqa: F401 diff --git a/qlasskit/ast2ast/ast2ast.py b/qlasskit/ast2ast/ast2ast.py new file mode 100644 index 00000000..bb023db1 --- /dev/null +++ b/qlasskit/ast2ast/ast2ast.py @@ -0,0 +1,45 @@ +# Copyright 2023-2024 Davide Gessa + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from ast import NodeTransformer + +from .astrewriter import ASTRewriter +from .constantfolder import ConstantFolder + + +class IndexReplacer(NodeTransformer): + """Replace Index with its content (for python < 3.9)""" + + def generic_visit(self, node): + return super().generic_visit(node) + + def visit_Index(self, node): + return self.visit(node.value) + + +def ast2ast(a_tree): + # print(ast.dump(a_tree)) + + # Replace indexes with its content if python < 3.9 + if sys.version_info < (3, 9): + a_tree = IndexReplacer().visit(a_tree) + + # Fold constants + a_tree = ConstantFolder().visit(a_tree) + + # Rewrite the ast + a_tree = ASTRewriter().visit(a_tree) + + # print(ast.dump(a_tree)) + return a_tree diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast/astrewriter.py similarity index 97% rename from qlasskit/ast2ast.py rename to qlasskit/ast2ast/astrewriter.py index 654879d6..4438ac7f 100644 --- a/qlasskit/ast2ast.py +++ b/qlasskit/ast2ast/astrewriter.py @@ -13,19 +13,8 @@ # limitations under the License. import ast import copy -import sys -from .ast2logic import flatten - - -class IndexReplacer(ast.NodeTransformer): - """Replace Index with its content (for python < 3.9)""" - - def generic_visit(self, node): - return super().generic_visit(node) - - def visit_Index(self, node): - return self.visit(node.value) +from ..ast2logic import flatten class IsNamePresent(ast.NodeTransformer): @@ -580,6 +569,10 @@ def visit_Call(self, node): return node def visit_BinOp(self, node): + # Check if we have two constants + # if isinstance(node.right, ast.Constant) and isinstance(node.left, ast.Constant): + # # return a constant evaluting the inner + # rewrite the ** operator to be a series of multiplications if isinstance(node.op, ast.Pow): if ( @@ -599,13 +592,3 @@ def visit_BinOp(self, node): ): return ast.Constant(value=1) return super().generic_visit(node) - - -def ast2ast(a_tree): - # print(ast.dump(a_tree)) - if sys.version_info < (3, 9): - a_tree = IndexReplacer().visit(a_tree) - - a_tree = ASTRewriter().visit(a_tree) - # print(ast.dump(a_tree)) - return a_tree diff --git a/qlasskit/ast2ast/constantfolder.py b/qlasskit/ast2ast/constantfolder.py new file mode 100644 index 00000000..4389aa18 --- /dev/null +++ b/qlasskit/ast2ast/constantfolder.py @@ -0,0 +1,123 @@ +# Copyright 2023-2024 Davide Gessa + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import operator + + +class ConstantFolder(ast.NodeTransformer): + def __init__(self): + self.builtin_funcs = { + "abs": abs, + "len": len, + "min": min, + "max": max, + "sum": sum, + "any": any, + "all": all, + "chr": chr, + "ord": ord, + # "range": range, # This is handled differently + } + + def visit_Compare(self, node): + self.generic_visit(node) + if len(node.ops) == 1 and len(node.comparators) == 1: + if isinstance(node.left, ast.Constant) and isinstance( + node.comparators[0], ast.Constant + ): + op = { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, + ast.Is: operator.is_, + ast.IsNot: operator.is_not, + ast.In: lambda x, y: x in y, + ast.NotIn: lambda x, y: x not in y, + }.get(type(node.ops[0])) + if op: + result = op(node.left.value, node.comparators[0].value) + return ast.Constant(value=result) + return node + + def visit_UnaryOp(self, node): + self.generic_visit(node) + if isinstance(node.operand, ast.Constant): + op = { + ast.UAdd: operator.pos, + ast.USub: operator.neg, + ast.Not: operator.not_, + ast.Invert: operator.invert, + }.get(type(node.op)) + if op: + return ast.Constant(op(node.operand.value)) # type: ignore + return node + + def visit_BinOp(self, node): + self.generic_visit(node) + if isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant): + op = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.LShift: operator.lshift, + ast.RShift: operator.rshift, + ast.BitOr: operator.or_, + ast.BitXor: operator.xor, + ast.BitAnd: operator.and_, + }.get(type(node.op)) + if op: + return ast.Constant(op(node.left.value, node.right.value)) + return node + + def visit_Call(self, node): + self.generic_visit(node) + if isinstance(node.func, ast.Name) and node.func.id in self.builtin_funcs: + if all(isinstance(arg, ast.Constant) for arg in node.args): + func = self.builtin_funcs[node.func.id] + args = [arg.value for arg in node.args] + return ast.Constant(func(*args)) # type: ignore + return node + + def visit_Subscript(self, node): + self.generic_visit(node) + if isinstance(node.value, ast.List) and isinstance(node.slice, ast.Constant): + if all(isinstance(elt, ast.Constant) for elt in node.value.elts): + try: + return node.value.elts[node.slice.value] + except IndexError: + pass + return node + + def visit_If(self, node): + self.generic_visit(node) + if isinstance(node.test, ast.Constant): + if node.test.value: + return node.body + else: + return node.orelse + return node + + def visit_IfExp(self, node): + self.generic_visit(node) + if isinstance(node.test, ast.Constant): + return node.body if node.test.value else node.orelse + return node diff --git a/test/test_ast_rewriter.py b/test/test_ast2ast.py similarity index 68% rename from test/test_ast_rewriter.py rename to test/test_ast2ast.py index 9138830b..4dd778bd 100644 --- a/test/test_ast_rewriter.py +++ b/test/test_ast2ast.py @@ -16,7 +16,9 @@ import sys import unittest -from qlasskit.ast2ast import ASTRewriter +from parameterized import parameterized + +from qlasskit.ast2ast import ASTRewriter, ConstantFolder class TestASTRewriter(unittest.TestCase): @@ -24,34 +26,17 @@ class TestASTRewriter(unittest.TestCase): def setUp(self): self.rewriter = ASTRewriter() - def test_exponentiation_transformation(self): - code = "a = b ** 3" - tree = ast.parse(code) - new_tree = self.rewriter.visit(tree) - - expected_code = "a = b * b * b" - expected_tree = ast.parse(expected_code) - - self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree)) - - def test_non_exponentiation(self): - code = "a = b + 3" - tree = ast.parse(code) - new_tree = self.rewriter.visit(tree) - - expected_code = "a = b + 3" - expected_tree = ast.parse(expected_code) - - self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree)) - - def test_exponentiation_with_non_constant(self): - code = "a = b ** c" + @parameterized.expand( + [ + ("a = b ** 3", "a = b * b * b"), + ("a = b + 3", "a = b + 3"), + ("a = b ** c", "a = b ** c"), + ] + ) + def test_exponentiation_transformation(self, code, expected_code): tree = ast.parse(code) new_tree = self.rewriter.visit(tree) - - expected_code = "a = b ** c" expected_tree = ast.parse(expected_code) - self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree)) def test_exponentiation_with_zero(self): @@ -68,3 +53,23 @@ def test_exponentiation_with_zero(self): expected_code = "a = 1" expected_tree = ast.parse(expected_code) self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree)) + + +class TestASTConstantFolder(unittest.TestCase): + def setUp(self): + self.rewriter = ConstantFolder() + + @parameterized.expand( + [ + ("a + (13 - 12 + 1)", "a + 2"), + # ( "a + 13 - 12 + 1", "a + 2" ), + # ( "a + len([12])", "a + 1" ), + ("if True: a \nelse: b", "a"), + ("a if False else b", "b"), + ] + ) + def test_expected_code(self, code, expected_code): + tree = ast.parse(code) + new_tree = self.rewriter.visit(tree) + expected_tree = ast.parse(expected_code) + self.assertEqual(ast.dump(new_tree), ast.dump(expected_tree))