pytorched_strassen_matmul.py
# Title: Combining Assembly-Language-Powered PyTorch MatMul with the Strassen MatMul Method
# Author: Ajay Khanna | Git: Ajaykhanna | Twitter: @samdig | LinkedIn: ajay-khanna
# Date: Feb.20.2020
# Place: UC Merced
# Lab: Dr. Isborn
# Architecture: Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz/32.0GB
# OS: Windows 11 21H2
# Python: 3.9.7
#------------------------------------------------------
# Import time and PyTorch
import time
import torch
#------------------------------------------------------
#
# Method #1: Pure Python loops
# Matrix multiplication written with explicit Python loops over PyTorch tensors
def matmul(a, b):
    a_rows, a_cols = a.shape
    b_rows, b_cols = b.shape
    assert a_cols == b_rows
    c = torch.zeros(a_rows, b_cols)
    # Triple loop: accumulate each output element one product at a time
    for i in range(a_rows):
        for j in range(b_cols):
            for k in range(a_cols):
                c[i, j] += a[i, k] * b[k, j]
    return c
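#
# Illustrative sanity check on a tiny input: the hand-rolled loop should match
# PyTorch's built-in matmul up to floating-point rounding (hence allclose).
_x = torch.rand(4, 4)
_y = torch.rand(4, 4)
assert torch.allclose(matmul(_x, _y), _x.matmul(_y), atol=1e-5)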
#------------------------------------------------------
#
# PyTorch's built-in matmul (the "assembly language powered" routine)
def matmul_pytorch(a, b):
    c = a.matmul(b)
    return c
#------------------------------------------------------
#
# Building the Strassen algorithm
# Currently applicable only to square matrices whose dimension is a power of
# two (2^n x 2^n, n >= 1). To apply it to other shapes, pad the rows and
# columns with zeros until the matrix is a power-of-two square (a padding
# sketch follows split() below).
# split() cuts a square matrix into its four quadrant sub-matrices.
def split(matrix):
    n = len(matrix)
    return (matrix[:n//2, :n//2], matrix[:n//2, n//2:],
            matrix[n//2:, :n//2], matrix[n//2:, n//2:])
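#
# One way to handle the non-square / non-power-of-two case described above:
# zero-pad up to the next power-of-two square, multiply, then slice the
# product back down. pad_to_pow2 is only an illustrative helper sketch; it is
# not used by the Strassen routines below.
def pad_to_pow2(matrix):
    # Illustrative helper: embed `matrix` in the top-left corner of a
    # power-of-two square filled with zeros.
    rows, cols = matrix.shape
    size = 1
    while size < max(rows, cols):
        size *= 2
    padded = torch.zeros(size, size, dtype=matrix.dtype)
    padded[:rows, :cols] = matrix
    return padded
# Example use: for a (3 x 5) @ (5 x 4) product, pad both operands to the same
# power-of-two square (8 x 8 here), multiply, and keep the top-left 3 x 4
# block of the result.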
#------------------------------------------------------
#
# Strassen algorithm with the pure-Python multiplication at the leaves
def standard_strassen(A, B):
    if len(A) <= 2:
        return matmul(A, B)
    a, b, c, d = split(A)
    e, f, g, h = split(B)
    p1 = standard_strassen(a + d, e + h)
    p2 = standard_strassen(d, g - e)
    p3 = standard_strassen(a + b, h)
    p4 = standard_strassen(b - d, g + h)
    p5 = standard_strassen(a, f - h)
    p6 = standard_strassen(c + d, e)
    p7 = standard_strassen(a - c, e + f)
    C11 = p1 + p2 - p3 + p4
    C12 = p5 + p3
    C21 = p6 + p2
    C22 = p5 + p1 - p6 - p7
    # Stitch the four quadrants back together (kept as torch tensors)
    C = torch.cat((torch.cat((C11, C12), dim=1),
                   torch.cat((C21, C22), dim=1)), dim=0)
    return C
#------------------------------------------------------
#
# PyTorched Strassen algorithm: torch.matmul at the leaves
def pytorched_strassen(A, B):
    if len(A) <= 2:
        return matmul_pytorch(A, B)
    a, b, c, d = split(A)
    e, f, g, h = split(B)
    p1 = pytorched_strassen(a + d, e + h)
    p2 = pytorched_strassen(d, g - e)
    p3 = pytorched_strassen(a + b, h)
    p4 = pytorched_strassen(b - d, g + h)
    p5 = pytorched_strassen(a, f - h)
    p6 = pytorched_strassen(c + d, e)
    p7 = pytorched_strassen(a - c, e + f)
    C11 = p1 + p2 - p3 + p4
    C12 = p5 + p3
    C21 = p6 + p2
    C22 = p5 + p1 - p6 - p7
    C = torch.cat((torch.cat((C11, C12), dim=1),
                   torch.cat((C21, C22), dim=1)), dim=0)
    return C
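#
# The base case above recurses all the way down to 2 x 2 blocks. Strassen's
# extra additions only pay off above some crossover size, so a common tweak is
# to hand larger leaves straight to torch.matmul. The variant below is an
# illustrative sketch; leaf_size = 64 is an assumed, untuned crossover value.
def pytorched_strassen_leaf(A, B, leaf_size=64):
    # Assumed tuning knob: blocks at or below leaf_size go straight to torch.matmul
    if len(A) <= leaf_size:
        return matmul_pytorch(A, B)
    a, b, c, d = split(A)
    e, f, g, h = split(B)
    p1 = pytorched_strassen_leaf(a + d, e + h, leaf_size)
    p2 = pytorched_strassen_leaf(d, g - e, leaf_size)
    p3 = pytorched_strassen_leaf(a + b, h, leaf_size)
    p4 = pytorched_strassen_leaf(b - d, g + h, leaf_size)
    p5 = pytorched_strassen_leaf(a, f - h, leaf_size)
    p6 = pytorched_strassen_leaf(c + d, e, leaf_size)
    p7 = pytorched_strassen_leaf(a - c, e + f, leaf_size)
    C11 = p1 + p2 - p3 + p4
    C12 = p5 + p3
    C21 = p6 + p2
    C22 = p5 + p1 - p6 - p7
    return torch.cat((torch.cat((C11, C12), dim=1),
                      torch.cat((C21, C22), dim=1)), dim=0)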
#------------------------------------------------------
#
# Generate random square matrices
# The dimension must be a power of two: 2^n x 2^n, n >= 1
n = 10  # exponent: the matrices below are 2**n x 2**n = 1024 x 1024
a_matrix = torch.rand(2**n, 2**n)
b_matrix = torch.rand(2**n, 2**n)
#
# Print and time each matrix-multiplication method
tic = time.perf_counter()
print('Matrix Multiplication with Standard Method: \n',
      matmul(a_matrix, b_matrix), '\n')
toc = time.perf_counter()
print(f"Standard Method took {toc - tic:0.4e} seconds")
#
tic = time.perf_counter()
print('Matrix Multiplication with Standard Strassen Method: \n',
      standard_strassen(a_matrix, b_matrix), '\n')
toc = time.perf_counter()
print(f"Standard Strassen Method took {toc - tic:0.4e} seconds")
#
tic = time.perf_counter()
print('Matrix Multiplication with PyTorched Strassen Method: \n',
      pytorched_strassen(a_matrix, b_matrix), '\n')
toc = time.perf_counter()
print(f"PyTorched Strassen Method took {toc - tic:0.4e} seconds")
#
tic = time.perf_counter()
print('Matrix Multiplication with PyTorch MatMul Method: \n',
      matmul_pytorch(a_matrix, b_matrix), '\n')
toc = time.perf_counter()
print(f"PyTorch MatMul Method took {toc - tic:0.4e} seconds")
#------------------------------------------------------
# Results
# Method #1: Standard method took: 14286.00 seconds
# Method #2: Standard Strassen method took: 5212.60 seconds (2.74x faster than #1)
# Method #3: PyTorched Strassen method took: 761.03 seconds (18.77x faster than #1, 6.84x than #2)
# Method #4: PyTorch MatMul method took: 0.02 seconds (~638K faster than #1, ~233K than #2)
#
# Conclusion
# For 1024 x 1024 matrices with random elements, the standard/brute-force
# method is not practical. The Strassen algorithm combined with the
# brute-force multiplication at the leaves is 2.74 times faster than the
# standard matmul. The Strassen algorithm combined with PyTorch's
# assembly-language-powered matmul at the leaves is 18.77 times faster than
# the brute-force method and 6.84 times faster than the standard Strassen
# version. Even so, PyTorch's assembly-language-powered MatMul on its own is
# ~638K times faster than the standard method, ~233K times faster than
# standard Strassen, and ~34K times faster than the PyTorched Strassen method.
# In short: don't mess with PyTorch MatMul.