-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_symbolic_shapetracker.py
173 lines (151 loc) · 6.26 KB
/
test_symbolic_shapetracker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import unittest
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
class TestSymbolic(unittest.TestCase):
  """ShapeTracker behaviour when a dimension is a symbolic Variable."""

  def test_symbolic_st(self):
    """A tracker built from (x, 3) reports that shape and strides (3, 1)."""
    sym = Variable("x", 1, 100)
    tracker = ShapeTracker.from_shape((sym, 3))
    assert tracker.shape == (sym, 3)
    assert tracker.real_strides() == (3, 1)

  def test_expr_idxs(self):
    """Check the rendered expr_idxs() output before and after permuting the axes."""
    sym = Variable("x", 1, 100)
    tracker = ShapeTracker.from_shape((sym, 3))
    index_vars = [Variable("x", 0, 100), Variable("y", 0, 100)]

    idx_expr, valid_expr = tracker.expr_idxs(index_vars)
    assert idx_expr.render() == "((x*3)+y)"
    assert valid_expr.render() == "1"

    # Swapping the two axes swaps which index variable gets multiplied by 3.
    permuted = tracker.permute((1, 0))
    idx_expr, valid_expr = permuted.expr_idxs(index_vars)
    assert idx_expr.render() == "((y*3)+x)"
    assert valid_expr.render() == "1"

  def test_cat_dim0_strides(self):
    """Concatenating along dim 0 adds the symbolic sizes; real_strides() stays (cols, 1)."""
    i, j, k = (Variable(name, 1, 5).bind(3) for name in "ijk")

    joined = Tensor.rand(3, 4).reshape(i, 4)
    joined = joined.cat(Tensor.rand(3, 4).reshape(j, 4), dim=0)
    joined = joined.cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
    tracker = joined.lazydata.st
    assert tracker.shape == (i+j+k, 4)
    assert tracker.real_strides() == (4, 1)

    # Same variable used twice, followed by a tensor with a plain integer shape.
    joined = Tensor.rand(3, 3).reshape(i, 3)
    joined = joined.cat(Tensor.rand(3, 3).reshape(i, 3), dim=0)
    joined = joined.cat(Tensor.rand(3, 3), dim=0)
    tracker = joined.lazydata.st
    assert tracker.shape == (2*i+3, 3)
    assert tracker.real_strides() == (3, 1)

  def test_cat_dim1_strides(self):
    """Concatenating along dim 1 makes the row stride the symbolic sum i+j+k."""
    i, j, k = (Variable(name, 1, 5).bind(4) for name in "ijk")

    joined = Tensor.rand(3, 4).reshape(3, i)
    joined = joined.cat(Tensor.rand(3, 4).reshape(3, j), dim=1)
    joined = joined.cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
    tracker = joined.lazydata.st
    assert tracker.shape == (3, i+j+k)
    assert tracker.real_strides() == (i+j+k, 1)
class TestSymbolicVarVals(unittest.TestCase):
  """ShapeTracker.var_vals maps each Variable found in the tracker to its bound value."""

  def test_var_vals_empty(self):
    """A tracker with only integer dimensions has no variable values."""
    tracker = ShapeTracker.from_shape((3, 4, 5))
    assert tracker.var_vals == {}

  def test_var_vals_shape(self):
    """A bound variable used in the shape shows up in var_vals."""
    bound_x = Variable("x", 1, 100).bind(3)
    tracker = ShapeTracker.from_shape((bound_x, 3))
    assert tracker.var_vals == {Variable("x", 1, 100): 3}

  def test_var_vals_offset(self):
    """A bound variable used only in the shrink offset shows up in var_vals."""
    bound_x = Variable("x", 1, 100).bind(3)
    shrink_arg = ((bound_x, bound_x+1), (0, 3))
    tracker = ShapeTracker.from_shape((4, 3)).shrink(shrink_arg)
    assert tracker.real_offset() == bound_x * 3
    assert tracker.var_vals == {Variable("x", 1, 100): 3}

  def test_var_vals_mask(self):
    """A bound variable used only in a view mask shows up in var_vals."""
    bound_x = Variable("x", 1, 100).bind(3)
    masked_view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, bound_x), (0, 4)))
    tracker = ShapeTracker(views=(masked_view,))
    assert tracker.var_vals == {Variable("x", 1, 100): 3}

  def test_var_vals_complex(self):
    """Variables from both the shape and the shrink offset are all collected."""
    bound_x = Variable("x", 1, 100).bind(3)
    bound_y = Variable("y", 1, 100).bind(4)
    bound_z = Variable("z", 1, 100).bind(5)
    shrink_arg = ((0, bound_x), (bound_z, bound_z+1), (0, 3))
    tracker = ShapeTracker.from_shape((bound_x, 5, bound_y)).shrink(shrink_arg)
    assert tracker.real_offset() == bound_y * bound_z
    assert tracker.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}

  def test_shrink_reshape(self):
    """var_vals survives a reshape applied after a symbolic shrink."""
    bound_x = Variable("x", 1, 100).bind(3)
    shrink_arg = ((bound_x, bound_x+3), (3, 7), (2, 5))
    tracker = ShapeTracker.from_shape((10, 10, 10)).shrink(shrink_arg)
    # The shrunk axes have sizes 3, 4 and 3; flatten them into one dimension.
    tracker = tracker.reshape((3*4*3,))
    assert tracker.var_vals == {Variable("x", 1, 100): 3}
