"""Export a trained VITS / VITS2 checkpoint (.pth) to an ONNX model."""
import argparse
from pathlib import Path
from typing import Optional

import torch

import utils
from models import SynthesizerTrn
from text.symbols import symbols

OPSET_VERSION = 15


def main() -> None:
    torch.manual_seed(1234)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path", required=True, help="Path to model weights (.pth)"
    )
    parser.add_argument(
        "--config-path", required=True, help="Path to model config (.json)"
    )
    parser.add_argument("--output", required=True, help="Path to output model (.onnx)")
    args = parser.parse_args()

    args.model_path = Path(args.model_path)
    args.config_path = Path(args.config_path)
    args.output = Path(args.output)
    args.output.parent.mkdir(parents=True, exist_ok=True)

    hps = utils.get_hparams_from_file(args.config_path)

    # VITS2 configs use a mel-spectrogram posterior encoder (80 channels);
    # VITS1 configs use a linear-spectrogram posterior encoder.
    if (
        "use_mel_posterior_encoder" in hps.model.keys()
        and hps.model.use_mel_posterior_encoder == True
    ):
        print("Using mel posterior encoder for VITS2")
        posterior_channels = 80  # vits2
        hps.data.use_mel_posterior_encoder = True
    else:
        print("Using lin posterior encoder for VITS1")
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False

    model_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    _ = model_g.eval()
    _ = utils.load_checkpoint(args.model_path, model_g, None)

    def infer_forward(text, text_lengths, scales, sid=None):
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0]
        return audio

    model_g.forward = infer_forward
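    # torch.onnx.export() traces the module's forward(); patching forward to
    # infer_forward above makes the exported graph run the inference path, with
    # the three sampling knobs (noise_scale, length_scale, noise_scale_w)
    # packed into a single "scales" tensor input.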

    dummy_input_length = 50
    sequences = torch.randint(
        low=0, high=len(symbols), size=(1, dummy_input_length), dtype=torch.long
    )
    sequence_lengths = torch.LongTensor([sequences.size(1)])

    sid: Optional[torch.LongTensor] = None
    if hps.data.n_speakers > 1:
        sid = torch.LongTensor([0])

    # noise, length, noise_w
    scales = torch.FloatTensor([0.667, 1.0, 0.8])
    dummy_input = (sequences, sequence_lengths, scales, sid)
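    # The dummy input only needs valid shapes and dtypes for tracing; the
    # dynamic_axes declared below mark the batch and time dimensions as
    # symbolic, so the exported model accepts other lengths at runtime.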

    # Export
    torch.onnx.export(
        model=model_g,
        args=dummy_input,
        f=str(args.output),
        verbose=False,
        opset_version=OPSET_VERSION,
        input_names=["input", "input_lengths", "scales", "sid"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "phonemes"},
            "input_lengths": {0: "batch_size"},
            "output": {0: "batch_size", 1: "time1", 2: "time2"},
        },
    )

    print(f"Exported model to {args.output}")


if __name__ == "__main__":
    main()
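
# Example usage (a sketch, not part of the script). The flags are the ones
# defined above; the paths are placeholders:
#
#   python export_onnx.py \
#       --model-path /path/to/G_xxxx.pth \
#       --config-path /path/to/config.json \
#       --output /path/to/model.onnx
#
# Running the exported model is outside this script, but assuming onnxruntime
# and numpy are installed, a minimal inference sketch that matches the
# exported input/output names would look like this (phoneme_ids is a
# placeholder for an already-encoded phoneme ID sequence):
#
#   import numpy as np
#   import onnxruntime
#
#   session = onnxruntime.InferenceSession("/path/to/model.onnx")
#   text = np.asarray([phoneme_ids], dtype=np.int64)            # (1, phonemes)
#   text_lengths = np.asarray([text.shape[1]], dtype=np.int64)  # (1,)
#   scales = np.asarray([0.667, 1.0, 0.8], dtype=np.float32)    # noise, length, noise_w
#   inputs = {"input": text, "input_lengths": text_lengths, "scales": scales}
#   # Only pass "sid" if the model was exported with n_speakers > 1:
#   # inputs["sid"] = np.asarray([0], dtype=np.int64)
#   audio = session.run(["output"], inputs)[0]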