Skip to content

Commit

Permalink
Merge pull request #792 from AznamirWoW/audio_tensorboard
Browse files Browse the repository at this point in the history
made a randomly picked, but static reference for tensorboard
  • Loading branch information
blaisewf authored Oct 7, 2024
2 parents 8c5ef4e + ca44c03 commit 8e98824
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,16 @@ def run(
scaler = GradScaler(enabled=config.train.fp16_run and device.type == "cuda")

cache = []
# get the first sample as reference for tensorboard evaluation
for info in train_loader:
phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
reference = (phone.to(device),
phone_lengths.to(device),
pitch.to(device) if pitch_guidance else None,
pitchf.to(device) if pitch_guidance else None,
sid.to(device))
break

for epoch in range(epoch_str, total_epoch + 1):
if rank == 0:
train_and_evaluate(
Expand All @@ -504,6 +514,7 @@ def run(
custom_save_every_weights,
custom_total_epoch,
device,
reference
)
else:
train_and_evaluate(
Expand All @@ -519,6 +530,7 @@ def run(
custom_save_every_weights,
custom_total_epoch,
device,
reference,
)
scheduler_g.step()
scheduler_d.step()
Expand All @@ -537,6 +549,7 @@ def train_and_evaluate(
custom_save_every_weights,
custom_total_epoch,
device,
reference,
):
"""
Trains and evaluates the model for one epoch.
Expand Down Expand Up @@ -778,10 +791,10 @@ def train_and_evaluate(
),
"all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
}
audio_dict = {}

with torch.no_grad():
o, *_ = net_g.infer(phone, phone_lengths, pitch, pitchf, sid)
audio_dict.update({f"gen/audio_{global_step:07d}": o[0, :, :]})
o, *_ = net_g.infer(*reference)
audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, : ]}

summarize(
writer=writer,
Expand Down

0 comments on commit 8e98824

Please sign in to comment.