diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index a1fabf9ad6..816b65e65a 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -25,7 +25,7 @@ KLDivergenceLoss, MelSpectrogramLoss, ) -from torch.cuda.amp import autocast +from torch.amp import autocast from utils import get_segments AVAILABLE_GENERATERS = { @@ -410,7 +410,7 @@ def _forward_generator( p = self.discriminator(speech_) # calculate losses - with autocast(enabled=False): + with autocast('cuda',enabled=False): if not return_sample: mel_loss = self.mel_loss(speech_hat_, speech_) else: @@ -518,7 +518,7 @@ def _forward_discrminator( p = self.discriminator(speech_) # calculate losses - with autocast(enabled=False): + with autocast('cuda',enabled=False): real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) loss = real_loss + fake_loss