fixed loss functions & scaling factors
JinZr committed Oct 6, 2024
1 parent 58f6562 commit 01cc307
Showing 3 changed files with 80 additions and 42 deletions.
21 changes: 14 additions & 7 deletions egs/libritts/CODEC/encodec/loss.py
@@ -142,10 +142,10 @@ def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
         return F.mse_loss(x, x.new_zeros(x.size()))
 
     def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) - x).mean()
+        return F.relu(torch.ones_like(x) - x).mean()
 
     def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) + x).mean()
+        return F.relu(torch.ones_like(x) + x).mean()
 
 
 class FeatureLoss(torch.nn.Module):
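Note: replacing x.new_ones(x.size()) with torch.ones_like(x) is functionally equivalent (same shape, dtype, and device); the terms themselves are the standard hinge GAN discriminator objective. A small illustrative sketch, with disc_real and disc_fake as hypothetical stand-ins for discriminator outputs:

    import torch
    import torch.nn.functional as F

    def hinge_d_loss(disc_real: torch.Tensor, disc_fake: torch.Tensor) -> torch.Tensor:
        # E[relu(1 - D(x_real))] + E[relu(1 + D(x_fake))]
        real_loss = F.relu(torch.ones_like(disc_real) - disc_real).mean()
        fake_loss = F.relu(torch.ones_like(disc_fake) + disc_fake).mean()
        return real_loss + fake_loss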
@@ -200,7 +200,7 @@ def forward(
                 feats_ = feats_[:-1]
             for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                 feat_match_loss_ += (
-                    (feat_hat_ - feat_).abs() / (feat_.abs().mean())
+                    F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
                 ).mean()
             if self.average_by_layers:
                 feat_match_loss_ /= j + 1
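Note: the rewritten feature-matching term computes an L1 distance against detached real-branch features, normalized per layer by the mean magnitude of those features, so no gradient flows back through the discriminator's real pass. A minimal standalone sketch of the computation (hypothetical helper, not the repository's exact module):

    import torch
    import torch.nn.functional as F

    def feature_match_loss(feats_hat, feats):
        # L1 between generated and (detached) real feature maps, each layer
        # normalized by the mean magnitude of the real features.
        loss = torch.tensor(0.0)
        for feat_hat, feat in zip(feats_hat, feats):
            feat = feat.detach()
            loss = loss + F.l1_loss(feat_hat, feat) / feat.abs().mean()
        return loss / len(feats)  # average over layers, as with average_by_layers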
@@ -272,9 +272,16 @@ def forward(
             mel_hat = wav_to_spec(x_hat.squeeze(1))
             mel = wav_to_spec(x.squeeze(1))
 
-            mel_loss += F.l1_loss(
-                mel_hat, mel, reduce=True, reduction="mean"
-            ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")
+            mel_loss += (
+                F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
+                + (
+                    (
+                        (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
+                        ** 2
+                    ).mean(dim=-2)
+                    ** 0.5
+                ).mean()
+            )
 
         # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
         # mel = self.wav_to_spec(x.squeeze(1))
@@ -307,7 +314,7 @@ def forward(
             Tensor: Wav loss value.
         """
-        wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")
+        wav_loss = F.l1_loss(x, x_hat)
 
         return wav_loss

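Note: the reworked mel-spectrogram term above replaces the earlier L1 + MSE pair with an L1 distance on the mel spectrogram plus a root-mean-square error on its log magnitude, similar in spirit to a log-STFT magnitude loss. A hedged sketch of that term, assuming mel and mel_hat are magnitude mel spectrograms with the frequency bins on dim=-2:

    import torch
    import torch.nn.functional as F

    def mel_reconstruction_term(mel_hat: torch.Tensor, mel: torch.Tensor) -> torch.Tensor:
        # L1 on the mel spectrogram plus RMSE on its log magnitude;
        # eps guards the log at zero, as in the diff above.
        eps = 1e-7
        l1 = F.l1_loss(mel_hat, mel)
        log_diff = torch.log(mel.abs() + eps) - torch.log(mel_hat.abs() + eps)
        log_rmse = (log_diff.pow(2).mean(dim=-2) ** 0.5).mean()
        return l1 + log_rmse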
75 changes: 50 additions & 25 deletions egs/libritts/CODEC/encodec/scheduler.py
@@ -4,16 +4,40 @@
 from torch.optim.lr_scheduler import _LRScheduler
 
 
+# It will be replaced with huggingface optimization
+class WarmUpLR(_LRScheduler):
+    """warmup_training learning rate scheduler
+    Args:
+        optimizer: optimizer (e.g. SGD)
+        total_iters: total_iters of warmup phase
+    """
+
+    def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1):
+
+        self.total_iters = iter_per_epoch * warmup_epoch
+        self.iter_per_epoch = iter_per_epoch
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        """we will use the first m batches, and set the learning
+        rate to base_lr * m / total_iters
+        """
+        return [
+            base_lr * self.last_epoch / (self.total_iters + 1e-8)
+            for base_lr in self.base_lrs
+        ]
+
+
 class WarmupLrScheduler(_LRScheduler):
     def __init__(
         self,
         optimizer,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
-        self.warmup_epoch = warmup_epoch
+        self.warmup_iter = warmup_iter
         self.warmup_ratio = warmup_ratio
         self.warmup = warmup
         super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
@@ -24,7 +48,7 @@ def get_lr(self):
         return lrs
 
     def get_lr_ratio(self):
-        if self.last_epoch < self.warmup_epoch:
+        if self.last_epoch < self.warmup_iter:
             ratio = self.get_warmup_ratio()
         else:
             ratio = self.get_main_ratio()
@@ -35,7 +59,7 @@ def get_main_ratio(self):
 
     def get_warmup_ratio(self):
         assert self.warmup in ("linear", "exp")
-        alpha = self.last_epoch / self.warmup_epoch
+        alpha = self.last_epoch / self.warmup_iter
         if self.warmup == "linear":
             ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
         elif self.warmup == "exp":
@@ -48,22 +72,22 @@ def __init__(
         self,
         optimizer,
         power,
-        max_epoch,
-        warmup_epoch=500,
+        max_iter,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.power = power
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupPolyLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        real_max_epoch = self.max_epoch - self.warmup_epoch
-        alpha = real_epoch / real_max_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
+        alpha = real_iter / real_max_iter
         ratio = (1 - alpha) ** self.power
         return ratio

@@ -74,46 +98,47 @@ def __init__(
         optimizer,
         gamma,
         interval=1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.gamma = gamma
         self.interval = interval
         super(WarmupExpLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** (real_epoch // self.interval)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** (real_iter // self.interval)
         return ratio
 
 
 class WarmupCosineLrScheduler(WarmupLrScheduler):
     def __init__(
         self,
         optimizer,
-        max_epoch,
+        max_iter,
         eta_ratio=0,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.eta_ratio = eta_ratio
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupCosineLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_max_epoch = self.max_epoch - self.warmup_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
         return (
             self.eta_ratio
             + (1 - self.eta_ratio)
-            * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
+            * (1 + math.cos(math.pi * self.last_epoch / real_max_iter))
             / 2
         )
 
@@ -124,18 +149,18 @@ def __init__(
         optimizer,
         milestones: list,
         gamma=0.1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.milestones = milestones
         self.gamma = gamma
         super(WarmupStepLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** bisect_right(self.milestones, real_iter)
         return ratio
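Note: with the warmup_epoch/max_epoch parameters renamed to warmup_iter/max_iter, these schedulers are now meant to be stepped once per training iteration rather than once per epoch (train.py below moves scheduler_g.step() and scheduler_d.step() into the batch loop accordingly). A usage sketch, assuming the script runs from the encodec directory and roughly 1500 iterations per epoch, the figure hard-coded in train.py:

    import torch

    from scheduler import WarmupCosineLrScheduler

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    iters_per_epoch = 1500  # assumption mirrored from train.py
    num_epochs = 20
    scheduler = WarmupCosineLrScheduler(
        optimizer=optimizer,
        max_iter=num_epochs * iters_per_epoch,
        eta_ratio=0.1,
        warmup_iter=5 * iters_per_epoch,  # e.g. warm up for discriminator_epoch_start epochs
        warmup_ratio=1e-4,
    )

    for step in range(num_epochs * iters_per_epoch):
        optimizer.step()   # parameter update; loss/backward omitted for brevity
        scheduler.step()   # stepped per iteration, not per epoch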
26 changes: 16 additions & 10 deletions egs/libritts/CODEC/encodec/train.py
@@ -187,6 +187,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 200,
             "env_info": get_env_info(),
             "sampling_rate": 24000,
+            "audio_normalization": False,
             "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 1.0,  # loss scaling coefficient for waveform loss
@@ -276,13 +277,13 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
-        "discriminator_epoch_start": 3,
+        "discriminator_epoch_start": 5,
         "n_ffts": [1024, 2048, 512],
         "hop_lengths": [256, 512, 128],
         "win_lengths": [1024, 2048, 512],
     }
     inference_params = {
-        "target_bw": 12,
+        "target_bw": 6,
     }
 
     params.update(generator_params)
@@ -353,6 +354,11 @@ def prepare_input(
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]
 
+        if params.audio_normalization:
+            mean = audio.mean(dim=-1, keepdim=True)
+            std = audio.std(dim=-1, keepdim=True)
+            audio = (audio - mean) / (std + 1e-7)
+
     return audio, audio_lens, features, features_lens


@@ -532,6 +538,10 @@ def save_bad_model(suffix: str = ""):
             save_bad_model()
             raise
 
+        # step per iteration
+        scheduler_g.step()
+        scheduler_d.step()
+
         if params.print_diagnostics and batch_idx == 5:
             return

@@ -1009,16 +1019,16 @@ def run(rank, world_size, args):
 
     scheduler_g = WarmupCosineLrScheduler(
         optimizer=optimizer_g,
-        max_epoch=params.num_epochs,
+        max_iter=params.num_epochs * 1500,
         eta_ratio=0.1,
-        warmup_epoch=params.discriminator_epoch_start,
+        warmup_iter=params.discriminator_epoch_start * 1500,
         warmup_ratio=1e-4,
     )
     scheduler_d = WarmupCosineLrScheduler(
         optimizer=optimizer_d,
-        max_epoch=params.num_epochs,
+        max_iter=params.num_epochs * 1500,
         eta_ratio=0.1,
-        warmup_epoch=params.discriminator_epoch_start,
+        warmup_iter=params.discriminator_epoch_start * 1500,
         warmup_ratio=1e-4,
     )
 
@@ -1128,10 +1138,6 @@ def run(rank, world_size, args):
             best_valid_filename = params.exp_dir / "best-valid-loss.pt"
             copyfile(src=filename, dst=best_valid_filename)
 
-        # step per epoch
-        scheduler_g.step()
-        scheduler_d.step()
-
     logging.info("Done!")
 
     if world_size > 1:
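Note: prepare_input gains an optional per-utterance audio normalization, disabled by default via the new "audio_normalization" flag in get_params(). A minimal sketch of what the added block computes, assuming audio is a (batch, samples) waveform tensor:

    import torch

    def normalize_audio(audio: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        # Zero-mean, unit-variance normalization per utterance, over the time axis.
        mean = audio.mean(dim=-1, keepdim=True)
        std = audio.std(dim=-1, keepdim=True)
        return (audio - mean) / (std + eps)

    wav = torch.randn(2, 24000)  # e.g. one second of audio at the 24 kHz sampling rate
    wav_norm = normalize_audio(wav)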
