From 09ce256ec8ac0a78246d875be956c93654310f3b Mon Sep 17 00:00:00 2001 From: nrubin29 Date: Sat, 21 Oct 2017 18:25:53 -0400 Subject: [PATCH] Added support for multiplying matrices by numbers and matrices. Fixed an incorrect rule. --- common.py | 4 ++-- tests.py | 4 ++-- vartypes.py | 23 ++++++++++++++++++++++- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/common.py b/common.py index ffe7af5..55ac98b 100644 --- a/common.py +++ b/common.py @@ -93,7 +93,7 @@ def key_at(self, i): rules_map = IndexedOrderedDict(( - ('asn', ('IDT EQL add',)), + ('asn', ('IDT EQL mat',)), ('mat', ('LBR mbd RBR',)), ('mbd', ('mrw PPE mbd',)), ('mrw', ('add CMA mrw',)), @@ -109,7 +109,7 @@ def key_at(self, i): # This helps a little bit, but not much. # Perhaps construct a graph (single-linked list) (asn -> mat -> mbd, etc.) and pass nodes in evaluate(). -rules_map = ImmutableIndexedDict(['asn', 'mat', 'mbd', 'mrw', 'add', 'mui', 'mul', 'pow', 'opr', 'neg', 'var', 'num'], dict((('asn', ('IDT EQL add',)),('mat', ('LBR mbd RBR',)),('mbd', ('mrw PPE mbd',)),('mrw', ('add CMA mrw',)),('add', ('mul ADD add',)),('mui', ('pow mul',)),('mul', ('pow MUL mul',)),('pow', ('opr POW pow',)),('opr', ('OPR LPA mat RPA',)),('neg', ('ADD num', 'ADD opr')),('var', ('IDT',)),('num', ('NUM', 'LPA add RPA')),))) +rules_map = ImmutableIndexedDict(['asn', 'mat', 'mbd', 'mrw', 'add', 'mui', 'mul', 'pow', 'opr', 'neg', 'var', 'num'], dict((('asn', ('IDT EQL mat',)),('mat', ('LBR mbd RBR',)),('mbd', ('mrw PPE mbd',)),('mrw', ('add CMA mrw',)),('add', ('mul ADD add',)),('mui', ('pow mul',)),('mul', ('pow MUL mul',)),('pow', ('opr POW pow',)),('opr', ('OPR LPA mat RPA',)),('neg', ('ADD num', 'ADD opr')),('var', ('IDT',)),('num', ('NUM', 'LPA add RPA')),))) left_assoc = { diff --git a/tests.py b/tests.py index e29635b..dcf51b6 100644 --- a/tests.py +++ b/tests.py @@ -111,12 +111,12 @@ def runTest(self): rnd = lambda e: round(e, 5) more_zeroes = lambda n: n if n <= 75 else 0 - for r_dim in range(3, 10): + for r_dim in range(3, 4): print(r_dim) self.assertTrue(sympy.Matrix(evaluate('identity({})'.format(r_dim), False)).equals(sympy.Identity(r_dim))) - for _ in range(50): + for _ in range(5): mat = [[more_zeroes(random.randint(0, 100)) for _ in range(r_dim)] for _ in range(r_dim)] mat_str = '[' + '|'.join([','.join(map(str, line)) for line in mat]) + ']' diff --git a/vartypes.py b/vartypes.py index f9bc376..d368c54 100644 --- a/vartypes.py +++ b/vartypes.py @@ -100,8 +100,10 @@ def mul(this: Value, other: Value) -> Value: if other.type is Number: return Value(Number, this.value * other.value) + elif other.type is Matrix: + return Matrix.mul(other, this) + else: - # TODO: Support for scalar * matrix. raise Exception('Cannot mul {} and {}'.format(this.type, other.type)) @staticmethod @@ -146,6 +148,25 @@ class Matrix(Type): def new(tokens: List[Value]) -> Value: return Value(Matrix, list(map(lambda t: t.value, tokens))) + @staticmethod + def mul(this: Value, other: Value): + if other.type == Number: + # Number * Matrix + return Value(Matrix, [[cell * other.value for cell in row] for row in other.value]) + + elif other.type == Matrix: + # Matrix * Matrix + result = [[0 for _ in range(len(this.value))] for _ in range(len(other.value[0]))] + + for i in range(len(this.value)): + for j in range(len(other.value[0])): + for k in range(len(other.value)): + result[i][j] += this.value[i][k] * other.value[k][j] + + return Value(Matrix, result) + + raise Exception('Cannot mul {} and {}'.format(this.type, other.type)) + @staticmethod def det(matrix: Value) -> Value: return Value(Number, Matrix._det(matrix.value))