diff --git a/common.py b/common.py index a6e8543..5b8731b 100644 --- a/common.py +++ b/common.py @@ -28,6 +28,10 @@ def __repr__(self): (r'sqrt', 'OPR'), (r'exp', 'OPR'), (r'det', 'OPR'), + (r'adj', 'OPR'), + (r'trans', 'OPR'), + (r'cof', 'OPR'), + (r'inv', 'OPR'), (r'[a-zA-Z_]+', 'IDT'), (r'=', 'EQL'), (r'\+', 'ADD'), diff --git a/rules.py b/rules.py index 8b2b12e..f4d29f6 100644 --- a/rules.py +++ b/rules.py @@ -1,6 +1,7 @@ """ This file contains methods to handle calculation of all the rules. """ +import copy import math import operator from typing import List, Union @@ -21,7 +22,7 @@ def pow(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: def opr(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: - return Value(Type.Number, {'sqrt': math.sqrt, 'exp': math.exp, 'det': det}[tokens[0].value](tokens[1].value.value)) + return Value(Type.Number, {'sqrt': math.sqrt, 'exp': math.exp, 'det': det, 'trans': trans, 'cof': cof, 'adj': adj, 'inv': inv}[tokens[0].value](tokens[1].value.value)) def neg(tokens: List[Union[Token, Value, RuleMatch]]) -> Value: @@ -59,6 +60,39 @@ def det(matrix: List[List[float]]) -> float: return sum(cofactors) +def trans(matrix: List[List[float]]) -> List[List[float]]: + return list(map(list, zip(*matrix))) + + +def cof(matrix: List[List[float]]) -> List[List[float]]: + # TODO: This code is pretty ugly. + + cofactor_matrix = [] + + for row in range(len(matrix)): + cofactor_matrix.append([]) + + for col in range(len(matrix[row])): + minor = copy.deepcopy(matrix) + del minor[row] + + for r in minor: + del r[col] + + cofactor_matrix[row].append(det(minor) * (1 if (row + col) % 2 is 0 else -1)) + + return cofactor_matrix + + +def adj(matrix: List[List[float]]) -> List[List[float]]: + return trans(cof(matrix)) + + +def inv(matrix: List[List[float]]): + multiplier = 1 / det(matrix) + return [[cell * multiplier for cell in row] for row in adj(matrix)] + + calc_map = { 'add': add, 'mul': mul, diff --git a/tests.py b/tests.py index a0141ae..456b1d1 100644 --- a/tests.py +++ b/tests.py @@ -103,11 +103,25 @@ def runTest(self): self.assertEqual(evaluate('det([1,2,3|4,5,6|7,8,8])'), 3.0) self.assertEqual(evaluate('[1,2|4,5]'), [[1.0, 2.0], [4.0, 5.0]]) + self.assertEqual(evaluate('trans([1,2|4,5])'), [[1.0, 4.0], [2.0, 5.0]]) + + self.assertEqual(evaluate('inv([1,4,7|3,0,5|-1,9,11])'), [[45/8, -19/8, -5/2], [19/4, -9/4, -2], [-27/8, 13/8, 3/2]]) + self.assertEqual(evaluate('[1,0,0|0,1,0|0,0,1]'), [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) self.assertEqual(evaluate('[1,0,0,0|0,1,0,0|0,0,1,0|0,0,0,1]'), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]) self.assertEqual(evaluate('det([1,3,5,7|2,4,6,8|9,7,5,4|8,6,5,9])'), 2.0) + self.assertEqual(evaluate('cof([1,2,3|0,4,5|1,0,6])'), [[24, 5, -4], [-12, 3, 2], [-2, -5, 4]]) + + # Since we have floating-point issues, we have to test each value individually. + calc = evaluate('inv([1,2,3|0,4,5|1,0,6])') + ans = [[12/11, -6/11, -1/11], [5/22, 3/22, -5/22], [-2/11, 1/11, 2/11]] + + for row in range(len(calc)): + for col in range(len(calc)): + self.assertAlmostEqual(calc[row][col], ans[row][col]) + # for _ in range(10): # mat = [[random.randint(1, 100) for _ in range(10)] for _ in range(10)] # mat_str = '[' + '|'.join([','.join(map(str, line)) for line in mat]) + ']'