refactored loss functions
JinZr committed Oct 5, 2024
1 parent 1e65a97 commit f9340cc
Showing 5 changed files with 620 additions and 186 deletions.
2 changes: 1 addition & 1 deletion egs/libritts/CODEC/encodec/codec_datamodule.py
@@ -139,7 +139,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
         group.add_argument(
             "--num-workers",
             type=int,
-            default=2,
+            default=8,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
156 changes: 106 additions & 50 deletions egs/libritts/CODEC/encodec/encodec.py
@@ -4,7 +4,13 @@

 import numpy as np
 import torch
-from loss import loss_dis, loss_g
+from loss import (
+    DiscriminatorAdversarialLoss,
+    FeatureMatchLoss,
+    GeneratorAdversarialLoss,
+    MelSpectrogramReconstructionLoss,
+    WavReconstructionLoss,
+)
 from torch import nn
 from torch.cuda.amp import autocast

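The old monolithic loss_g / loss_dis helpers are replaced by the five loss modules imported above. For orientation, here is a minimal sketch of the two hinge-style adversarial modules, modeled on the ESPnet-style GAN losses this kind of recipe usually adapts; the real definitions live in egs/libritts/CODEC/encodec/loss.py and may differ in detail:

    from typing import List, Tuple, Union

    import torch
    import torch.nn.functional as F


    class GeneratorAdversarialLoss(torch.nn.Module):
        """Hinge generator loss: -E[D(x_hat)], summed over sub-discriminators."""

        def __init__(self, average_by_discriminators: bool = True, loss_type: str = "hinge"):
            super().__init__()
            assert loss_type == "hinge", "this sketch covers only the hinge variant"
            self.average_by_discriminators = average_by_discriminators

        def forward(self, outputs: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
            # `outputs` are discriminator scores for generated audio,
            # one tensor per sub-discriminator.
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            loss = sum(-o.mean() for o in outputs)
            if self.average_by_discriminators:
                loss = loss / len(outputs)
            return loss


    class DiscriminatorAdversarialLoss(torch.nn.Module):
        """Hinge discriminator loss, returned as a (real, fake) pair."""

        def __init__(self, average_by_discriminators: bool = True, loss_type: str = "hinge"):
            super().__init__()
            assert loss_type == "hinge", "this sketch covers only the hinge variant"
            self.average_by_discriminators = average_by_discriminators

        def forward(self, outputs, outputs_hat) -> Tuple[torch.Tensor, torch.Tensor]:
            # `outputs`: scores for real audio; `outputs_hat`: scores for generated audio.
            if not isinstance(outputs, (list, tuple)):
                outputs, outputs_hat = [outputs], [outputs_hat]
            real_loss = sum(F.relu(1.0 - o).mean() for o in outputs)
            fake_loss = sum(F.relu(1.0 + o_hat).mean() for o_hat in outputs_hat)
            if self.average_by_discriminators:
                real_loss = real_loss / len(outputs)
                fake_loss = fake_loss / len(outputs)
            return real_loss, fake_loss
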
@@ -47,19 +53,30 @@ def __init__(
         self.cache_generator_outputs = cache_generator_outputs
         self._cache = None
 
+        # construct loss functions
+        self.generator_adversarial_loss = GeneratorAdversarialLoss(
+            average_by_discriminators=True, loss_type="hinge"
+        )
+        self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
+            average_by_discriminators=True, loss_type="hinge"
+        )
+        self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
+        self.wav_reconstruction_loss = WavReconstructionLoss()
+        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
+            sampling_rate=self.sampling_rate
+        )
+
     def _forward_generator(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
         return_sample: bool = False,
     ):
         """Perform generator forward.
         Args:
             speech (Tensor): Speech waveform tensor (B, T_wav).
             speech_lengths (Tensor): Speech length tensor (B,).
-            global_step (int): Global step.
             return_sample (bool): Return the generator output.
         Returns:
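The constructor above also wires up the feature-matching and reconstruction losses. A hedged sketch of those three modules follows; it is only an approximation of loss.py (EnCodec's paper actually uses a multi-scale spectral reconstruction term over several STFT resolutions combining L1 and L2, and the n_fft / hop_length / n_mels values below are assumptions, not taken from this commit):

    import torch
    import torch.nn.functional as F
    import torchaudio


    class FeatureMatchLoss(torch.nn.Module):
        """L1 distance between real and generated discriminator feature maps."""

        def __init__(self, average_by_layers: bool = False):
            super().__init__()
            self.average_by_layers = average_by_layers

        def forward(self, feats, feats_hat) -> torch.Tensor:
            # feats / feats_hat: per-discriminator lists of layer activations.
            loss = 0.0
            n_layers = 0
            for feats_d, feats_hat_d in zip(feats, feats_hat):
                for f, f_hat in zip(feats_d, feats_hat_d):
                    loss = loss + F.l1_loss(f_hat, f.detach())
                    n_layers += 1
            if self.average_by_layers:
                loss = loss / n_layers
            return loss


    class WavReconstructionLoss(torch.nn.Module):
        """Time-domain L1 between reference and reconstructed waveforms."""

        def forward(self, x, x_hat) -> torch.Tensor:
            return F.l1_loss(x_hat, x)


    class MelSpectrogramReconstructionLoss(torch.nn.Module):
        """L1 between mel spectrograms of reference and reconstruction."""

        def __init__(self, sampling_rate: int, n_mels: int = 64):
            super().__init__()
            self.mel = torchaudio.transforms.MelSpectrogram(
                sample_rate=sampling_rate, n_fft=1024, hop_length=256, n_mels=n_mels
            )

        def forward(self, x, x_hat) -> torch.Tensor:
            return F.l1_loss(self.mel(x_hat), self.mel(x))
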
@@ -107,33 +124,56 @@ def _forward_generator(

         # calculate losses
         with autocast(enabled=False):
-            loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
-                commit_loss,
-                speech,
-                speech_hat,
-                fmap,
-                fmap_hat,
-                y,
-                y_hat,
-                global_step,
-                y_p,
-                y_p_hat,
-                y_s,
-                y_s_hat,
-                fmap_p,
-                fmap_p_hat,
-                fmap_s,
-                fmap_s_hat,
-                args=self.params,
-            )
+            gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
+            gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
+            gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
+
+            feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
+            feature_period_loss = self.feature_match_loss(
+                feats=fmap_p, feats_hat=fmap_p_hat
+            )
+            feature_scale_loss = self.feature_match_loss(
+                feats=fmap_s, feats_hat=fmap_s_hat
+            )
+
+            wav_reconstruction_loss = self.wav_reconstruction_loss(
+                x=speech, x_hat=speech_hat
+            )
+            mel_reconstruction_loss = self.mel_reconstruction_loss(
+                x=speech, x_hat=speech_hat
+            )
+
+            # loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
+            #     commit_loss,
+            #     speech,
+            #     speech_hat,
+            #     fmap,
+            #     fmap_hat,
+            #     y,
+            #     y_hat,
+            #     y_p,
+            #     y_p_hat,
+            #     y_s,
+            #     y_s_hat,
+            #     fmap_p,
+            #     fmap_p_hat,
+            #     fmap_s,
+            #     fmap_s_hat,
+            #     args=self.params,
+            # )
 
         stats = dict(
-            generator_loss=loss.item(),
-            generator_reconstruction_loss=rec_loss.item(),
-            generator_feature_loss=feat_loss.item(),
-            generator_adv_loss=adv_loss.item(),
+            # generator_loss=loss.item(),
+            generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
+            generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
+            generator_feature_stft_loss=feature_stft_loss.item(),
+            generator_feature_period_loss=feature_period_loss.item(),
+            generator_feature_scale_loss=feature_scale_loss.item(),
+            generator_stft_adv_loss=gen_stft_adv_loss.item(),
+            generator_period_adv_loss=gen_period_adv_loss.item(),
+            generator_scale_adv_loss=gen_scale_adv_loss.item(),
             generator_commit_loss=commit_loss.item(),
-            d_weight=d_weight.item(),
+            # d_weight=d_weight.item(),
         )
 
         if return_sample:
@@ -147,19 +187,28 @@ def _forward_generator(
         # reset cache
         if reuse_cache or not self.training:
             self._cache = None
-        return loss, stats
+        return (
+            commit_loss,
+            gen_stft_adv_loss,
+            gen_period_adv_loss,
+            gen_scale_adv_loss,
+            feature_stft_loss,
+            feature_period_loss,
+            feature_scale_loss,
+            wav_reconstruction_loss,
+            mel_reconstruction_loss,
+            stats,
+        )
 
     def _forward_discriminator(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
     ):
         """
         Args:
             speech (Tensor): Speech waveform tensor (B, T_wav).
             speech_lengths (Tensor): Speech length tensor (B,).
-            global_step (int): Global step.
         Returns:
             * loss (Tensor): Loss scalar tensor.
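With the generator forward now returning its raw components (see the return tuple above) rather than a single weighted loss, the weighting that loss_g applied internally, including its global_step-dependent d_weight, moves to the caller. A hedged sketch of how a training loop such as train.py might combine them; variable names and weight values here are illustrative assumptions, not code from this commit:

    # hypothetical caller-side combination of the generator loss components
    (
        commit_loss,
        gen_stft_adv_loss,
        gen_period_adv_loss,
        gen_scale_adv_loss,
        feature_stft_loss,
        feature_period_loss,
        feature_scale_loss,
        wav_reconstruction_loss,
        mel_reconstruction_loss,
        stats,
    ) = model(
        speech=audio,
        speech_lengths=audio_lens,
        return_sample=False,
        forward_generator=True,
    )

    gen_adv_loss = gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
    feat_loss = feature_stft_loss + feature_period_loss + feature_scale_loss

    loss_g = (
        3.0 * gen_adv_loss                 # weights below are assumed, not
        + 3.0 * feat_loss                  # taken from this commit
        + 1.0 * wav_reconstruction_loss
        + 45.0 * mel_reconstruction_loss   # HiFiGAN-style mel weight, assumed
        + 1.0 * commit_loss
    )
    loss_g.backward()
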
@@ -206,52 +255,59 @@ def _forward_discriminator(
         )
         # calculate losses
         with autocast(enabled=False):
-            loss = loss_dis(
-                y,
-                y_hat,
-                fmap,
-                fmap_hat,
-                y_p,
-                y_p_hat,
-                fmap_p,
-                fmap_p_hat,
-                y_s,
-                y_s_hat,
-                fmap_s,
-                fmap_s_hat,
-                global_step,
-                args=self.params,
-            )
+            (
+                disc_stft_real_adv_loss,
+                disc_stft_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
+            (
+                disc_period_real_adv_loss,
+                disc_period_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
+            (
+                disc_scale_real_adv_loss,
+                disc_scale_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)
 
         stats = dict(
-            discriminator_loss=loss.item(),
+            discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
+            discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(),
+            discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(),
+            discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(),
+            discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(),
+            discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(),
         )
 
         # reset cache
         if reuse_cache or not self.training:
             self._cache = None
 
-        return loss, stats
+        return (
+            disc_stft_real_adv_loss,
+            disc_stft_fake_adv_loss,
+            disc_period_real_adv_loss,
+            disc_period_fake_adv_loss,
+            disc_scale_real_adv_loss,
+            disc_scale_fake_adv_loss,
+            stats,
+        )
 
     def forward(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
         return_sample: bool,
         forward_generator: bool,
     ):
         if forward_generator:
             return self._forward_generator(
                 speech=speech,
                 speech_lengths=speech_lengths,
-                global_step=global_step,
                 return_sample=return_sample,
             )
         else:
             return self._forward_discriminator(
                 speech=speech,
                 speech_lengths=speech_lengths,
-                global_step=global_step,
             )
 
     def encode(self, x, target_bw=None, st=None):
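The discriminator side mirrors this: the module now returns the six per-discriminator real/fake hinge terms, and the caller sums them itself. Again a hedged caller-side sketch, not code from this commit:

    # hypothetical discriminator update using the new return tuple
    (
        disc_stft_real_adv_loss,
        disc_stft_fake_adv_loss,
        disc_period_real_adv_loss,
        disc_period_fake_adv_loss,
        disc_scale_real_adv_loss,
        disc_scale_fake_adv_loss,
        stats,
    ) = model(
        speech=audio,
        speech_lengths=audio_lens,
        return_sample=False,
        forward_generator=False,
    )

    loss_d = (
        disc_stft_real_adv_loss + disc_stft_fake_adv_loss
        + disc_period_real_adv_loss + disc_period_fake_adv_loss
        + disc_scale_real_adv_loss + disc_scale_fake_adv_loss
    )
    loss_d.backward()
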
2 changes: 1 addition & 1 deletion egs/libritts/CODEC/encodec/infer.py
@@ -71,7 +71,7 @@ def get_parser():
     parser.add_argument(
         "--target-bw",
         type=float,
-        default=7.5,
+        default=24,
         help="The target bandwidth for the generator",
     )

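The new default of 24 corresponds to EnCodec's maximum bandwidth of 24 kbps, up from 7.5 kbps. Assuming the standard EnCodec configuration (24 kHz audio, hop of 320 samples, 1024-entry codebooks; the recipe's actual settings may differ), the target bandwidth directly determines how many residual VQ quantizers are used at inference:

    # hedged arithmetic behind the new default, under the assumptions above
    frame_rate = 24_000 / 320          # 75 frames per second
    bits_per_codebook = 10             # log2(1024)
    kbps_per_quantizer = frame_rate * bits_per_codebook / 1000  # 0.75 kbps

    for target_bw in (7.5, 24):
        n_q = int(target_bw / kbps_per_quantizer)
        print(f"target_bw={target_bw} kbps -> {n_q} quantizers")
    # target_bw=7.5 kbps -> 10 quantizers
    # target_bw=24 kbps -> 32 quantizers
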
[diffs for the remaining two changed files not loaded]