Skip to content

Commit

Permalink
Merge pull request #1 from nrubin29/new_eval
Browse files Browse the repository at this point in the history
Rewrote evaluation code to be more object-oriented.
  • Loading branch information
nrubin29 authored Oct 8, 2017
2 parents 1cd9c6d + a0d757a commit 3f7e035
Show file tree
Hide file tree
Showing 7 changed files with 500 additions and 277 deletions.
8 changes: 8 additions & 0 deletions EXPLANATION.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
1. The input is processed into a series of tokens using regular expressions.
* Calculator#_tokenize
2. The tokens are matched into a tree of RuleMatch objects using a right-recursive pattern-matching algorithm.
  * Calculator#_match
3. The tree is normalized: extraneous symbols are removed, left-associativity and precedence issues are corrected, implicit multiplication is rewritten as explicit multiplication, etc.
  * Ast#_fixed
4. The tree is evaluated in a recursive fashion.
* Ast#evaluate
122 changes: 58 additions & 64 deletions ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,103 +2,97 @@
This file contains the Ast class, which represents an abstract syntax tree which can be evaluated.
"""
import copy
from typing import Dict, Union
from typing import Dict

from common import RuleMatch, remove, left_assoc, Token, Value, value_map
from rules import calc_map
from common import RuleMatch, remove, left_assoc, Token
from rules import rule_process_map, rule_process_value_map


class Ast:
def __init__(self, ast: RuleMatch):
self.ast = self._fixed(ast)
def __init__(self, root: RuleMatch):
self.root = self._fixed(root)

def _fixed(self, ast):
def _fixed(self, node):
# print('**_fixed ast', ast)

if not isinstance(ast, RuleMatch):
return ast

# This flattens rules with a single matched element.
if len(ast.matched) is 1 and ast.name != 'num' and ast.name != 'mbd':
return self._fixed(ast.matched[0])
if not isinstance(node, RuleMatch):
return node

# This removes extraneous symbols from the tree.
for i in range(len(ast.matched) - 1, -1, -1):
if ast.matched[i].name in remove:
del ast.matched[i]
for i in range(len(node.matched) - 1, -1, -1):
if node.matched[i].name in remove:
del node.matched[i]

# This flattens rules with a single matched rule.
if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch):
return self._fixed(node.matched[0])

# This makes left-associative operations left-associative.
for token_name, rule in left_assoc.items():
if len(ast.matched) == 3 and ast.matched[1].name == token_name and isinstance(ast.matched[2], RuleMatch) and len(ast.matched[2].matched) == 3 and ast.matched[2].matched[1].name == token_name:
ast.matched[0] = RuleMatch(rule, [ast.matched[0], ast.matched[1], ast.matched[2].matched[0]])
ast.matched[1] = ast.matched[2].matched[1]
ast.matched[2] = ast.matched[2].matched[2]
return self._fixed(ast)
if len(node.matched) == 3 and node.matched[1].name == token_name and isinstance(node.matched[2], RuleMatch) and len(node.matched[2].matched) == 3 and node.matched[2].matched[1].name == token_name:
node.matched[0] = RuleMatch(rule, [node.matched[0], node.matched[1], node.matched[2].matched[0]])
node.matched[1] = node.matched[2].matched[1]
node.matched[2] = node.matched[2].matched[2]
return self._fixed(node)

# This converts implicit multiplication to regular multiplication.
if ast.name == 'mui':
return self._fixed(RuleMatch('mul', [ast.matched[0], Token('MUL', '*'), ast.matched[1]]))
if node.name == 'mui':
return self._fixed(RuleMatch('mul', [node.matched[0], Token('MUL', '*'), node.matched[1]]))

# This flattens matrix rows into parent matrix rows.
if ast.name == 'mrw':
for i in range(len(ast.matched) - 1, -1, -1):
if ast.matched[i].name == 'mrw':
ast.matched[i:] = ast.matched[i].matched
return self._fixed(ast)
if node.name == 'mrw':
for i in range(len(node.matched) - 1, -1, -1):
if node.matched[i].name == 'mrw':
node.matched[i:] = node.matched[i].matched
return self._fixed(node)

# This flattens matrix bodies into parent matrix bodies.
if ast.name == 'mbd':
for i in range(len(ast.matched) - 1, -1, -1):
if ast.matched[i].name == 'mbd':
ast.matched[i:] = ast.matched[i].matched
return self._fixed(ast)
if node.name == 'mbd':
for i in range(len(node.matched) - 1, -1, -1):
if node.matched[i].name == 'mbd':
node.matched[i:] = node.matched[i].matched
return self._fixed(node)

if isinstance(ast, RuleMatch):
for i in range(len(ast.matched)):
ast.matched[i] = self._fixed(ast.matched[i])
if isinstance(node, RuleMatch):
for i in range(len(node.matched)):
node.matched[i] = self._fixed(node.matched[i])

return ast
return node

def evaluate(self, vrs: Dict[str, RuleMatch]):
return self._evaluate(self.ast, vrs)
return self._evaluate(self.root, vrs)

def _evaluate(self, ast, vrs: Dict[str, RuleMatch]): # -> Union[Dict[str, RuleMatch], Token]:
if ast.name == 'idt':
return {ast.matched[0].value: ast.matched[1]}
def _evaluate(self, node, vrs: Dict[str, RuleMatch]):
if node.name == 'asn':
return {node.matched[0].value: node.matched[1]}

for token in ast.matched:
for token in node.matched:
if isinstance(token, RuleMatch) and not token.value:
token.value = self._evaluate(token, vrs)

if any(map(lambda t: isinstance(t, RuleMatch), ast.matched)):
return calc_map[ast.name](ast.matched)
values = [token.value for token in node.matched if isinstance(token, RuleMatch) and token.value]
tokens = [token for token in node.matched if not isinstance(token, RuleMatch)]

if node.matched[0].name == 'IDT':
return self._evaluate(copy.deepcopy(vrs[node.matched[0].value]), vrs)

elif node.name in rule_process_value_map:
process = rule_process_value_map[node.name](values, tokens)

else:
if ast.matched[0].name == 'IDT':
return self._evaluate(copy.deepcopy(vrs[ast.matched[0].value]), vrs)
process = rule_process_map[node.name](values, tokens[0] if len(tokens) > 0 else None) # This extra rule is part of the num hotfix.

else:
# At this point, ast.name will _always_ be `num`.
return calc_map[ast.name](ast.matched)
return process.value

def infix(self) -> str:
# TODO: Add parentheses where needed.
return self._infix(self.ast)
# TODO: Add parentheses and missing tokens.
return self._infix(self.root)

def _infix(self, ast: RuleMatch) -> str:
return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), ast.matched))
def _infix(self, node: RuleMatch) -> str:
return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), node.matched))

def __str__(self):
return self._str(self.ast) # + '\n>> ' + self.infix()

def _str(self, ast, depth=0) -> str:
output = (('\t' * depth) + ast.name + ' = ' + str(ast.value.value)) + '\n'

for matched in ast.matched:
if isinstance(matched, RuleMatch) and matched.matched:
output += self._str(matched, depth + 1)

else:
output += (('\t' * (depth + 1)) + matched.name + ': ' + matched.value) + '\n'
return str(self.root) # + '\n>> ' + self.infix()

return output
def __repr__(self):
return str(self)
21 changes: 14 additions & 7 deletions calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re

from ast import Ast
from common import Value, Token, token_map, rules_map, RuleMatch
from common import Token, token_map, rules_map, RuleMatch, Value


class Calculator:
Expand All @@ -15,7 +15,7 @@ def __init__(self):

def evaluate(self, eqtn: str, verbose=True) -> Value:
for e in eqtn.split(';'):
root, remaining_tokens = self._match(self._tokenize(e), 'idt')
root, remaining_tokens = self._match(self._tokenize(e), 'asn')

if remaining_tokens:
raise Exception('Invalid equation (bad format)')
Expand All @@ -24,7 +24,7 @@ def evaluate(self, eqtn: str, verbose=True) -> Value:
res = ast.evaluate(self.vrs)

if isinstance(res, Value):
ast.ast.value = res
ast.root.value = res

if verbose:
print(ast)
Expand All @@ -49,10 +49,13 @@ def _tokenize(self, eqtn: str) -> List[Token]:
def _match(self, tokens: List[Token], target_rule: str):
# print('match', tokens, target_rule)

if tokens and tokens[0].name == target_rule: # This is a token, not a rule.
return tokens[0], tokens[1:]
if target_rule.isupper(): # This is a token, not a rule.
if tokens and tokens[0].name == target_rule:
return tokens[0], tokens[1:]

for pattern in rules_map.get(target_rule, ()):
return None, None

for pattern in rules_map[target_rule]:
# print('trying pattern', pattern)

remaining_tokens = tokens
Expand All @@ -69,5 +72,9 @@ def _match(self, tokens: List[Token], target_rule: str):
matched.append(m)
else:
# Success!
return RuleMatch(target_rule, matched, None), remaining_tokens
return RuleMatch(target_rule, matched), remaining_tokens

if rules_map.index(target_rule) + 1 < len(rules_map):
return self._match(tokens, rules_map.key_at(rules_map.index(target_rule) + 1))

return None, None
70 changes: 40 additions & 30 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,36 @@
This file contains important information for the calculator.
"""

