Update train.py: replace the deprecated torch.cuda.amp API with torch.amp
mablue authored Nov 30, 2024
1 parent a1ade8e commit b7b6520
Showing 1 changed file with 6 additions and 6 deletions.
egs/ljspeech/TTS/vits/train.py
@@ -30,7 +30,7 @@
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from tokenizer import Tokenizer
-from torch.cuda.amp import GradScaler, autocast
+from torch.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
@@ -396,7 +396,7 @@ def save_bad_model(suffix: str = ""):
         loss_info["samples"] = batch_size
 
         try:
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda', enabled=params.use_fp16):
                 # forward discriminator
                 loss_d, stats_d = model(
                     text=tokens,
@@ -414,7 +414,7 @@ def save_bad_model(suffix: str = ""):
             scaler.scale(loss_d).backward()
             scaler.step(optimizer_d)
 
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda', enabled=params.use_fp16):
                 # forward generator
                 loss_g, stats_g = model(
                     text=tokens,
@@ -673,7 +673,7 @@ def scan_pessimistic_batches_for_oom(
         )
         try:
             # for discriminator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda', enabled=params.use_fp16):
                 loss_d, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -686,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda', enabled=params.use_fp16):
                 loss_g, stats_g = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -838,7 +838,7 @@ def remove_short_and_long_utt(c: Cut):
         params=params,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler('cuda', enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
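
Background: newer PyTorch releases deprecate the torch.cuda.amp namespace in favor of the device-agnostic torch.amp API, whose autocast and GradScaler take the device type ('cuda') as their first argument, which is the change applied above. Below is a minimal self-contained sketch of that pattern; the linear model, optimizer, and random data are illustrative stand-ins, not taken from train.py, and a CUDA device is assumed to be available.

import torch
from torch.amp import GradScaler, autocast

# Illustrative stand-ins; the real train.py trains a VITS model with
# separate generator and discriminator optimizers.
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters())
use_fp16 = True

# torch.amp.GradScaler takes the device type, unlike torch.cuda.amp.GradScaler.
scaler = GradScaler('cuda', enabled=use_fp16, init_scale=1.0)

for _ in range(3):
    x = torch.randn(8, 10, device='cuda')
    y = torch.randn(8, 1, device='cuda')
    optimizer.zero_grad()
    # torch.amp.autocast likewise takes the device type as its first argument.
    with autocast('cuda', enabled=use_fp16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscale grads, then optimizer.step()
    scaler.update()                # adjust the scale for the next iteration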