forked from open-mmlab/mmaction2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pytorch2onnx.py
169 lines (144 loc) · 5.8 KB
/
pytorch2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import argparse
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmaction.models import build_model
try:
import onnx
import onnxruntime as rt
except ImportError as e:
raise ImportError(f'Please install onnx and onnxruntime first. {e}')
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=1.0.4')
def _convert_batchnorm(module):
    """Recursively replace every SyncBatchNorm in ``module`` with BatchNorm3d.

    Args:
        module (:obj:`nn.Module`): Root of the module tree to convert.
            Modules that are not ``SyncBatchNorm`` are kept and only get
            their children re-registered in place.

    Returns:
        :obj:`nn.Module`: ``module`` itself, or a new ``BatchNorm3d`` when
        ``module`` is a ``SyncBatchNorm``.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm3d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        # the statistics buffers are reused from the source layer, not copied
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # a freshly constructed BN is in training mode; keep the mode of the
        # layer it replaces so an eval-mode model stays in eval mode
        module_output.training = module.training
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    return module_output
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a pytorch model to an ONNX file and optionally verify it.

    The model is moved to CPU, put in eval mode and exported with a random
    tensor of shape ``input_shape`` as its only input.

    Args:
        model (:obj:`nn.Module`): The pytorch model to be exported.
        input_shape (tuple[int]): Shape of the random input tensor.
        opset_version (int): ONNX opset version to export with. Default: 11.
        show (bool): Passed to the exporter as ``verbose``. Default: False.
        output_file (str): Path of the written ONNX file.
            Default: 'tmp.onnx'.
        verify (bool): Whether to check the exported file with the onnx
            checker and compare its output with the pytorch output.
            Default: False.
    """
    model.cpu().eval()
    dummy_input = torch.randn(input_shape)

    register_extra_symbolics(opset_version)
    export_options = dict(
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=show,
        opset_version=opset_version)
    torch.onnx.export(model, dummy_input, output_file, **export_options)
    print(f'Successfully exported ONNX model: {output_file}')

    if not verify:
        return

    # structural check of the exported file
    exported = onnx.load(output_file)
    onnx.checker.check_model(exported)

    # reference output from pytorch
    torch_scores = model(dummy_input)[0].detach().numpy()

    # graph inputs that are not initializers are the tensors to be fed
    graph_inputs = {entry.name for entry in exported.graph.input}
    weight_names = {entry.name for entry in exported.graph.initializer}
    feed_names = list(graph_inputs.difference(weight_names))
    assert len(feed_names) == 1

    session = rt.InferenceSession(output_file)
    feed = {feed_names[0]: dummy_input.detach().numpy()}
    onnx_scores = session.run(None, feed)[0]

    # only one randomly chosen column of the results is compared
    column = np.random.randint(torch_scores.shape[1])
    matches = np.allclose(torch_scores[:, column], onnx_scores[:, column])
    assert matches, 'The outputs are different between Pytorch and ONNX'
    print('The numerical values are same between Pytorch and ONNX')
def parse_args():
    """Parse the command line options of the conversion script.

    Returns:
        argparse.Namespace: Parsed values with the attributes ``config``,
        ``checkpoint``, ``show``, ``output_file``, ``opset_version``,
        ``verify``, ``is_localizer``, ``shape`` and ``softmax``.
    """
    cli = argparse.ArgumentParser(
        description='Convert MMAction2 models to ONNX')

    # required positional arguments
    positionals = (
        ('config', 'test config file path'),
        ('checkpoint', 'checkpoint file'),
    )
    for dest, text in positionals:
        cli.add_argument(dest, help=text)

    # export options
    cli.add_argument('--show', action='store_true', help='show onnx graph')
    cli.add_argument('--output-file', type=str, default='tmp.onnx')
    cli.add_argument('--opset-version', type=int, default=11)
    cli.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    cli.add_argument(
        '--is-localizer',
        action='store_true',
        help='whether it is a localizer')
    cli.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1, 3, 8, 224, 224],
        help='input video size')
    cli.add_argument(
        '--softmax',
        action='store_true',
        help='wheter to add softmax layer at the end of recognizers')

    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: build the model described by the config, load the
    # checkpoint into it and export it to ONNX.
    args = parse_args()

    # Explicit check rather than ``assert`` so that the validation is not
    # skipped when python runs with ``-O``.
    if args.opset_version != 11:
        raise ValueError('MMAction2 only supports opset 11 now')

    cfg = mmcv.Config.fromfile(args.config)
    if not args.is_localizer:
        # presumably stops the backbone from fetching pretrained weights;
        # the weights come from the checkpoint loaded below
        cfg.model.backbone.pretrained = None

    # build the model
    model = build_model(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    model = _convert_batchnorm(model)

    # onnx.export does not support kwargs, so bind a positional-only forward
    if hasattr(model, 'forward_dummy'):
        from functools import partial
        model.forward = partial(model.forward_dummy, softmax=args.softmax)
    elif hasattr(model, '_forward') and args.is_localizer:
        model.forward = model._forward
    else:
        raise NotImplementedError(
            'Please implement the forward method for exporting.')

    # the returned checkpoint object is not needed afterwards
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    # convert model to onnx file
    pytorch2onnx(
        model,
        args.shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)