forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_utility_funs.py
222 lines (187 loc) · 8.85 KB
/
test_utility_funs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from __future__ import absolute_import, division, print_function, unicode_literals
from test_pytorch_common import TestCase, run_tests
import torch
import torch.onnx
from torch.onnx import utils
from torch.onnx.symbolic_helper import _set_opset_version
import onnx
import io
import copy
class TestUtilityFuns(TestCase):
    """Tests for ``torch.onnx`` export utilities.

    Covers the ``torch.onnx.is_in_onnx_export`` flag, ONNX constant folding as
    performed by ``utils._model_to_graph``, ``strip_doc_string`` handling in
    ``torch.onnx.export``, and rejection of ``torch.nn.DataParallel`` models.
    ``opset_version`` is a class attribute so the same tests can be re-run
    under another opset by cloning this class with a different value.
    """

    opset_version = 9

    def _assert_graph_folded(self, graph, forbidden_kinds, expected_num_nodes):
        """Assert that constant folding left the expected graph behind.

        Args:
            graph: the graph returned by ``utils._model_to_graph``.
            forbidden_kinds: node kinds (e.g. ``"onnx::Transpose"``) that must
                not appear anywhere in ``graph``.
            expected_num_nodes: exact number of nodes ``graph`` must contain.
        """
        # Materialize once so both checks see the same node list. unittest
        # assertions (unlike bare ``assert``) still run under ``python -O``.
        nodes = list(graph.nodes())
        for node in nodes:
            self.assertNotIn(node.kind(), forbidden_kinds)
        self.assertEqual(len(nodes), expected_num_nodes)

    def test_is_in_onnx_export(self):
        test_self = self

        class MyModule(torch.nn.Module):
            def forward(self, x):
                # forward() runs while torch.onnx.export is in progress, so
                # the flag must be set here.
                test_self.assertTrue(torch.onnx.is_in_onnx_export())
                # Abort the export on purpose, so we can check the flag after
                # an exception has escaped torch.onnx.export.
                raise ValueError

        x = torch.randn(3, 4)
        f = io.BytesIO()
        # assertRaises makes the test fail if export no longer raises.
        # Previously a missing exception silently skipped the final check.
        with self.assertRaises(ValueError):
            torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
        self.assertFalse(torch.onnx.is_in_onnx_export())

    def test_constant_fold_transpose(self):
        class TransposeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = torch.transpose(a, 1, 0)
                return b + x

        _set_opset_version(self.opset_version)
        x = torch.ones(3, 2)
        graph, _, __ = utils._model_to_graph(TransposeModule(), (x, ),
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True)
        self._assert_graph_folded(
            graph, ("onnx::Transpose", "onnx::Cast", "onnx::Constant"), 1)

    def test_constant_fold_slice(self):
        class NarrowModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = torch.narrow(a, 0, 0, 1)
                return b + x

        _set_opset_version(self.opset_version)
        x = torch.ones(1, 3)
        graph, _, __ = utils._model_to_graph(NarrowModule(), (x, ),
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True)
        self._assert_graph_folded(
            graph, ("onnx::Slice", "onnx::Cast", "onnx::Constant"), 1)

    def test_constant_fold_slice_index_exceeds_dim(self):
        class SliceIndexExceedsDimModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = a[1:10]  # index exceeds dimension
                return b + x

        _set_opset_version(self.opset_version)
        x = torch.ones(1, 3)
        graph, _, __ = utils._model_to_graph(SliceIndexExceedsDimModule(), (x, ),
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True)
        self._assert_graph_folded(
            graph, ("onnx::Slice", "onnx::Cast", "onnx::Constant"), 1)

    def test_constant_fold_slice_negative_index(self):
        class SliceNegativeIndexModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = a[0:-1]  # index relative to the end
                return b + x

        _set_opset_version(self.opset_version)
        x = torch.ones(1, 3)
        graph, _, __ = utils._model_to_graph(SliceNegativeIndexModule(), (x, ),
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True)
        self._assert_graph_folded(
            graph, ("onnx::Slice", "onnx::Cast", "onnx::Constant"), 1)

    def test_constant_fold_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = torch.unsqueeze(a, 0)
                return b + x

        _set_opset_version(self.opset_version)
        x = torch.ones(1, 2, 3)
        graph, _, __ = utils._model_to_graph(UnsqueezeModule(), (x, ),
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True)
        # Fixed typo: this previously checked "onnx::Unsqueeeze", which no
        # node can ever have.
        self._assert_graph_folded(
            graph, ("onnx::Unsqueeze", "onnx::Cast", "onnx::Constant"), 1)

    def test_constant_fold_concat(self):
        class ConcatModule(torch.nn.Module):
            def forward(self, x):
                # NOTE(review): ``x`` is not used, so the output depends only
                # on constants. Kept as-is to preserve the original node-count
                # expectation; consider whether ``x`` should participate.
                a = torch.tensor([[1., 2., 3.]])
                b = torch.tensor([[4., 5., 6.]])
                c = torch.cat((a, b), 0)
                return b + c

        _set_opset_version(self.opset_version)
        x = torch.ones(2, 3)
        graph, _, __ = utils._model_to_graph(ConcatModule(), (x, ),
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True)
        self._assert_graph_folded(
            graph, ("onnx::Concat", "onnx::Cast", "onnx::Constant"), 1)

    def test_constant_fold_lstm(self):
        # NOTE: despite the test name, the module under test is a GRU. The test
        # expects no Slice/Concat/Unsqueeze nodes to survive folding (presumably
        # these come from rearranging the recurrent weights -- TODO confirm).
        class GruNet(torch.nn.Module):
            def __init__(self):
                super(GruNet, self).__init__()
                self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

            def forward(self, input, initial_state):
                return self.mygru(input, initial_state)

        _set_opset_version(self.opset_version)
        input = torch.randn(5, 3, 7)
        h0 = torch.randn(1, 3, 3)
        graph, _, __ = utils._model_to_graph(GruNet(), (input, h0),
                                             do_constant_folding=True)
        self._assert_graph_folded(
            graph, ("onnx::Slice", "onnx::Concat", "onnx::Unsqueeze"), 3)

    def test_constant_fold_transpose_matmul(self):
        class MatMulNet(torch.nn.Module):
            def __init__(self):
                super(MatMulNet, self).__init__()
                self.B = torch.nn.Parameter(torch.ones(5, 3))

            def forward(self, A):
                return torch.matmul(A, torch.transpose(self.B, -1, -2))

        _set_opset_version(self.opset_version)
        A = torch.randn(2, 3)
        # Pass a real 1-tuple: ``(A)`` is just ``A`` and only worked if the
        # callee happened to wrap bare tensors itself.
        graph, _, __ = utils._model_to_graph(MatMulNet(), (A, ),
                                             do_constant_folding=True)
        self._assert_graph_folded(graph, ("onnx::Transpose",), 1)

    def test_strip_doc_string(self):
        class MyModule(torch.nn.Module):
            def forward(self, input):
                return torch.exp(input)

        x = torch.randn(3, 4)

        def is_model_stripped(f, strip_doc_string=None):
            """Export MyModule into ``f`` and report whether doc strings are absent.

            ``strip_doc_string=None`` means "use the exporter's default".
            """
            if strip_doc_string is None:
                torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
            else:
                torch.onnx.export(MyModule(), x, f, strip_doc_string=strip_doc_string,
                                  opset_version=self.opset_version)
            model = onnx.load(io.BytesIO(f.getvalue()))
            # Deep copy so that stripping cannot touch ``model`` itself.
            model_strip = copy.deepcopy(model)
            onnx.helper.strip_doc_string(model_strip)
            # Equal only if there were no doc strings to strip.
            return model == model_strip

        # test strip_doc_string=True (default)
        self.assertTrue(is_model_stripped(io.BytesIO()))
        # test strip_doc_string=False
        self.assertFalse(is_model_stripped(io.BytesIO(), False))

    # NB: remove this test once DataParallel can be correctly handled
    def test_error_on_data_parallel(self):
        model = torch.nn.DataParallel(torch.nn.ReflectionPad2d((1, 2, 3, 4)))
        x = torch.randn(1, 2, 3, 4)
        f = io.BytesIO()
        with self.assertRaisesRegex(ValueError,
                                    'torch.nn.DataParallel is not supported by ONNX '
                                    'exporter, please use \'attribute\' module to '
                                    'unwrap model from torch.nn.DataParallel. Try '):
            torch.onnx.export(model, x, f, opset_version=self.opset_version)
# opset 10 tests: clone every attribute of TestUtilityFuns into a new TestCase
# subclass, overriding only ``opset_version``.
_opset10_attrs = dict(TestUtilityFuns.__dict__, opset_version=10)
TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
                               (TestCase,),
                               _opset10_attrs)
# Drop the temporary so the module namespace is unchanged.
del _opset10_attrs


if __name__ == "__main__":
    run_tests()