Skip to content

Commit

Permalink
fix: the vocoder expects [B, K, T] tensors and this applies during tr…
Browse files Browse the repository at this point in the history
…aining too
  • Loading branch information
roedoejet authored and joanise committed Dec 10, 2024
1 parent 229645c commit 9549082
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions fs2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def _validation_global_step_0(self, batch, batch_idx) -> None:
self.config.preprocessing.audio.output_sampling_rate,
)
if self.config.training.vocoder_path:
input_ = batch["mel"]
input_ = batch["mel"].transpose(1, 2)

Check warning on line 356 in fs2/model.py

View check run for this annotation

Codecov / codecov/patch

fs2/model.py#L356

Added line #L356 was not covered by tests
vocoder_ckpt = torch.load(
self.config.training.vocoder_path, map_location=input_.device
)
Expand Down Expand Up @@ -431,7 +431,7 @@ def _validation_batch_idx_0(self, batch, batch_idx, output) -> None:
)

if self.config.training.vocoder_path:
input_ = output[self.output_key]
input_ = output[self.output_key].transpose(1, 2)

Check warning on line 434 in fs2/model.py

View check run for this annotation

Codecov / codecov/patch

fs2/model.py#L434

Added line #L434 was not covered by tests
vocoder_ckpt = torch.load(
self.config.training.vocoder_path, map_location=input_.device
)
Expand Down

0 comments on commit 9549082

Please sign in to comment.