-
Notifications
You must be signed in to change notification settings - Fork 487
/
train.py
1023 lines (845 loc) · 37.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Trainining script for seq2seq text-to-speech synthesis model.
usage: train.py [options]
options:
--data-root=<dir> Directory contains preprocessed features.
--checkpoint-dir=<dir> Directory where to save model checkpoints [default: checkpoints].
--hparams=<parmas> Hyper parameters [default: ].
--preset=<json> Path of preset parameters (json).
--checkpoint=<path> Restore model from checkpoint path if given.
--checkpoint-seq2seq=<path> Restore seq2seq model from checkpoint path.
--checkpoint-postnet=<path> Restore postnet model from checkpoint path.
--train-seq2seq-only Train only seq2seq model.
--train-postnet-only Train only postnet model.
--restore-parts=<path> Restore part of the model.
--log-event-path=<name> Log event path.
--reset-optimizer Reset optimizer.
--load-embedding=<path> Load embedding from checkpoint.
--speaker-id=<N> Use specific speaker of data in case for multi-speaker datasets.
-h, --help Show this help message and exit
"""
from docopt import docopt
import sys
import gc
import platform
from os.path import dirname, join
from tqdm import tqdm, trange
from datetime import datetime
# The deepvoice3 model
from deepvoice3_pytorch import frontend, builder
import audio
import lrschedule
import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
from torch.utils.data.sampler import Sampler
import numpy as np
from numba import jit
from nnmnkwii.datasets import FileSourceDataset, FileDataSource
from os.path import join, expanduser
import random
import librosa.display
from matplotlib import pyplot as plt
import sys
import os
from tensorboardX import SummaryWriter
from matplotlib import cm
from warnings import warn
from hparams import hparams, hparams_debug_string
# Training progress counters; `train` advances them and `load_checkpoint`
# overwrites them when resuming.
global_step, global_epoch = 0, 0

# CUDA availability is probed once at import time.
use_cuda = torch.cuda.is_available()
if use_cuda:
    cudnn.benchmark = False

# Text frontend module; resolved from `hparams.frontend` later on.
_frontend = None  # to be set later
def _pad(seq, max_len, constant_values=0):
return np.pad(seq, (0, max_len - len(seq)),
mode='constant', constant_values=constant_values)
def _pad_2d(x, max_len, b_pad=0):
x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)],
mode="constant", constant_values=0)
return x
def plot_alignment(alignment, path, info=None):
    """Render ``alignment`` as a heat map and save it to ``path`` as PNG.

    When ``info`` is given it is appended below the x-axis label.
    """
    fig, ax = plt.subplots()
    heatmap = ax.imshow(alignment, aspect='auto', origin='lower',
                        interpolation='none')
    fig.colorbar(heatmap, ax=ax)

    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel = xlabel + '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()
    plt.savefig(path, format='png')
    plt.close()
class TextDataSource(FileDataSource):
    """Text side of the dataset, read from ``train.txt`` under ``data_root``.

    Each line of ``train.txt`` is ``|``-separated with 4 columns
    (single speaker) or 5 columns (multi speaker, speaker id last).
    The text is taken from column index 3.
    """

    def __init__(self, data_root, speaker_id=None):
        self.data_root = data_root
        self.speaker_ids = None
        self.multi_speaker = False
        # If not None, filter by speaker_id
        self.speaker_id = speaker_id

    def collect_files(self):
        """Return the texts, plus speaker ids for an unfiltered multi-speaker set."""
        meta_path = join(self.data_root, "train.txt")
        with open(meta_path, "rb") as f:
            raw_lines = f.readlines()

        # The column count of the first line decides single vs. multi speaker.
        first_row = raw_lines[0].decode("utf-8").split("|")
        assert len(first_row) == 4 or len(first_row) == 5
        self.multi_speaker = len(first_row) == 5

        texts = [raw.decode("utf-8").split("|")[3] for raw in raw_lines]
        if not self.multi_speaker:
            return texts

        speaker_ids = [int(raw.decode("utf-8").split("|")[-1])
                       for raw in raw_lines]
        if self.speaker_id is None:
            return texts, speaker_ids

        # Filter by speaker_id:
        # use the multi-speaker dataset as a single speaker dataset
        keep = np.array(speaker_ids) == self.speaker_id
        texts = list(np.array(texts)[keep])
        self.multi_speaker = False
        return texts

    def collect_features(self, *args):
        """Convert one text (and optional speaker id) to an int32 id sequence."""
        global _frontend

        if self.multi_speaker:
            text, speaker_id = args
        else:
            text = args[0]

        if _frontend is None:
            _frontend = getattr(frontend, hparams.frontend)
        seq = _frontend.text_to_sequence(
            text, p=hparams.replace_pronunciation_prob)

        if platform.system() == "Windows" and hasattr(hparams, 'gc_probability'):
            _frontend = None  # memory leaking prevention in Windows
            if np.random.rand() < hparams.gc_probability:
                gc.collect()  # garbage collection enforced
                print("GC done")

        ids = np.asarray(seq, dtype=np.int32)
        if self.multi_speaker:
            return ids, int(speaker_id)
        return ids
class _NPYDataSource(FileDataSource):
    """``.npy`` feature files listed in column ``col`` of ``train.txt``.

    Frame lengths are read from column index 2 and kept in
    ``self.frame_lengths`` after ``collect_files`` has run.
    """

    def __init__(self, data_root, col, speaker_id=None):
        self.data_root = data_root
        self.col = col
        self.frame_lengths = []
        self.speaker_id = speaker_id

    def collect_files(self):
        """Return the feature file paths, optionally filtered by speaker id."""
        meta_path = join(self.data_root, "train.txt")
        with open(meta_path, "rb") as f:
            raw_lines = f.readlines()

        first_row = raw_lines[0].decode("utf-8").split("|")
        assert len(first_row) == 4 or len(first_row) == 5
        multi_speaker = len(first_row) == 5

        self.frame_lengths = [int(raw.decode("utf-8").split("|")[2])
                              for raw in raw_lines]
        names = [raw.decode("utf-8").split("|")[self.col] for raw in raw_lines]
        paths = [join(self.data_root, name) for name in names]

        if multi_speaker and self.speaker_id is not None:
            speaker_ids = [int(raw.decode("utf-8").split("|")[-1])
                           for raw in raw_lines]
            # Filter by speaker_id:
            # use the multi-speaker dataset as a single speaker dataset
            keep = np.array(speaker_ids) == self.speaker_id
            paths = list(np.array(paths)[keep])
            # numpy.int64 entries are cast back to plain int
            self.frame_lengths = [
                int(n) for n in np.array(self.frame_lengths)[keep]]

        return paths

    def collect_features(self, path):
        """Load one feature array from ``path``."""
        return np.load(path)
class MelSpecDataSource(_NPYDataSource):
    """Feature source reading the paths in column index 1 of ``train.txt``."""

    def __init__(self, data_root, speaker_id=None):
        super().__init__(data_root, 1, speaker_id)
class LinearSpecDataSource(_NPYDataSource):
    """Feature source reading the paths in column index 0 of ``train.txt``."""

    def __init__(self, data_root, speaker_id=None):
        super().__init__(data_root, 0, speaker_id)
class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
    """Partially randomized sampler.

    1. Sort by lengths
    2. Pick a small patch and randomize it
    3. Permutate mini-batches

    Args:
        lengths: per-item lengths used for sorting.
        batch_size: mini-batch size used when permuting batches.
        batch_group_size: number of consecutive (length-sorted) items that
            are shuffled together. Must be a multiple of ``batch_size``.
            Defaults to ``min(batch_size * 32, len(lengths))`` rounded down
            to a multiple of ``batch_size``.
        permutate: whether to permute whole mini-batches after the
            per-group shuffle.

    Every index is yielded exactly once per iteration.
    """

    def __init__(self, lengths, batch_size=16, batch_group_size=None,
                 permutate=True):
        self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths))
        self.batch_size = batch_size
        if batch_group_size is None:
            batch_group_size = min(batch_size * 32, len(self.lengths))
            if batch_group_size % batch_size != 0:
                batch_group_size -= batch_group_size % batch_size

        self.batch_group_size = batch_group_size
        assert batch_group_size % batch_size == 0
        self.permutate = permutate

    @staticmethod
    def _shuffled(values):
        """Return a shuffled copy of the 1-D tensor ``values``.

        ``random.shuffle`` must not be applied to a tensor directly: its
        element swap goes through views of the same storage, which
        duplicates some entries and loses others. A shuffled index list is
        built instead and applied with one indexing operation.
        """
        order = list(range(len(values)))
        random.shuffle(order)
        return values[torch.LongTensor(order)]

    def __iter__(self):
        indices = self.sorted_indices.clone()
        group = self.batch_group_size
        # group can be 0 when batch_size exceeds the dataset size
        n_groups = len(indices) // group if group > 0 else 0

        # Shuffle inside each group of similar-length items
        for i in range(n_groups):
            s = i * group
            e = s + group
            indices[s:e] = self._shuffled(indices[s:e])

        # Permutate batches
        e = n_groups * group
        if self.permutate and e > 0:
            perm = list(range(e // self.batch_size))
            random.shuffle(perm)
            batches = indices[:e].view(-1, self.batch_size)
            indices[:e] = batches[torch.LongTensor(perm), :].reshape(-1)

        # Handle last elements (those not covered by a full group)
        if e < len(indices):
            indices[e:] = self._shuffled(indices[e:])

        return iter(indices)

    def __len__(self):
        return len(self.sorted_indices)
class PyTorchDataset(object):
    """Zip the text, mel and linear datasets into one indexable dataset.

    An item is ``(text, mel, linear)``, or ``(text, mel, linear, speaker_id)``
    when the text data source reports ``multi_speaker``.
    """

    def __init__(self, X, Mel, Y):
        self.X = X
        self.Mel = Mel
        self.Y = Y
        # alias
        self.multi_speaker = X.file_data_source.multi_speaker

    def __getitem__(self, idx):
        if not self.multi_speaker:
            return self.X[idx], self.Mel[idx], self.Y[idx]
        text, speaker_id = self.X[idx]
        return text, self.Mel[idx], self.Y[idx], speaker_id

    def __len__(self):
        return len(self.X)
def sequence_mask(sequence_length, max_len=None):
    """Build a float mask that is 1 inside each sequence and 0 in the padding.

    Args:
        sequence_length: 1-D integer tensor of shape ``(B,)``.
        max_len: width of the mask; defaults to the largest entry of
            ``sequence_length``.

    Returns:
        Float tensor of shape ``(B, max_len)`` on the same device as
        ``sequence_length``; entry ``[b, t]`` is 1.0 when
        ``t < sequence_length[b]`` and 0.0 otherwise.
    """
    if max_len is None:
        max_len = sequence_length.max()
    # Accept python ints as well as 0-dim tensors
    max_len = int(max_len)
    batch_size = sequence_length.size(0)
    # Create the range directly on the input's device so the comparison below
    # works for any device, not only the default CUDA device.
    seq_range = torch.arange(
        max_len, dtype=torch.long, device=sequence_length.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = sequence_length.unsqueeze(1) \
        .expand_as(seq_range_expand)
    return (seq_range_expand < seq_length_expand).float()
class MaskedL1Loss(nn.Module):
    """L1 loss averaged over the masked (non-padded) entries only."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss(reduction="sum")

    def forward(self, input, target, lengths=None, mask=None, max_len=None):
        """Return the summed masked L1 error divided by the mask sum.

        Either ``mask`` or ``lengths`` must be given; when only ``lengths``
        is given the mask is derived with ``sequence_mask``.
        """
        if mask is None:
            if lengths is None:
                raise RuntimeError("Should provide either lengths or mask")
            # (B, T, 1)
            mask = sequence_mask(lengths, max_len).unsqueeze(-1)

        # (B, T, D)
        weights = mask.expand_as(input)
        total = self.criterion(input * weights, target * weights)
        return total / weights.sum()
