forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_pytorch_onnx_onnxruntime.py
520 lines (405 loc) · 18.2 KB
/
test_pytorch_onnx_onnxruntime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import onnxruntime # noqa
import torch
import numpy as np
import io
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
import model_defs.word_language_model as word_language_model
def run_model_test(self, model, train, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   example_outputs=None, do_constant_folding=True):
    """Export ``model`` to ONNX and assert onnxruntime reproduces PyTorch's outputs.

    ``self`` is the calling test case; its ``opset_version`` attribute selects
    the ONNX opset passed to ``torch.onnx.export``.

    Args:
        model: module to export; it is put into eval mode first.
        train: accepted for signature compatibility; not used here.
        batch_size: leading dimension of the default random input. It is used
            only when ``input`` is None.
        state_dict, use_gpu, example_outputs, do_constant_folding: accepted
            for signature compatibility; currently not used (in particular,
            ``do_constant_folding`` is not forwarded to the exporter).
        input: a single tensor, or a tuple/list of positional model arguments.
            Defaults to a random ``(batch_size, 3, 224, 224)`` tensor.
        rtol, atol: tolerances forwarded to ``np.testing.assert_allclose``.

    Raises:
        AssertionError: if onnxruntime returns a different number of outputs,
            or any output differs beyond the given tolerances.
    """
    model.eval()
    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        # Normalise a single tensor argument / result into a 1-tuple.
        if isinstance(input, torch.Tensor):
            input = (input,)
        output = model(*input)
        if isinstance(output, torch.Tensor):
            output = (output,)

        # Export the model to an in-memory ONNX buffer.
        f = io.BytesIO()
        torch.onnx.export(model, input, f,
                          opset_version=self.opset_version,
                          example_outputs=output)

        # Flatten possibly nested input/output structures into flat tensor
        # sequences, in the order the exported graph consumes/produces them.
        input, _ = torch.jit._flatten(input)
        output, _ = torch.jit._flatten(output)

        def to_numpy(tensor):
            # detach() is required before .numpy() on tensors that require grad.
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            return tensor.cpu().numpy()

        inputs = [to_numpy(t) for t in input]
        outputs = [to_numpy(t) for t in output]

        # Run the exported graph in onnxruntime, binding inputs by position.
        ort_sess = onnxruntime.InferenceSession(f.getvalue())
        ort_input_meta = ort_sess.get_inputs()
        ort_inputs = {ort_input_meta[i].name: arr for i, arr in enumerate(inputs)}
        ort_outs = ort_sess.run(None, ort_inputs)

        # Compare onnxruntime and PyTorch results.
        assert len(outputs) == len(ort_outs), "number of outputs differ"
        for out, ort_out in zip(outputs, ort_outs):
            np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol)
