diff --git a/fs2/cli/synthesize.py b/fs2/cli/synthesize.py
index 274be64..b7f4674 100644
--- a/fs2/cli/synthesize.py
+++ b/fs2/cli/synthesize.py
@@ -343,7 +343,7 @@ def synthesize(  # noqa: C901
         help="""Which format(s) to synthesize to.
         Multiple formats can be provided by repeating `--output-type`.
         '**wav**' is the default and will synthesize to a playable audio file;
-        '**spec**' will generate predicted Mel spectrograms. Tensors are time-oriented (T, K) where T is equal to the number of frames and K is equal to the number of Mel bands.
+        '**spec**' will generate predicted Mel spectrograms. Tensors are Mel band-oriented (K, T) where K is equal to the number of Mel bands and T is equal to the number of frames.
         '**textgrid**' will generate a Praat TextGrid with alignment labels. This can be helpful for evaluation.
         '**readalong**' will generate a ReadAlong from the given text and synthesized audio (see https://github.com/ReadAlongs).
         """,
@@ -436,8 +436,8 @@ def synthesize(  # noqa: C901
     global_step = get_global_step(model_path)

     # load vocoder
-    logger.info(f"Loading Vocoder from {vocoder_path}")
     if vocoder_path is not None:
+        logger.info(f"Loading Vocoder from {vocoder_path}")
         vocoder_ckpt = torch.load(vocoder_path, map_location=device)
         try:
             vocoder_model, vocoder_config = load_hifigan_from_checkpoint(
diff --git a/fs2/model.py b/fs2/model.py
index 5a41c38..23660ee 100644
--- a/fs2/model.py
+++ b/fs2/model.py
@@ -353,7 +353,7 @@ def _validation_global_step_0(self, batch, batch_idx) -> None:
                 self.config.preprocessing.audio.output_sampling_rate,
             )
             if self.config.training.vocoder_path:
-                input_ = batch["mel"]
+                input_ = batch["mel"].transpose(1, 2)
                 vocoder_ckpt = torch.load(
                     self.config.training.vocoder_path, map_location=input_.device
                 )
@@ -431,7 +431,7 @@ def _validation_batch_idx_0(self, batch, batch_idx, output) -> None:
             )

             if self.config.training.vocoder_path:
-                input_ = output[self.output_key]
+                input_ = output[self.output_key].transpose(1, 2)
                 vocoder_ckpt = torch.load(
                     self.config.training.vocoder_path, map_location=input_.device
                 )
diff --git a/fs2/prediction_writing_callback.py b/fs2/prediction_writing_callback.py
index 8a8799f..f92bb7f 100644
--- a/fs2/prediction_writing_callback.py
+++ b/fs2/prediction_writing_callback.py
@@ -106,7 +106,7 @@ def get_filename(
         speaker: str,
         language: str,
         include_global_step: bool = False,
-    ) -> Path:
+    ) -> str:
         # We don't truncate or alter the filename here because the basename is
         # already truncated/cleaned in cli/synthesize.py
         name_parts = [basename, speaker, language, self.file_extension]
@@ -115,7 +115,7 @@ def get_filename(
         path = self.save_dir / self.sep.join(name_parts)
         # synthesizing spec allows nested outputs so we may need to make subdirs
         path.parent.mkdir(parents=True, exist_ok=True)
-        return path
+        return str(path)


 class PredictionWritingSpecCallback(PredictionWritingCallbackBase):
@@ -159,7 +159,9 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
             outputs["tgt_lens"],
         ):
             torch.save(
-                data[:unmasked_len].cpu(),
+                data[:unmasked_len]
+                .cpu()
+                .transpose(0, 1),  # save tensors as [K (bands), T (frames)]
                 self.get_filename(basename, speaker, language),
             )

@@ -409,7 +411,7 @@ def synthesize_audio(self, outputs):
         sr: int
         wavs: np.ndarray

-        output_value = outputs[self.output_key]
+        output_value = outputs[self.output_key].transpose(1, 2)
         if output_value is not None:
             wavs, sr = synthesize_data(
                 output_value, self.vocoder_model, self.vocoder_config
@@ -419,21 +421,15 @@ def synthesize_audio(self, outputs):
                 f"{self.output_key} does not exist in the output of your model"
             )

-        # wavs: [B (batch_size), T (samples)]
+        # wavs: [B (batch_size), C (channels), T (samples)]
         assert (
-            wavs.ndim == 2
-        ), f"The generated audio contained more than 2 dimensions. First dimension should be B(atch) and the second dimension should be T(ime) in samples. Got {wavs.shape} instead."
-        assert "output" in outputs and outputs["output"] is not None
-        assert wavs.shape[0] == outputs["output"].size(
+            wavs.ndim == 3
+        ), f"The generated audio did not contain 3 dimensions. First dimension should be B(atch) and the second dimension should be C(hannels) and third dimension should be T(ime) in samples. Got {wavs.shape} instead."
+        assert self.output_key in outputs and outputs[self.output_key] is not None
+        assert wavs.shape[0] == outputs[self.output_key].size(
             0
-        ), f"You provided {outputs['output'].size(0)} utterances, but {wavs.shape[0]} audio files were synthesized instead."
-
-        # synthesize 16 bit audio
-        # we don't do this higher up in the inference methods
-        # because tensorboard logs require audio data as floats
-        if (wavs >= -1.0).all() & (wavs <= 1.0).all():
-            wavs = wavs * self.config.preprocessing.audio.max_wav_value
-            wavs = wavs.astype("int16")
+        ), f"You provided {outputs[self.output_key].size(0)} utterances, but {wavs.shape[0]} audio files were synthesized instead."
+
         return wavs, sr

     def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
@@ -445,7 +441,7 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
         _batch_idx: int,
         _dataloader_idx: int = 0,
     ):
-        from scipy.io.wavfile import write
+        import torchaudio

         logger.trace("Generating waveform...")

@@ -459,11 +455,14 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
             wavs,
             outputs["tgt_lens"],
         ):
-            write(
+            torchaudio.save(
                 self.get_filename(
                     basename, speaker, language, include_global_step=True
                 ),
-                sr,
                 # the vocoder output includes padding so we have to remove that
                 wav[: (unmasked_len * self.output_hop_size)],
+                sr,
+                format="wav",
+                encoding="PCM_S",
+                bits_per_sample=16,
             )