Skip to content

Commit

Permalink
refactor ast2ast with Env class, and grouping common code
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Jul 7, 2024
1 parent 55c55f9 commit 0a40c27
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 109 deletions.
212 changes: 103 additions & 109 deletions qlasskit/ast2ast/astrewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,65 @@
from typing import Any

from ..ast2logic import flatten
from .env import Environment


def create_if_exp(nname, iname, max_i, jname=None, max_j=None):
"""Given a List or List of List `nname`, an index `iname` and an optional index `jname`,
returns L[0] if i == 0 else L[1] if i == 1 ..."""

def access_ij(i, j):
fsub = ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
)

if jname is not None:
return ast.Subscript(
value=fsub,
slice=ast.Constant(value=j),
ctx=ast.Load(),
)
else:
return fsub

def _create_if_exp(i, j=None):
if i == max_i and (jname is None or j == max_j):
return access_ij(i, j)
else:
cmp_i = ast.Compare(
left=ast.Name(id=iname, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=i)],
)
if jname is not None:
next_j = j + 1 if j < max_j else 0
next_i = i if j < max_j else i + 1

return ast.IfExp(
test=ast.BoolOp(
op=ast.And(),
values=[
cmp_i,
ast.Compare(
left=ast.Name(id=jname, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=j)],
),
],
),
body=access_ij(i, j),
orelse=_create_if_exp(next_i, next_j),
)
else:
return ast.IfExp(
test=cmp_i,
body=access_ij(i, j),
orelse=_create_if_exp(i + 1),
)

return _create_if_exp(0, None if jname is None else 0)


@dataclass
Expand Down Expand Up @@ -116,9 +175,8 @@ def _replace_types_annotations(ann, arg=None):
class ASTRewriter(ast.NodeTransformer):
"""Rewrites the ast to a simplified version"""

def __init__(self, env={}, ret=None):
self.env = {}
self.const = {}
def __init__(self, ret=None):
self.env = Environment()
self.ret = None
self._uniqd = 1

Expand All @@ -137,10 +195,10 @@ def __unroll_arg(self, arg):
# If it's a name, is in env and is a Tuple, return elements
if (
arg.id in self.env
and isinstance(self.env[arg.id], ast.Subscript)
and self.env[arg.id].value.id == "Tuple"
and isinstance(self.env.get_type(arg.id), ast.Subscript)
and self.env.get_type(arg.id).value.id == "Tuple"
):
_sval = self.env[arg.id].slice
_sval = self.env.get_type(arg.id).slice

return [
ast.Subscript(
Expand All @@ -156,50 +214,27 @@ def generic_visit(self, node):

def visit_Subscript(self, node): # noqa: C901
# Replace L[a] with const a, to L[const]
if isinstance(node.slice, ast.Name) and node.slice.id in self.const:
node.slice = self.const[node.slice.id]
if isinstance(node.slice, ast.Name) and self.env.has_constant(node.slice.id):
node.slice = self.env.get_constant(node.slice.id)

# Handle inner access L[i]
elif isinstance(node.value, ast.Name) and isinstance(node.slice, ast.Name):
nname = node.value.id
iname = node.slice.id

def create_if_exp_single(i, max_i):
if i == max_i:
return ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
)
else:
next_i = i + 1
return ast.IfExp(
test=ast.Compare(
left=ast.Name(id=iname, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=i)],
),
body=ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
),
orelse=create_if_exp_single(next_i, max_i),
)

# Infer i and j sizes from env['a']
a_type = self.env[nname]
gtype = self.env.get_type(nname)

# self.env[nname] is a constant
if isinstance(a_type, ast.Tuple):
max_i = len(a_type.elts) - 1
if isinstance(gtype, ast.Tuple):
max_i = len(gtype.elts) - 1
# self.env[nname] is a type annotation
else:
outer_tuple = a_type.slice
outer_tuple = gtype.slice
max_i = len(outer_tuple.elts) - 1

