Skip to content

Commit

Permalink
ast2ast refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Jul 8, 2024
1 parent 0a40c27 commit 919a92a
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 188 deletions.
10 changes: 8 additions & 2 deletions qlasskit/ast2ast/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from .astrewriter import ASTRewriter
from .constantfolder import ConstantFolder
from .replacemultitargetassign import ReplaceMultiTargetAssign
from .replacetypeann import ReplaceTypeAnn


class IndexReplacer(NodeTransformer):
Expand All @@ -35,11 +37,15 @@ def ast2ast(a_tree):
if sys.version_info < (3, 9):
a_tree = IndexReplacer().visit(a_tree)

# Matrix translator

# Fold constants
a_tree = ConstantFolder().visit(a_tree)

# Replace Type Annotations
a_tree = ReplaceTypeAnn().visit(a_tree)

# Replace multi-target assign
a_tree = ReplaceMultiTargetAssign().visit(a_tree)

# Rewrite the ast
a_tree = ASTRewriter().visit(a_tree)

Expand Down
229 changes: 43 additions & 186 deletions qlasskit/ast2ast/astrewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,79 +104,11 @@ def visit_Name(self, node):
return node


def _replace_types_annotations(ann, arg=None):
"""Replaces type annotations, translating high level types"""
if (
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Tuple"
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[_replace_types_annotations(el) for el in _elts])

ann = ast.Subscript(
value=ast.Name(id="Tuple", ctx=ast.Load()),
slice=_ituple,
)

# Replace QintX with Qint[X]
if isinstance(ann, ast.Name) and ann.id[:4] == "Qint":
ann = ast.Subscript(
value=ast.Name(id="Qint", ctx=ast.Load()),
slice=ast.Constant(value=int(ann.id[4:])),
)

# Replace QfixedX with Qfixed[X]
if isinstance(ann, ast.Name) and ann.id[:6] == "Qfixed":
ann = ast.Subscript(
value=ast.Name(id="Qfixed", ctx=ast.Load()),
slice=ast.Constant(value=int(ann.id[6:])),
)

# Replace Qlist[T,n] with Tuple[(T,)*n]
if (
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qlist"
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value)

ann = ast.Subscript(
value=ast.Name(id="Tuple", ctx=ast.Load()),
slice=_ituple,
)

# Replace Qmatrix[T,n,m] with Tuple[(Tuple[(T,)*m],)*n]
if (
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qmatrix"
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple_row = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[2].value)
_ituple = ast.Tuple(elts=[copy.deepcopy(_ituple_row)] * _elts[1].value)

ann = ast.Subscript(
value=ast.Name(id="Tuple", ctx=ast.Load()),
slice=_ituple,
)

if arg is not None:
arg.annotation = ann
return arg
else:
return ann


class ASTRewriter(ast.NodeTransformer):
"""Rewrites the ast to a simplified version"""

def __init__(self, ret=None):
self.env = Environment()
def __init__(self, env=None, ret=None):
self.env = Environment() if env is None else env
self.ret = None
self._uniqd = 1

Expand All @@ -186,29 +118,6 @@ def uniqd(self):
self._uniqd += 1
return f"{self._uniqd:x}"

def __unroll_arg(self, arg):
"""Argument unrolling for visit_call()"""
if isinstance(arg, ast.Tuple):
# If it's a tuple, return elts
return arg.elts
elif isinstance(arg, ast.Name):
# If it's a name, is in env and is a Tuple, return elements
if (
arg.id in self.env
and isinstance(self.env.get_type(arg.id), ast.Subscript)
and self.env.get_type(arg.id).value.id == "Tuple"
):
_sval = self.env.get_type(arg.id).slice

return [
ast.Subscript(
value=ast.Name(id=arg.id, ctx=ast.Load()),
slice=ast.Constant(value=i, kind=None),
)
for i in range(len(_sval.elts))
]
return [arg]

def generic_visit(self, node):
return super().generic_visit(node)

Expand Down Expand Up @@ -382,58 +291,19 @@ def visit_List(self, node):
return ast.Tuple(elts=[self.visit(el) for el in node.elts])

def visit_AnnAssign(self, node):
node.annotation = _replace_types_annotations(node.annotation)
node.value = self.visit(node.value) if node.value else node.value
self.env.set_type(node.target.id, node.annotation)
return node

def visit_FunctionDef(self, node):
node.args.args = [
_replace_types_annotations(x.annotation, arg=x) for x in node.args.args
]

for x in node.args.args:
self.env.set_type(x.arg, x.annotation)

node.returns = _replace_types_annotations(node.returns)
self.ret = node.returns

return super().generic_visit(node)

def visit_Assign(self, node):
# Transform multi-target assign to single target assigns
if len(node.targets) == 1 and hasattr(node.targets[0], "elts"):
if isinstance(node.value, ast.Name):
return [
self.visit(
ast.Assign(
targets=[ast.Name(id=node.targets[0].elts[i].id)],
value=ast.Subscript(
value=node.value, slice=ast.Constant(value=i)
),
)
)
for i in range(len(node.targets[0].elts))
]

_temptup = self.visit(
ast.Assign(targets=[ast.Name(id="_temptup")], value=node.value)
)

single_assigns = [
self.visit(
ast.Assign(
targets=[ast.Name(id=node.targets[0].elts[i].id)],
value=ast.Subscript(
value=ast.Name(id="_temptup"), slice=ast.Constant(value=i)
),
)
)
for i in range(len(node.targets[0].elts))
]

