-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_onnx.py
47 lines (40 loc) · 1.01 KB
/
export_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import subprocess
import sys

import torch

from model.nsnet2 import NSNet2
# Export a pretrained NSNet2 model to ONNX, then simplify it with onnxsim.
#
# Reads:  'nsnet2.ckpt'             -- a state_dict, e.g. one produced by
#                                      torch.save(model.state_dict(), 'nsnet2.ckpt')
# Writes: 'nsnet2.onnx'             -- the exported graph
#         'nsnet2_simplified.onnx'  -- the onnxsim-simplified graph

# STFT configuration shared by the model and the dummy-input shape probe.
# At an assumed 16 kHz sample rate: 320 samples = 20 ms window, 160 = 10 ms hop.
cfg = {
    'n_fft': 320,
    'hop_len': 160,
    'win_len': 320,
}

CKPT_PATH = 'nsnet2.ckpt'
ONNX_PATH = 'nsnet2.onnx'
SIMPLIFIED_ONNX_PATH = 'nsnet2_simplified.onnx'

model = NSNet2(cfg=cfg)
# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine; everything below runs on CPU anyway.
model.load_state_dict(torch.load(CKPT_PATH, map_location='cpu'))
# Switch to inference mode before tracing (affects layers such as dropout).
model.eval()

# Run a throwaway STFT over one window of random audio purely to learn the
# spectrogram shape (freq bins, frames) for this cfg, so the dummy input
# matches exactly what torch.stft would feed the model.
waveform = torch.randn(cfg['n_fft'])
spec = torch.stft(
    waveform,
    n_fft=cfg['n_fft'],
    hop_length=cfg['hop_len'],
    win_length=cfg['win_len'],
    window=torch.hann_window(cfg['win_len']),
    return_complex=True,
)
n_freq, n_frames = spec.shape[-2:]

# Dummy model input laid out as (batch, frames, freq_bins), matching the
# 'input' entry of dynamic_axes below.
x = torch.randn(1, n_frames, n_freq)

# Save as an ONNX model.
torch.onnx.export(
    model,
    x,
    ONNX_PATH,
    do_constant_folding=True,
    opset_version=16,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        # NOTE(review): freq_bins is marked dynamic, but the model presumably
        # only accepts n_fft // 2 + 1 bins -- confirm it is really variable.
        'input': {0: 'batch_size', 1: 'frames', 2: 'freq_bins'},
        # NOTE(review): indices 0, 2, 3 imply a 4-D output while the input is
        # 3-D; verify against NSNet2.forward's actual output rank.
        'output': {0: 'batch_size', 2: 'frames', 3: 'freq_bins'},
    },
)

# Simplify with onnxsim using the *current* interpreter (the environment that
# has torch installed) rather than whatever 'python3' resolves to on PATH.
# An argument list avoids the shell, and check=True reports a failed
# simplification instead of silently leaving the simplified model missing.
subprocess.run(
    [sys.executable, '-m', 'onnxsim', ONNX_PATH, SIMPLIFIED_ONNX_PATH],
    check=True,
)