Skip to content

Commit

Permalink
Added support for multiplying matrices by numbers and matrices.
Browse files Browse the repository at this point in the history
Fixed an incorrect rule.
  • Loading branch information
nrubin29 committed Oct 21, 2017
1 parent 9f0e176 commit 09ce256
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
4 changes: 2 additions & 2 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def key_at(self, i):


rules_map = IndexedOrderedDict((
('asn', ('IDT EQL add',)),
('asn', ('IDT EQL mat',)),
('mat', ('LBR mbd RBR',)),
('mbd', ('mrw PPE mbd',)),
('mrw', ('add CMA mrw',)),
Expand All @@ -109,7 +109,7 @@ def key_at(self, i):

# This helps a little bit, but not much.
# Perhaps construct a graph (single-linked list) (asn -> mat -> mbd, etc.) and pass nodes in evaluate().
rules_map = ImmutableIndexedDict(['asn', 'mat', 'mbd', 'mrw', 'add', 'mui', 'mul', 'pow', 'opr', 'neg', 'var', 'num'], dict((('asn', ('IDT EQL add',)),('mat', ('LBR mbd RBR',)),('mbd', ('mrw PPE mbd',)),('mrw', ('add CMA mrw',)),('add', ('mul ADD add',)),('mui', ('pow mul',)),('mul', ('pow MUL mul',)),('pow', ('opr POW pow',)),('opr', ('OPR LPA mat RPA',)),('neg', ('ADD num', 'ADD opr')),('var', ('IDT',)),('num', ('NUM', 'LPA add RPA')),)))
rules_map = ImmutableIndexedDict(['asn', 'mat', 'mbd', 'mrw', 'add', 'mui', 'mul', 'pow', 'opr', 'neg', 'var', 'num'], dict((('asn', ('IDT EQL mat',)),('mat', ('LBR mbd RBR',)),('mbd', ('mrw PPE mbd',)),('mrw', ('add CMA mrw',)),('add', ('mul ADD add',)),('mui', ('pow mul',)),('mul', ('pow MUL mul',)),('pow', ('opr POW pow',)),('opr', ('OPR LPA mat RPA',)),('neg', ('ADD num', 'ADD opr')),('var', ('IDT',)),('num', ('NUM', 'LPA add RPA')),)))


left_assoc = {
Expand Down
4 changes: 2 additions & 2 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ def runTest(self):
rnd = lambda e: round(e, 5)
more_zeroes = lambda n: n if n <= 75 else 0

for r_dim in range(3, 10):
for r_dim in range(3, 4):
print(r_dim)

self.assertTrue(sympy.Matrix(evaluate('identity({})'.format(r_dim), False)).equals(sympy.Identity(r_dim)))

for _ in range(50):
for _ in range(5):
mat = [[more_zeroes(random.randint(0, 100)) for _ in range(r_dim)] for _ in range(r_dim)]
mat_str = '[' + '|'.join([','.join(map(str, line)) for line in mat]) + ']'

Expand Down
23 changes: 22 additions & 1 deletion vartypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ def mul(this: Value, other: Value) -> Value:
if other.type is Number:
return Value(Number, this.value * other.value)

elif other.type is Matrix:
return Matrix.mul(other, this)

else:
# TODO: Support for scalar * matrix.
raise Exception('Cannot mul {} and {}'.format(this.type, other.type))

@staticmethod
Expand Down Expand Up @@ -146,6 +148,25 @@ class Matrix(Type):
def new(tokens: List[Value]) -> Value:
return Value(Matrix, list(map(lambda t: t.value, tokens)))

@staticmethod
def mul(this: Value, other: Value):
if other.type == Number:
# Number * Matrix
return Value(Matrix, [[cell * other.value for cell in row] for row in other.value])

elif other.type == Matrix:
# Matrix * Matrix
result = [[0 for _ in range(len(this.value))] for _ in range(len(other.value[0]))]

for i in range(len(this.value)):
for j in range(len(other.value[0])):
for k in range(len(other.value)):
result[i][j] += this.value[i][k] * other.value[k][j]

return Value(Matrix, result)

raise Exception('Cannot mul {} and {}'.format(this.type, other.type))

@staticmethod
def det(matrix: Value) -> Value:
return Value(Number, Matrix._det(matrix.value))
Expand Down

0 comments on commit 09ce256

Please sign in to comment.