return [_temptup] + single_assigns

target_0id = node.targets[0].id
was_known = target_0id in self.env

Expand Down Expand Up @@ -487,73 +357,60 @@ def visit_AugAssign(self, node):
),
]

def visit_For(self, node): # noqa: C901
"""Unroll for loops to single iterations"""
iter = self.visit(node.iter)

# Get the list to iterate (should be defined with a fixed size)
if isinstance(iter, ast.Name) and self.env.has_type(iter.id):
iter_type = self.env.get_type(iter.id)

if isinstance(iter_type, ast.Tuple):
iter = iter_type.elts

elif isinstance(iter_type, ast.Subscript):
_elts = iter_type.slice.elts # type: ignore

iter = [
def __unroll_arg(self, arg):
"""Transform a node to a list (when is a Tuple or a subscribable type)"""
if isinstance(arg, ast.Tuple):
# If it's a tuple, return elts
return arg.elts
elif isinstance(arg, ast.Constant) and isinstance(arg.value, ast.Tuple):
return arg.value.elts
elif isinstance(arg, ast.Subscript):
_sval = self.env.get_type(arg.value.id) # type: ignore
if isinstance(_sval, ast.Subscript) and isinstance(_sval.slice, ast.Tuple):
return [
ast.Subscript(
value=ast.Name(id=iter.id, ctx=ast.Load()),
slice=ast.Constant(value=e),
ctx=ast.Load(),
value=ast.Subscript(
value=ast.Name(id=arg.value.id, ctx=ast.Load()), # type: ignore
slice=ast.Constant(value=arg.slice.value, kind=None), # type: ignore
),
slice=ast.Constant(value=i, kind=None),
)
for e in range(len(_elts))
for i in range(len(_sval.slice.elts))
]
elif isinstance(iter, ast.Tuple):
iter = iter.elts
elif (
isinstance(iter, ast.Subscript)
and isinstance(iter.value, ast.Name)
and self.env.has_type(iter.value.id)
and hasattr(iter.slice, "value")
):
iter_value_type = self.env.get_type(iter.value.id)
if isinstance(iter_value_type, ast.Tuple):
new_iter = iter_value_type.elts[iter.slice.value]

elif isinstance(iter_value_type, ast.Subscript):
_elts = iter_value_type.slice.elts[iter.slice.value] # type: ignore

if isinstance(_elts, ast.Tuple):
_elts = _elts.elts
elif isinstance(arg, ast.Name):
# If it's a name, is in env and is a Tuple, return elements
if (
self.env.has_type(arg.id)
and isinstance(self.env.get_type(arg.id), ast.Subscript)
and self.env.get_type(arg.id).value.id == "Tuple"
):
_sval = self.env.get_type(arg.id).slice

new_iter = [
return [
ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=iter.value.id, ctx=ast.Load()),
slice=ast.Constant(value=iter.slice.value),
ctx=ast.Load(),
),
slice=ast.Constant(value=e),
value=ast.Name(id=arg.id, ctx=ast.Load()),
slice=ast.Constant(value=i, kind=None),
)
for e in range(len(_elts))
for i in range(len(_sval.elts))
]
else:
new_iter = iter

iter = new_iter

if isinstance(iter, ast.Constant) and isinstance(iter.value, ast.Tuple):
iter = iter.value.elts
# If it's a tuple constant, return elements
elif self.env.has_constant(arg.id) and isinstance(
self.env.get_type(arg.id), ast.Tuple
):
return self.env.get_constant(arg.id).elts
return [arg]

# Unroll each for iteration
def visit_For(self, node):
"""Unroll for loops to single iterations"""
iter = self.__unroll_arg(self.visit(node.iter))
rolls = []
iter = flatten(iter)

for i in iter:
if isinstance(i, ast.Subscript) or isinstance(i, ast.Constant):
if isinstance(i, ast.Constant) or isinstance(i, ast.Subscript):
_val = i
else:
_val = ast.Constant(value=i)

self.env.set_constant(node.target.id, _val)

tar_assign = self.visit(ast.Assign(targets=[node.target], value=_val))
Expand Down
54 changes: 54 additions & 0 deletions qlasskit/ast2ast/replacemultitargetassign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2023-2024 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast


class ReplaceMultiTargetAssign(ast.NodeTransformer):

def visit_Assign(self, node):
if len(node.targets) != 1 or not hasattr(node.targets[0], "elts"):
return node

# Transform multi-target assign to single target assigns
if isinstance(node.value, ast.Name):
return [
self.visit(
ast.Assign(
targets=[ast.Name(id=node.targets[0].elts[i].id)],
value=ast.Subscript(
value=node.value, slice=ast.Constant(value=i)
),
)
)
for i in range(len(node.targets[0].elts))
]

_temptup = self.visit(
ast.Assign(targets=[ast.Name(id="_temptup")], value=node.value)
)

single_assigns = [
self.visit(
ast.Assign(
targets=[ast.Name(id=node.targets[0].elts[i].id)],
value=ast.Subscript(
value=ast.Name(id="_temptup"), slice=ast.Constant(value=i)
),
)
)
for i in range(len(node.targets[0].elts))
]

return [_temptup] + single_assigns
Loading

0 comments on commit 919a92a

Please sign in to comment.