class TestONNXRuntime(unittest.TestCase):
    """Export small models to ONNX and check onnxruntime matches PyTorch.

    The export opset is the ``opset_version`` class attribute. It defaults to
    the exporter's ``_export_onnx_opset_version`` at class-definition time,
    and subclasses/clones may override it.
    """
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    opset_version = _export_onnx_opset_version

    def run_test(self, model, input, rtol=1e-3, atol=1e-7):
        """Export ``model`` with ``input`` and compare against onnxruntime."""
        run_model_test(self, model, False, None,
                       input=input, rtol=rtol, atol=atol)

    def run_word_language_model(self, model_name):
        """Build a tiny ``word_language_model.RNNModel`` of type ``model_name`` and test it."""
        ntokens = 50
        emsize = 5
        nhid = 5
        nlayers = 5
        dropout = 0.2
        tied = False
        batchsize = 5
        model = word_language_model.RNNModel(model_name, ntokens, emsize,
                                             nhid, nlayers, dropout, tied,
                                             batchsize)
        x = torch.arange(0, ntokens).long().view(-1, batchsize)
        # Only support CPU version, since tracer is not working in GPU RNN.
        self.run_test(model, (x, model.hidden))

    def test_word_language_model_RNN_TANH(self):
        self.run_word_language_model("RNN_TANH")

    def test_word_language_model_RNN_RELU(self):
        self.run_word_language_model("RNN_RELU")

    def test_word_language_model_LSTM(self):
        self.run_word_language_model("LSTM")

    def test_word_language_model_GRU(self):
        self.run_word_language_model("GRU")

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_trace(self):
        class FullModel(torch.nn.Module):
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        x = torch.tensor(12)
        self.run_test(FullModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_script(self):
        class FullModelScripting(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        x = torch.tensor(12)
        self.run_test(FullModelScripting(), x)

    def test_maxpool(self):
        model = torch.nn.MaxPool1d(2, stride=1)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_maxpool_with_indices(self):
        model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_maxpool_dilation(self):
        model = torch.nn.MaxPool1d(2, stride=1, dilation=2)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    def test_avgpool(self):
        model = torch.nn.AvgPool1d(2, stride=1)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    def test_slice_trace(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x[0:1]

        x = torch.randn(3)
        self.run_test(MyModule(), x)

    def test_slice_script(self):
        class DynamicSliceModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x[1:x.size(0)]

        x = torch.rand(1, 2)
        self.run_test(DynamicSliceModel(), x)

    def _test_index_generic(self, fn):
        """Wrap indexing lambda ``fn`` in a module and test it on a 5-D random tensor."""
        class MyModel(torch.nn.Module):
            def __init__(self):
                super(MyModel, self).__init__()

            def forward(self, input):
                return fn(input)

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    def test_tensor_index_advanced_indexing(self):
        self._test_index_generic(
            lambda input: input[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])])
        self._test_index_generic(lambda input: input[..., torch.tensor([2, 1]), torch.tensor([0, 3])])
        self._test_index_generic(lambda input: input[:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])])
        self._test_index_generic(lambda input: input[:, torch.tensor([0, 2]), torch.tensor([1]), 2:4, torch.tensor([[1], [4]])])

    def test_tensor_index_advanced_indexing_consecutive(self):
        self._test_index_generic(lambda input: input[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None])

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_flip(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.flip(x, dims=[0])

        x = torch.tensor(np.arange(6.0).reshape(2, 3))
        self.run_test(MyModule(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_interpolate_scale(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2)

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(MyModel(), x)

    # NOTE: Supported in onnxruntime master, enable this after 0.5 release.
    @skipIfUnsupportedOpsetVersion([10])
    def test_interpolate_output_size(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.interpolate(x, mode="nearest", size=(6, 8))

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(MyModel(), x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_interpolate_downsample(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=[1, 1, 0.5, 0.5])

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(MyModel(), x)

    def test_index_select_constant_scaler_index(self):
        class IndexSelectScalerIndexModel(torch.nn.Module):
            def forward(self, x):
                index = 2
                return torch.index_select(x, 1, torch.tensor(index))

        x = torch.randn(3, 4)
        self.run_test(IndexSelectScalerIndexModel(), x)

    def test_index_select_scaler_index(self):
        class IndexSelectScalerIndexModel(torch.nn.Module):
            def __init__(self, index_base):
                super(IndexSelectScalerIndexModel, self).__init__()
                self.index_base = torch.tensor(index_base)

            def forward(self, x, index_offset):
                index = self.index_base + index_offset
                return torch.index_select(x, 1, index)

        x = torch.randn(3, 4)
        offset = 2
        index_offset = torch.tensor(offset)
        base = 1
        self.run_test(IndexSelectScalerIndexModel(base), (x, index_offset))

    # TODO: enable for opset 10 when ONNXRuntime version will be updated
    @skipIfUnsupportedOpsetVersion([10])
    def test_topk(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.topk(x, 3)

        x = torch.arange(1., 6., requires_grad=True)
        self.run_test(MyModule(), x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_topk_script(self):
        class MyModuleDynamic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, k):
                return torch.topk(x, k)

        x = torch.arange(1., 6., requires_grad=True)
        k = torch.tensor(3)
        self.run_test(MyModuleDynamic(), (x, k))

    def test_layer_norm(self):
        model = torch.nn.LayerNorm([10, 10])
        x = torch.randn(20, 5, 10, 10)
        self.run_test(model, x)

    def test_reduce_log_sum_exp(self):
        class ReduceLogSumExpModel(torch.nn.Module):
            def forward(self, input):
                a = torch.logsumexp(input, dim=0)
                b = torch.logsumexp(input, dim=(0, 1))
                return a + b

        x = torch.randn(4, 4, requires_grad=True)
        self.run_test(ReduceLogSumExpModel(), x)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_adaptive_max_pool(self):
        model = torch.nn.AdaptiveMaxPool1d(5, return_indices=False)
        x = torch.randn(20, 16, 50, requires_grad=True)
        self.run_test(model, x)

    def test_maxpool_2d(self):
        model = torch.nn.MaxPool2d(5, padding=(1, 2))
        x = torch.randn(1, 20, 16, 50, requires_grad=True)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_max_tensors(self):
        class MaxModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.max(input, other)

        model = MaxModel()
        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 1, requires_grad=True)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_end(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

        x = torch.randn(3, 4, requires_grad=True)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_start_end(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

        x = torch.randn(3, 4, requires_grad=True)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_start_end_step(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

        x = torch.randn(3, 4, requires_grad=True)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test__dim_arange(self):
        class DimArange(torch.nn.Module):
            def forward(self, input):
                return torch._dim_arange(input, 1)

        x = torch.ones(5, 6)
        self.run_test(DimArange(), x)

    def test_gt(self):
        class GreaterModel(torch.nn.Module):
            def forward(self, input, other):
                return input > other

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(GreaterModel(), (x, y))

        x = torch.randint(10, (3, 4), dtype=torch.int32)
        y = torch.randint(10, (3, 4), dtype=torch.int32)
        self.run_test(GreaterModel(), (x, y))

    def test_gt_scalar(self):
        class GreaterModel(torch.nn.Module):
            def forward(self, input):
                return input > 1

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(GreaterModel(), x)

        x = torch.randint(10, (3, 4), dtype=torch.int32)
        self.run_test(GreaterModel(), x)

    def test_lt(self):
        class LessModel(torch.nn.Module):
            def forward(self, input, other):
                # Must be '<' so this test exercises the Less export path
                # rather than duplicating test_gt.
                return input < other

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(LessModel(), (x, y))

        x = torch.randint(10, (3, 4), dtype=torch.int32)
        y = torch.randint(10, (3, 4), dtype=torch.int32)
        self.run_test(LessModel(), (x, y))

    def test_matmul(self):
        class MatmulModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.matmul(input, other)

        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(4, 5, requires_grad=True)
        self.run_test(MatmulModel(), (x, y))

        x = torch.randint(10, (3, 4))
        y = torch.randint(10, (4, 5))
        self.run_test(MatmulModel(), (x, y))

    def test_matmul_batch(self):
        class MatmulModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.matmul(input, other)

        x = torch.randn(2, 3, 4, requires_grad=True)
        y = torch.randn(2, 4, 5, requires_grad=True)
        self.run_test(MatmulModel(), (x, y))

        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 4, 5))
        self.run_test(MatmulModel(), (x, y))

    def test_view(self):
        class ViewModel(torch.nn.Module):
            def forward(self, input):
                return input.view(4, 24)

        x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
        self.run_test(ViewModel(), x)

    def test_flatten(self):
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input)

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    def test_flatten2d(self):
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input, 1)

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories(self):
        class TensorFactory(torch.nn.Module):
            def forward(self, x):
                return torch.zeros(x.size()) + torch.ones(x.size())

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories_script(self):
        class TensorFactory(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.zeros(x.shape, dtype=torch.float) + torch.ones(x.shape, dtype=torch.float)

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_like_factories_script(self):
        class TensorFactory(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                zeros = torch.zeros_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
                ones = torch.ones_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
                return zeros + ones

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)

    def test_sort(self):
        class SortModel(torch.nn.Module):
            def __init__(self, dim):
                super(SortModel, self).__init__()
                self.dim = dim

            def forward(self, x):
                return torch.sort(x, dim=self.dim, descending=True)

        dim = 1
        x = torch.randn(3, 4)
        self.run_test(SortModel(dim), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_masked_fill(self):
        class MaskedFillModel(torch.nn.Module):
            def forward(self, x):
                mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
                return x.masked_fill(mask, 2)

        x = torch.zeros(4, 2, 3, requires_grad=True)
        self.run_test(MaskedFillModel(), x)

        class MaskedFillModel2(torch.nn.Module):
            def forward(self, x):
                return x.masked_fill(x > 3, -1)

        x = torch.arange(16).view(2, 2, 4).to(torch.float32)
        self.run_test(MaskedFillModel2(), x)
def _make_opset_test_class(class_name, opset_version):
    """Return a TestCase named ``class_name`` that copies every TestONNXRuntime member.

    The copy overrides ``opset_version``, so each inherited test exports at
    that opset.
    """
    members = dict(TestONNXRuntime.__dict__, opset_version=opset_version)
    # str() keeps the class name a native str when unicode_literals is active.
    return type(str(class_name), (unittest.TestCase,), members)


# opset 7 tests
TestONNXRuntime_opset7 = _make_opset_test_class("TestONNXRuntime_opset7", 7)

# opset 8 tests
TestONNXRuntime_opset8 = _make_opset_test_class("TestONNXRuntime_opset8", 8)

# opset 10 tests
TestONNXRuntime_opset10 = _make_opset_test_class("TestONNXRuntime_opset10", 10)


if __name__ == '__main__':
    unittest.main()