class TestShapeTrackerUnbind(unittest.TestCase):
  """unbind() should give the same View/ShapeTracker as one built from the unbound Variable."""

  def test_view_unbind(self):
    """Unbinding a View with a bound shape variable equals the View built from the free variable."""
    free_v = Variable("v", 1, 100)
    bound_v = Variable("v", 1, 100).bind(3)
    unbound_view = View.create(shape=(bound_v, 4)).unbind()
    assert unbound_view == View.create(shape=(free_v, 4))

  def test_reshape_unbind(self):
    """Unbinding the tracker of a tensor reshaped to (bound v, 4) yields a (v, 4) tracker."""
    free_v = Variable("v", 1, 100)
    bound_v = Variable("v", 1, 100).bind(3)
    tensor = Tensor.rand(3, 4).reshape(bound_v, 4)
    unbound_st = tensor.lazydata.st.unbind()
    assert unbound_st == ShapeTracker((View.create(shape=(free_v, 4)),))

  def test_shrink_unbind(self):
    """Unbinding after a symbolic shrink leaves the offset expressed in the free variable."""
    free_v = Variable("v", 1, 100)
    bound_v = Variable("v", 1, 100).bind(2)
    tensor = Tensor.rand(3, 4).shrink(((bound_v, bound_v+1), (0, 4)))
    unbound_st = tensor.lazydata.st.unbind()
    assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*free_v),))
class TestSymbolicReshape(unittest.TestCase):
  """Reshaping tensors into and out of shapes that contain bound Variables."""

  def test_reshape_into_symbols_simple(self):
    """An integer dimension can be replaced by a Variable bound to the same value."""
    for size in range(1, 6):
      sym = Variable("i", 1, 5).bind(size)
      two_d = Tensor.rand(size, 4).reshape(sym, 4)
      assert two_d.shape == (sym, 4)
      three_d = Tensor.rand(size, 6).reshape(sym, 2, 3)
      assert three_d.shape == (sym, 2, 3)

  def test_reshape_symbols_reshape_ints(self):
    """A symbolic shape can be reshaped back to the matching integer shape."""
    for size in range(1, 6):
      sym = Variable("i", 1, 5).bind(size)
      symbolic = Tensor.rand(size, 4).reshape(sym, 4)
      assert symbolic.shape == (sym, 4)
      concrete = symbolic.reshape(size, 4)
      assert concrete.shape == (size, 4)

  def test_reshape_into_symbols_bad_shape(self):
    """Invalid symbolic reshapes raise AssertionError."""
    sym = Variable("i", 1, 10).bind(4)
    # Going through a symbolic shape to a shape with a different number of elements (24 vs 77).
    with self.assertRaises(AssertionError):
      Tensor.rand(4, 6).reshape(sym, 6).reshape(1, 77)
    # Reshaping into an expression (vi+1) that is a non-Variable Node.
    with self.assertRaises(AssertionError):
      Tensor.rand(3, 4).reshape(3, (sym+1))

  def test_two_symbol_reshape(self):
    """Two symbolic dimensions can be set together and then swapped by reshape."""
    for rows in range(1, 6):
      for cols in range(1, 6):
        sym_i = Variable("i", 1, 5).bind(rows)
        sym_j = Variable("j", 1, 5).bind(cols)
        tensor = Tensor.rand(rows, cols).reshape(sym_i, sym_j)
        assert tensor.shape == (sym_i, sym_j)
        # NOTE: reshaping to (1, vi*vj) is currently not allowed, so it is not exercised here.
        swapped = tensor.reshape(sym_j, sym_i)
        assert swapped.shape == (sym_j, sym_i)
class TestSymbolicExpand(unittest.TestCase):
  """Expanding tensors so that a dimension becomes a symbolic Variable."""

  def test_expand_into_symbols(self):
    """A size-1 axis can be expanded into a Variable, one axis at a time."""
    # TODO: enforce expand only into bound variables
    sym_i = Variable("i", 1, 5)
    sym_j = Variable("j", 1, 5)
    expanded = Tensor([[1], [2], [3]]).expand((3, sym_i))
    assert expanded.shape == (3, sym_i)
    # Append a trailing size-1 axis, then expand that axis into the second variable.
    expanded = expanded.reshape(3, sym_i, 1).expand((3, sym_i, sym_j))
    assert expanded.shape == (3, sym_i, sym_j)

  def test_plus_expands_constant(self):
    """Adding a Python scalar to a symbolic-shaped tensor keeps the symbolic shape."""
    for size in range(1, 6):
      sym = Variable("i", 1, 5).bind(size)
      symbolic = Tensor.rand(3, size).reshape(3, sym)
      summed = symbolic + 1
      assert summed.shape == (3, sym)
class TestSymbolicShrink(unittest.TestCase):
  """Shrinking a tensor with symbolic start/end bounds."""

  def test_shrink_symbols(self):
    """A (vi, vi+1) window on the second axis yields a concrete size of 1."""
    sym = Variable("i", 1, 5)
    window = ((0, 2), (sym, sym+1))
    shrunk = Tensor.rand(3, 5).shrink(window)
    assert shrunk.shape == (2, 1)
class TestSymbolicShapeExpr(unittest.TestCase):
  """Index expressions for a view whose shape and strides are both symbolic."""

  def test_symbolic_expr_idxs(self):
    """Check the rendered index expression for a view with symbolic shape and strides."""
    # taken from symbolic shape llama
    i = Variable("i", 1, 120)
    gidx0 = Variable("gidx0", 0, i)
    lidx1 = Variable("lidx1", 0, 7)
    index_vars = (gidx0, lidx1, Variable.num(1))
    view_shape = (i+1, 8, 4)
    view_strides = (1, (i*4)+4, i+1)
    tracker = ShapeTracker((View.create(view_shape, view_strides), ))
    # Only the index expression is checked; the second returned expression is not asserted on.
    idx_expr, _valid_expr = tracker.expr_idxs(index_vars)
    assert idx_expr.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
# Run every TestCase in this file when it is executed directly as a script.
if __name__ == "__main__":
  unittest.main()