Skip to content

Commit

Permalink
Update vits.py
Browse files Browse the repository at this point in the history
FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  • Loading branch information
mablue authored Nov 30, 2024
1 parent b7b6520 commit 046d257
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions egs/ljspeech/TTS/vits/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
KLDivergenceLoss,
MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from torch.amp import autocast
from utils import get_segments

AVAILABLE_GENERATERS = {
Expand Down Expand Up @@ -410,7 +410,7 @@ def _forward_generator(
p = self.discriminator(speech_)

# calculate losses
with autocast(enabled=False):
with autocast('cuda',enabled=False):
if not return_sample:
mel_loss = self.mel_loss(speech_hat_, speech_)
else:
Expand Down Expand Up @@ -518,7 +518,7 @@ def _forward_discrminator(
p = self.discriminator(speech_)

# calculate losses
with autocast(enabled=False):
with autocast('cuda',enabled=False):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss

Expand Down

0 comments on commit 046d257

Please sign in to comment.