Skip to content

Commit

Permalink
fix(cli): remove requirement for vocoder when synthesizing non-wav ou…
Browse files Browse the repository at this point in the history
…tputs
  • Loading branch information
roedoejet committed Sep 19, 2024
1 parent db4f336 commit 6b1608d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 36 deletions.
62 changes: 30 additions & 32 deletions fs2/cli/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,11 @@ def get_global_step(model_path: Path) -> int:

def synthesize_helper(
model,
vocoder_model,
vocoder_config,
texts: list[str],
language: Optional[str],
speaker: Optional[str],
duration_control: Optional[float],
global_step: int,
vocoder_global_step: int,
output_type: list[SynthesizeOutputFormats],
text_representation: DatasetTextRepresentation,
accelerator: str,
Expand All @@ -191,6 +188,9 @@ def synthesize_helper(
filelist: Path,
output_dir: Path,
teacher_forcing_directory: Path,
vocoder_global_step: Optional[int] = None,
vocoder_model=None,
vocoder_config=None,
):
"""This is a helper to perform synthesis once the model has been loaded.
It allows us to use the same command for synthesis via the CLI and
Expand Down Expand Up @@ -416,7 +416,7 @@ def synthesize( # noqa: C901

# Load checkpoints
print(f"Loading checkpoint from {model_path}", file=sys.stderr)
model: FastSpeech2 = FastSpeech2.load_from_checkpoint(model_path).to(device)
model: FastSpeech2 = FastSpeech2.load_from_checkpoint(model_path).to(device) # type: ignore
model.eval()

# get global step
Expand All @@ -425,37 +425,35 @@ def synthesize( # noqa: C901

# load vocoder
logger.info(f"Loading Vocoder from {vocoder_path}")
if vocoder_path is None:
logger.error(
"No vocoder was provided, please specify "
"--vocoder-path /path/to/vocoder on the command line."
)
sys.exit(1)
else:
if vocoder_path is not None:
vocoder_ckpt = torch.load(vocoder_path, map_location=device)
vocoder_model, vocoder_config = load_hifigan_from_checkpoint(
vocoder_ckpt, device
)
# We can't just use model.global_step because it gets reset by lightning
vocoder_global_step = get_global_step(vocoder_path)
return synthesize_helper(
model=model,
texts=texts,
language=language,
speaker=speaker,
duration_control=duration_control,
global_step=global_step,
output_type=output_type,
text_representation=text_representation,
accelerator=accelerator,
devices=devices,
device=device,
batch_size=batch_size,
num_workers=num_workers,
filelist=filelist,
teacher_forcing_directory=teacher_forcing_directory,
output_dir=output_dir,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
)
else:
vocoder_model = None
vocoder_config = None
vocoder_global_step = None
return synthesize_helper(
model=model,
texts=texts,
language=language,
speaker=speaker,
duration_control=duration_control,
global_step=global_step,
output_type=output_type,
text_representation=text_representation,
accelerator=accelerator,
devices=devices,
device=device,
batch_size=batch_size,
num_workers=num_workers,
filelist=filelist,
teacher_forcing_directory=teacher_forcing_directory,
output_dir=output_dir,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
)
16 changes: 12 additions & 4 deletions fs2/prediction_writing_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Sequence
from typing import Any, Optional, Sequence

import numpy as np
import torch
Expand All @@ -21,9 +21,9 @@ def get_synthesis_output_callbacks(
output_key: str,
device: torch.device,
global_step: int,
vocoder_model: HiFiGAN,
vocoder_config: HiFiGANConfig,
vocoder_global_step: int,
vocoder_model: Optional[HiFiGAN] = None,
vocoder_config: Optional[HiFiGANConfig] = None,
vocoder_global_step: Optional[int] = None,
):
"""
Given a list of desired output file formats, return the proper callbacks
Expand All @@ -49,6 +49,14 @@ def get_synthesis_output_callbacks(
)
)
if SynthesizeOutputFormats.wav in output_type:
if (
vocoder_model is None
or vocoder_config is None
or vocoder_global_step is None
):
raise ValueError(
"We cannot synthesize waveforms without a vocoder. Please ensure that a vocoder is specified."
)
callbacks.append(
PredictionWritingWavCallback(
config=config,
Expand Down

0 comments on commit 6b1608d

Please sign in to comment.