Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

All tests pass #9

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ target/
# Temporary data
.ipynb_checkpoints/

extract_functions.py
117 changes: 117 additions & 0 deletions matrix_math.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,119 @@
import functools
import math


class ShapeException(Exception):
pass


def shape(array):
"""Take a vector or matrix and return a tuple with the
number of rows (for a vector) or the number of rows and columns
(for a matrix.)"""
try:
return (len(array), len(array[0]))
except TypeError:
return (len(array),)


def vector_walk(x, y, op=sum, filter=lambda x_, y_: True):
if shape(x) != shape(y):
raise ShapeException
try:
return [op([x_val, y[idx_r][idx_c]])
for idx_r, row in enumerate(x)
for idx_c, x_val in row
if filter(idx_r, idx_c)
]
except TypeError:
return [op([x_val, y[idx_r]]) for idx_r, x_val in enumerate(x)]


def sub(a_list):
if len(a_list) != 2:
raise ShapeException
return a_list[0] - a_list[1]


def times(a_list):
if len(a_list) != 2:
raise ShapeException
return a_list[0] * a_list[1]


def is_equal(idx_x, idx_y):
return idx_x == idx_y


def vector_add(x, y):
return vector_walk(x, y, op=sum)


def vector_sub(x, y):
return vector_walk(x, y, op=sub)


def vector_sum(*vectors):
return functools.reduce(vector_add, vectors)


def dot(x, y):
return sum(vector_walk(x, y, op=times, filter=is_equal))


def vector_multiply(x, scalar):
scalar_matrix = vector_walk(x, x, op=lambda x_: scalar)
return vector_walk(x, scalar_matrix, op=times)


def vector_mean(*vectors):
sum_vector = vector_sum(*vectors)
n = len(vectors)
return vector_multiply(sum_vector, 1 / n)


def magnitude(x):
return math.sqrt(dot(x, x))


def matrix_row(x, n):
return x[n]


def matrix_col(x, n):
return [val for row in x for idx, val in enumerate(row) if idx == n]


def matrix_cols(x):
for col in [val for row in x for idx, val in enumerate(row)]:
yield col


def matrix_scalar_multiply(matrix, scalar):
return [[i*scalar for i in row] for row in matrix]


def matrix_vector_multiply(matrix, vector):
if shape(matrix)[1] != shape(vector)[0]:
raise ShapeException

step1 = [[val * vector[idx] for idx, val in enumerate(row)]
for row in matrix]
return [sum(x) for x in step1]


def matrix_matrix_multiply(x, y):
if shape(x)[1] != shape(y)[0]:
raise ShapeException

y_transposed = [[row[i] for row in y] for i in range(len(y[0]))]

return [[dot(row, col) for col in y_transposed] for row in x]

# return [[dot(matrix_row(x, j), matrix_row(y_transposed, i))
# for i, val in enumerate(row)]
# for j, row in enumerate(x)]

# return [[dot(row, val) for i,val in enumerate(matrix_cols(y))
# if j < len(y)]
# for j, row in enumerate(x)]
2 changes: 1 addition & 1 deletion test_matrix_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_matrix_matrix_multiply():

Matrix * Matrix = Matrix
"""
assert matrix_matrix_multiply(A, B) == A
assert matrix_matrix_multiply(A, B) == B
assert matrix_matrix_multiply(B, C) == [[8, 10],
[20, 25],
[32, 40]]
Expand Down