forked from NVIDIA/tacotron2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
71 lines (52 loc) · 1.95 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# import matplotlib
import matplotlib.pylab as plt
import numpy as np
import torch
from scipy.io.wavfile import write
import os
from hparams import create_hparams
# from model import Tacotron2
# from layers import TacotronSTFT, STFT
from audio_processing import griffin_lim
from text import text_to_sequence
import sys
sys.path.append('waveglow')
from denoiser import Denoiser
def plot_data(data, figsize=(16, 4)):
    """Show every array in ``data`` as an image, side by side in one row.

    Args:
        data: Sized sequence of array-likes accepted by ``Axes.imshow``;
            one subplot column is created per element.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    # squeeze=False keeps ``axes`` two-dimensional even for a single panel;
    # without it plt.subplots returns a bare Axes and indexing it fails.
    _, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False)
    for ax, image in zip(axes[0], data):
        # 'lower' places row 0 at the bottom of the plot. The previous value
        # 'bottom' is rejected by current matplotlib, which only accepts
        # 'upper' or 'lower'.
        ax.imshow(image, aspect='auto', origin='lower',
                  interpolation='none')
    plt.show()
def save_audio(path, sampling_rate, audio):
    """Write ``audio`` to ``path`` as a 32-bit float WAV file.

    Args:
        path: Destination handed to ``scipy.io.wavfile.write``.
        sampling_rate: Sample rate stored in the WAV header.
        audio: NumPy array; it is transposed before writing, so it is
            presumably laid out as (channels, samples) -- confirm with
            the caller.
    """
    print("saving audio to", path)
    samples = audio.T
    samples = samples.astype(np.float32)
    write(path, sampling_rate, samples)
def load_waveglow(path):
    """Load a WaveGlow checkpoint and build a matching denoiser.

    Args:
        path: Checkpoint file; the loaded object must have a ``'model'``
            entry holding the model itself.

    Returns:
        Tuple ``(waveglow, denoiser)``.
    """
    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be passed in.
    checkpoint = torch.load(path)
    vocoder = checkpoint['model']

    # Move to the GPU, switch to eval mode and cast to half precision.
    vocoder.cuda().eval().half()

    # The convinv layers are cast back to float32 after the model-wide
    # half() call (presumably they are not usable in fp16 -- confirm).
    for inv_conv in vocoder.convinv:
        inv_conv.float()

    return vocoder, Denoiser(vocoder)
def infer(tacotron2, waveglow_path, text, audio_path, denoiser_strength=0.006,
          sigma=0.85):
    """Synthesize speech for ``text`` and optionally save it as a WAV file.

    Args:
        tacotron2: Object exposing ``inference(sequence)`` that returns a
            4-tuple; its second element is fed to the vocoder.
        waveglow_path: Checkpoint path handed to ``load_waveglow``.
        text: Input text, encoded with ``text_to_sequence`` using the
            ``german_cleaners`` cleaner.
        audio_path: Destination WAV path; nothing is written when falsy.
        denoiser_strength: ``strength`` argument passed to the denoiser.
        sigma: ``sigma`` argument passed to ``waveglow.infer``.

    Returns:
        The denoised audio as a NumPy array.
    """
    hparams = create_hparams()
    # Only sampling_rate is read below; the other overrides are kept as in
    # the original configuration.
    hparams.max_wav_value = 32768.0
    hparams.sampling_rate = 22050
    hparams.filter_length = 1024
    hparams.hop_length = 256
    hparams.win_length = 1024

    # NOTE: the vocoder checkpoint is reloaded on every call.
    waveglow, denoiser = load_waveglow(waveglow_path)

    # Encode the text and add a leading batch dimension of size 1.
    sequence = np.array(text_to_sequence(text, ['german_cleaners']))[None, :]
    sequence = torch.from_numpy(sequence).cuda().long()

    # The whole pipeline is inference only, so no autograd graph is needed
    # for any stage (previously the text -> mel step ran with grad enabled).
    with torch.no_grad():
        # text -> mel spectrogram
        _, mel_outputs_postnet, _, _ = tacotron2.inference(sequence)
        # mel spectrogram -> sound wave
        audio = waveglow.infer(mel_outputs_postnet, sigma=sigma)
        # denoise
        audio_denoised = denoiser(audio, strength=denoiser_strength)[:, 0]

    audio_denoised_np = audio_denoised.cpu().numpy()
    if audio_path:
        save_audio(audio_path, hparams.sampling_rate, audio_denoised_np)
    return audio_denoised_np