
Commit b301c36
refactor!: change model to expect mel-band oriented tensors instead of time-oriented ones
roedoejet committed Oct 30, 2024
1 parent 7dc68ff commit b301c36
Showing 2 changed files with 30 additions and 8 deletions.
hfgl/cli.py (29 changes: 26 additions & 3 deletions)
@@ -119,7 +119,7 @@ def synthesize(
         exists=True,
         dir_okay=False,
         file_okay=True,
-        help="The path to a torch file containing time-oriented spectral features [T (frames), K (Mel bands)]",
+        help="The path to a torch file containing Mel band-oriented spectral features [K (Mel bands), T (frames)]",
         autocompletion=complete_path,
     ),
     generator_path: Path = typer.Option(
@@ -132,28 +132,51 @@ def synthesize(
         help="The path to a trained EveryVoice spec-to-wav model",
         autocompletion=complete_path,
     ),
+    time_oriented: bool = typer.Option(
+        False,
+        help="By default, EveryVoice assumes your spectrograms are of the shape [K (Mel bands), T (frames)]. If instead your spectrograms are of shape [T (frames), K (Mel bands)], then please add this flag to transpose the dimensions.",
+    ),
 ):
     """Given some Mel spectrograms and a trained model, generate some audio. i.e. perform *copy synthesis*"""
     import sys

     with spinner():
         import torch
+        import torchaudio
         from pydantic import ValidationError
-        from scipy.io.wavfile import write

         from .utils import load_hifigan_from_checkpoint, synthesize_data

     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     checkpoint = torch.load(generator_path, map_location=device)
     data = torch.load(data_path, map_location=device)
+    if time_oriented:
+        data = data.transpose(0, 1)
+    data_size = data.size()
+    if (
+        checkpoint["hyper_parameters"]["config"]["preprocessing"]["audio"]["n_mels"]
+        not in data_size
+    ):
+        raise ValueError(
+            f"Your model expects a spectrogram of dimensions [K (Mel bands), T (frames)] where K == {checkpoint['hyper_parameters']['config']['preprocessing']['audio']['n_mels']}, but you provided a tensor of size {data_size}"
+        )
+    if (
+        data_size[0]
+        != checkpoint["hyper_parameters"]["config"]["preprocessing"]["audio"]["n_mels"]
+    ):
+        raise ValueError(
+            f"We expected the first dimension of your Mel spectrogram to correspond to the number of Mel bands declared by your model ({checkpoint['hyper_parameters']['config']['preprocessing']['audio']['n_mels']}). Instead, your tensor has dimensions {data_size}. If your spectrogram is time-oriented, please re-run this command with the '--time-oriented' flag."
+        )
     try:
         vocoder_model, vocoder_config = load_hifigan_from_checkpoint(checkpoint, device)
     except (TypeError, ValidationError) as e:
         logger.error(f"Unable to load {generator_path}: {e}")
         sys.exit(1)
     wav, sr = synthesize_data(data, vocoder_model, vocoder_config)
     logger.info(f"Writing file {data_path}.wav")
-    write(f"{data_path}.wav", sr, wav)
+    torchaudio.save(
+        f"{data_path}.wav", wav, sr, format="wav", encoding="PCM_S", bits_per_sample=16
+    )


 if __name__ == "__main__":
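Taken together, the new checks mean a spectrogram saved to disk must arrive Mel-band oriented unless --time-oriented is passed. Below is a minimal sketch (not part of the commit; the file name and the n_mels value are assumptions) of re-orienting a saved spectrogram ahead of time instead of relying on the flag:

import torch

n_mels = 80  # assumption: must match the model's preprocessing.audio.n_mels
mel = torch.load("spec.pt")  # hypothetical path to a saved spectrogram

if mel.size(0) != n_mels and mel.size(1) == n_mels:
    # Looks time-oriented [T (frames), K (Mel bands)]; transpose to [K, T]
    mel = mel.transpose(0, 1)
assert mel.size(0) == n_mels, f"expected [K={n_mels}, T], got {tuple(mel.size())}"
torch.save(mel, "spec.pt")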
hfgl/utils.py (9 changes: 4 additions & 5 deletions)
@@ -54,7 +54,7 @@ def synthesize_data(
     """Synthesize a batch of waveforms from spectral features

     Args:
-        data (Tensor): data tensor, expects output from feature prediction network to be size (b=batch_size, t=number_of_frames, k=n_mels)
+        data (Tensor): data tensor, expects output from feature prediction network to be size (b=batch_size, k=n_mels, t=number_of_frames)
         ckpt (dict): HiFiGANLightning checkpoint, expects checkpoint to have a 'hyper_parameters.config' key and HiFiGANConfig object value as well as a 'state_dict' key with model weights as the value
     Returns:
         Tuple[np.ndarray, int]: a B, T array of the synthesized audio and the sampling rate
@@ -67,7 +67,7 @@ def synthesize_data(
             model.generator.post_n_fft // 4,
         ).to(data.device)
         with torch.no_grad():
-            mag, phase = model.generator(data.transpose(1, 2))
+            mag, phase = model.generator(data)
         # We can remove this once the fix for https://github.com/pytorch/pytorch/issues/119088 is merged
         if mag.device.type == "mps" or phase.device.type == "mps":
             logger.warning(
Expand All @@ -79,10 +79,9 @@ def synthesize_data(
             wavs = inverse_spectral_transform(mag * torch.exp(phase * 1j)).unsqueeze(-2)
     else:
         with torch.no_grad():
-            wavs = model.generator(data.transpose(1, 2))
-    # squeeze to remove the channel dimension
+            wavs = model.generator(data)
     return (
-        wavs.squeeze(1).cpu().numpy(),
+        wavs.cpu(),
         config.preprocessing.audio.output_sampling_rate,
     )

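Since synthesize_data now receives tensors already oriented as (batch, Mel bands, frames), callers no longer transpose before the generator runs. A minimal sketch of copy synthesis under the new layout (checkpoint and spectrogram paths are hypothetical; the calls mirror those in hfgl/cli.py above):

import torch
import torchaudio

from hfgl.utils import load_hifigan_from_checkpoint, synthesize_data

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("generator.ckpt", map_location=device)  # hypothetical path
vocoder_model, vocoder_config = load_hifigan_from_checkpoint(checkpoint, device)

data = torch.load("spec.pt", map_location=device)  # [K (Mel bands), T (frames)]
wav, sr = synthesize_data(data, vocoder_model, vocoder_config)

# Write 16-bit PCM, as the updated CLI does
torchaudio.save("spec.wav", wav, sr, format="wav", encoding="PCM_S", bits_per_sample=16)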
