diff --git a/EXPLANATION.md b/EXPLANATION.md new file mode 100644 index 0000000..0930e3a --- /dev/null +++ b/EXPLANATION.md @@ -0,0 +1,8 @@ +1. The input is processed into a series of tokens using regular expressions. + * Calculator#_tokenize +2. The tokens are turned into a tree of RuleMatches using a right-recursive pattern matching algorithm. + * Calculator#_match +3. The tree is fixed. Unnecessary tokens are removed, precedence issues are fixed, etc. + * Ast#_fixed +4. The tree is evaluated in a recursive fashion. + * Ast#evaluate \ No newline at end of file diff --git a/ast.py b/ast.py index 861b545..6a18915 100644 --- a/ast.py +++ b/ast.py @@ -2,103 +2,97 @@ This file contains the Ast class, which represents an abstract syntax tree which can be evaluated. """ import copy -from typing import Dict, Union +from typing import Dict -from common import RuleMatch, remove, left_assoc, Token, Value, value_map -from rules import calc_map +from common import RuleMatch, remove, left_assoc, Token +from rules import rule_process_map, rule_process_value_map class Ast: - def __init__(self, ast: RuleMatch): - self.ast = self._fixed(ast) + def __init__(self, root: RuleMatch): + self.root = self._fixed(root) - def _fixed(self, ast): + def _fixed(self, node): # print('**_fixed ast', ast) - if not isinstance(ast, RuleMatch): - return ast - - # This flattens rules with a single matched element. - if len(ast.matched) is 1 and ast.name != 'num' and ast.name != 'mbd': - return self._fixed(ast.matched[0]) + if not isinstance(node, RuleMatch): + return node # This removes extraneous symbols from the tree. - for i in range(len(ast.matched) - 1, -1, -1): - if ast.matched[i].name in remove: - del ast.matched[i] + for i in range(len(node.matched) - 1, -1, -1): + if node.matched[i].name in remove: + del node.matched[i] + + # This flattens rules with a single matched rule. + if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch): + return self._fixed(node.matched[0]) # This makes left-associative operations left-associative. for token_name, rule in left_assoc.items(): - if len(ast.matched) == 3 and ast.matched[1].name == token_name and isinstance(ast.matched[2], RuleMatch) and len(ast.matched[2].matched) == 3 and ast.matched[2].matched[1].name == token_name: - ast.matched[0] = RuleMatch(rule, [ast.matched[0], ast.matched[1], ast.matched[2].matched[0]]) - ast.matched[1] = ast.matched[2].matched[1] - ast.matched[2] = ast.matched[2].matched[2] - return self._fixed(ast) + if len(node.matched) == 3 and node.matched[1].name == token_name and isinstance(node.matched[2], RuleMatch) and len(node.matched[2].matched) == 3 and node.matched[2].matched[1].name == token_name: + node.matched[0] = RuleMatch(rule, [node.matched[0], node.matched[1], node.matched[2].matched[0]]) + node.matched[1] = node.matched[2].matched[1] + node.matched[2] = node.matched[2].matched[2] + return self._fixed(node) # This converts implicit multiplication to regular multiplication. - if ast.name == 'mui': - return self._fixed(RuleMatch('mul', [ast.matched[0], Token('MUL', '*'), ast.matched[1]])) + if node.name == 'mui': + return self._fixed(RuleMatch('mul', [node.matched[0], Token('MUL', '*'), node.matched[1]])) # This flattens matrix rows into parent matrix rows. - if ast.name == 'mrw': - for i in range(len(ast.matched) - 1, -1, -1): - if ast.matched[i].name == 'mrw': - ast.matched[i:] = ast.matched[i].matched - return self._fixed(ast) + if node.name == 'mrw': + for i in range(len(node.matched) - 1, -1, -1): + if node.matched[i].name == 'mrw': + node.matched[i:] = node.matched[i].matched + return self._fixed(node) # This flattens matrix bodies into parent matrix bodies. - if ast.name == 'mbd': - for i in range(len(ast.matched) - 1, -1, -1): - if ast.matched[i].name == 'mbd': - ast.matched[i:] = ast.matched[i].matched - return self._fixed(ast) + if node.name == 'mbd': + for i in range(len(node.matched) - 1, -1, -1): + if node.matched[i].name == 'mbd': + node.matched[i:] = node.matched[i].matched + return self._fixed(node) - if isinstance(ast, RuleMatch): - for i in range(len(ast.matched)): - ast.matched[i] = self._fixed(ast.matched[i]) + if isinstance(node, RuleMatch): + for i in range(len(node.matched)): + node.matched[i] = self._fixed(node.matched[i]) - return ast + return node def evaluate(self, vrs: Dict[str, RuleMatch]): - return self._evaluate(self.ast, vrs) + return self._evaluate(self.root, vrs) - def _evaluate(self, ast, vrs: Dict[str, RuleMatch]): # -> Union[Dict[str, RuleMatch], Token]: - if ast.name == 'idt': - return {ast.matched[0].value: ast.matched[1]} + def _evaluate(self, node, vrs: Dict[str, RuleMatch]): + if node.name == 'asn': + return {node.matched[0].value: node.matched[1]} - for token in ast.matched: + for token in node.matched: if isinstance(token, RuleMatch) and not token.value: token.value = self._evaluate(token, vrs) - if any(map(lambda t: isinstance(t, RuleMatch), ast.matched)): - return calc_map[ast.name](ast.matched) + values = [token.value for token in node.matched if isinstance(token, RuleMatch) and token.value] + tokens = [token for token in node.matched if not isinstance(token, RuleMatch)] + + if node.matched[0].name == 'IDT': + return self._evaluate(copy.deepcopy(vrs[node.matched[0].value]), vrs) + + elif node.name in rule_process_value_map: + process = rule_process_value_map[node.name](values, tokens) else: - if ast.matched[0].name == 'IDT': - return self._evaluate(copy.deepcopy(vrs[ast.matched[0].value]), vrs) + process = rule_process_map[node.name](values, tokens[0] if len(tokens) > 0 else None) # This extra rule is part of the num hotfix. - else: - # At this point, ast.name will _always_ be `num`. - return calc_map[ast.name](ast.matched) + return process.value def infix(self) -> str: - # TODO: Add parentheses where needed. - return self._infix(self.ast) + # TODO: Add parentheses and missing tokens. + return self._infix(self.root) - def _infix(self, ast: RuleMatch) -> str: - return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), ast.matched)) + def _infix(self, node: RuleMatch) -> str: + return ' '.join(map(lambda t: t.value if isinstance(t, Token) else self._infix(t), node.matched)) def __str__(self): - return self._str(self.ast) # + '\n>> ' + self.infix() - - def _str(self, ast, depth=0) -> str: - output = (('\t' * depth) + ast.name + ' = ' + str(ast.value.value)) + '\n' - - for matched in ast.matched: - if isinstance(matched, RuleMatch) and matched.matched: - output += self._str(matched, depth + 1) - - else: - output += (('\t' * (depth + 1)) + matched.name + ': ' + matched.value) + '\n' + return str(self.root) # + '\n>> ' + self.infix() - return output + def __repr__(self): + return str(self) diff --git a/calculator.py b/calculator.py index 73f870f..5d6e655 100644 --- a/calculator.py +++ b/calculator.py @@ -6,7 +6,7 @@ import re from ast import Ast -from common import Value, Token, token_map, rules_map, RuleMatch +from common import Token, token_map, rules_map, RuleMatch, Value class Calculator: @@ -15,7 +15,7 @@ def __init__(self): def evaluate(self, eqtn: str, verbose=True) -> Value: for e in eqtn.split(';'): - root, remaining_tokens = self._match(self._tokenize(e), 'idt') + root, remaining_tokens = self._match(self._tokenize(e), 'asn') if remaining_tokens: raise Exception('Invalid equation (bad format)') @@ -24,7 +24,7 @@ def evaluate(self, eqtn: str, verbose=True) -> Value: res = ast.evaluate(self.vrs) if isinstance(res, Value): - ast.ast.value = res + ast.root.value = res if verbose: print(ast) @@ -49,10 +49,13 @@ def _tokenize(self, eqtn: str) -> List[Token]: def _match(self, tokens: List[Token], target_rule: str): # print('match', tokens, target_rule) - if tokens and tokens[0].name == target_rule: # This is a token, not a rule. - return tokens[0], tokens[1:] + if target_rule.isupper(): # This is a token, not a rule. + if tokens and tokens[0].name == target_rule: + return tokens[0], tokens[1:] - for pattern in rules_map.get(target_rule, ()): + return None, None + + for pattern in rules_map[target_rule]: # print('trying pattern', pattern) remaining_tokens = tokens @@ -69,5 +72,9 @@ def _match(self, tokens: List[Token], target_rule: str): matched.append(m) else: # Success! - return RuleMatch(target_rule, matched, None), remaining_tokens + return RuleMatch(target_rule, matched), remaining_tokens + + if rules_map.index(target_rule) + 1 < len(rules_map): + return self._match(tokens, rules_map.key_at(rules_map.index(target_rule) + 1)) + return None, None diff --git a/common.py b/common.py index 515f34b..0b5700f 100644 --- a/common.py +++ b/common.py @@ -2,26 +2,36 @@ This file contains important information for the calculator. """ -from collections import namedtuple, OrderedDict -from enum import Enum +from collections import OrderedDict, namedtuple from typing import List Token = namedtuple('Token', ('name', 'value')) Value = namedtuple('Value', ('type', 'value')) - class RuleMatch: - def __init__(self, name: str, matched: List[Token], value: Value = None): + def __init__(self, name: str, matched: List[Token]): self.name = name self.matched = matched - self.value = value + self.value = None def __str__(self): - return 'RuleMatch(' + ', '.join(map(str, [self.name, self.matched, self.value])) + ')' + return self._str(self) # + '\n>> ' + self.infix() def __repr__(self): return str(self) + def _str(self, ast, depth=0) -> str: + output = (('\t' * depth) + ast.name + ' = ' + str(ast.value if ast.value else None)) + '\n' + + for matched in ast.matched: + if isinstance(matched, RuleMatch) and matched.matched: + output += self._str(matched, depth + 1) + + else: + output += (('\t' * (depth + 1)) + matched.name + ': ' + matched.value) + '\n' + + return output + token_map = OrderedDict(( (r'\d+(?:\.\d+)?', 'NUM'), @@ -32,6 +42,8 @@ def __repr__(self): (r'trans', 'OPR'), (r'cof', 'OPR'), (r'inv', 'OPR'), + (r'identity', 'OPR'), + (r'trnsform', 'OPR'), (r'rref', 'OPR'), (r'[a-zA-Z_]+', 'IDT'), (r'=', 'EQL'), @@ -52,34 +64,32 @@ def __repr__(self): remove = ('EQL', 'LPA', 'RPA', 'LBR', 'RBR', 'CMA', 'PPE') -rules_map = OrderedDict(( - ('idt', ('IDT EQL add', 'mat')), - ('mat', ('LBR mbd RBR', 'add')), - ('mbd', ('mrw PPE mbd', 'mrw', 'add')), - ('mrw', ('add CMA mrw', 'add')), - ('add', ('mul ADD add', 'mui', 'mul')), + +class IndexedOrderedDict(OrderedDict): + def index(self, key): + return list(self.keys()).index(key) + + def key_at(self, i): + return list(self.keys())[i] + + +rules_map = IndexedOrderedDict(( + ('asn', ('IDT EQL add',)), + ('mat', ('LBR mbd RBR',)), + ('mbd', ('mrw PPE mbd',)), + ('mrw', ('add CMA mrw',)), + ('add', ('mul ADD add',)), ('mui', ('pow mul',)), - ('mul', ('pow MUL mul', 'pow')), - ('pow', ('opr POW pow', 'opr')), - ('opr', ('OPR LPA mat RPA', 'neg')), - ('neg', ('ADD num', 'ADD opr', 'num')), - ('num', ('NUM', 'IDT', 'LPA add RPA')), + ('mul', ('pow MUL mul',)), + ('pow', ('opr POW pow',)), + ('opr', ('OPR LPA mat RPA',)), + ('neg', ('ADD num', 'ADD opr')), + ('var', ('IDT',)), + ('num', ('NUM', 'LPA add RPA')), )) + left_assoc = { 'ADD': 'add', 'MUL': 'mul', } - - -class Type(Enum): - Number = 0 - Matrix = 1 - MatrixRow = 2 - - -value_map = { - 'NUM': Type.Number, - 'MAT': Type.Matrix, - 'MRW': Type.MatrixRow -} diff --git a/rules.py b/rules.py index 78a71be..dbeff1e 100644 --- a/rules.py +++ b/rules.py @@ -1,169 +1,78 @@ """ -This file contains methods to handle calculation of all the rules. +This file contains methods to handle processing RuleMatches """ -import copy -import math -import operator -from typing import List, Union +from typing import List -from common import Token, Value, Type, RuleMatch +from common import Token, Value +from vartypes import Number, MatrixRow, Matrix, Variable -def add(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return Value(Type.Number, {'+': operator.add, '-': operator.sub}[tokens[1].value](tokens[0].value.value, tokens[2].value.value)) +class Process: + def __init__(self, operation, operands: List, raw_args=False): + self.operation = operation + self.operands = operands + if raw_args: + self.value = operation(operands) + else: + self.value = operation(*operands) -def mul(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return Value(Type.Number, {'*': operator.mul, '/': operator.truediv, '%': operator.mod}[tokens[1].value](tokens[0].value.value, tokens[2].value.value)) +def var(_, tokens: List[Token]) -> Process: + return Process(Variable.new, tokens, raw_args=True) -def pow(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return Value(Type.Number, tokens[0].value.value ** tokens[2].value.value) +def num(_, tokens: List[Token]) -> Process: + return Process(Number.new, tokens, raw_args=True) -def opr(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return Value(Type.Number, {'sqrt': math.sqrt, 'exp': math.exp, 'det': det, 'trans': trans, 'cof': cof, 'adj': adj, 'inv': inv, 'rref': rref}[tokens[0].value](tokens[1].value.value)) +def mrw(values: List[Value], _) -> Process: + return Process(MatrixRow.new, values, raw_args=True) -def neg(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return Value(Type.Number, tokens[1].value.value if tokens[0].value == '+' else -tokens[1].value.value) +def mbd(values: List[Value], _) -> Process: + return Process(Matrix.new, values, raw_args=True) -def num(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - if isinstance(tokens[0], RuleMatch): - return tokens[0].value - return Value(Type.Number, float(tokens[0].value)) +def add(operands: List[Value], operator: Token) -> Process: + return Process({'+': operands[0].type.add, '-': operands[0].type.sub}[operator.value], operands) -def mrw(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return Value(Type.MatrixRow, list(map(lambda t: t.value.value, tokens))) +def mul(operands: List[Value], operator: Token) -> Process: + return Process({'*': operands[0].type.mul, '/': operands[0].type.div, '%': operands[0].type.mod}[operator.value], operands) -def mbd(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return Value(Type.Matrix, list(map(lambda t: t.value.value, tokens))) +def pow(operands: List[Value], operator: Token) -> Process: + return Process(operands[0].type.pow, operands) -def mat(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return tokens[0].value +def opr(operands: List[Value], operator: Token) -> Process: + return Process(getattr(operands[0].type, operator.value), operands) -def det(matrix: List[List[float]]) -> float: - if len(matrix) is 2: - return (matrix[0][0] * matrix[1][1]) - (matrix[0][1] * matrix[1][0]) +def neg(operands: List[Value], operator: Token) -> Process: + return Process({'+': operands[0].type.pos, '-': operands[0].type.neg}[operator.value], operands) - cofactors = [] - for col in range(len(matrix)): - cofactors.append(det([matrix[row][0:col] + matrix[row][col + 1:] for row in range(1, len(matrix))]) * matrix[0][col] * (1 if col % 2 is 0 else -1)) +def mat(operands: List[Value], operator: Token) -> Process: + # Since mbd creates the matrix, we just want this to return the matrix. + return Process(lambda x: x, operands) - return sum(cofactors) - - -def trans(matrix: List[List[float]]) -> List[List[float]]: - return list(map(list, zip(*matrix))) - - -def cof(matrix: List[List[float]]) -> List[List[float]]: - # TODO: This code is pretty ugly. - - cofactor_matrix = [] - - for row in range(len(matrix)): - cofactor_matrix.append([]) - - for col in range(len(matrix[row])): - minor = copy.deepcopy(matrix) - del minor[row] - - for r in minor: - del r[col] - - cofactor_matrix[row].append(det(minor) * (1 if (row + col) % 2 is 0 else -1)) - - return cofactor_matrix - - -def adj(matrix: List[List[float]]) -> List[List[float]]: - return trans(cof(matrix)) - - -def inv(matrix: List[List[float]]): - multiplier = 1 / det(matrix) - return [[cell * multiplier for cell in row] for row in adj(matrix)] - - -def rref(matrix: List[List[float]]): - mat = copy.deepcopy(matrix) - row = 0 - col = 0 - - def count_leading_zeroes(row): - for i in range(len(row)): - if row[i] != 0: - return i - - return len(row) - - # Sort the matrix by the number of 0s in each row with the most 0s going to the bottom. - mat = sorted(mat, key=count_leading_zeroes) - - # print(mat) - - while row < len(mat) and col < len(mat[row]): - # print(row, mat) - - # If there is a leading 0, move column over but remain on the same row. - if mat[row][col] == 0: - col += 1 - continue - - # Divide each cell in the row by the first cell to ensure that the row starts with a 1. - mat[row] = [cell / mat[row][col] for cell in mat[row]] - - # Multiply all lower rows as needed. - for i in range(row + 1, len(mat)): - multiplier = -mat[i][col] / mat[row][col] - mat[i] = [cell + (mat[row][c] * multiplier) for c, cell in enumerate(mat[i])] - - row += 1 - col += 1 - - row = len(mat) - 1 - col = len(mat[row]) - 1 - - # print('going back up', row, col) - - while row > 0: - # If we have a 0 at this point, we don't need to go back up for this row. - if mat[row][col] == 0: - row -= 1 - col -= 1 - continue - - for i in range(row - 1, -1, -1): - multiplier = -mat[i][col] / mat[row][col] - - # print('multiplier', multiplier) - - mat[i] = [cell + (mat[row][c] * multiplier) for c, cell in enumerate(mat[i])] - - # print('it is now', mat[i]) - - row -= 1 - col -= 1 - - return mat +# The mapping for num, mrw, mbd. +rule_process_value_map = { + 'var': var, + 'num': num, + 'mrw': mrw, + 'mbd': mbd, +} -calc_map = { +# The mapping for all other rules. +rule_process_map = { 'add': add, 'mul': mul, 'pow': pow, 'opr': opr, 'neg': neg, - 'num': num, - 'mrw': mrw, - 'mbd': mbd, - 'mat': mat + 'mat': mat, } diff --git a/tests.py b/tests.py index 0a3d21b..d54185c 100644 --- a/tests.py +++ b/tests.py @@ -5,8 +5,6 @@ import random import unittest -import sympy - from calculator import Calculator @@ -75,13 +73,14 @@ def runTest(self): class ParenthesisTests(unittest.TestCase): def runTest(self): - # self.assertEqual(evaluate('(1 + 2) * 3'), 9.0) - # self.assertEqual(evaluate('(1 + 2) ^ (2 * 3 - 2)'), 81.0) + self.assertEqual(evaluate('(1 + 2) * 3'), 9.0) + self.assertEqual(evaluate('(1 + 2) ^ (2 * 3 - 2)'), 81.0) self.assertEqual(evaluate('2 (1 + 1)'), 4.0) class IdentifierTests(unittest.TestCase): def runTest(self): + pass self.assertEqual(evaluate('r = 10; r'), 10.0) self.assertEqual(round(evaluate('r = 5.2 * (3 + 2 / (1 + 1/6)); pi = 3.14159; area = pi * r^2; area'), 5), 1887.93915) self.assertEqual(round(evaluate('area = pi * r^2; r = 5.2 * (3 + 2 / (1 + 1/6)); pi = 3.14159; area'), 5), 1887.93915) @@ -106,49 +105,52 @@ def runTest(self): class MatrixTests(unittest.TestCase): def runTest(self): - # self.assertEqual(evaluate('[1,2]'), [[1.0, 2.0]]) - # self.assertEqual(evaluate('det([1,2,3|4,5,6|7,8,8])'), 3.0) - # - # self.assertEqual(evaluate('[1,2|4,5]'), [[1.0, 2.0], [4.0, 5.0]]) - # self.assertEqual(evaluate('trans([1,2|4,5])'), [[1.0, 4.0], [2.0, 5.0]]) - # - # self.assertEqual(evaluate('inv([1,4,7|3,0,5|-1,9,11])'), [[45/8, -19/8, -5/2], [19/4, -9/4, -2], [-27/8, 13/8, 3/2]]) - # - # self.assertEqual(evaluate('[1,0,0|0,1,0|0,0,1]'), [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - # - # self.assertEqual(evaluate('[1,0,0,0|0,1,0,0|0,0,1,0|0,0,0,1]'), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]) - # self.assertEqual(evaluate('det([1,3,5,7|2,4,6,8|9,7,5,4|8,6,5,9])'), 2.0) - # - # self.assertEqual(evaluate('cof([1,2,3|0,4,5|1,0,6])'), [[24, 5, -4], [-12, 3, 2], [-2, -5, 4]]) - # - # # Since we have floating-point issues, we have to test each value individually. - # calc = evaluate('inv([1,2,3|0,4,5|1,0,6])') - # print(calc) - # ans = [[12/11, -6/11, -1/11], [5/22, 3/22, -5/22], [-2/11, 1/11, 2/11]] - # - # for row in range(len(calc)): - # for col in range(len(calc)): - # self.assertAlmostEqual(calc[row][col], ans[row][col]) + self.assertEqual(evaluate('[1,2]'), [1.0, 2.0]) # TODO: I guess we can now tell the difference between matrices and vectors...is that good? + self.assertEqual(evaluate('det([1,2,3|4,5,6|7,8,8])'), 3.0) - # self.assertEqual(evaluate('rref([1,2,3|4,5,6|7,8,8])'), [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - # self.assertEqual(evaluate('rref([1,2,4|4,7,6|7,1,8])'), [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - self.assertEqual(evaluate('rref([1,2,3|4,5,6|4,5,6])'), [[1, 0, -1], [0, 1, 2], [0, 0, 0]]) - # self.assertEqual(evaluate('rref([1,2,3|4,5,6|7,8,9])'), [[1, 0, -1], [0, 1, 2], [0, 0, 0]]) + self.assertEqual(evaluate('[1,2|4,5]'), [[1.0, 2.0], [4.0, 5.0]]) + self.assertEqual(evaluate('trans([1,2|4,5])'), [[1.0, 4.0], [2.0, 5.0]]) - for r_dim in range(3, 10): - print(r_dim) + self.assertEqual(evaluate('inv([1,4,7|3,0,5|-1,9,11])'), [[45/8, -19/8, -5/2], [19/4, -9/4, -2], [-27/8, 13/8, 3/2]]) - for _ in range(10): - mat = [[random.randint(0, 100) for _ in range(r_dim)] for _ in range(r_dim)] - mat_str = '[' + '|'.join([','.join(map(str, line)) for line in mat]) + ']' + self.assertEqual(evaluate('[1,0,0|0,1,0|0,0,1]'), [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - # print('*****<') - # print(sympy.Matrix(mat)) - # print(sympy.Matrix(evaluate('rref({})'.format(mat_str), False))) - # print(sympy.Matrix(mat).rref()[0]) - # print('>*****') + self.assertEqual(evaluate('[1,0,0,0|0,1,0,0|0,0,1,0|0,0,0,1]'), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]) + self.assertEqual(evaluate('det([1,3,5,7|2,4,6,8|9,7,5,4|8,6,5,9])'), 2.0) - self.assertTrue(sympy.Matrix(evaluate('rref({})'.format(mat_str), False)).equals(sympy.Matrix(mat).rref()[0])) + self.assertEqual(evaluate('cof([1,2,3|0,4,5|1,0,6])'), [[24, 5, -4], [-12, 3, 2], [-2, -5, 4]]) + + # Since we have floating-point issues, we have to test each value individually. + calc = evaluate('inv([1,2,3|0,4,5|1,0,6])') + print(calc) + ans = [[12/11, -6/11, -1/11], [5/22, 3/22, -5/22], [-2/11, 1/11, 2/11]] + + for row in range(len(calc)): + for col in range(len(calc)): + self.assertAlmostEqual(calc[row][col], ans[row][col]) + + self.assertEqual(evaluate('identity(3)'), [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + self.assertEqual(evaluate('trnsform([0,0,2|0,3,0|4,0,0])'), [[0, 0, 1], [0, 1, 0], [1, 0, 0]]) + + self.assertEqual(evaluate('rref([1,2,3|4,5,6|7,8,8])'), [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + self.assertEqual(evaluate('rref([1,2,4|4,7,6|7,1,8])'), [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + self.assertEqual(evaluate('rref([1,2,3|4,5,6|4,5,6])'), [[1, 0, -1], [0, 1, 2], [0, 0, 0]]) + self.assertEqual(evaluate('rref([1,2,3|4,5,6|7,8,9])'), [[1, 0, -1], [0, 1, 2], [0, 0, 0]]) + + # for r_dim in range(3, 10): + # print(r_dim) + # + # for _ in range(10): + # mat = [[random.randint(0, 100) for _ in range(r_dim)] for _ in range(r_dim)] + # mat_str = '[' + '|'.join([','.join(map(str, line)) for line in mat]) + ']' + # + # # print('*****<') + # # print(sympy.Matrix(mat)) + # # print(sympy.Matrix(evaluate('rref({})'.format(mat_str), False))) + # # print(sympy.Matrix(mat).rref()[0]) + # # print('>*****') + # + # self.assertTrue(sympy.Matrix(evaluate('rref({})'.format(mat_str), False)).equals(sympy.Matrix(mat).rref()[0])) class RandomTests(unittest.TestCase): diff --git a/vartypes.py b/vartypes.py new file mode 100644 index 0000000..fd7a755 --- /dev/null +++ b/vartypes.py @@ -0,0 +1,293 @@ +from abc import ABCMeta, abstractmethod +from typing import List + +import math + +import copy + +from common import Value, Token + + +class Type(metaclass=ABCMeta): + @staticmethod + @abstractmethod + def new(tokens: List) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def pos(this: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def neg(this: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def add(this: Value, other: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def sub(this: Value, other: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def mul(this: Value, other: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def div(this: Value, other: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def mod(this: Value, other: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def pow(this: Value, other: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def sqrt(this: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def exp(this: Value) -> Value: + raise Exception('Operation not defined for this type.') + + @staticmethod + def identity(this: Value) -> Value: + raise Exception('Operation not defined for this type.') + + +class Variable(Type): + @staticmethod + def new(tokens: List) -> Value: + return Value(Variable, tokens[0].value) + + +class Number(Type): + @staticmethod + def new(tokens: List[Token]) -> Value: + return Value(Number, float(tokens[0].value)) + + @staticmethod + def pos(this: Value) -> Value: + return Value(Number, this.value) + + @staticmethod + def neg(this: Value) -> Value: + return Value(Number, -this.value) + + @staticmethod + def add(this: Value, other: Value) -> Value: + if other.type is Number: + return Value(Number, this.value + other.value) + + else: + raise Exception('Cannot add {} and {}'.format(this.type, other.type)) + + @staticmethod + def sub(this: Value, other: Value) -> Value: + if other.type is Number: + return Value(Number, this.value - other.value) + + else: + raise Exception('Cannot sub {} and {}'.format(this.type, other.type)) + + @staticmethod + def mul(this: Value, other: Value) -> Value: + if other.type is Number: + return Value(Number, this.value * other.value) + + else: + # TODO: Support for scalar * matrix. + raise Exception('Cannot mul {} and {}'.format(this.type, other.type)) + + @staticmethod + def div(this: Value, other: Value) -> Value: + if other.type is Number: + return Value(Number, this.value / other.value) + + else: + raise Exception('Cannot div {} and {}'.format(this.type, other.type)) + + @staticmethod + def mod(this: Value, other: Value) -> Value: + if other.type is Number: + return Value(Number, this.value % other.value) + + else: + raise Exception('Cannot mod {} and {}'.format(this.type, other.type)) + + @staticmethod + def pow(this: Value, other: Value) -> Value: + if other.type is Number: + return Value(Number, this.value ** other.value) + + else: + raise Exception('Cannot pow {} and {}'.format(this.type, other.type)) + + @staticmethod + def sqrt(this: Value) -> Value: + return Value(Number, math.sqrt(this.value)) + + @staticmethod + def exp(this: Value) -> Value: + return Value(Number, math.exp(this.value)) + + @staticmethod + def identity(this: Value) -> Value: + return Value(Matrix, [[1 if col is row else 0 for col in range(int(this.value))] for row in range(int(this.value))]) + + +class Matrix(Type): + @staticmethod + def new(tokens: List[Value]) -> Value: + return Value(Matrix, list(map(lambda t: t.value, tokens))) + + @staticmethod + def det(matrix: Value) -> Value: + return Value(Number, Matrix._det(matrix.value)) + + @staticmethod + def _det(matrix: List[List[float]]) -> float: + if len(matrix) is 2: + return (matrix[0][0] * matrix[1][1]) - (matrix[0][1] * matrix[1][0]) + + cofactors = [] + + for col in range(len(matrix)): + cofactors.append(Matrix._det([matrix[row][0:col] + matrix[row][col + 1:] for row in range(1, len(matrix))]) * matrix[0][col] * (1 if col % 2 is 0 else -1)) + + return sum(cofactors) + + @staticmethod + def trans(matrix: Value) -> Value: + return Value(Matrix, list(map(list, zip(*matrix.value)))) + + @staticmethod + def cof(matrix: Value) -> Value: + # TODO: This code is pretty ugly. + + cofactor_matrix = [] + mat = matrix.value + + for row in range(len(mat)): + cofactor_matrix.append([]) + + for col in range(len(mat[row])): + minor = copy.deepcopy(mat) + del minor[row] + + for r in minor: + del r[col] + + cofactor_matrix[row].append(Matrix._det(minor) * (1 if (row + col) % 2 is 0 else -1)) + + return Value(Matrix, cofactor_matrix) + + @staticmethod + def adj(matrix: Value) -> Value: + return Matrix.trans(Matrix.cof(matrix)) + + @staticmethod + def inv(matrix: Value) -> Value: + multiplier = 1 / Matrix._det(matrix.value) + return Value(Matrix, [[cell * multiplier for cell in row] for row in Matrix.adj(matrix).value]) + + @staticmethod + def trnsform(matrix) -> Value: + # Returns the transformation matrix which, when multiplied by the original matrix, will give the rref form of the original matrix. + mat = copy.deepcopy(matrix.value) + ident = Number.identity(Value(Number, len(mat))).value + row = 0 + col = 0 + + def count_leading_zeroes(row): + for i in range(len(row)): + if row[i] != 0: + return i + + return len(row) + + while row < len(mat) - 1: + if count_leading_zeroes(mat[row]) > count_leading_zeroes(mat[row + 1]): + mat[row], mat[row + 1] = mat[row + 1], mat[row] + ident[row], ident[row + 1] = ident[row + 1], ident[row] + row = 0 + + else: + row += 1 + + row = 0 + + return Value(Matrix, ident) + + @staticmethod + def rref(matrix: Value) -> Value: + mat = copy.deepcopy(matrix.value) + row = 0 + col = 0 + + def count_leading_zeroes(row): + for i in range(len(row)): + if row[i] != 0: + return i + + return len(row) + + # Sort the matrix by the number of 0s in each row with the most 0s going to the bottom. + mat = sorted(mat, key=count_leading_zeroes) + + # print(mat) + + while row < len(mat) and col < len(mat[row]): + # print(row, mat) + + # If there is a leading 0, move column over but remain on the same row. + if mat[row][col] == 0: + col += 1 + continue + + # Divide each cell in the row by the first cell to ensure that the row starts with a 1. + mat[row] = [cell / mat[row][col] for cell in mat[row]] + + # Multiply all lower rows as needed. + for i in range(row + 1, len(mat)): + multiplier = -mat[i][col] / mat[row][col] + mat[i] = [cell + (mat[row][c] * multiplier) for c, cell in enumerate(mat[i])] + + row += 1 + col += 1 + + row = len(mat) - 1 + col = len(mat[row]) - 1 + + # print('going back up', row, col) + + while row > 0: + # If we have a 0 at this point, we don't need to go back up for this row. + if mat[row][col] == 0: + row -= 1 + col -= 1 + continue + + for i in range(row - 1, -1, -1): + multiplier = -mat[i][col] / mat[row][col] + + # print('multiplier', multiplier) + + mat[i] = [cell + (mat[row][c] * multiplier) for c, cell in enumerate(mat[i])] + + # print('it is now', mat[i]) + + row -= 1 + col -= 1 + + return Value(Matrix, mat) + + +class MatrixRow(Type): + @staticmethod + def new(tokens: List[Value]) -> Value: + return Value(MatrixRow, list(map(lambda t: t.value, tokens)))