def collate_fn(batch):
    """Create a padded mini-batch from a list of dataset items.

    Each item is ``(text_ids, mel, linear)`` or, for multi-speaker data,
    ``(text_ids, mel, linear, speaker_id)``.

    Returns:
        Tuple ``(x_batch, input_lengths, mel_batch, y_batch,
        (text_positions, frame_positions), done, target_lengths,
        speaker_ids)``. ``speaker_ids`` is ``None`` for single-speaker
        batches.
    """
    r = hparams.outputs_per_step
    downsample_step = hparams.downsample_step
    multi_speaker = len(batch[0]) == 4

    # Lengths
    input_lengths = [len(x[0]) for x in batch]
    max_input_len = max(input_lengths)

    target_lengths = [len(x[1]) for x in batch]

    # Round the target length up to a multiple of r, then of downsample_step
    max_target_len = max(target_lengths)
    if max_target_len % r != 0:
        max_target_len += r - max_target_len % r
        assert max_target_len % r == 0
    if max_target_len % downsample_step != 0:
        max_target_len += downsample_step - max_target_len % downsample_step
        assert max_target_len % downsample_step == 0

    # Set 0 for zero beginning padding
    # imitates initial decoder states
    b_pad = r
    max_target_len += b_pad * downsample_step

    # NOTE: np.int64 is used because the ``np.int`` alias was removed from NumPy.
    a = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.int64)
    x_batch = torch.LongTensor(a)

    input_lengths = torch.LongTensor(input_lengths)
    target_lengths = torch.LongTensor(target_lengths)

    b = np.array([_pad_2d(x[1], max_target_len, b_pad=b_pad) for x in batch],
                 dtype=np.float32)
    mel_batch = torch.FloatTensor(b)

    c = np.array([_pad_2d(x[2], max_target_len, b_pad=b_pad) for x in batch],
                 dtype=np.float32)
    y_batch = torch.FloatTensor(c)

    # text positions (1-based, 0 is used for padding)
    text_positions = np.array([_pad(np.arange(1, len(x[0]) + 1), max_input_len)
                               for x in batch], dtype=np.int64)
    text_positions = torch.LongTensor(text_positions)

    max_decoder_target_len = max_target_len // r // downsample_step

    # frame positions
    s, e = 1, max_decoder_target_len + 1
    # NOTE: needs clone to suppress RuntimeError in dataloader...
    # ref: https://github.com/pytorch/pytorch/issues/10756
    frame_positions = torch.arange(s, e).long().unsqueeze(0).expand(
        len(batch), max_decoder_target_len).clone()

    # done flags: 0 while decoding, padded with 1 afterwards
    done = np.array([_pad(np.zeros(len(x[1]) // r // downsample_step - 1),
                          max_decoder_target_len, constant_values=1)
                     for x in batch])
    done = torch.FloatTensor(done).unsqueeze(-1)

    if multi_speaker:
        speaker_ids = torch.LongTensor([x[3] for x in batch])
    else:
        speaker_ids = None

    return x_batch, input_lengths, mel_batch, y_batch, \
        (text_positions, frame_positions), done, target_lengths, speaker_ids
def time_string():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM``."""
    return "{:%Y-%m-%d %H:%M}".format(datetime.now())
def save_alignment(path, attn):
    """Plot the transposed attention ``attn`` to ``path``.

    The plot caption carries the builder name, the current time and the
    global step.
    """
    caption = "{}, {}, step={}".format(
        hparams.builder, time_string(), global_step)
    plot_alignment(attn.T, path, info=caption)
def prepare_spec_image(spectrogram):
    """Convert a spectrogram to a uint8 colour image using the magma colormap.

    The values are min-max normalized to [0, 1], flipped along axis 1,
    transposed and colour-mapped. A spectrogram without any value range
    (all entries equal) is mapped to all zeros instead of dividing by zero.
    """
    lowest = np.min(spectrogram)
    value_range = np.max(spectrogram) - lowest
    if value_range > 0:
        # [0, 1]
        spectrogram = (spectrogram - lowest) / value_range
    else:
        # Flat input: avoid 0/0 -> NaN, which cannot be cast to uint8 sensibly
        spectrogram = np.zeros(np.shape(spectrogram), dtype=np.float32)
    spectrogram = np.flip(spectrogram, axis=1)  # flip against freq axis
    return np.uint8(cm.magma(spectrogram.T) * 255)
def eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeaker):
    """Synthesize a fixed set of sentences and log the results.

    A fresh model is built with ``build_model``, loaded with the weights of
    ``model`` and used to synthesize every sentence in ``texts`` (for each
    evaluated speaker id when ``ismultispeaker`` is true). For each
    utterance the alignment plot and the wav file are written under
    ``<checkpoint_dir>/eval`` and the alignment, mel spectrogram and audio
    are sent to ``writer``. Failures of the ``writer`` calls are downgraded
    to warnings.
    """
    # hard-coded evaluation sentences
    texts = [
        "Scientists at the CERN laboratory say they have discovered a new particle.",
        "There's a way to measure the acute emotional intelligence that has never gone out of style.",
        "President Trump met with other leaders at the Group of 20 conference.",
        "Generative adversarial network or variational auto-encoder.",
        "Please call Stella.",
        "Some have accepted this as a miracle without any physical explanation.",
    ]
    # Imported here rather than at module level; the synthesis module is
    # handed the frontend currently held by this module.
    import synthesis
    synthesis._frontend = _frontend

    eval_output_dir = join(checkpoint_dir, "eval")
    os.makedirs(eval_output_dir, exist_ok=True)

    # Prepare model for evaluation
    model_eval = build_model().to(device)
    model_eval.load_state_dict(model.state_dict())

    # hard-coded: first, second and last speaker for multi-speaker models
    speaker_ids = [0, 1, hparams.n_speakers-1] if ismultispeaker else [None]
    for speaker_id in speaker_ids:
        speaker_str = "multispeaker{}".format(speaker_id) if speaker_id is not None else "single"

        for idx, text in enumerate(texts):
            signal, alignment, _, mel = synthesis.tts(
                model_eval, text, p=0, speaker_id=speaker_id, fast=True)
            # Peak-normalize the waveform in place
            signal /= np.max(np.abs(signal))

            # Alignment
            path = join(eval_output_dir, "step{:09d}_text{}_{}_alignment.png".format(
                global_step, idx, speaker_str))
            save_alignment(path, alignment)
            tag = "eval_averaged_alignment_{}_{}".format(idx, speaker_str)
            try:
                writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)
            except Exception as e:
                warn(str(e))

            # Mel
            try:
                writer.add_image("(Eval) Predicted mel spectrogram text{}_{}".format(idx, speaker_str),
                                 prepare_spec_image(mel), global_step)
            except Exception as e:
                warn(str(e))

            # Audio
            path = join(eval_output_dir, "step{:09d}_text{}_{}_predicted.wav".format(
                global_step, idx, speaker_str))
            audio.save_wav(signal, path)

            try:
                writer.add_audio("(Eval) Predicted audio signal {}_{}".format(idx, speaker_str),
                                 signal, global_step, sample_rate=hparams.sample_rate)
            except Exception as e:
                warn(str(e))
                pass
def save_states(global_step, writer, mel_outputs, linear_outputs, attn, mel, y,
                input_lengths, checkpoint_dir=None):
    """Dump intermediate training outputs for one item of the batch.

    Depending on which of ``attn``, ``mel_outputs`` and ``linear_outputs``
    are not None, this logs attention alignments, predicted and target
    spectrogram images and the predicted audio to ``writer``, and writes
    alignment plots and a predicted wav file under ``checkpoint_dir``.
    Failures of the ``writer`` calls are downgraded to warnings.

    NOTE(review): ``checkpoint_dir`` defaults to None but is passed to
    ``join`` unconditionally in several branches -- callers presumably
    always provide it; confirm.
    """
    print("Save intermediate states at step {}".format(global_step))

    # idx = np.random.randint(0, len(input_lengths))
    # Use the second item of the batch, or the first if the batch has one item
    idx = min(1, len(input_lengths) - 1)
    # NOTE: input_length is not used below
    input_length = input_lengths[idx]

    # Alignment
    # Multi-hop attention
    if attn is not None and attn.dim() == 4:
        # Iterating the 4-D attention yields one alignment stack per entry
        # of its first axis; each is logged as a separate "layer".
        for i, alignment in enumerate(attn):
            alignment = alignment[idx].cpu().data.numpy()
            tag = "alignment_layer{}".format(i + 1)
            try:
                writer.add_image(tag, np.uint8(cm.viridis(
                    np.flip(alignment, 1).T) * 255), global_step)
                # save files as well for now
                alignment_dir = join(
                    checkpoint_dir, "alignment_layer{}".format(i + 1))
                os.makedirs(alignment_dir, exist_ok=True)
                path = join(alignment_dir, "step{:09d}_layer_{}_alignment.png".format(
                    global_step, i + 1))
                save_alignment(path, alignment)
            except Exception as e:
                warn(str(e))

        # Save averaged alignment (mean over the first axis of attn)
        alignment_dir = join(checkpoint_dir, "alignment_ave")
        os.makedirs(alignment_dir, exist_ok=True)
        path = join(alignment_dir, "step{:09d}_layer_alignment.png".format(global_step))
        alignment = attn.mean(0)[idx].cpu().data.numpy()
        save_alignment(path, alignment)

        tag = "averaged_alignment"
        try:
            writer.add_image(tag, np.uint8(cm.viridis(
                np.flip(alignment, 1).T) * 255), global_step)
        except Exception as e:
            warn(str(e))

    # Predicted mel spectrogram
    if mel_outputs is not None:
        mel_output = mel_outputs[idx].cpu().data.numpy()
        mel_output = prepare_spec_image(audio._denormalize(mel_output))
        try:
            writer.add_image("Predicted mel spectrogram",
                             mel_output, global_step)
        except Exception as e:
            warn(str(e))
            pass

    # Predicted spectrogram
    if linear_outputs is not None:
        linear_output = linear_outputs[idx].cpu().data.numpy()
        spectrogram = prepare_spec_image(audio._denormalize(linear_output))
        try:
            writer.add_image("Predicted linear spectrogram",
                             spectrogram, global_step)
        except Exception as e:
            warn(str(e))
            pass

        # Predicted audio signal, peak-normalized in place
        signal = audio.inv_spectrogram(linear_output.T)
        signal /= np.max(np.abs(signal))
        path = join(checkpoint_dir, "step{:09d}_predicted.wav".format(
            global_step))
        try:
            writer.add_audio("Predicted audio signal", signal,
                             global_step, sample_rate=hparams.sample_rate)
        except Exception as e:
            warn(str(e))
            pass
        audio.save_wav(signal, path)

    # Target mel spectrogram
    if mel_outputs is not None:
        mel_output = mel[idx].cpu().data.numpy()
        mel_output = prepare_spec_image(audio._denormalize(mel_output))
        try:
            writer.add_image("Target mel spectrogram", mel_output, global_step)
        except Exception as e:
            warn(str(e))
            pass

    # Target spectrogram
    if linear_outputs is not None:
        linear_output = y[idx].cpu().data.numpy()
        spectrogram = prepare_spec_image(audio._denormalize(linear_output))
        try:
            writer.add_image("Target linear spectrogram",
                             spectrogram, global_step)
        except Exception as e:
            warn(str(e))
            pass
def logit(x, eps=1e-8):
    """Return ``log(x + eps) - log(1 - x + eps)`` element-wise."""
    numerator = x + eps
    denominator = 1 - x + eps
    return numerator.log() - denominator.log()
def masked_mean(y, mask):
    """Mean of ``y`` over the entries selected by ``mask``.

    ``mask`` is expanded to the shape of ``y``; the masked sum is divided
    by the sum of the expanded mask.
    """
    # (B, T, D)
    weights = mask.expand_as(y)
    masked_total = (y * weights).sum()
    return masked_total / weights.sum()
def spec_loss(y_hat, y, mask, priority_bin=None, priority_w=0):
    """Compute the L1 and binary-divergence losses for a spectrogram.

    ``hparams.masked_loss_weight`` blends the masked and the plain variant
    of each loss; ``mask`` must be given exactly when that weight is
    positive. When ``priority_bin`` is given and ``priority_w > 0`` the L1
    loss over the first ``priority_bin`` bins is mixed in with weight
    ``priority_w``.

    Returns:
        Tuple ``(l1_loss, binary_div)``.
    """
    masked_criterion = MaskedL1Loss()
    plain_criterion = nn.L1Loss()

    w = hparams.masked_loss_weight
    use_mask = w > 0

    def blended_l1(pred, target):
        # Masked/plain blend when masking is enabled, plain L1 otherwise
        if use_mask:
            return w * masked_criterion(pred, target, mask=mask) \
                + (1 - w) * plain_criterion(pred, target)
        return plain_criterion(pred, target)

    # L1 loss
    if use_mask:
        assert mask is not None
    else:
        assert mask is None
    l1_loss = blended_l1(y_hat, y)

    # Priority L1 loss
    if priority_bin is not None and priority_w > 0:
        priority_loss = blended_l1(
            y_hat[:, :, :priority_bin], y[:, :, :priority_bin])
        l1_loss = (1 - priority_w) * l1_loss + priority_w * priority_loss

    # Binary divergence loss
    if hparams.binary_divergence_weight <= 0:
        return l1_loss, y.data.new(1).zero_()

    y_hat_logits = logit(y_hat)
    z = -y * y_hat_logits + torch.log1p(torch.exp(y_hat_logits))
    if use_mask:
        binary_div = w * masked_mean(z, mask) + (1 - w) * z.mean()
    else:
        binary_div = z.mean()
    return l1_loss, binary_div
@jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
    """Return a ``(max_N, max_T)`` float32 guided-attention weight matrix.

    Entry ``[n, t]`` for ``n < N`` and ``t < T`` is
    ``1 - exp(-(n / N - t / T)**2 / (2 * g * g))``; all other entries are 0.
    """
    W = np.zeros((max_N, max_T), dtype=np.float32)
    denom = 2 * g * g
    for row in range(N):
        for col in range(T):
            offset = row / N - col / T
            W[row, col] = 1 - np.exp(-(offset ** 2) / denom)
    return W
def guided_attentions(input_lengths, target_lengths, max_target_len, g=0.2):
    """Stack per-item guided-attention matrices for a batch.

    Returns a float32 array of shape
    ``(B, max_target_len, input_lengths.max())`` whose item ``b`` is the
    transposed result of ``guided_attention`` for that item's lengths.
    """
    batch_size = len(input_lengths)
    max_input_len = input_lengths.max()
    W = np.zeros((batch_size, max_target_len, max_input_len), dtype=np.float32)
    for b, n_input in enumerate(input_lengths):
        weights = guided_attention(n_input, max_input_len,
                                   target_lengths[b], max_target_len, g)
        W[b] = weights.T
    return W
def train(device, model, data_loader, optimizer, writer,
          init_lr=0.002,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None,
          clip_thresh=1.0,
          train_seq2seq=True, train_postnet=True):
    """Run the training loop until ``global_epoch`` reaches ``nepochs``.

    Per batch: optionally updates the learning rate from
    ``hparams.lr_schedule``, runs the whole model, the seq2seq part only or
    the postnet only (selected by ``train_seq2seq`` / ``train_postnet``),
    sums the mel, done, linear and optional guided-attention losses,
    back-propagates, optionally clips gradients to ``clip_thresh`` and
    steps ``optimizer``. Scalars are logged to ``writer`` on every step.
    Intermediate states and a checkpoint are saved every
    ``checkpoint_interval`` steps, and ``eval_model`` runs every
    ``hparams.eval_interval`` steps.

    Mutates the module-level ``global_step`` and ``global_epoch``.

    Raises:
        RuntimeError: if a text or decoder length in a batch is greater
            than or equal to ``hparams.max_positions``.
    """
    linear_dim = model.linear_dim
    r = hparams.outputs_per_step
    downsample_step = hparams.downsample_step
    current_lr = init_lr

    binary_criterion = nn.BCELoss()

    # At least one of the two sub-models has to be trained
    assert train_seq2seq or train_postnet

    global global_step, global_epoch
    while global_epoch < nepochs:
        running_loss = 0.
        for step, (x, input_lengths, mel, y, positions, done, target_lengths,
                   speaker_ids) \
                in tqdm(enumerate(data_loader)):
            model.train()
            # collate_fn returns speaker_ids=None for single-speaker batches
            ismultispeaker = speaker_ids is not None
            # Learning rate schedule
            if hparams.lr_schedule is not None:
                lr_schedule_f = getattr(lrschedule, hparams.lr_schedule)
                current_lr = lr_schedule_f(
                    init_lr, global_step, **hparams.lr_schedule_kwargs)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr
            optimizer.zero_grad()

            # Used for Position encoding
            text_positions, frame_positions = positions

            # Downsample mel spectrogram (keep every downsample_step-th frame)
            if downsample_step > 1:
                mel = mel[:, 0::downsample_step, :].contiguous()

            # Lengths (as numpy arrays; also used for guided attention below)
            input_lengths = input_lengths.long().numpy()
            decoder_lengths = target_lengths.long().numpy() // r // downsample_step

            max_seq_len = max(input_lengths.max(), decoder_lengths.max())
            if max_seq_len >= hparams.max_positions:
                raise RuntimeError(
                    """max_seq_len ({}) >= max_posision ({})
Input text or decoder targget length exceeded the maximum length.
Please set a larger value for ``max_position`` in hyper parameters.""".format(
                        max_seq_len, hparams.max_positions))

            # Transform data to CUDA device
            # (only the tensors needed by the sub-model(s) being trained)
            if train_seq2seq:
                x = x.to(device)
                text_positions = text_positions.to(device)
                frame_positions = frame_positions.to(device)
            if train_postnet:
                y = y.to(device)
            mel, done = mel.to(device), done.to(device)
            target_lengths = target_lengths.to(device)
            speaker_ids = speaker_ids.to(device) if ismultispeaker else None

            # Create mask if we use masked loss
            if hparams.masked_loss_weight > 0:
                # decoder output domain mask
                decoder_target_mask = sequence_mask(
                    target_lengths // (r * downsample_step),
                    max_len=mel.size(1)).unsqueeze(-1)
                if downsample_step > 1:
                    # spectrogram-domain mask
                    target_mask = sequence_mask(
                        target_lengths, max_len=y.size(1)).unsqueeze(-1)
                else:
                    target_mask = decoder_target_mask
                # shift mask: drop the first r steps, matching the
                # [:, r:, :] target slices used in the losses below
                decoder_target_mask = decoder_target_mask[:, r:, :]
                target_mask = target_mask[:, r:, :]
            else:
                decoder_target_mask, target_mask = None, None

            # Apply model
            if train_seq2seq and train_postnet:
                mel_outputs, linear_outputs, attn, done_hat = model(
                    x, mel, speaker_ids=speaker_ids,
                    text_positions=text_positions, frame_positions=frame_positions,
                    input_lengths=input_lengths)
            elif train_seq2seq:
                # Partial training is only done for single-speaker data
                assert speaker_ids is None
                mel_outputs, attn, done_hat, _ = model.seq2seq(
                    x, mel,
                    text_positions=text_positions, frame_positions=frame_positions,
                    input_lengths=input_lengths)
                # reshape
                mel_outputs = mel_outputs.view(len(mel), -1, mel.size(-1))
                linear_outputs = None
            elif train_postnet:
                assert speaker_ids is None
                linear_outputs = model.postnet(mel)
                mel_outputs, attn, done_hat = None, None, None

            # Losses
            w = hparams.binary_divergence_weight

            # mel:
            if train_seq2seq:
                # Outputs without the last r steps are compared against
                # targets without the first r steps.
                mel_l1_loss, mel_binary_div = spec_loss(
                    mel_outputs[:, :-r, :], mel[:, r:, :], decoder_target_mask)
                mel_loss = (1 - w) * mel_l1_loss + w * mel_binary_div

            # done:
            if train_seq2seq:
                done_loss = binary_criterion(done_hat, done)

            # linear:
            if train_postnet:
                # Number of linear bins below hparams.priority_freq
                n_priority_freq = int(hparams.priority_freq / (hparams.sample_rate * 0.5) * linear_dim)
                linear_l1_loss, linear_binary_div = spec_loss(
                    linear_outputs[:, :-r, :], y[:, r:, :], target_mask,
                    priority_bin=n_priority_freq,
                    priority_w=hparams.priority_freq_weight)
                linear_loss = (1 - w) * linear_l1_loss + w * linear_binary_div

            # Combine losses
            if train_seq2seq and train_postnet:
                loss = mel_loss + linear_loss + done_loss
            elif train_seq2seq:
                loss = mel_loss + done_loss
            elif train_postnet:
                loss = linear_loss

            # attention
            if train_seq2seq and hparams.use_guided_attention:
                soft_mask = guided_attentions(input_lengths, decoder_lengths,
                                              attn.size(-2),
                                              g=hparams.guided_attention_sigma)
                soft_mask = torch.from_numpy(soft_mask).to(device)
                attn_loss = (attn * soft_mask).mean()
                loss += attn_loss

            # Periodic dumps; skipped at step 0
            if global_step > 0 and global_step % checkpoint_interval == 0:
                save_states(
                    global_step, writer, mel_outputs, linear_outputs, attn,
                    mel, y, input_lengths, checkpoint_dir)
                save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir, global_epoch,
                    train_seq2seq, train_postnet)

            if global_step > 0 and global_step % hparams.eval_interval == 0:
                eval_model(global_step, writer, device, model,
                           checkpoint_dir, ismultispeaker)

            # Update
            loss.backward()
            if clip_thresh > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.get_trainable_parameters(), clip_thresh)
            optimizer.step()

            # Logs
            writer.add_scalar("loss", float(loss.item()), global_step)
            if train_seq2seq:
                writer.add_scalar("done_loss", float(done_loss.item()), global_step)
                writer.add_scalar("mel loss", float(mel_loss.item()), global_step)
                writer.add_scalar("mel_l1_loss", float(mel_l1_loss.item()), global_step)
                writer.add_scalar("mel_binary_div_loss", float(mel_binary_div.item()), global_step)
            if train_postnet:
                writer.add_scalar("linear_loss", float(linear_loss.item()), global_step)
                writer.add_scalar("linear_l1_loss", float(linear_l1_loss.item()), global_step)
                writer.add_scalar("linear_binary_div_loss", float(linear_binary_div.item()), global_step)
            if train_seq2seq and hparams.use_guided_attention:
                writer.add_scalar("attn_loss", float(attn_loss.item()), global_step)
            if clip_thresh > 0:
                writer.add_scalar("gradient norm", grad_norm, global_step)
            writer.add_scalar("learning rate", current_lr, global_step)

            global_step += 1
            running_loss += loss.item()

        # End of epoch: report the loss averaged over the number of batches
        averaged_loss = running_loss / (len(data_loader))
        writer.add_scalar("loss (per epoch)", averaged_loss, global_epoch)
        print("Loss: {}".format(running_loss / (len(data_loader))))

        global_epoch += 1
def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch,
                    train_seq2seq, train_postnet):
    """Save a checkpoint of the trained (sub-)model to ``checkpoint_dir``.

    The file is named ``checkpoint_step<step>[_seq2seq|_postnet].pth`` and
    holds the state dict of the whole model, of ``model.seq2seq`` or of
    ``model.postnet`` depending on the two flags, the optimizer state (only
    if ``hparams.save_optimizer_state`` is set, otherwise None), and the
    given ``step`` and ``epoch``.

    Raises:
        ValueError: if neither ``train_seq2seq`` nor ``train_postnet`` is set.
    """
    if train_seq2seq and train_postnet:
        suffix = ""
        m = model
    elif train_seq2seq:
        suffix = "_seq2seq"
        m = model.seq2seq
    elif train_postnet:
        suffix = "_postnet"
        m = model.postnet
    else:
        raise ValueError(
            "at least one of train_seq2seq and train_postnet must be set")

    # Name the file after the ``step`` argument (not the module-level
    # counter) so the file name always matches the stored "global_step".
    checkpoint_path = join(
        checkpoint_dir, "checkpoint_step{:09d}{}.pth".format(step, suffix))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": m.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)
def build_model():
    """Instantiate the model named by ``hparams.builder``.

    The builder function is looked up on the ``builder`` module and called
    with keyword arguments taken from ``hparams`` and the text frontend.
    """
    factory = getattr(builder, hparams.builder)
    model_kwargs = dict(
        n_speakers=hparams.n_speakers,
        speaker_embed_dim=hparams.speaker_embed_dim,
        n_vocab=_frontend.n_vocab,
        embed_dim=hparams.text_embed_dim,
        mel_dim=hparams.num_mels,
        linear_dim=hparams.fft_size // 2 + 1,
        r=hparams.outputs_per_step,
        downsample_step=hparams.downsample_step,
        padding_idx=hparams.padding_idx,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        encoder_channels=hparams.encoder_channels,
        decoder_channels=hparams.decoder_channels,
        converter_channels=hparams.converter_channels,
        use_memory_mask=hparams.use_memory_mask,
        trainable_positional_encodings=hparams.trainable_positional_encodings,
        force_monotonic_attention=hparams.force_monotonic_attention,
        use_decoder_state_for_postnet_input=hparams.use_decoder_state_for_postnet_input,
        max_positions=hparams.max_positions,
        speaker_embedding_weight_std=hparams.speaker_embedding_weight_std,
        freeze_embedding=hparams.freeze_embedding,
        window_ahead=hparams.window_ahead,
        window_backward=hparams.window_backward,
        key_projection=hparams.key_projection,
        value_projection=hparams.value_projection,
    )
    return factory(**model_kwargs)
def _load(checkpoint_path):
    """Load a checkpoint file with ``torch.load``.

    Without CUDA every storage is kept where ``torch.load`` deserialized
    it; with CUDA the default device mapping is used.
    """
    map_location = None if use_cuda else (lambda storage, loc: storage)
    return torch.load(checkpoint_path, map_location=map_location)
def load_checkpoint(path, model, optimizer, reset_optimizer):
    """Restore ``model`` (and optionally ``optimizer``) from ``path``.

    The optimizer state is restored only when ``reset_optimizer`` is false
    and the checkpoint stores one. The module-level ``global_step`` and
    ``global_epoch`` are set from the checkpoint. Returns ``model``.
    """
    global global_step, global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    model.load_state_dict(checkpoint["state_dict"])

    saved_optimizer = None if reset_optimizer else checkpoint["optimizer"]
    if saved_optimizer is not None:
        print("Load optimizer state from {}".format(path))
        optimizer.load_state_dict(saved_optimizer)

    global_step = checkpoint["global_step"]
    global_epoch = checkpoint["global_epoch"]
    return model
def _load_embedding(path, model):
    """Replace the encoder token-embedding weights with those stored in ``path``."""
    checkpoint = _load(path)
    weights = checkpoint["state_dict"]["seq2seq.encoder.embed_tokens.weight"]
    embedding = model.seq2seq.encoder.embed_tokens
    embedding.weight.data = weights
# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
def restore_parts(path, model):
    """Load every parameter from ``path`` whose name also exists in ``model``.

    All matching entries are first loaded in one go. If that is rejected
    with a RuntimeError, the entries are retried one by one and a warning
    is issued for each entry that is rejected.
    """
    print("Restore part of the model from: {}".format(path))
    state = _load(path)["state_dict"]
    model_dict = model.state_dict()
    valid_state_dict = {k: v for k, v in state.items() if k in model_dict}

    try:
        model_dict.update(valid_state_dict)
        model.load_state_dict(model_dict)
    except RuntimeError as e:
        # there should be invalid size of weight(s), so load them per parameter
        print(str(e))
        model_dict = model.state_dict()
        for k, v in valid_state_dict.items():
            previous = model_dict[k]
            model_dict[k] = v
            try:
                model.load_state_dict(model_dict)
            except RuntimeError as e:
                print(str(e))
                warn("{}: may contain invalid size of weight. skipping...".format(k))
                # Put the model's own tensor back; otherwise the rejected
                # entry would make every following load fail as well and
                # valid parameters would be reported as invalid.
                model_dict[k] = previous
if __name__ == "__main__":
args = docopt(__doc__)
print("Command line args:\n", args)
checkpoint_dir = args["--checkpoint-dir"]
checkpoint_path = args["--checkpoint"]
checkpoint_seq2seq_path = args["--checkpoint-seq2seq"]
checkpoint_postnet_path = args["--checkpoint-postnet"]
load_embedding = args["--load-embedding"]
checkpoint_restore_parts = args["--restore-parts"]
speaker_id = args["--speaker-id"]
speaker_id = int(speaker_id) if speaker_id is not None else None
preset = args["--preset"]
data_root = args["--data-root"]
if data_root is None:
data_root = join(dirname(__file__), "data", "ljspeech")
log_event_path = args["--log-event-path"]
reset_optimizer = args["--reset-optimizer"]
# Which model to be trained
train_seq2seq = args["--train-seq2seq-only"]
train_postnet = args["--train-postnet-only"]
# train both if not specified
if not train_seq2seq and not train_postnet:
print("Training whole model")
train_seq2seq, train_postnet = True, True
if train_seq2seq:
print("Training seq2seq model")
elif train_postnet:
print("Training postnet model")
else:
assert False, "must be specified wrong args"
# Load preset if specified
if preset is not None:
with open(preset) as f:
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args["--hparams"])
# Preventing Windows specific error such as MemoryError
# Also reduces the occurrence of THAllocator.c 0x05 error in Widows build of PyTorch
if platform.system() == "Windows":
print(" [!] Windows Detected - IF THAllocator.c 0x05 error occurs SET num_workers to 1")
assert hparams.name == "deepvoice3"
print(hparams_debug_string())
_frontend = getattr(frontend, hparams.frontend)
os.makedirs(checkpoint_dir, exist_ok=True)
# Input dataset definitions
X = FileSourceDataset(TextDataSource(data_root, speaker_id))
Mel = FileSourceDataset(MelSpecDataSource(data_root, speaker_id))
Y = FileSourceDataset(LinearSpecDataSource(data_root, speaker_id))
# Prepare sampler
frame_lengths = Mel.file_data_source.frame_lengths
sampler = PartialyRandomizedSimilarTimeLengthSampler(
frame_lengths, batch_size=hparams.batch_size)
# Dataset and Dataloader setup
dataset = PyTorchDataset(X, Mel, Y)
data_loader = data_utils.DataLoader(
dataset, batch_size=hparams.batch_size,
num_workers=hparams.num_workers, sampler=sampler,
collate_fn=collate_fn, pin_memory=hparams.pin_memory, drop_last=True)
device = torch.device("cuda" if use_cuda else "cpu")
# Model
model = build_model().to(device)
optimizer = optim.Adam(model.get_trainable_parameters(),
lr=hparams.initial_learning_rate, betas=(
hparams.adam_beta1, hparams.adam_beta2),
eps=hparams.adam_eps, weight_decay=hparams.weight_decay,
amsgrad=hparams.amsgrad)
if checkpoint_restore_parts is not None:
restore_parts(checkpoint_restore_parts, model)
# Load checkpoints
if checkpoint_postnet_path is not None:
load_checkpoint(checkpoint_postnet_path, model.postnet, optimizer, reset_optimizer)
if checkpoint_seq2seq_path is not None:
load_checkpoint(checkpoint_seq2seq_path, model.seq2seq, optimizer, reset_optimizer)
if checkpoint_path is not None:
load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer)
# Load embedding
if load_embedding is not None:
print("Loading embedding from {}".format(load_embedding))
_load_embedding(load_embedding, model)
# Setup summary writer for tensorboard
if log_event_path is None: