-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
67 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,112 +1,98 @@ | ||
""" | ||
This file contains the Ast class, which represents an abstract syntax tree which can be evaluated. | ||
""" | ||
from typing import Dict | ||
|
||
import copy | ||
from typing import Dict | ||
|
||
from common import RuleMatch, remove, left_assoc, Token | ||
from rules import rule_process_map, rule_process_value_map | ||
from vartypes import Variable | ||
|
||
|
||
class Ast: | ||
def __init__(self, ast: RuleMatch): | ||
self.ast = self._fixed(ast) | ||
def __init__(self, root: RuleMatch): | ||
self.root = self._fixed(root) | ||
|
||
def _fixed(self, ast): | ||
def _fixed(self, node): | ||
# print('**_fixed ast', ast) | ||
|
||
if not isinstance(ast, RuleMatch): | ||
return ast | ||
if not isinstance(node, RuleMatch): | ||
return node | ||
|
||
# This removes extraneous symbols from the tree. | ||
for i in range(len(ast.matched) - 1, -1, -1): | ||
if ast.matched[i].name in remove: | ||
del ast.matched[i] | ||
for i in range(len(node.matched) - 1, -1, -1): | ||
if node.matched[i].name in remove: | ||
del node.matched[i] | ||
|
||
# This flattens rules with a single matched rule. | ||
if len(ast.matched) is 1 and isinstance(ast.matched[0], RuleMatch): | ||
return self._fixed(ast.matched[0]) | ||
if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch): | ||
return self._fixed(node.matched[0]) | ||
|
||
# This makes left-associative operations left-associative. | ||
for token_name, rule in left_assoc.items(): | ||
if len(ast.matched) == 3 and ast.matched[1].name == token_name and isinstance(ast.matched[2], RuleMatch) and len(ast.matched[2].matched) == 3 and ast.matched[2].matched[1].name == token_name: | ||
ast.matched[0] = RuleMatch(rule, [ast.matched[0], ast.matched[1], ast.matched[2].matched[0]]) | ||
ast.matched[1] = ast.matched[2].matched[1] | ||
ast.matched[2] = ast.matched[2].matched[2] | ||
return self._fixed(ast) | ||
if len(node.matched) == 3 and node.matched[1].name == token_name and isinstance(node.matched[2], RuleMatch) and len(node.matched[2].matched) == 3 and node.matched[2].matched[1].name == token_name: | ||
node.matched[0] = RuleMatch(rule, [node.matched[0], node.matched[1], node.matched[2].matched[0]]) | ||
node.matched[1] = node.matched[2].matched[1] | ||
node.matched[2] = node.matched[2].matched[2] | ||
return self._fixed(node) | ||
|
||
# This converts implicit multiplication to regular multiplication. | ||
if ast.name == 'mui': | ||
return self._fixed(RuleMatch('mul', [ast.matched[0], Token('MUL', '*'), ast.matched[1]])) | ||
if node.name == 'mui': | ||
return self._fixed(RuleMatch('mul', [node.matched[0], Token('MUL', '*'), node.matched[1]])) | ||
|
||
# This flattens matrix rows into parent matrix rows. | ||
if ast.name == 'mrw': | ||
for i in range(len(ast.matched) - 1, -1, -1): | ||
if ast.matched[i].name == 'mrw': | ||
ast.matched[i:] = ast.matched[i].matched | ||
return self._fixed(ast) | ||
if node.name == 'mrw': | ||
for i in range(len(node.matched) - 1, -1, -1): | ||
if node.matched[i].name == 'mrw': | ||
node.matched[i:] = node.matched[i].matched | ||
return self._fixed(node) | ||
|
||
# This flattens matrix bodies into parent matrix bodies. | ||
if ast.name == 'mbd': | ||
for i in range(len(ast.matched) - 1, -1, -1): | ||
if ast.matched[i].name == 'mbd': | ||
ast.matched[i:] = ast.matched[i].matched | ||
return self._fixed(ast) | ||
if node.name == 'mbd': | ||
for i in range(len(node.matched) - 1, -1, -1): | ||
if node.matched[i].name == 'mbd': | ||
node.matched[i:] = node.matched[i].matched | ||
return self._fixed(node) | ||
|
||
if isinstance(ast, RuleMatch): | ||
for i in range(len(ast.matched)): | ||
ast.matched[i] = self._fixed(ast.matched[i]) | ||
if isinstance(node, RuleMatch): | ||
for i in range(len(node.matched)): | ||
node.matched[i] = self._fixed(node.matched[i]) | ||
|
||
return ast | ||
return node | ||
|
||
def evaluate(self, vrs: Dict[str, RuleMatch]): | ||
return self._evaluate(self.ast, vrs) | ||
return self._evaluate(self.root, vrs) | ||
|
||
def _evaluate(self, ast, vrs: Dict[str, RuleMatch]): | ||
if ast.name == 'asn': | ||
return {ast.matched[0].value: ast.matched[1]} | ||
def _evaluate(self, node, vrs: Dict[str, RuleMatch]): | ||
if node.name == 'asn': | ||
return {node.matched[0].value: node.matched[1]} | ||
|
||
for token in ast.matched: | ||
for token in node.matched: | ||
if isinstance(token, RuleMatch) and not token.value: | ||
token.value = self._evaluate(token, vrs) | ||
|
||
values = [token.value for token in ast.matched if isinstance(token, RuleMatch) and token.value] | ||
tokens = [token for token in ast.matched if not isinstance(token, RuleMatch)] | ||
values = [token.value for token in node.matched if isinstance(token, RuleMatch) and token.value] | ||
tokens = [token for token in node.matched if not isinstance(token, RuleMatch)] | ||
|
||
if ast.matched[0].name == 'IDT': | ||
return self._evaluate(copy.deepcopy(vrs[ast.matched[0].value]), vrs) | ||
if node.matched[0].name == 'IDT': | ||
return self._evaluate(copy.deepcopy(vrs[node.matched[0].value]), vrs) | ||
|
||
elif ast.name in rule_process_value_map: | ||
process = rule_process_value_map[ast.name](values, tokens) | ||
elif node.name in rule_process_value_map: | ||
process = rule_process_value_map[node.name](values, tokens) | ||
|
||
else: | ||
process = rule_process_map[ast.name](values, tokens[0] if len(tokens) > 0 else None) # This extra rule is part of the num hotfix. | ||
process = rule_process_map[node.name](values, tokens[0] if len(tokens) > 0 else None) # This extra rule is part of the num hotfix. | ||
|
||
return process.value | ||
|
||
def infix(self) -> str: | ||
# TODO: Add parentheses where needed. | ||
return self._infix(self.ast) | ||
# TODO: Add parentheses and missing tokens. | ||
return self._infix(self.root) | ||
|
||
def _infix(self, ast: RuleMatch) -> str: | ||
return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), ast.matched)) | ||
def _infix(self, node: RuleMatch) -> str: | ||
return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), node.matched)) | ||
|
||
def __str__(self): | ||
return self._str(self.ast) # + '\n>> ' + self.infix() | ||
return str(self.root) # + '\n>> ' + self.infix() | ||
|
||
def __repr__(self): | ||
return str(self) | ||
|
||
def _str(self, ast, depth=0) -> str: | ||
output = (('\t' * depth) + ast.name + ' = ' + str(ast.value if ast.value else None)) + '\n' | ||
|
||
for matched in ast.matched: | ||
if isinstance(matched, RuleMatch) and matched.matched: | ||
output += self._str(matched, depth + 1) | ||
|
||
else: | ||
output += (('\t' * (depth + 1)) + matched.name + ': ' + matched.value) + '\n' | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters