From b7ffd79019b4a70a05ba3834dd5c61f00bdedb9b Mon Sep 17 00:00:00 2001 From: nrubin29 Date: Sun, 10 Sep 2017 18:34:17 -0400 Subject: [PATCH] Split code into multiple files. --- __init__.py | 0 ast.py | 104 ++++++++++++++++++++ calculator.py | 98 +++++++++---------- calculator_ast.py | 242 ---------------------------------------------- common.py | 63 ++++++++++++ main.py | 11 +++ rules.py | 56 +++++++++++ tests.py | 9 +- 8 files changed, 285 insertions(+), 298 deletions(-) create mode 100644 __init__.py create mode 100644 ast.py delete mode 100644 calculator_ast.py create mode 100644 common.py create mode 100644 main.py create mode 100644 rules.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ast.py b/ast.py new file mode 100644 index 0000000..b0ca84f --- /dev/null +++ b/ast.py @@ -0,0 +1,104 @@ +""" +This file contains the Ast class, which represents an abstract syntax tree which can be evaluated. +""" +import copy +from typing import Dict, Union + +from common import RuleMatch, remove, left_assoc, Token, Value, value_map +from rules import calc_map + + +class Ast: + def __init__(self, ast: RuleMatch): + self.ast = self._fixed(ast) + + def _fixed(self, ast): + # print('**_fixed ast', ast) + + if not isinstance(ast, RuleMatch): + return ast + + # This flattens rules with a single matched element. + if len(ast.matched) is 1 and ast.name != 'num' and ast.name != 'mbd': + return self._fixed(ast.matched[0]) + + # This removes extraneous symbols from the tree. + for i in range(len(ast.matched) - 1, -1, -1): + if ast.matched[i].name in remove: + del ast.matched[i] + + # This makes left-associative operations left-associative. + for token_name, rule in left_assoc.items(): + if len(ast.matched) == 3 and ast.matched[1].name == token_name and isinstance(ast.matched[2], RuleMatch) and len(ast.matched[2].matched) == 3 and ast.matched[2].matched[1].name == token_name: + ast.matched[0] = RuleMatch(rule, [ast.matched[0], ast.matched[1], ast.matched[2].matched[0]]) + ast.matched[1] = ast.matched[2].matched[1] + ast.matched[2] = ast.matched[2].matched[2] + return self._fixed(ast) + + # This converts implicit multiplication to regular multiplication. + if ast.name == 'mui': + return self._fixed(RuleMatch('mul', [ast.matched[0], Token('MUL', '*'), ast.matched[1]])) + + # This flattens matrix rows into parent matrix rows. + if ast.name == 'mrw' and ast.matched[1].name == 'mrw': + ast.matched[1:] = ast.matched[1].matched + return self._fixed(ast) + + # This flattens matrix bodies into parent matrix bodies. + if ast.name == 'mbd' and len(ast.matched) > 1 and ast.matched[1].name == 'mbd': + ast.matched[1:] = ast.matched[1].matched + return self._fixed(ast) + + if isinstance(ast, RuleMatch): + for i in range(len(ast.matched)): + ast.matched[i] = self._fixed(ast.matched[i]) + + return ast + + def evaluate(self, vrs: Dict[str, RuleMatch]): + res = self._evaluate(self.ast, vrs) + + if isinstance(res, Token): + return Value(value_map[res.name], res.value) + + return res + + def _evaluate(self, ast, vrs: Dict[str, RuleMatch]) -> Union[Dict[str, RuleMatch], Token]: + if ast.name == 'idt': + return {ast.matched[0].value: ast.matched[1]} + + for i in range(len(ast.matched)): + token = ast.matched[i] + + if isinstance(token, RuleMatch): + ast.matched[i] = self._evaluate(token, vrs) + return self._evaluate(ast, vrs) + else: + if ast.matched[0].name == 'IDT': + return self._evaluate(copy.deepcopy(vrs[ast.matched[0].value]), vrs) + + else: + return calc_map[ast.name](ast.matched) + + def infix(self) -> str: + # TODO: Add parentheses where needed. + return self._infix(self.ast) + + def _infix(self, ast: RuleMatch) -> str: + return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), ast.matched)) + + def __str__(self): + return self._str(self.ast) # + '\n>> ' + self.infix() + + def _str(self, ast, depth=0) -> str: + output = (('\t' * depth) + ast.name) + '\n' + + for matched in ast.matched: + # print('**matched', matched) + if isinstance(matched, RuleMatch) and matched.matched: + output += self._str(matched, depth + 1) + + else: + output += (('\t' * (depth + 1)) + matched.name + ': ' + matched.value) + '\n' + + return output diff --git a/calculator.py b/calculator.py index c21c3cc..f15122d 100644 --- a/calculator.py +++ b/calculator.py @@ -1,75 +1,69 @@ """ -A calculator implemented with the Shunting Yard Algorithm. The AST version is far superior and this version should be disregarded. +This file contains the Calculator class, which accept an equation and generates an AST, and also keeps track of variables. """ +from typing import List import re -from collections import namedtuple -from operator import add, sub, mul, truediv, pow, mod -Operator = namedtuple('Operator', ('symbol', 'function', 'precedence', 'associativity')) +from ast import Ast +from common import Value, Token, token_map, rules_map, RuleMatch -operators = { - '+': Operator('+', add, 2, 'l'), - '-': Operator('-', sub, 2, 'l'), - '*': Operator('*', mul, 3, 'l'), - '/': Operator('/', truediv, 3, 'l'), - '%': Operator('%', mod, 3, 'l'), - '^': Operator('^', pow, 4, 'r'), - '(': Operator('(', None, 5, ''), - ')': Operator(')', None, 5, '') -} -expr = r'[\+\-\*\/|\^\(\)]|\d+' -eqtn = '3 + 4 * 2 / ( 1 - 5 ) ^ 2 ^ 3' # => 3.0001220703 # input() +class Calculator: + def __init__(self): + self.vrs = {} -rpn = [] -ops = [] + def evaluate(self, eqtn: str) -> Value: + for e in eqtn.split(';'): + root, remaining_tokens = self._match(self._tokenize(e), 'idt') -for token in re.findall(expr, eqtn): - if re.match(r'\d+', token): - rpn.append(token) + if remaining_tokens: + raise Exception('Invalid equation (bad format)') - elif token is '(': - ops.append(token) + ast = Ast(root) + print(ast) + res = ast.evaluate(self.vrs) - elif token is ')': - for op in reversed(list(map(operators.__getitem__, ops))): - if op.symbol is not '(': - rpn.append(ops.pop()) + if isinstance(res, Value): + return res - else: - break - - ops.pop() + elif isinstance(res, dict): + self.vrs.update(res) - else: - for op in reversed(list(map(operators.__getitem__, ops))): - if op.symbol is not '(' and op.precedence >= operators[token].precedence and op.associativity is 'l': - rpn.append(ops.pop()) + def _tokenize(self, eqtn: str) -> List[Token]: + tokens = [] - else: - break + if re.sub('(' + ')|('.join(token_map.keys()) + ')', '', 'eqtn').strip(): + raise Exception('Invalid equation (illegal tokens)') - ops.append(token) + for match in re.findall('(' + ')|('.join(token_map.keys()) + ')', eqtn): + entry = next(filter(lambda entry: entry[1] != '', enumerate(match)), None) + tokens.append(Token(list(token_map.values())[entry[0]], entry[1])) - # print('{:20s}|{:20s}|{:20s}'.format(token, ' '.join(rpn), ' '.join(ops)).strip()) + return tokens -# print() + def _match(self, tokens: List[Token], target_rule: str): + # print('match', tokens, target_rule) -while len(ops) > 0: - rpn.append(ops.pop()) + if tokens and tokens[0].name == target_rule: # This is a token, not a rule. + return tokens[0], tokens[1:] -print(' '.join(rpn)) + for pattern in rules_map.get(target_rule, ()): + # print('trying pattern', pattern) -output = [] + remaining_tokens = tokens + matched = [] -while len(rpn) > 0: - token = rpn.pop(0) - if re.match(r'\d+', token): - output.append(int(token)) + for pattern_token in pattern.split(): + # print('checking pattern_token', pattern_token) + m, remaining_tokens = self._match(remaining_tokens, pattern_token) - else: - b, a = output.pop(), output.pop() - output.append(operators[token].function(a, b)) + if not m: + # print('failed pattern match') + break -print(output[0]) + matched.append(m) + else: + # Success! + return RuleMatch(target_rule, matched), remaining_tokens + return None, None diff --git a/calculator_ast.py b/calculator_ast.py deleted file mode 100644 index 3b3ff01..0000000 --- a/calculator_ast.py +++ /dev/null @@ -1,242 +0,0 @@ -""" -A calculator implemented with an Abstract Syntax Tree (AST). -""" - -import copy -import re -from collections import namedtuple, OrderedDict -from enum import Enum -from typing import List, Dict, Union - -import math - -Token = namedtuple('Token', ('name', 'value')) -RuleMatch = namedtuple('RuleMatch', ('name', 'matched')) - -token_map = OrderedDict(( - (r'\d+(?:\.\d+)?', 'NUM'), - (r'sqrt', 'OPR'), - (r'exp', 'OPR'), - (r'[a-zA-Z_]+', 'IDT'), - (r'=', 'EQL'), - (r'\+', 'ADD'), - (r'-', 'ADD'), - (r'\*\*', 'POW'), - (r'\*', 'MUL'), - (r'\/', 'MUL'), - (r'%', 'MUL'), - (r'\^', 'POW'), - (r'\(', 'LPA'), - (r'\)', 'RPA'), - (r'\[', 'LBR'), - (r'\]', 'RBR'), - (r'\,', 'CMA'), - (r'\|', 'PPE') -)) - -remove = ('EQL', 'LPA', 'RPA', 'LBR', 'RBR', 'CMA', 'PPE') - -rules_map = OrderedDict(( - ('idt', ('IDT EQL add', 'mat')), - ('mat', ('LBR mbd RBR', 'add')), - ('mbd', ('mrw PPE mbd', 'mrw', 'add')), - ('mrw', ('add CMA mrw', 'add')), - ('add', ('mul ADD add', 'mui', 'mul')), - ('mui', ('pow mul',)), - ('mul', ('pow MUL mul', 'pow')), - ('pow', ('opr POW pow', 'opr')), - ('opr', ('OPR LPA add RPA', 'neg')), - ('neg', ('ADD num', 'ADD opr', 'num')), - ('num', ('NUM', 'IDT', 'LPA add RPA')), -)) - -left_assoc = { - 'ADD': 'add', - 'MUL': 'mul', -} - - -class Type(Enum): - Number = 0 - Matrix = 1 - - -Value = namedtuple('Value', ('type', 'value')) -value_map = { - 'NUM': Type.Number, - 'MAT': Type.Matrix -} - -calc_map = { - 'add': lambda tokens: Token('NUM', float(tokens[0].value) + float(tokens[2].value) if tokens[1].value == '+' else float(tokens[0].value) - float(tokens[2].value)), - 'mul': lambda tokens: Token('NUM', float(tokens[0].value) * float(tokens[2].value) if tokens[1].value == '*' else float(tokens[0].value) / float(tokens[2].value) if tokens[1].value == '/' else float(tokens[0].value) % float(tokens[2].value)), - 'pow': lambda tokens: Token('NUM', float(tokens[0].value) ** float(tokens[2].value)), - 'opr': lambda tokens: Token('NUM', {'sqrt': math.sqrt, 'exp': math.exp}[tokens[0].value](tokens[1].value)), - 'neg': lambda tokens: Token('NUM', float(tokens[1].value) if tokens[0].value == '+' else -float(tokens[1].value)), - 'num': lambda tokens: Token('NUM', float(tokens[0].value)), - 'mat': lambda tokens: Token('MAT', [float(tokens[1].value)]) -} - - -class Calculator: - def __init__(self): - self.vrs = {} - - def evaluate(self, eqtn: str) -> Value: - for e in eqtn.split(';'): - root, remaining_tokens = self._match(self._tokenize(e), 'idt') - - if remaining_tokens: - raise Exception('Invalid equation (bad format)') - - ast = Ast(root) - print(ast) - res = ast.evaluate(self.vrs) - - if isinstance(res, Value): - return res - - elif isinstance(res, dict): - self.vrs.update(res) - - def _tokenize(self, eqtn: str) -> List[Token]: - tokens = [] - - if re.sub('(' + ')|('.join(token_map.keys()) + ')', '', 'eqtn').strip(): - raise Exception('Invalid equation (illegal tokens)') - - for match in re.findall('(' + ')|('.join(token_map.keys()) + ')', eqtn): - entry = next(filter(lambda entry: entry[1] != '', enumerate(match)), None) - tokens.append(Token(list(token_map.values())[entry[0]], entry[1])) - - return tokens - - def _match(self, tokens: List[Token], target_rule: str): - # print('match', tokens, target_rule) - - if tokens and tokens[0].name == target_rule: # This is a token, not a rule. - return tokens[0], tokens[1:] - - for pattern in rules_map.get(target_rule, ()): - # print('trying pattern', pattern) - - remaining_tokens = tokens - matched = [] - - for pattern_token in pattern.split(): - # print('checking pattern_token', pattern_token) - m, remaining_tokens = self._match(remaining_tokens, pattern_token) - - if not m: - # print('failed pattern match') - break - - matched.append(m) - else: - # Success! - return RuleMatch(target_rule, matched), remaining_tokens - return None, None - - -class Ast: - def __init__(self, ast: RuleMatch): - self.ast = self._fixed(ast) - - def _fixed(self, ast): - # print('**_fixed ast', ast) - - if not isinstance(ast, RuleMatch): - return ast - - # This flattens rules with a single matched element. - if len(ast.matched) is 1 and ast.name != 'num': - if ast.name != 'mat' or ast.matched[0].name != 'mbd': - return self._fixed(ast.matched[0]) - - # This removes extraneous symbols from the tree. - for i in range(len(ast.matched) - 1, -1, -1): - if ast.matched[i].name in remove: - del ast.matched[i] - - # This makes left-associative operations left-associative. - for token_name, rule in left_assoc.items(): - if len(ast.matched) == 3 and ast.matched[1].name == token_name and isinstance(ast.matched[2], RuleMatch) and len(ast.matched[2].matched) == 3 and ast.matched[2].matched[1].name == token_name: - ast.matched[0] = RuleMatch(rule, [ast.matched[0], ast.matched[1], ast.matched[2].matched[0]]) - ast.matched[1] = ast.matched[2].matched[1] - ast.matched[2] = ast.matched[2].matched[2] - return self._fixed(ast) - - # This converts implicit multiplication to regular multiplication. - if ast.name == 'mui': - return self._fixed(RuleMatch('mul', [ast.matched[0], Token('MUL', '*'), ast.matched[1]])) - - # This flattens matrix rows into parent matrix rows. - if ast.name == 'mrw' and ast.matched[1].name == 'mrw': - ast.matched[1:] = ast.matched[1].matched - return self._fixed(ast) - - # This flattens matrix bodies into parent matrix bodies. - if ast.name == 'mbd' and ast.matched[1].name == 'mbd': - ast.matched[1:] = ast.matched[1].matched - return self._fixed(ast) - - if isinstance(ast, RuleMatch): - for i in range(len(ast.matched)): - ast.matched[i] = self._fixed(ast.matched[i]) - - return ast - - def evaluate(self, vrs: Dict[str, RuleMatch]): - res = self._evaluate(self.ast, vrs) - - if isinstance(res, Token): - return Value(value_map[res.name], res.value) - - return res - - def _evaluate(self, ast, vrs: Dict[str, RuleMatch]) -> Union[Dict[str, RuleMatch], Token]: - if ast.name == 'idt': - return {ast.matched[0].value: ast.matched[1]} - - for i in range(len(ast.matched)): - token = ast.matched[i] - - if isinstance(token, RuleMatch): - ast.matched[i] = self._evaluate(token, vrs) - return self._evaluate(ast, vrs) - else: - if ast.matched[0].name == 'IDT': - return self._evaluate(copy.deepcopy(vrs[ast.matched[0].value]), vrs) - - else: - return calc_map[ast.name](ast.matched) - - def infix(self) -> str: - # TODO: Add parentheses where needed. - return self._infix(self.ast) - - def _infix(self, ast: RuleMatch) -> str: - return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), ast.matched)) - - def __str__(self): - return self._str(self.ast) # + '\n>> ' + self.infix() - - def _str(self, ast, depth=0) -> str: - output = (('\t' * depth) + ast.name) + '\n' - - for matched in ast.matched: - # print('**matched', matched) - if isinstance(matched, RuleMatch) and matched.matched: - output += self._str(matched, depth + 1) - - else: - output += (('\t' * (depth + 1)) + matched.name + ': ' + matched.value) + '\n' - - return output - - -if __name__ == '__main__': - calc = Calculator() - - while True: - print(calc.evaluate(input('>> '))) diff --git a/common.py b/common.py new file mode 100644 index 0000000..2dc0b9b --- /dev/null +++ b/common.py @@ -0,0 +1,63 @@ +""" +This file contains important information for the calculator. +""" + +from collections import namedtuple, OrderedDict +from enum import Enum + +Token = namedtuple('Token', ('name', 'value')) +RuleMatch = namedtuple('RuleMatch', ('name', 'matched')) + +token_map = OrderedDict(( + (r'\d+(?:\.\d+)?', 'NUM'), + (r'sqrt', 'OPR'), + (r'exp', 'OPR'), + (r'[a-zA-Z_]+', 'IDT'), + (r'=', 'EQL'), + (r'\+', 'ADD'), + (r'-', 'ADD'), + (r'\*\*', 'POW'), + (r'\*', 'MUL'), + (r'\/', 'MUL'), + (r'%', 'MUL'), + (r'\^', 'POW'), + (r'\(', 'LPA'), + (r'\)', 'RPA'), + (r'\[', 'LBR'), + (r'\]', 'RBR'), + (r'\,', 'CMA'), + (r'\|', 'PPE') +)) + +remove = ('EQL', 'LPA', 'RPA', 'LBR', 'RBR', 'CMA', 'PPE') + +rules_map = OrderedDict(( + ('idt', ('IDT EQL add', 'mat')), + ('mat', ('LBR mbd RBR', 'add')), + ('mbd', ('mrw PPE mbd', 'mrw', 'add')), + ('mrw', ('add CMA mrw', 'add')), + ('add', ('mul ADD add', 'mui', 'mul')), + ('mui', ('pow mul',)), + ('mul', ('pow MUL mul', 'pow')), + ('pow', ('opr POW pow', 'opr')), + ('opr', ('OPR LPA add RPA', 'neg')), + ('neg', ('ADD num', 'ADD opr', 'num')), + ('num', ('NUM', 'IDT', 'LPA add RPA')), +)) + +left_assoc = { + 'ADD': 'add', + 'MUL': 'mul', +} + + +class Type(Enum): + Number = 0 + Matrix = 1 + + +Value = namedtuple('Value', ('type', 'value')) +value_map = { + 'NUM': Type.Number, + 'MAT': Type.Matrix +} \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..cff20be --- /dev/null +++ b/main.py @@ -0,0 +1,11 @@ +""" +A calculator implemented with an Abstract Syntax Tree (AST). +""" + +from calculator import Calculator + +if __name__ == '__main__': + calc = Calculator() + + while True: + print(calc.evaluate(input('>> '))) diff --git a/rules.py b/rules.py new file mode 100644 index 0000000..ead80a8 --- /dev/null +++ b/rules.py @@ -0,0 +1,56 @@ +""" +This file contains methods to handle calculation of all the rules. +""" +import math +from typing import List + +from common import Token + + +def add(tokens: List[Token]) -> Token: + return Token('NUM', float(tokens[0].value) + float(tokens[2].value) if tokens[1].value == '+' else float(tokens[0].value) - float(tokens[2].value)) + + +def mul(tokens: List[Token]) -> Token: + return Token('NUM', float(tokens[0].value) * float(tokens[2].value) if tokens[1].value == '*' else float(tokens[0].value) / float(tokens[2].value) if tokens[1].value == '/' else float(tokens[0].value) % float(tokens[2].value)) + + +def pow(tokens: List[Token]) -> Token: + return Token('NUM', float(tokens[0].value) ** float(tokens[2].value)) + + +def opr(tokens: List[Token]) -> Token: + return Token('NUM', {'sqrt': math.sqrt, 'exp': math.exp}[tokens[0].value](tokens[1].value)) + + +def neg(tokens: List[Token]) -> Token: + return Token('NUM', float(tokens[1].value) if tokens[0].value == '+' else -float(tokens[1].value)) + + +def num(tokens: List[Token]) -> Token: + return Token('NUM', float(tokens[0].value)) + + +def mrw(tokens: List[Token]) -> Token: + return Token('MRW', list(map(lambda t: float(t.value), tokens))) + + +def mbd(tokens: List[Token]) -> Token: + return Token('MAT', list(map(lambda t: t.value, tokens))) + + +def mat(tokens: List[Token]) -> Token: + return tokens[0] + + +calc_map = { + 'add': add, + 'mul': mul, + 'pow': pow, + 'opr': opr, + 'neg': neg, + 'num': num, + 'mrw': mrw, + 'mbd': mbd, + 'mat': mat +} diff --git a/tests.py b/tests.py index 7bbd194..5fdf1b0 100644 --- a/tests.py +++ b/tests.py @@ -5,7 +5,7 @@ import random import unittest -from calculator_ast import Calculator +from calculator import Calculator def evaluate(eqtn: str): @@ -99,9 +99,10 @@ def runTest(self): class MatrixTests(unittest.TestCase): def runTest(self): - # self.assertEqual(evaluate('[1,2]'), [1.0]) - # self.assertEqual(evaluate('[1,2|4,5]'), [1.0]) - self.assertEqual(evaluate('[1,0,0|0,1,0|0,0,1]'), [1.0]) + self.assertEqual(evaluate('[1,2]'), [[1.0, 2.0]]) + self.assertEqual(evaluate('[1,2|4,5]'), [[1.0, 2.0], [4.0, 5.0]]) + self.assertEqual(evaluate('[1,0,0|0,1,0|0,0,1]'), [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [[0.0, 0.0, 1.0]]]) + self.assertEqual(evaluate('[1,0,0,0|0,1,0,0|0,0,1,0|0,0,0,1]'), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]]) class RandomTests(unittest.TestCase):