# Create the IfExp structure
return create_if_exp_single(0, max_i)
return create_if_exp(nname, iname, max_i)

# Handle inner access L[i][j]
elif (
Expand All @@ -212,74 +247,33 @@ def create_if_exp_single(i, max_i):
iname = node.value.slice.id
jname = node.slice.id

def create_if_exp(i, j, max_i, max_j):
if i == max_i and j == max_j:
return ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
),
slice=ast.Constant(value=j),
ctx=ast.Load(),
)
else:
next_j = j + 1 if j < max_j else 0
next_i = i if j < max_j else i + 1
return ast.IfExp(
test=ast.BoolOp(
op=ast.And(),
values=[
ast.Compare(
left=ast.Name(id=iname, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=i)],
),
ast.Compare(
left=ast.Name(id=jname, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=j)],
),
],
),
body=ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
),
slice=ast.Constant(value=j),
ctx=ast.Load(),
),
orelse=create_if_exp(next_i, next_j, max_i, max_j),
)

# Infer i and j sizes from env['a']
a_type = self.env[nname]
gtype = self.env.get_type(nname)

# self.env[nname] is a constant
if isinstance(a_type, ast.Tuple):
max_i = len(a_type.elts) - 1
max_j = len(a_type.elts[0].elts) - 1 # type: ignore
if isinstance(gtype, ast.Tuple):
max_i = len(gtype.elts) - 1
max_j = len(gtype.elts[0].elts) - 1 # type: ignore
# self.env[nname] is a type annotation
else:
outer_tuple = a_type.slice
outer_tuple = gtype.slice
max_i = len(outer_tuple.elts) - 1
inner_tuple = outer_tuple.elts
max_j = len(inner_tuple) - 1

# Create the IfExp structure
return create_if_exp(0, 0, max_i, max_j)
return create_if_exp(nname, iname, max_i, jname, max_j)

# Unroll L[a] with (L[0] if a == 0 else L[1] if a == 1 ...) when self.env[L] is constant
# Unroll L[a] with (L[0] if a == 0 else L[1] if a == 1 ...) when L is constant
elif (
isinstance(node.slice, ast.Name) and node.slice.id not in self.const
isinstance(node.slice, ast.Name)
and not self.env.has_constant(node.slice.id)
) or isinstance(node.slice, ast.Subscript):
if isinstance(node.value, ast.Name):
if node.value.id == "Tuple":
return node

tup = self.env[node.value.id]
tup = self.env.get_constant(node.value.id)
else:
tup = node.value

Expand Down Expand Up @@ -390,7 +384,7 @@ def visit_List(self, node):
def visit_AnnAssign(self, node):
node.annotation = _replace_types_annotations(node.annotation)
node.value = self.visit(node.value) if node.value else node.value
self.env[node.target] = node.annotation
self.env.set_type(node.target.id, node.annotation)
return node

def visit_FunctionDef(self, node):
Expand All @@ -399,7 +393,7 @@ def visit_FunctionDef(self, node):
]

for x in node.args.args:
self.env[x.arg] = x.annotation
self.env.set_type(x.arg, x.annotation)

node.returns = _replace_types_annotations(node.returns)
self.ret = node.returns
Expand Down Expand Up @@ -443,13 +437,14 @@ def visit_Assign(self, node):
target_0id = node.targets[0].id
was_known = target_0id in self.env

if isinstance(node.value, ast.Name) and node.value.id in self.env:
self.env[target_0id] = self.env[node.value.id]
if isinstance(node.value, ast.Constant):
self.env.set_constant(target_0id, node.value)
elif isinstance(node.value, ast.Name) and node.value.id in self.env:
self.env.copy_type(node.value.id, target_0id)
elif isinstance(node.value, ast.Tuple) or isinstance(node.value, ast.List):
# TODO: this is a constant, not an annotation
self.env[target_0id] = self.visit(node.value)
self.env.set_constant(target_0id, self.visit(node.value))
else:
self.env[target_0id] = "Unknown"
self.env.set_type(target_0id, "Unknown")

