diff --git a/fs2/model.py b/fs2/model.py index 76f6b4c..2e7cb42 100644 --- a/fs2/model.py +++ b/fs2/model.py @@ -313,7 +313,7 @@ def _validation_global_step_0(self, batch, batch_idx) -> None: self.config.preprocessing.audio.output_sampling_rate, ) if self.config.training.vocoder_path: - input_ = batch["mel"] + input_ = batch["mel"].transpose(1, 2) vocoder_ckpt = torch.load( self.config.training.vocoder_path, map_location=input_.device ) @@ -391,7 +391,7 @@ def _validation_batch_idx_0(self, batch, batch_idx, output) -> None: ) if self.config.training.vocoder_path: - input_ = output[self.output_key] + input_ = output[self.output_key].transpose(1, 2) vocoder_ckpt = torch.load( self.config.training.vocoder_path, map_location=input_.device )