Skip to content

Commit

Permalink
Added checks for validity of equation.
Browse files Browse the repository at this point in the history
Cleaned up AST fixing.
Added matrix fixing.
  • Loading branch information
nrubin29 committed Sep 6, 2017
1 parent f931439 commit 05a3f35
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 21 deletions.
59 changes: 38 additions & 21 deletions calculator_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from collections import namedtuple, OrderedDict
from enum import Enum
from typing import List, Dict
from typing import List, Dict, Union

import math

Expand Down Expand Up @@ -34,6 +34,8 @@
(r'\|', 'PPE')
))

remove = ('EQL', 'LPA', 'RPA', 'LBR', 'RBR', 'CMA', 'PPE')

rules_map = OrderedDict((
('idt', ('IDT EQL add', 'mat')),
('mat', ('LBR mbd RBR', 'add')),
Expand All @@ -58,6 +60,7 @@ class Type(Enum):
Number = 0
Matrix = 1


Value = namedtuple('Value', ('type', 'value'))
value_map = {
'NUM': Type.Number,
Expand All @@ -79,9 +82,14 @@ class Calculator:
def __init__(self):
self.vrs = {}

def evaluate(self, eqtn: str):
def evaluate(self, eqtn: str) -> Value:
for e in eqtn.split(';'):
ast = Ast(self._match(self._tokenize(e), 'idt')[0])
root, remaining_tokens = self._match(self._tokenize(e), 'idt')

if remaining_tokens:
raise Exception('Invalid equation (bad format)')

ast = Ast(root)
print(ast)
res = ast.evaluate(self.vrs)

Expand All @@ -91,9 +99,12 @@ def evaluate(self, eqtn: str):
elif isinstance(res, dict):
self.vrs.update(res)

def _tokenize(self, eqtn: str):
def _tokenize(self, eqtn: str) -> List[Token]:
tokens = []

if re.sub('(' + ')|('.join(token_map.keys()) + ')', '', 'eqtn').strip():
raise Exception('Invalid equation (illegal tokens)')

for match in re.findall('(' + ')|('.join(token_map.keys()) + ')', eqtn):
entry = next(filter(lambda entry: entry[1] != '', enumerate(match)), None)
tokens.append(Token(list(token_map.values())[entry[0]], entry[1]))
Expand Down Expand Up @@ -139,11 +150,13 @@ def _fixed(self, ast):

# This flattens rules with a single matched element.
if len(ast.matched) is 1 and ast.name != 'num':
return self._fixed(ast.matched[0])
if ast.name != 'mat' or ast.matched[0].name != 'mbd':
return self._fixed(ast.matched[0])

# This flattens `num`s by removing parentheses.
if ast.name == 'num' and len(ast.matched) is 3:
return self._fixed(ast.matched[1])
# This removes extraneous symbols from the tree.
for i in range(len(ast.matched) - 1, -1, -1):
if ast.matched[i].name in remove:
del ast.matched[i]

# This makes left-associative operations left-associative.
for token_name, rule in left_assoc.items():
Expand All @@ -155,14 +168,16 @@ def _fixed(self, ast):

# This converts implicit multiplication to regular multiplication.
if ast.name == 'mui':
m = RuleMatch('mul', [ast.matched[0], Token('MUL', '*'), ast.matched[1]])
return self._fixed(m)

# This removes the parentheses from an operation.
if ast.name == 'opr' and len(ast.matched) == 4:
ast.matched[1] = ast.matched[2]
del ast.matched[3]
del ast.matched[2]
return self._fixed(RuleMatch('mul', [ast.matched[0], Token('MUL', '*'), ast.matched[1]]))

# This flattens matrix rows into parent matrix rows.
if ast.name == 'mrw' and ast.matched[1].name == 'mrw':
ast.matched[1:] = ast.matched[1].matched
return self._fixed(ast)

# This flattens matrix bodies into parent matrix bodies.
if ast.name == 'mbd' and ast.matched[1].name == 'mbd':
ast.matched[1:] = ast.matched[1].matched
return self._fixed(ast)

if isinstance(ast, RuleMatch):
Expand All @@ -179,9 +194,9 @@ def evaluate(self, vrs: Dict[str, RuleMatch]):

return res

def _evaluate(self, ast, vrs: Dict[str, RuleMatch]):
def _evaluate(self, ast, vrs: Dict[str, RuleMatch]) -> Union[Dict[str, RuleMatch], Token]:
if ast.name == 'idt':
return {ast.matched[0].value: ast.matched[2]}
return {ast.matched[0].value: ast.matched[1]}

for i in range(len(ast.matched)):
token = ast.matched[i]
Expand All @@ -196,17 +211,17 @@ def _evaluate(self, ast, vrs: Dict[str, RuleMatch]):
else:
return calc_map[ast.name](ast.matched)

def infix(self):
def infix(self) -> str:
# TODO: Add parentheses where needed.
return self._infix(self.ast)

def _infix(self, ast: RuleMatch):
def _infix(self, ast: RuleMatch) -> str:
return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), ast.matched))

def __str__(self):
return self._str(self.ast) # + '\n>> ' + self.infix()

def _str(self, ast, depth=0):
def _str(self, ast, depth=0) -> str:
output = (('\t' * depth) + ast.name) + '\n'

for matched in ast.matched:
Expand All @@ -219,7 +234,9 @@ def _str(self, ast, depth=0):

return output


if __name__ == '__main__':
calc = Calculator()

while True:
print(calc.evaluate(input('>> ')))
12 changes: 12 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ def evaluate(eqtn: str):
return res.value


class InvalidEquationTests(unittest.TestCase):
def runTest(self):
with self.assertRaises(Exception):
evaluate('$')

with self.assertRaises(Exception):
evaluate('1 +')

with self.assertRaises(Exception):
evaluate('1 * / 2')


class AdditionTests(unittest.TestCase):
def runTest(self):
self.assertEqual(evaluate('1 + 2'), 3.0)
Expand Down

0 comments on commit 05a3f35

Please sign in to comment.