# If value is not self referencing, we can skip this (ie: a = b + 1)
ip = IsNamePresent(target_0id)
Expand Down Expand Up @@ -497,12 +492,14 @@ def visit_For(self, node): # noqa: C901
iter = self.visit(node.iter)

# Get the list to iterate (should be defined with a fixed size)
if isinstance(iter, ast.Name) and iter.id in self.env:
if isinstance(self.env[iter.id], ast.Tuple):
iter = self.env[iter.id].elts
if isinstance(iter, ast.Name) and self.env.has_type(iter.id):
iter_type = self.env.get_type(iter.id)

elif isinstance(self.env[iter.id], ast.Subscript):
_elts = self.env[iter.id].slice.elts
if isinstance(iter_type, ast.Tuple):
iter = iter_type.elts

elif isinstance(iter_type, ast.Subscript):
_elts = iter_type.slice.elts # type: ignore

iter = [
ast.Subscript(
Expand All @@ -517,14 +514,15 @@ def visit_For(self, node): # noqa: C901
elif (
isinstance(iter, ast.Subscript)
and isinstance(iter.value, ast.Name)
and iter.value.id in self.env
and self.env.has_type(iter.value.id)
and hasattr(iter.slice, "value")
):
if isinstance(self.env[iter.value.id], ast.Tuple):
new_iter = self.env[iter.value.id].elts[iter.slice.value]
iter_value_type = self.env.get_type(iter.value.id)
if isinstance(iter_value_type, ast.Tuple):
new_iter = iter_value_type.elts[iter.slice.value]

elif isinstance(self.env[iter.value.id], ast.Subscript):
_elts = self.env[iter.value.id].slice.elts[iter.slice.value]
elif isinstance(iter_value_type, ast.Subscript):
_elts = iter_value_type.slice.elts[iter.slice.value] # type: ignore

if isinstance(_elts, ast.Tuple):
_elts = _elts.elts
Expand Down Expand Up @@ -556,7 +554,7 @@ def visit_For(self, node): # noqa: C901
else:
_val = ast.Constant(value=i)

self.const[node.target.id] = _val
self.env.set_constant(node.target.id, _val)

tar_assign = self.visit(ast.Assign(targets=[node.target], value=_val))
rolls.extend(flatten([tar_assign]))
Expand Down Expand Up @@ -670,10 +668,6 @@ def visit_Call(self, node):
return node

def visit_BinOp(self, node):
# Check if we have two constants
# if isinstance(node.right, ast.Constant) and isinstance(node.left, ast.Constant):
# # return a constant evaluting the inner

# Rewrite the ** operator to be a series of multiplications
if isinstance(node.op, ast.Pow):
if (
Expand Down
63 changes: 63 additions & 0 deletions qlasskit/ast2ast/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023-2024 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast


class Environment:
def __init__(self):
self.types = {}
self.constants = {}

def set_type(self, name, type_annotation):
self.types[name] = type_annotation

def set_constant(self, name, value):
self.constants[name] = value

if name not in self.types:
self.types[name] = value

def copy_type(self, origin, dest):
self.types[dest] = self.types[origin]

def get_type(self, name):
return self.types.get(name)

def get_constant(self, name):
return self.constants.get(name)

def has_type(self, name):
return name in self.types

def has_constant(self, name):
return name in self.constants

def remove(self, name):
self.types.pop(name, None)
self.constants.pop(name, None)

def __contains__(self, name):
return name in self.types or name in self.constants

def __getitem__(self, name):
if name in self.constants:
return self.constants[name]
return self.types.get(name)

def __setitem__(self, name, value):
if isinstance(value, ast.AST):
self.set_type(name, value)
else:
self.set_constant(name, value)

0 comments on commit 0a40c27

Please sign in to comment.