diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py
index e84f08e708..e77a255e56 100644
--- a/egs/libritts/CODEC/encodec/codec_datamodule.py
+++ b/egs/libritts/CODEC/encodec/codec_datamodule.py
@@ -139,7 +139,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
         group.add_argument(
             "--num-workers",
             type=int,
-            default=2,
+            default=8,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py
index 4f45be9c25..4701423922 100644
--- a/egs/libritts/CODEC/encodec/encodec.py
+++ b/egs/libritts/CODEC/encodec/encodec.py
@@ -4,7 +4,13 @@
 import numpy as np
 import torch
-from loss import loss_dis, loss_g
+from loss import (
+    DiscriminatorAdversarialLoss,
+    FeatureMatchLoss,
+    GeneratorAdversarialLoss,
+    MelSpectrogramReconstructionLoss,
+    WavReconstructionLoss,
+)
 from torch import nn
 from torch.cuda.amp import autocast
 
@@ -47,11 +53,23 @@ def __init__(
         self.cache_generator_outputs = cache_generator_outputs
         self._cache = None
 
+        # construct loss functions
+        self.generator_adversarial_loss = GeneratorAdversarialLoss(
+            average_by_discriminators=True, loss_type="hinge"
+        )
+        self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
+            average_by_discriminators=True, loss_type="hinge"
+        )
+        self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
+        self.wav_reconstruction_loss = WavReconstructionLoss()
+        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
+            sampling_rate=self.sampling_rate
+        )
+
     def _forward_generator(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
         return_sample: bool = False,
     ):
         """Perform generator forward.
 
@@ -59,7 +77,6 @@ def _forward_generator(
         Args:
             speech (Tensor): Speech waveform tensor (B, T_wav).
             speech_lengths (Tensor): Speech length tensor (B,).
-            global_step (int): Global step.
             return_sample (bool): Return the generator output.
 
         Returns:
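The five loss modules constructed above replace the monolithic loss_g/loss_dis helpers with small, keyword-driven callables. Below is a minimal interface sketch on dummy tensors; the shapes and the list-of-lists feature-map structure are illustrative assumptions, and it is meant to be run from the recipe directory so that loss.py is importable.

    import torch
    from loss import (
        DiscriminatorAdversarialLoss,
        FeatureMatchLoss,
        GeneratorAdversarialLoss,
        MelSpectrogramReconstructionLoss,
        WavReconstructionLoss,
    )

    # two dummy "discriminators", each emitting one logit map (shapes assumed)
    y = [torch.randn(2, 1, 100) for _ in range(2)]      # on real audio
    y_hat = [torch.randn(2, 1, 100) for _ in range(2)]  # on generated audio
    # per-discriminator lists of intermediate feature maps
    fmap = [[torch.randn(2, 8, 50) for _ in range(3)] for _ in range(2)]
    fmap_hat = [[torch.randn(2, 8, 50) for _ in range(3)] for _ in range(2)]
    speech = torch.randn(2, 1, 16000)
    speech_hat = torch.randn(2, 1, 16000)

    gen_adv = GeneratorAdversarialLoss(loss_type="hinge")(outputs=y_hat)
    real, fake = DiscriminatorAdversarialLoss(loss_type="hinge")(
        outputs=y, outputs_hat=y_hat
    )
    feat = FeatureMatchLoss(average_by_layers=False)(feats=fmap, feats_hat=fmap_hat)
    wav = WavReconstructionLoss()(x=speech, x_hat=speech_hat)
    mel = MelSpectrogramReconstructionLoss(sampling_rate=16000)(
        x=speech, x_hat=speech_hat
    )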
@@ -107,33 +124,56 @@ def _forward_generator(
 
         # calculate losses
         with autocast(enabled=False):
-            loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
-                commit_loss,
-                speech,
-                speech_hat,
-                fmap,
-                fmap_hat,
-                y,
-                y_hat,
-                global_step,
-                y_p,
-                y_p_hat,
-                y_s,
-                y_s_hat,
-                fmap_p,
-                fmap_p_hat,
-                fmap_s,
-                fmap_s_hat,
-                args=self.params,
+            gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
+            gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
+            gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
+
+            feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
+            feature_period_loss = self.feature_match_loss(
+                feats=fmap_p, feats_hat=fmap_p_hat
+            )
+            feature_scale_loss = self.feature_match_loss(
+                feats=fmap_s, feats_hat=fmap_s_hat
+            )
+
+            wav_reconstruction_loss = self.wav_reconstruction_loss(
+                x=speech, x_hat=speech_hat
+            )
+            mel_reconstruction_loss = self.mel_reconstruction_loss(
+                x=speech, x_hat=speech_hat
             )
+            # loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
+            #     commit_loss,
+            #     speech,
+            #     speech_hat,
+            #     fmap,
+            #     fmap_hat,
+            #     y,
+            #     y_hat,
+            #     y_p,
+            #     y_p_hat,
+            #     y_s,
+            #     y_s_hat,
+            #     fmap_p,
+            #     fmap_p_hat,
+            #     fmap_s,
+            #     fmap_s_hat,
+            #     args=self.params,
+            # )
+
             stats = dict(
-                generator_loss=loss.item(),
-                generator_reconstruction_loss=rec_loss.item(),
-                generator_feature_loss=feat_loss.item(),
-                generator_adv_loss=adv_loss.item(),
+                # generator_loss=loss.item(),
+                generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
+                generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
+                generator_feature_stft_loss=feature_stft_loss.item(),
+                generator_feature_period_loss=feature_period_loss.item(),
+                generator_feature_scale_loss=feature_scale_loss.item(),
+                generator_stft_adv_loss=gen_stft_adv_loss.item(),
+                generator_period_adv_loss=gen_period_adv_loss.item(),
+                generator_scale_adv_loss=gen_scale_adv_loss.item(),
                 generator_commit_loss=commit_loss.item(),
-                d_weight=d_weight.item(),
+                # d_weight=d_weight.item(),
             )
 
         if return_sample:
@@ -147,19 +187,28 @@ def _forward_generator(
         # reset cache
         if reuse_cache or not self.training:
             self._cache = None
-        return loss, stats
+        return (
+            commit_loss,
+            gen_stft_adv_loss,
+            gen_period_adv_loss,
+            gen_scale_adv_loss,
+            feature_stft_loss,
+            feature_period_loss,
+            feature_scale_loss,
+            wav_reconstruction_loss,
+            mel_reconstruction_loss,
+            stats,
+        )
 
     def _forward_discriminator(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
     ):
         """
         Args:
             speech (Tensor): Speech waveform tensor (B, T_wav).
             speech_lengths (Tensor): Speech length tensor (B,).
-            global_step (int): Global step.
 
         Returns:
             * loss (Tensor): Loss scalar tensor.
@@ -206,37 +255,46 @@ def _forward_discriminator(
         )
 
         # calculate losses
        with autocast(enabled=False):
-            loss = loss_dis(
-                y,
-                y_hat,
-                fmap,
-                fmap_hat,
-                y_p,
-                y_p_hat,
-                fmap_p,
-                fmap_p_hat,
-                y_s,
-                y_s_hat,
-                fmap_s,
-                fmap_s_hat,
-                global_step,
-                args=self.params,
-            )
+            (
+                disc_stft_real_adv_loss,
+                disc_stft_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
+            (
+                disc_period_real_adv_loss,
+                disc_period_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
+            (
+                disc_scale_real_adv_loss,
+                disc_scale_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)
+
             stats = dict(
-                discriminator_loss=loss.item(),
+                discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
+                discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(),
+                discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(),
+                discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(),
+                discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(),
+                discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(),
             )
 
         # reset cache
         if reuse_cache or not self.training:
             self._cache = None
-        return loss, stats
+        return (
+            disc_stft_real_adv_loss,
+            disc_stft_fake_adv_loss,
+            disc_period_real_adv_loss,
+            disc_period_fake_adv_loss,
+            disc_scale_real_adv_loss,
+            disc_scale_fake_adv_loss,
+            stats,
+        )
 
     def forward(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
         return_sample: bool,
         forward_generator: bool,
     ):
@@ -244,14 +302,12 @@ def forward(
             return self._forward_generator(
                 speech=speech,
                 speech_lengths=speech_lengths,
-                global_step=global_step,
                 return_sample=return_sample,
             )
         else:
             return self._forward_discriminator(
                 speech=speech,
                 speech_lengths=speech_lengths,
-                global_step=global_step,
             )
 
     def encode(self, x, target_bw=None, st=None):
diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py
index c407b4a593..e5d69fa600 100755
--- a/egs/libritts/CODEC/encodec/infer.py
+++ b/egs/libritts/CODEC/encodec/infer.py
@@ -71,7 +71,7 @@ def get_parser():
     parser.add_argument(
         "--target-bw",
         type=float,
-        default=7.5,
+        default=24,
         help="The target bandwidth for the generator",
     )
diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py
index 96300e9d67..7e9bf5590d 100644
--- a/egs/libritts/CODEC/encodec/loss.py
+++ b/egs/libritts/CODEC/encodec/loss.py
@@ -1,8 +1,310 @@
+from typing import List, Tuple, Union
+
 import torch
 import torch.nn.functional as F
+from lhotse.features.kaldi import Wav2LogFilterBank
 from torchaudio.transforms import MelSpectrogram
+
+class GeneratorAdversarialLoss(torch.nn.Module):
+    """Generator adversarial loss module."""
+
+    def __init__(
+        self,
+        average_by_discriminators: bool = True,
+        loss_type: str = "hinge",
+    ):
+        """Initialize GeneratorAdversarialLoss module.
+
+        Args:
+            average_by_discriminators (bool): Whether to average the loss by
+                the number of discriminators.
+            loss_type (str): Loss type, "mse" or "hinge".
+
+        """
+        super().__init__()
+        self.average_by_discriminators = average_by_discriminators
+        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+        if loss_type == "mse":
+            self.criterion = self._mse_loss
+        else:
+            self.criterion = self._hinge_loss
+
+    def forward(
+        self,
+        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+    ) -> torch.Tensor:
+        """Calculate generator adversarial loss.
+
+        Args:
+            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+                outputs, list of discriminator outputs, or list of list of
+                discriminator outputs.
+
+        Returns:
+            Tensor: Generator adversarial loss value.
+
+        """
+        if isinstance(outputs, (tuple, list)):
+            adv_loss = 0.0
+            for i, outputs_ in enumerate(outputs):
+                if isinstance(outputs_, (tuple, list)):
+                    # NOTE(kan-bayashi): case including feature maps
+                    outputs_ = outputs_[-1]
+                adv_loss += self.criterion(outputs_)
+            if self.average_by_discriminators:
+                adv_loss /= i + 1
+        else:
+            adv_loss = self.criterion(outputs)
+
+        return adv_loss
+
+    def _mse_loss(self, x):
+        return F.mse_loss(x, x.new_ones(x.size()))
+
+    def _hinge_loss(self, x):
+        return F.relu(1 - x).mean()
+
+
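The hinge generator loss used here penalizes discriminator logits below 1 linearly, relu(1 - x).mean(). A tiny worked example with made-up logits:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([2.0, 0.5, -1.0])  # logits on generated audio
    # relu(1 - x) = relu([-1.0, 0.5, 2.0]) = [0.0, 0.5, 2.0]
    print(F.relu(1 - x).mean())  # tensor(0.8333)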
+class DiscriminatorAdversarialLoss(torch.nn.Module):
+    """Discriminator adversarial loss module."""
+
+    def __init__(
+        self,
+        average_by_discriminators: bool = True,
+        loss_type: str = "hinge",
+    ):
+        """Initialize DiscriminatorAdversarialLoss module.
+
+        Args:
+            average_by_discriminators (bool): Whether to average the loss by
+                the number of discriminators.
+            loss_type (str): Loss type, "mse" or "hinge".
+
+        """
+        super().__init__()
+        self.average_by_discriminators = average_by_discriminators
+        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+        if loss_type == "mse":
+            self.fake_criterion = self._mse_fake_loss
+            self.real_criterion = self._mse_real_loss
+        else:
+            self.fake_criterion = self._hinge_fake_loss
+            self.real_criterion = self._hinge_real_loss
+
+    def forward(
+        self,
+        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Calculate discriminator adversarial loss.
+
+        Args:
+            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+                outputs, list of discriminator outputs, or list of list of
+                discriminator outputs calculated from generator.
+            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+                outputs, list of discriminator outputs, or list of list of
+                discriminator outputs calculated from groundtruth.
+
+        Returns:
+            Tensor: Discriminator real loss value.
+            Tensor: Discriminator fake loss value.
+
+        """
+        if isinstance(outputs, (tuple, list)):
+            real_loss = 0.0
+            fake_loss = 0.0
+            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
+                if isinstance(outputs_hat_, (tuple, list)):
+                    # NOTE(kan-bayashi): case including feature maps
+                    outputs_hat_ = outputs_hat_[-1]
+                    outputs_ = outputs_[-1]
+                real_loss += self.real_criterion(outputs_)
+                fake_loss += self.fake_criterion(outputs_hat_)
+            if self.average_by_discriminators:
+                fake_loss /= i + 1
+                real_loss /= i + 1
+        else:
+            real_loss = self.real_criterion(outputs)
+            fake_loss = self.fake_criterion(outputs_hat)
+
+        return real_loss, fake_loss
+
+    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+        return F.mse_loss(x, x.new_ones(x.size()))
+
+    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+        return F.mse_loss(x, x.new_zeros(x.size()))
+
+    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu(x.new_ones(x.size()) - x).mean()
+
+    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu(x.new_ones(x.size()) + x).mean()
+
+
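The discriminator criterion returns a (real, fake) pair: the real branch penalizes logits below +1 and the fake branch penalizes logits above -1. A minimal sketch with one discriminator and made-up logits:

    import torch
    from loss import DiscriminatorAdversarialLoss

    criterion = DiscriminatorAdversarialLoss(loss_type="hinge")
    y = [torch.tensor([1.5, 0.0])]       # logits on real audio
    y_hat = [torch.tensor([-2.0, 0.5])]  # logits on generated audio
    real_loss, fake_loss = criterion(outputs_hat=y_hat, outputs=y)
    # real: mean(relu(1 - y))     = mean([0.0, 1.0]) = 0.5
    # fake: mean(relu(1 + y_hat)) = mean([0.0, 1.5]) = 0.75
    print(real_loss, fake_loss)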
+ + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_ones(x.size())) + + def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_zeros(x.size())) + + def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.relu(x.new_ones(x.size()) - x).mean() + + def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.relu(x.new_ones(x.size()) + x).mean() + + +class FeatureMatchLoss(torch.nn.Module): + """Feature matching loss module.""" + + def __init__( + self, + average_by_layers: bool = True, + average_by_discriminators: bool = True, + include_final_outputs: bool = False, + ): + """Initialize FeatureMatchLoss module. + + Args: + average_by_layers (bool): Whether to average the loss by the number + of layers. + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + include_final_outputs (bool): Whether to include the final output of + each discriminator for loss calculation. + + """ + super().__init__() + self.average_by_layers = average_by_layers + self.average_by_discriminators = average_by_discriminators + self.include_final_outputs = include_final_outputs + + def forward( + self, + feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], + feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], + ) -> torch.Tensor: + """Calculate feature matching loss. + + Args: + feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from generator's outputs. + feats (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from groundtruth.. + + Returns: + Tensor: Feature matching loss value. 
+ + """ + feat_match_loss = 0.0 + for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): + feat_match_loss_ = 0.0 + if not self.include_final_outputs: + feats_hat_ = feats_hat_[:-1] + feats_ = feats_[:-1] + for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): + feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + if self.average_by_layers: + feat_match_loss_ /= j + 1 + feat_match_loss += feat_match_loss_ + if self.average_by_discriminators: + feat_match_loss /= i + 1 + + return feat_match_loss + + +class MelSpectrogramReconstructionLoss(torch.nn.Module): + """Mel Spec Reconstruction loss.""" + + def __init__( + self, + sampling_rate: int = 22050, + n_mels: int = 64, + use_fft_mag: bool = True, + return_mel: bool = False, + ): + super().__init__() + self.wav_to_specs = [] + for i in range(5, 12): + s = 2**i + # self.wav_to_specs.append( + # Wav2LogFilterBank( + # sampling_rate=sampling_rate, + # frame_length=s, + # frame_shift=s // 4, + # use_fft_mag=use_fft_mag, + # num_filters=n_mels, + # ) + # ) + self.wav_to_specs.append( + MelSpectrogram( + sample_rate=sampling_rate, + n_fft=max(s, 512), + win_length=s, + hop_length=s // 4, + n_mels=n_mels, + ) + ) + self.return_mel = return_mel + + def forward( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + """Calculate Mel-spectrogram loss. + + Args: + x_hat (Tensor): Generated waveform tensor (B, 1, T). + x (Tensor): Groundtruth waveform tensor (B, 1, T). + spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor + (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth + waveform. + + Returns: + Tensor: Mel-spectrogram loss value. + + """ + mel_loss = 0.0 + + for i, wav_to_spec in enumerate(self.wav_to_specs): + s = 2 ** (i + 5) + wav_to_spec.to(x.device) + + mel_hat = wav_to_spec(x_hat.squeeze(1)) + mel = wav_to_spec(x.squeeze(1)) + + alpha = (s / 2) ** 0.5 + mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel) + + # mel_hat = self.wav_to_spec(x_hat.squeeze(1)) + # mel = self.wav_to_spec(x.squeeze(1)) + # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel) + + if self.return_mel: + return mel_loss, (mel_hat, mel) + + return mel_loss + + +class WavReconstructionLoss(torch.nn.Module): + """Wav Reconstruction loss.""" + + def __init__(self): + super().__init__() + + def forward( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + ) -> torch.Tensor: + """Calculate wav loss. + + Args: + x_hat (Tensor): Generated waveform tensor (B, 1, T). + x (Tensor): Groundtruth waveform tensor (B, 1, T). + + Returns: + Tensor: Wav loss value. 
+ + """ + wav_loss = F.mse_loss(x, x_hat) + + return wav_loss + + def adversarial_g_loss(y_disc_gen): """Hinge loss""" loss = 0.0 @@ -63,88 +365,12 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7): return L -def criterion_d( - y_disc_r, - y_disc_gen, - fmap_r_det, - fmap_gen_det, - y_df_hat_r, - y_df_hat_g, - fmap_f_r, - fmap_f_g, - y_ds_hat_r, - y_ds_hat_g, - fmap_s_r, - fmap_s_g, -): - """Hinge Loss""" - loss = 0.0 - loss1 = 0.0 - loss2 = 0.0 - loss3 = 0.0 - for i in range(len(y_disc_r)): - loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean() - for i in range(len(y_df_hat_r)): - loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean() - for i in range(len(y_ds_hat_r)): - loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean() - - loss = ( - loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r) - ) / 3.0 - - return loss - - -def criterion_g( - commit_loss, - x, - G_x, - fmap_r, - fmap_gen, - y_disc_r, - y_disc_gen, - y_df_hat_r, - y_df_hat_g, - fmap_f_r, - fmap_f_g, - y_ds_hat_r, - y_ds_hat_g, - fmap_s_r, - fmap_s_g, - args, -): - adv_g_loss = adversarial_g_loss(y_disc_gen) - feat_loss = ( - feature_loss(fmap_r, fmap_gen) - + sim_loss(y_disc_r, y_disc_gen) - + feature_loss(fmap_f_r, fmap_f_g) - + sim_loss(y_df_hat_r, y_df_hat_g) - + feature_loss(fmap_s_r, fmap_s_g) - + sim_loss(y_ds_hat_r, y_ds_hat_g) - ) / 3.0 - rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args) - total_loss = ( - args.lambda_com * commit_loss - + args.lambda_adv * adv_g_loss - + args.lambda_feat * feat_loss - + args.lambda_rec * rec_loss - ) - return total_loss, adv_g_loss, feat_loss, rec_loss - - def adopt_weight(weight, global_step, threshold=0, value=0.0): if global_step < threshold: weight = value return weight -def adopt_dis_weight(weight, global_step, threshold=0, value=0.0): - if global_step % 3 == 0: - weight = value - return weight - - def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args): if last_layer is not None: nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] @@ -166,7 +392,6 @@ def loss_g( fmap_hat, y, y_hat, - global_step, y_df, y_df_hat, y_ds, @@ -215,9 +440,10 @@ def loss_g( feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0 d_weight = torch.tensor(1.0) - disc_factor = adopt_weight( - args.lambda_adv, global_step, threshold=args.discriminator_iter_start - ) + # disc_factor = adopt_weight( + # args.lambda_adv, global_step, threshold=args.discriminator_iter_start + # ) + disc_factor = 1 if disc_factor == 0.0: fm_loss_wt = 0 else: @@ -232,37 +458,9 @@ def loss_g( return loss, rec_loss, adv_loss, feat_loss_tot, d_weight -def loss_dis( - y, - y_hat, - fmap, - fmap_hat, - y_df, - y_df_hat, - fmap_f, - fmap_f_hat, - y_ds, - y_ds_hat, - fmap_s, - fmap_s_hat, - global_step, - args, -): - disc_factor = adopt_weight( - args.lambda_adv, global_step, threshold=args.discriminator_iter_start - ) - d_loss = disc_factor * criterion_d( - y, - y_hat, - fmap, - fmap_hat, - y_df, - y_df_hat, - fmap_f, - fmap_f_hat, - y_ds, - y_ds_hat, - fmap_s, - fmap_s_hat, - ) - return d_loss +if __name__ == "__main__": + la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False) + aa = [torch.rand(192, 192) for _ in range(3)] + bb = [torch.rand(192, 192) for _ in range(3)] + print(la(bb, aa)) + print(feature_loss(aa, bb)) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 65aec13831..206a72a760 100644 --- 
diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py
index 65aec13831..206a72a760 100644
--- a/egs/libritts/CODEC/encodec/train.py
+++ b/egs/libritts/CODEC/encodec/train.py
@@ -15,12 +15,13 @@
 from encodec import Encodec
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
+from loss import adopt_weight
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
-from utils import MetricsTracker, plot_feature, save_checkpoint
+from utils import MetricsTracker, save_checkpoint
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -250,11 +251,26 @@ def get_model(params: AttributeDict) -> nn.Module:
     from modules.seanet import SEANetDecoder, SEANetEncoder
     from quantization import ResidualVectorQuantizer
 
+    # generator_params = {
+    #     "generator_n_filters": 32,
+    #     "dimension": 512,
+    #     "ratios": [2, 2, 2, 4],
+    #     "target_bandwidths": [7.5, 15],
+    #     "bins": 1024,
+    # }
+    # discriminator_params = {
+    #     "stft_discriminator_n_filters": 32,
+    #     "discriminator_iter_start": 500,
+    # }
+    # inference_params = {
+    #     "target_bw": 7.5,
+    # }
+
     generator_params = {
         "generator_n_filters": 32,
         "dimension": 512,
-        "ratios": [2, 2, 2, 4],
-        "target_bandwidths": [7.5, 15],
+        "ratios": [8, 5, 4, 2],
+        "target_bandwidths": [1.5, 3, 6, 12, 24],
         "bins": 1024,
     }
     discriminator_params = {
@@ -262,7 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
         "discriminator_iter_start": 500,
     }
     inference_params = {
-        "target_bw": 7.5,
+        "target_bw": 12,
    }
 
     params.update(generator_params)
@@ -419,36 +435,93 @@ def save_bad_model(suffix: str = ""):
         try:
             with autocast(enabled=params.use_fp16):
+                d_weight = adopt_weight(
+                    params.lambda_adv,
+                    params.batch_idx_train,
+                    threshold=params.discriminator_iter_start,
+                )
                 # forward discriminator
-                loss_d, stats_d = model(
+                (
+                    disc_stft_real_adv_loss,
+                    disc_stft_fake_adv_loss,
+                    disc_period_real_adv_loss,
+                    disc_period_fake_adv_loss,
+                    disc_scale_real_adv_loss,
+                    disc_scale_fake_adv_loss,
+                    stats_d,
+                ) = model(
                     speech=audio,
                     speech_lengths=audio_lens,
-                    global_step=params.batch_idx_train,
                     return_sample=False,
                     forward_generator=False,
                 )
+                disc_loss = (
+                    (
+                        disc_stft_real_adv_loss
+                        + disc_stft_fake_adv_loss
+                        + disc_period_real_adv_loss
+                        + disc_period_fake_adv_loss
+                        + disc_scale_real_adv_loss
+                        + disc_scale_fake_adv_loss
+                    )
+                    * d_weight
+                    / 3
+                )
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size
             # update discriminator
             optimizer_d.zero_grad()
-            scaler.scale(loss_d).backward()
+            scaler.scale(disc_loss).backward()
             scaler.step(optimizer_d)
 
             with autocast(enabled=params.use_fp16):
+                g_weight = adopt_weight(
+                    params.lambda_adv,
+                    params.batch_idx_train,
+                    threshold=params.discriminator_iter_start,
+                )
                 # forward generator
-                loss_g, stats_g = model(
+                (
+                    commit_loss,
+                    gen_stft_adv_loss,
+                    gen_period_adv_loss,
+                    gen_scale_adv_loss,
+                    feature_stft_loss,
+                    feature_period_loss,
+                    feature_scale_loss,
+                    wav_reconstruction_loss,
+                    mel_reconstruction_loss,
+                    stats_g,
+                ) = model(
                     speech=audio,
                     speech_lengths=audio_lens,
-                    global_step=params.batch_idx_train,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                )
+                gen_adv_loss = (
+                    (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
+                    * g_weight
+                    / 3
+                )
+                feature_loss = (
+                    feature_stft_loss + feature_period_loss + feature_scale_loss
+                ) / 3
+                reconstruction_loss = (
+                    params.lambda_wav * wav_reconstruction_loss
+                    + mel_reconstruction_loss
+                )
+                gen_loss = (
+                    gen_adv_loss
+                    + params.lambda_rec * reconstruction_loss
+                    + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+                    + params.lambda_com * commit_loss
+                )
             for k, v in stats_g.items():
                 if "returned_sample" not in k:
                     loss_info[k] = v * batch_size
             # update generator
             optimizer_g.zero_grad()
-            scaler.scale(loss_g).backward()
+            scaler.scale(gen_loss).backward()
             scaler.step(optimizer_g)
             scaler.update()
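The new ratios [8, 5, 4, 2] imply a hop of 320 samples, i.e. 75 frames per second at 24 kHz, and with bins=1024 (10 bits per codebook) each residual quantizer then carries 0.75 kbps, which matches the EnCodec-style bandwidth ladder [1.5, 3, 6, 12, 24]. A quick sanity check; the 24 kHz sampling rate is an assumption, since it is not set in this hunk:

    # rough sketch of the bandwidth ladder implied by the new config
    sampling_rate = 24000             # assumed; not shown in this hunk
    hop = 8 * 5 * 4 * 2               # product of "ratios" -> 320 samples
    frame_rate = sampling_rate / hop  # 75 frames per second
    bits_per_codebook = 10            # log2(bins) with bins=1024
    kbps_per_quantizer = frame_rate * bits_per_codebook / 1000  # 0.75 kbps
    for bw in [1.5, 3, 6, 12, 24]:
        print(f"{bw:4.1f} kbps -> {int(bw / kbps_per_quantizer)} quantizers")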
@@ -619,27 +692,84 @@ def compute_validation_loss(
             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
 
+            d_weight = adopt_weight(
+                params.lambda_adv,
+                params.batch_idx_train,
+                threshold=params.discriminator_iter_start,
+            )
+
             # forward discriminator
-            loss_d, stats_d = model(
+            (
+                disc_stft_real_adv_loss,
+                disc_stft_fake_adv_loss,
+                disc_period_real_adv_loss,
+                disc_period_fake_adv_loss,
+                disc_scale_real_adv_loss,
+                disc_scale_fake_adv_loss,
+                stats_d,
+            ) = model(
                 speech=audio,
                 speech_lengths=audio_lens,
-                global_step=params.batch_idx_train,
                 return_sample=False,
                 forward_generator=False,
             )
-            assert loss_d.requires_grad is False
+            disc_loss = (
+                (
+                    disc_stft_real_adv_loss
+                    + disc_stft_fake_adv_loss
+                    + disc_period_real_adv_loss
+                    + disc_period_fake_adv_loss
+                    + disc_scale_real_adv_loss
+                    + disc_scale_fake_adv_loss
+                )
+                * d_weight
+                / 3
+            )
+            assert disc_loss.requires_grad is False
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size
 
+            g_weight = adopt_weight(
+                params.lambda_adv,
+                params.batch_idx_train,
+                threshold=params.discriminator_iter_start,
+            )
             # forward generator
-            loss_g, stats_g = model(
+            (
+                commit_loss,
+                gen_stft_adv_loss,
+                gen_period_adv_loss,
+                gen_scale_adv_loss,
+                feature_stft_loss,
+                feature_period_loss,
+                feature_scale_loss,
+                wav_reconstruction_loss,
+                mel_reconstruction_loss,
+                stats_g,
+            ) = model(
                 speech=audio,
                 speech_lengths=audio_lens,
-                global_step=params.batch_idx_train,
                 forward_generator=True,
                 return_sample=False,
             )
-            assert loss_g.requires_grad is False
+            gen_adv_loss = (
+                (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
+                * g_weight
+                / 3
+            )
+            feature_loss = (
+                feature_stft_loss + feature_period_loss + feature_scale_loss
+            ) / 3
+            reconstruction_loss = (
+                params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
+            )
+            gen_loss = (
+                gen_adv_loss
+                + params.lambda_rec * reconstruction_loss
+                + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+                + params.lambda_com * commit_loss
+            )
+            assert gen_loss.requires_grad is False
             for k, v in stats_g.items():
                 if "returned_sample" not in k:
                     loss_info[k] = v * batch_size
@@ -691,24 +821,74 @@ def scan_pessimistic_batches_for_oom(
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
-                loss_d, stats_d = model(
+                (
+                    disc_stft_real_adv_loss,
+                    disc_stft_fake_adv_loss,
+                    disc_period_real_adv_loss,
+                    disc_period_fake_adv_loss,
+                    disc_scale_real_adv_loss,
+                    disc_scale_fake_adv_loss,
+                    stats_d,
+                ) = model(
                     speech=audio,
                     speech_lengths=audio_lens,
-                    global_step=params.batch_idx_train,
                     return_sample=False,
                    forward_generator=False,
                 )
+                loss_d = (
+                    (
+                        disc_stft_real_adv_loss
+                        + disc_stft_fake_adv_loss
+                        + disc_period_real_adv_loss
+                        + disc_period_fake_adv_loss
+                        + disc_scale_real_adv_loss
+                        + disc_scale_fake_adv_loss
+                    )
+                    * adopt_weight(
+                        params.lambda_adv,
+                        params.batch_idx_train,
+                        threshold=params.discriminator_iter_start,
+                    )
+                    / 3
+                )
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
             with autocast(enabled=params.use_fp16):
-                loss_g, stats_g = model(
+                (
+                    commit_loss,
+                    gen_stft_adv_loss,
+                    gen_period_adv_loss,
+                    gen_scale_adv_loss,
+                    feature_stft_loss,
+                    feature_period_loss,
+                    feature_scale_loss,
+                    wav_reconstruction_loss,
+                    mel_reconstruction_loss,
+                    stats_g,
+                ) = model(
                     speech=audio,
                     speech_lengths=audio_lens,
                     forward_generator=True,
-                    global_step=params.batch_idx_train,
                     return_sample=False,
                 )
+                loss_g = (
+                    (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
+                    * adopt_weight(
+                        params.lambda_adv,
+                        params.batch_idx_train,
+                        threshold=params.discriminator_iter_start,
+                    )
+                    / 3
+                    + params.lambda_rec
+                    * (
+                        params.lambda_wav * wav_reconstruction_loss
+                        + mel_reconstruction_loss
+                    )
+                    + params.lambda_feat
+                    * (feature_stft_loss + feature_period_loss + feature_scale_loss)
+                    + params.lambda_com * commit_loss
+                )
             optimizer_g.zero_grad()
             loss_g.backward()
         except Exception as e:
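In the same spirit as the __main__ check added to loss.py for FeatureMatchLoss, the new discriminator criterion can be sanity-checked against the hinge formula of the removed criterion_d; a standalone sketch with dummy logits:

    import torch
    import torch.nn.functional as F
    from loss import DiscriminatorAdversarialLoss

    criterion = DiscriminatorAdversarialLoss(
        average_by_discriminators=True, loss_type="hinge"
    )
    y = [torch.randn(2, 100) for _ in range(3)]      # real logits, 3 discriminators
    y_hat = [torch.randn(2, 100) for _ in range(3)]  # fake logits
    real_loss, fake_loss = criterion(outputs_hat=y_hat, outputs=y)

    # manual hinge sum as in the removed criterion_d, averaged per discriminator
    manual = sum(
        F.relu(1 - r).mean() + F.relu(1 + f).mean() for r, f in zip(y, y_hat)
    ) / len(y)
    assert torch.isclose(real_loss + fake_loss, manual)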