from collections import namedtuple, OrderedDict
from enum import Enum
from collections import OrderedDict, namedtuple
from typing import List

Token = namedtuple('Token', ('name', 'value'))
Value = namedtuple('Value', ('type', 'value'))


class RuleMatch:
def __init__(self, name: str, matched: List[Token], value: Value = None):
def __init__(self, name: str, matched: List[Token]):
self.name = name
self.matched = matched
self.value = value
self.value = None

def __str__(self):
return 'RuleMatch(' + ', '.join(map(str, [self.name, self.matched, self.value])) + ')'
return self._str(self) # + '\n>> ' + self.infix()

def __repr__(self):
return str(self)

def _str(self, ast, depth=0) -> str:
output = (('\t' * depth) + ast.name + ' = ' + str(ast.value if ast.value else None)) + '\n'

for matched in ast.matched:
if isinstance(matched, RuleMatch) and matched.matched:
output += self._str(matched, depth + 1)

else:
output += (('\t' * (depth + 1)) + matched.name + ': ' + matched.value) + '\n'

return output


token_map = OrderedDict((
(r'\d+(?:\.\d+)?', 'NUM'),
Expand All @@ -32,6 +42,8 @@ def __repr__(self):
(r'trans', 'OPR'),
(r'cof', 'OPR'),
(r'inv', 'OPR'),
(r'identity', 'OPR'),
(r'trnsform', 'OPR'),
(r'rref', 'OPR'),
(r'[a-zA-Z_]+', 'IDT'),
(r'=', 'EQL'),
Expand All @@ -52,34 +64,32 @@ def __repr__(self):

remove = ('EQL', 'LPA', 'RPA', 'LBR', 'RBR', 'CMA', 'PPE')

rules_map = OrderedDict((
('idt', ('IDT EQL add', 'mat')),
('mat', ('LBR mbd RBR', 'add')),
('mbd', ('mrw PPE mbd', 'mrw', 'add')),
('mrw', ('add CMA mrw', 'add')),
('add', ('mul ADD add', 'mui', 'mul')),

class IndexedOrderedDict(OrderedDict):
def index(self, key):
return list(self.keys()).index(key)

def key_at(self, i):
return list(self.keys())[i]


rules_map = IndexedOrderedDict((
('asn', ('IDT EQL add',)),
('mat', ('LBR mbd RBR',)),
('mbd', ('mrw PPE mbd',)),
('mrw', ('add CMA mrw',)),
('add', ('mul ADD add',)),
('mui', ('pow mul',)),
('mul', ('pow MUL mul', 'pow')),
('pow', ('opr POW pow', 'opr')),
('opr', ('OPR LPA mat RPA', 'neg')),
('neg', ('ADD num', 'ADD opr', 'num')),
('num', ('NUM', 'IDT', 'LPA add RPA')),
('mul', ('pow MUL mul',)),
('pow', ('opr POW pow',)),
('opr', ('OPR LPA mat RPA',)),
('neg', ('ADD num', 'ADD opr')),
('var', ('IDT',)),
('num', ('NUM', 'LPA add RPA')),
))


left_assoc = {
'ADD': 'add',
'MUL': 'mul',
}


class Type(Enum):
Number = 0
Matrix = 1
MatrixRow = 2


value_map = {
'NUM': Type.Number,
'MAT': Type.Matrix,
'MRW': Type.MatrixRow
}
Loading

0 comments on commit 3f7e035

Please sign in to comment.