-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathexport.py
97 lines (76 loc) · 3.24 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import torch
import argparse
from easy_ViTPose.vit_models.model import ViTPose
from easy_ViTPose.vit_utils.util import infer_dataset_by_path, dyn_model_import
# CLI for converting a ViTPose .pth checkpoint to ONNX (and, when
# torch_tensorrt is available, to a TensorRT engine).
parser = argparse.ArgumentParser()
parser.add_argument('--model-ckpt', type=str, required=True,
                    help='The torch model that shall be used for conversion')
parser.add_argument('--model-name', type=str, required=True, choices=['s', 'b', 'l', 'h'],
                    help='[s: ViT-S, b: ViT-B, l: ViT-L, h: ViT-H]')
parser.add_argument('--output', type=str, default='ckpts/',
                    help='File (without extension) or dir path for checkpoint output')
parser.add_argument('--dataset', type=str, required=False, default=None,
                    help="Name of the dataset. If None it's extracted from the file name. "
                         '["coco", "coco_25", "wholebody", "mpii", "ap10k", "apt36k", "aic"]')
args = parser.parse_args()
# Resolve which dataset the checkpoint targets: an explicit --dataset flag
# wins; otherwise infer it from the checkpoint file name.
dataset = args.dataset
if dataset is None:
    dataset = infer_dataset_by_path(args.model_ckpt)
# Validate with an explicit raise: `assert` would be stripped under `python -O`.
if dataset not in ('mpii', 'coco', 'coco_25', 'wholebody', 'aic', 'ap10k', 'apt36k'):
    raise ValueError('The specified dataset is not valid')

# Dynamically import the model config matching (dataset, model size).
model_cfg = dyn_model_import(dataset, args.model_name)
# ---- Convert the checkpoint to ONNX and save ----------------------------
print('>>> Converting to ONNX')
CKPT_PATH = args.model_ckpt
C, H, W = (3, 256, 192)  # fixed model input: 3-channel, 256x192

model = ViTPose(model_cfg)
ckpt = torch.load(CKPT_PATH, map_location='cpu', weights_only=True)
# Training checkpoints wrap the weights under a 'state_dict' key.
if 'state_dict' in ckpt:
    ckpt = ckpt['state_dict']
model.load_state_dict(ckpt)
model.eval()

input_names = ["input_0"]
output_names = ["output_0"]

device = next(model.parameters()).device
inputs = torch.randn(1, C, H, W).to(device)

# Allow a variable batch size at inference time.
dynamic_axes = {'input_0': {0: 'batch_size'},
                'output_0': {0: 'batch_size'}}

# Output path: when --output is a directory, keep the checkpoint's base name
# (with .onnx extension); otherwise use --output itself as the file name.
out_name = os.path.basename(args.model_ckpt).replace('.pth', '.onnx')
if not os.path.isdir(args.output):
    out_name = os.path.basename(args.output)
output_onnx = os.path.join(os.path.dirname(args.output), out_name)

# torch.onnx.export writes the file in place and returns None when given a
# path, so its return value is not kept.
torch.onnx.export(model, inputs, output_onnx, export_params=True, verbose=False,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes)

print(f">>> Saved at: {os.path.abspath(output_onnx)}")
print('=' * 80)
print()
# TensorRT conversion is optional: exit cleanly when torch_tensorrt is not
# installed instead of crashing.
try:
    import torch_tensorrt
except ModuleNotFoundError:
    import sys
    print('>>> TRT module not found, skipping')
    sys.exit()
# ---- Convert the in-memory model to TensorRT (from yolo convert script) --
print('>>> Converting to TRT')
trt_ts_module = torch_tensorrt.compile(
    model,
    # Plain-tensor module inputs are described via the `inputs` argument.
    inputs=[
        torch_tensorrt.Input(  # input object with explicit shape and dtype
            shape=[1, C, H, W],
            dtype=torch.float32
        )
    ],
    # TODO: ADD Datatype for inference. Allowed options torch.(float|half|int8|int32|bool)
    enabled_precisions={torch.float32},  # use torch.half to run with FP16
    workspace_size=1 << 28
)

# Save next to the ONNX file, swapping the extension.
output_trt = output_onnx.replace('.onnx', '.engine')

# Persist the TRT-embedded TorchScript module.
torch.jit.save(trt_ts_module, output_trt)
print(f">>> Saved at: {os.path.abspath(output_trt)}")