diff --git a/hfgl/cli.py b/hfgl/cli.py
index 9b2ca86..e83da6e 100644
--- a/hfgl/cli.py
+++ b/hfgl/cli.py
@@ -119,7 +119,7 @@ def synthesize(
         exists=True,
         dir_okay=False,
         file_okay=True,
-        help="The path to a torch file containing time-oriented spectral features [T (frames), K (Mel bands)]",
+        help="The path to a torch file containing Mel band-oriented spectral features [K (Mel bands), T (frames)]",
         autocompletion=complete_path,
     ),
     generator_path: Path = typer.Option(
@@ -132,20 +132,41 @@ def synthesize(
         help="The path to a trained EveryVoice spec-to-wav model",
         autocompletion=complete_path,
     ),
+    time_oriented: bool = typer.Option(
+        False,
+        help="By default, EveryVoice assumes your spectrograms are of the shape [K (Mel bands), T (frames)]. If instead your spectrograms are of shape [T (frames), K (Mel bands)], then please add this flag to transpose the dimensions.",
+    ),
 ):
     """Given some Mel spectrograms and a trained model, generate some audio. i.e. perform *copy synthesis*"""
     import sys

     with spinner():
         import torch
+        import torchaudio
         from pydantic import ValidationError
-        from scipy.io.wavfile import write

         from .utils import load_hifigan_from_checkpoint, synthesize_data

     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     checkpoint = torch.load(generator_path, map_location=device)
     data = torch.load(data_path, map_location=device)
+    if time_oriented:
+        data = data.transpose(0, 1)
+    data_size = data.size()
+    if (
+        checkpoint["hyper_parameters"]["config"]["preprocessing"]["audio"]["n_mels"]
+        not in data_size
+    ):
+        raise ValueError(
+            f"Your model expects a spectrogram of dimensions [K (Mel bands), T (frames)] where K == {checkpoint['hyper_parameters']['config']['preprocessing']['audio']['n_mels']}, but you provided a tensor of size {data_size}"
+        )
+    if (
+        data_size[0]
+        != checkpoint["hyper_parameters"]["config"]["preprocessing"]["audio"]["n_mels"]
+    ):
+        raise ValueError(
+            f"We expected the first dimension of your Mel spectrogram to correspond with the number of Mel bands declared by your model ({checkpoint['hyper_parameters']['config']['preprocessing']['audio']['n_mels']}). Instead, we found a spectrogram with dimensions {data_size}. If your spectrogram is time-oriented, please re-run this command with the '--time-oriented' flag."
+        )
     try:
         vocoder_model, vocoder_config = load_hifigan_from_checkpoint(checkpoint, device)
     except (TypeError, ValidationError) as e:
@@ -153,7 +174,9 @@ def synthesize(
         sys.exit(1)
     wav, sr = synthesize_data(data, vocoder_model, vocoder_config)
     logger.info(f"Writing file {data_path}.wav")
-    write(f"{data_path}.wav", sr, wav)
+    torchaudio.save(
+        f"{data_path}.wav", wav, sr, format="wav", encoding="PCM_S", bits_per_sample=16
+    )


 if __name__ == "__main__":
diff --git a/hfgl/utils.py b/hfgl/utils.py
index 2e88274..d955c20 100644
--- a/hfgl/utils.py
+++ b/hfgl/utils.py
@@ -54,7 +54,7 @@ def synthesize_data(
     """Synthesize a batch of waveforms from spectral features

     Args:
-        data (Tensor): data tensor, expects output from feature prediction network to be size (b=batch_size, t=number_of_frames, k=n_mels)
+        data (Tensor): data tensor, expects output from feature prediction network to be size (b=batch_size, k=n_mels, t=number_of_frames)
         ckpt (dict): HiFiGANLightning checkpoint, expects checkpoint to have a 'hyper_parameters.config' key and HiFiGANConfig object value as well as a 'state_dict' key with model weight as the value
     Returns:
         Tuple[np.ndarray, int]: a B, T array of the synthesized audio and the sampling rate
@@ -67,7 +67,7 @@ def synthesize_data(
             model.generator.post_n_fft // 4,
         ).to(data.device)
         with torch.no_grad():
-            mag, phase = model.generator(data.transpose(1, 2))
+            mag, phase = model.generator(data)
         # We can remove this once the fix for https://github.com/pytorch/pytorch/issues/119088 is merged
         if mag.device.type == "mps" or phase.device.type == "mps":
             logger.warning(
@@ -79,10 +79,9 @@ def synthesize_data(
         wavs = inverse_spectral_transform(mag * torch.exp(phase * 1j)).unsqueeze(-2)
     else:
         with torch.no_grad():
-            wavs = model.generator(data.transpose(1, 2))
-            # squeeze to remove the channel dimension
+            wavs = model.generator(data)
     return (
-        wavs.squeeze(1).cpu().numpy(),
+        wavs.cpu(),
         config.preprocessing.audio.output_sampling_rate,
     )

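
Note (not part of the patch): a minimal sketch of preparing a saved spectrogram in the [K (Mel bands), T (frames)] orientation the command now expects. The file name "my_spec.pt" and the 80-band value are hypothetical; the band count should match your model's n_mels.

    import torch

    # Hypothetical spectrogram saved time-oriented as [T (frames), K (Mel bands)]
    spec = torch.load("my_spec.pt")

    # Mirror the CLI's check: the first dimension should equal n_mels (80 assumed here).
    if spec.size(0) != 80:
        spec = spec.transpose(0, 1)  # reorder to [K (Mel bands), T (frames)]

    torch.save(spec, "my_spec.pt")

Alternatively, leave the tensor time-oriented and pass the new --time-oriented flag so the CLI performs the transpose itself.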