# test_speed_v_torch.py (forked from tinygrad/tinygrad)
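# Benchmarks a set of tinygrad ops against their PyTorch equivalents and checks that both
# produce numerically close results (see helper_test_generic below). Behavior is driven by
# environment variables read via getenv: BIG=1 runs TestBigSpeed (otherwise TestSpeed runs),
# CNT (default 8) sets the number of timing iterations, HALF=1 switches to float16,
# MPS=1 / TORCHCUDA=1 pick the torch device, and IN_CHANS sets the channel sweep for test_conv2d.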
import os
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import unittest
import torch
torch.set_num_threads(1)
import time
import numpy as np
np.set_printoptions(linewidth=160)
from tinygrad import Device, GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, CI
from tinygrad.features.jit import TinyJit
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
torch_dt = torch.float16 if getenv("HALF", 0) else torch.float32
torch_device = torch.device('mps' if getenv("MPS", 0) else ('cuda' if getenv("TORCHCUDA", 0) else 'cpu'))
if str(torch_device) == "mps":
  import torch.mps
  def sync(): torch.mps.synchronize()
elif str(torch_device) == "cuda":
  import torch.cuda
  def sync(): torch.cuda.synchronize()
else:
  def sync(): pass
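# colorize_float formats the tinygrad/torch time ratio: green when tinygrad is clearly faster
# (ratio < 0.75), red when clearly slower (ratio > 1.15), yellow otherwise.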
def colorize_float(x):
  ret = f"{x:7.2f}x"
  if x < 0.75:
    return colored(ret, 'green')
  elif x > 1.15:
    return colored(ret, 'red')
  else:
    return colored(ret, 'yellow')
save_ops, save_mem = 0, 0
CNT = getenv("CNT", 8)
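# helper_test_speed runs f1 CNT times, perturbing the inputs each iteration to defeat op caching
# and syncing the device around the timed call. It returns the result as a numpy array together
# with the minimum elapsed time in ms (the first, warm-up iteration is excluded from the timings).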
def helper_test_speed(f1, *args):
  global save_ops, save_mem
  ets = []
  ret = None
  cache_defeat = np.zeros((2048,2048))
  for i in range(CNT):
    del ret

    # operation cache defeats
    args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]

    # force syncing
    [x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]

    # clear 32MB global memory cache (CPU and global memory only)
    cache_defeat += 1

    # manual pre sync
    if isinstance(args[0], Tensor): Device[args[0].device].synchronize()
    else: sync()

    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0
    st = time.perf_counter()
    ret = f1(*args)
    if isinstance(ret, Tensor): Device[ret.device].synchronize()
    else: sync()
    et = (time.perf_counter() - st) * 1000
    if i >= 1: ets.append(et)
    if GlobalCounters.global_ops:
      save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
  return ret.numpy() if isinstance(ret, Tensor) else ret.cpu().numpy(), np.min(ets)
def helper_test_generic_square(name, N, f1, f2, onearg=False):
  torch.manual_seed(0)
  torch_a = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device)
  torch_b = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device) if not onearg else None

  tiny_a = Tensor(torch_a.cpu().numpy())
  tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None

  helper_test_generic(f"{name:30s} {N:5d}x{N:5d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b))
def helper_test_matvec(name, N, M):
  torch.manual_seed(0)
  torch_a = (torch.rand(N, dtype=torch_dt) - 0.5).to(torch_device)
  torch_b = (torch.rand(N, M, dtype=torch_dt) - 0.5).to(torch_device)

  tiny_a = Tensor(torch_a.cpu().numpy())
  tiny_b = Tensor(torch_b.cpu().numpy())

  helper_test_generic(f"{name:30s} {N:5d}x{M:5d}", lambda a,b: a@b, (torch_a, torch_b), TinyJit(lambda a,b:(a@b).realize()), (tiny_a, tiny_b))
prefix = None
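# helper_test_generic times the torch and tinygrad versions of the same op, prints a comparison
# line (the GFLOPS/GB/s figures are derived from tinygrad's GlobalCounters, so the torch column
# uses the same op/memory counts), and asserts that the two outputs are numerically close.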
def helper_test_generic(name, f1, f1_args, f2, f2_args):
  global prefix
  with torch.no_grad():
    val_torch, et_torch = helper_test_speed(f1, *f1_args)
  val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args)

  desc = "faster" if et_torch > et_tinygrad else "slower"
  flops = save_ops*1e-6
  mem = save_mem*1e-6
  print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501
  np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3)
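# helper_test_conv builds a torch Conv2d and a tinygrad Conv2d with identical (copied) weights
# and no bias, then benchmarks a single forward pass of each.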
def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
  torch.manual_seed(0)
  torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x, dtype=torch_dt).to(torch_device)
  torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None, dtype=torch_dt).to(torch_device)

  tiny_dat = Tensor(torch_dat.cpu().numpy())
  tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None)
  tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

  def f1(torch_dat): return torch_conv(torch_dat)
  def f2(tiny_dat): return tiny_conv(tiny_dat).realize()
  helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))
@unittest.skipIf(getenv("BIG") == 0, "no big tests")
@unittest.skipIf(getenv("CUDACPU"), "no CUDACPU")
class TestBigSpeed(unittest.TestCase):
def test_add(self):
def f(a, b): return a+b
helper_test_generic_square('add', 8192, f, f)
def test_exp(self):
def f(a, b): return a.exp()
helper_test_generic_square('exp', 8192, f, f, onearg=True)
def test_gemm_2048(self):
def f(a, b): return a @ b
helper_test_generic_square('gemm', 2048, f, f)
def test_gemm_4096(self):
def f(a, b): return a @ b
helper_test_generic_square('gemm', 4096, f, f)
def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128)
def test_large_conv_3x3(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
def test_large_conv_5x5(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=5, img_size_y=132, img_size_x=132)
def test_matvec_4096_16384(self): helper_test_matvec('matvec_4096_16384', 4096, 16384)
def test_matvec_16384_4096(self): helper_test_matvec('matvec_16384_4096', 16384, 4096)
@unittest.skipIf(getenv("BIG") == 1, "only big tests")
@unittest.skipIf(getenv("CUDACPU"), "no CUDACPU")
class TestSpeed(unittest.TestCase):
def test_sub(self):
def f(a, b): return a-b
helper_test_generic_square('sub', 4096, f, f)
@unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "breaking on webgpu CI")
def test_pow(self):
def f(a, b): return a.pow(b)
helper_test_generic_square('pow', 2048, f, f)
def test_sum(self):
def f(a, b): return a.sum()
helper_test_generic_square('sum', 2048, f, f, onearg=True)
helper_test_generic_square('sum', 4096, f, f, onearg=True)
def test_partial_sum(self):
R = 256
def f(a, b): return a.reshape(int(4096//R), int(4096*R)).sum(axis=1)
helper_test_generic_square('partial_sum', 4096, f, f, onearg=True)
  @unittest.skip("not really used in models")
  def test_cumsum(self):
    def f0(a, b): return a.cumsum(axis=0)
    def f1(a, b): return a.cumsum(axis=1)
    helper_test_generic_square('cumsum_0', 256, f0, f0, onearg=True)
    helper_test_generic_square('cumsum_1', 256, f1, f1, onearg=True)

  def test_cat(self):
    helper_test_generic_square('cat_0', 256, lambda x,y: torch.cat((x,y),dim=0), lambda x,y: x.cat(y,dim=0))
    helper_test_generic_square('cat_1', 256, lambda x,y: torch.cat((x,y),dim=1), lambda x,y: x.cat(y,dim=1))

  def test_array_packing(self):
    N = 2048
    def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous()
    helper_test_generic_square('array_packing', N, f, f, onearg=True)
  def test_permute(self):
    for N in [1024, 4096]:
      # this is a 64MB tensor, M1 L1 cache is 128kB
      # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
      def f(a, b): return a.permute(1,0).contiguous()
      helper_test_generic_square('permute', N, f, f, onearg=True)

  def test_double_permute(self):
    N = 64
    torch.manual_seed(0)
    torch_a = (torch.rand(N, N, N, N, dtype=torch_dt) - 0.5).to(torch_device)
    tiny_a = Tensor(torch_a.cpu().numpy())
    def f(a): return a.permute(1,0,3,2).contiguous()
    helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,))
  def test_neg(self):
    def f(a, b): return -a
    helper_test_generic_square('neg', 4096, f, f, onearg=True)

  def test_exp(self):
    def f(a, b): return a.exp()
    helper_test_generic_square('exp', 2048, f, f, onearg=True)

  def test_relu(self):
    def f(a, b): return a.relu()
    helper_test_generic_square('relu', 4096, f, f, onearg=True)

  def test_max(self):
    def f(a, b): return a.max()
    helper_test_generic_square('max', 4096, f, f, onearg=True)

  def test_mul_sum(self):
    def f(a, b): return (a*b).sum()
    helper_test_generic_square('mul_sum', 4096, f, f)

  def test_add(self):
    for N in [1, 1024, 4096]:
      def f(a, b): return a + b
      helper_test_generic_square('add', N, f, f)

  def test_add_constant(self):
    def f(a, b): return a+2.0
    helper_test_generic_square('add_constant', 4096, f, f, onearg=True)

  def test_add_sq(self):
    def f(a, b): return a*a + b*b
    helper_test_generic_square('add_sq', 4096, f, f)
  def test_gemm(self):
    def f(a, b): return a @ b
    helper_test_generic_square('gemm', 1024, f, f)

  def test_gemm_small(self):
    def f(a, b): return a @ b
    helper_test_generic_square('gemm', 256, f, f)
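  # The *_unrolled variants below spell the matmul out as a broadcast multiply followed by a sum
  # over the shared axis, e.g. (a@b.T)[i,j] == sum_k a[i,k]*b[j,k], and compare that formulation
  # against the plain @ operator on transposed/permuted operands.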
  def test_gemm_unrolled(self):
    N = 512
    def f1(a, b): return a@b.T
    def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled', N, f1, f2)

  def test_gemm_unrolled_permute_l(self):
    N = 512
    def f1(a, b): return a.T@b.T
    def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_l', N, f1, f2)

  def test_gemm_unrolled_permute_r(self):
    N = 512
    def f1(a, b): return a@b
    def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_r', N, f1, f2)

  def test_gemm_unrolled_permute_lr(self):
    N = 512
    def f1(a, b): return a.T@b
    def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2)

  def test_matvec_1024_1024(self): helper_test_matvec('matvec_1024_1024', 1024, 1024)
  def test_matvec_1024_4096(self): helper_test_matvec('matvec_1024_4096', 1024, 4096)
  def test_matvec_4096_1024(self): helper_test_matvec('matvec_4096_1024', 4096, 1024)
  def test_matvec_4096_4096(self): helper_test_matvec('matvec_4096_4096', 4096, 4096)
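  # The openpilot-style conv takes NHWC input and permutes it to NCHW inside the timed function,
  # so the layout change is part of what gets benchmarked.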
  def test_openpilot_conv2d(self):
    bs, in_chans, out_chans = 1,12,32
    torch.manual_seed(0)
    torch_dat = torch.rand(bs, 64, 128, 12, dtype=torch_dt).to(torch_device)
    torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1, dtype=torch_dt).to(torch_device)

    tiny_dat = Tensor(torch_dat.cpu().numpy())
    tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
    tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

    def f1(torch_dat): return torch_conv(torch_dat.permute(0,3,1,2))
    def f2(tiny_dat): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:3", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))

  def test_conv2d(self):
    for bs in [32]:
      for in_chans in IN_CHANS:
        for out_chans in [32]:
          helper_test_conv(bs, in_chans, out_chans, 3, 34, 34)
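# Typical invocations (file path and environment defaults may differ in your checkout):
#   python3 test_speed_v_torch.py TestSpeed.test_gemm
#   BIG=1 CNT=16 python3 test_speed_v_torch.py TestBigSpeed.test_gemm_4096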
if __name__ == '__main__':
  unittest.main()