-
Notifications
You must be signed in to change notification settings - Fork 1
/
export.py
121 lines (101 loc) · 3.81 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import argparse
import os
from loguru import logger
import torch
from torch import nn
from yolox.exp import get_exp
from yolox.models.network_blocks import SiLU
from yolox.utils import replace_module
def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the YOLOX ONNX export script.

    Returns:
        argparse.ArgumentParser: a parser exposing
          * export options: ``--output-name``, ``--input``, ``--output``,
            ``-o/--opset``, ``--batch-size``, ``--dynamic``, ``--no-onnxsim``;
          * experiment selection: ``-f/--exp_file``, ``-expn/--experiment-name``,
            ``-n/--name``, ``-c/--ckpt``;
          * ``opts``: every remaining command-line token, collected as a list;
          * ``--decode_in_inference``: boolean flag, off by default.
    """
    parser = argparse.ArgumentParser("YOLOX onnx deploy")

    # --- ONNX export options -------------------------------------------------
    parser.add_argument(
        "--output-name", type=str, default="yolox.onnx", help="output name of models"
    )
    parser.add_argument(
        "--input", default="data", type=str, help="input node name of onnx model"
    )
    parser.add_argument(
        "--output", default="output", type=str, help="output node name of onnx model"
    )
    parser.add_argument(
        "-o", "--opset", default=9, type=int, help="onnx opset version"
    )
    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
    parser.add_argument(
        "--dynamic", action="store_true", help="whether the input shape should be dynamic or not"
    )
    parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")

    # --- experiment / checkpoint selection -----------------------------------
    parser.add_argument(
        "-f",
        "--exp_file",
        default='./exps/default/yolov5_repvgg.py',
        type=str,
        help="experiment description file",
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default="yolov5_repvgg")
    # The help string describes the option; the option itself defaults to None.
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")

    # Everything left on the command line after the recognised options.
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--decode_in_inference",
        action="store_true",
        help="decode in inference or not"
    )

    return parser
@logger.catch
def main():
    """Export a YOLOX model to ONNX and optionally simplify the result.

    Steps, in order:
      1. Parse the command line and build the experiment with ``get_exp``,
         then apply the trailing ``opts`` via ``exp.merge``.
      2. Load checkpoint weights on CPU. When ``--ckpt`` is not given the
         file is ``<exp.output_dir>/<experiment_name>/best_ckpt.pth``.
      3. Call ``replace_module(model, nn.SiLU, SiLU)`` (presumably swapping in
         an export-friendly SiLU -- confirm in ``yolox.utils``) and set
         ``model.head.decode_in_inference`` from the CLI flag.
      4. Export the model to ``args.output_name``.
      5. Unless ``--no-onnxsim`` is set, run onnxsim on the exported file and
         overwrite it with the simplified model.

    Raises:
        RuntimeError: if onnxsim reports that the simplified model could not
            be validated. An explicit raise is used instead of ``assert`` so
            the check still runs under ``python -O``.

    Note:
        The function is wrapped in ``@logger.catch``, so uncaught exceptions
        are handed to the loguru logger.
    """
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))

    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    # Fall back to the experiment's own name when none was supplied.
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    # The checkpoint may wrap the state dict under a "model" key.
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = args.decode_in_inference

    logger.info("loading checkpoint done.")

    # Example input handed to the exporter:
    # (batch_size, 3, test_size[0], test_size[1]).
    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    # With --dynamic, axis 0 of both the input and the output is marked as a
    # named dynamic axis ('batch'); otherwise all axes stay fixed.
    dynamic_axes = None
    if args.dynamic:
        dynamic_axes = {args.input: {0: 'batch'},
                        args.output: {0: 'batch'}}

    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=[args.output],
        dynamic_axes=dynamic_axes,
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        # Imported lazily so onnx/onnxsim are only required when simplifying.
        import onnx

        from onnxsim import simplify

        input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None

        # use onnx-simplifier to remove redundant parts of the exported model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model,
                                     dynamic_input_shape=args.dynamic,
                                     input_shapes=input_shapes)
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))
# Script entry point: run the export only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()