From 046d25757087f155249102eab936e819f2d004e5 Mon Sep 17 00:00:00 2001 From: Masoud Azizi Date: Sat, 30 Nov 2024 18:55:13 +0330 Subject: [PATCH] Update vits.py FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. --- egs/ljspeech/TTS/vits/vits.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index a1fabf9ad6..816b65e65a 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -25,7 +25,7 @@ KLDivergenceLoss, MelSpectrogramLoss, ) -from torch.cuda.amp import autocast +from torch.amp import autocast from utils import get_segments AVAILABLE_GENERATERS = { @@ -410,7 +410,7 @@ def _forward_generator( p = self.discriminator(speech_) # calculate losses - with autocast(enabled=False): + with autocast('cuda',enabled=False): if not return_sample: mel_loss = self.mel_loss(speech_hat_, speech_) else: @@ -518,7 +518,7 @@ def _forward_discrminator( p = self.discriminator(speech_) # calculate losses - with autocast(enabled=False): + with autocast('cuda',enabled=False): real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) loss = real_loss + fake_loss