Skip to content

Commit

Permalink
Merge pull request #837 from IAHispano/formatter/main
Browse files Browse the repository at this point in the history
chore(format): run black on main
  • Loading branch information
blaisewf authored Oct 26, 2024
2 parents 5ebc144 + 90a2e40 commit e8383ce
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def continue_overtrain_detector(training_file_path):

if cleanup:
print("Removing files from the prior training attempt...")

# Clean up unnecessary files
for root, dirs, files in os.walk(
os.path.join(now_dir, "logs", model_name), topdown=False
Expand All @@ -273,10 +273,11 @@ def continue_overtrain_detector(training_file_path):
os.rmdir(folder_path)

print("Cleanup done!")

continue_overtrain_detector(training_file_path)
start()


def run(
rank,
n_gpus,
Expand Down Expand Up @@ -675,7 +676,9 @@ def train_and_evaluate(
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
with autocast(enabled=False):
loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
loss_kl = (
kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
)
loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, losses_gen = generator_loss(y_d_hat_g)
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
Expand Down Expand Up @@ -717,11 +720,11 @@ def train_and_evaluate(
"loss/g/mel": loss_mel,
"loss/g/kl": loss_kl,
}
# commented out
#scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
#scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
#scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})
# commented out
# scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
# scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})

image_dict = {
"slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
"slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
Expand Down

0 comments on commit e8383ce

Please sign in to comment.