From 90a2e4062f62440d62fb3ad4f0304e40368a5318 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 26 Oct 2024 11:20:16 +0000 Subject: [PATCH] chore(format): run black on main --- rvc/train/train.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/rvc/train/train.py b/rvc/train/train.py index ca4c8da6..d9026f75 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -248,7 +248,7 @@ def continue_overtrain_detector(training_file_path): if cleanup: print("Removing files from the prior training attempt...") - + # Clean up unnecessary files for root, dirs, files in os.walk( os.path.join(now_dir, "logs", model_name), topdown=False @@ -273,10 +273,11 @@ def continue_overtrain_detector(training_file_path): os.rmdir(folder_path) print("Cleanup done!") - + continue_overtrain_detector(training_file_path) start() + def run( rank, n_gpus, @@ -675,7 +676,9 @@ def train_and_evaluate( y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat) with autocast(enabled=False): loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel - loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl + loss_kl = ( + kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl + ) loss_fm = feature_loss(fmap_r, fmap_g) loss_gen, losses_gen = generator_loss(y_d_hat_g) loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl @@ -717,11 +720,11 @@ def train_and_evaluate( "loss/g/mel": loss_mel, "loss/g/kl": loss_kl, } - # commented out - #scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)}) - #scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)}) - #scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)}) - + # commented out + # scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)}) + # scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)}) + # scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)}) + image_dict = { "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),