Skip to content

Commit

Permalink
fix(cli): remove requirement for vocoder when synthesizing non-wav outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
roedoejet committed Sep 19, 2024
1 parent db4f336 commit 1cfeb4e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 35 deletions.
60 changes: 29 additions & 31 deletions fs2/cli/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,11 @@ def get_global_step(model_path: Path) -> int:

def synthesize_helper(
model,
vocoder_model,
vocoder_config,
texts: list[str],
language: Optional[str],
speaker: Optional[str],
duration_control: Optional[float],
global_step: int,
vocoder_global_step: int,
output_type: list[SynthesizeOutputFormats],
text_representation: DatasetTextRepresentation,
accelerator: str,
Expand All @@ -191,6 +188,9 @@ def synthesize_helper(
filelist: Path,
output_dir: Path,
teacher_forcing_directory: Path,
vocoder_global_step: Optional[int] = None,
vocoder_model = None,
vocoder_config = None,
):
"""This is a helper to perform synthesis once the model has been loaded.
It allows us to use the same command for synthesis via the CLI and
Expand Down Expand Up @@ -425,37 +425,35 @@ def synthesize( # noqa: C901

# load vocoder
logger.info(f"Loading Vocoder from {vocoder_path}")
if vocoder_path is None:
logger.error(
"No vocoder was provided, please specify "
"--vocoder-path /path/to/vocoder on the command line."
)
sys.exit(1)
else:
if vocoder_path is not None:
vocoder_ckpt = torch.load(vocoder_path, map_location=device)
vocoder_model, vocoder_config = load_hifigan_from_checkpoint(
vocoder_ckpt, device
)
# We can't just use model.global_step because it gets reset by lightning
vocoder_global_step = get_global_step(vocoder_path)
return synthesize_helper(
model=model,
texts=texts,
language=language,
speaker=speaker,
duration_control=duration_control,
global_step=global_step,
output_type=output_type,
text_representation=text_representation,
accelerator=accelerator,
devices=devices,
device=device,
batch_size=batch_size,
num_workers=num_workers,
filelist=filelist,
teacher_forcing_directory=teacher_forcing_directory,
output_dir=output_dir,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
)
else:
vocoder_model = None
vocoder_config = None
vocoder_global_step = None
return synthesize_helper(
model=model,
texts=texts,
language=language,
speaker=speaker,
duration_control=duration_control,
global_step=global_step,
output_type=output_type,
text_representation=text_representation,
accelerator=accelerator,
devices=devices,
device=device,
batch_size=batch_size,
num_workers=num_workers,
filelist=filelist,
teacher_forcing_directory=teacher_forcing_directory,
output_dir=output_dir,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
)
10 changes: 6 additions & 4 deletions fs2/prediction_writing_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Sequence
from typing import Any, Optional, Sequence

import numpy as np
import torch
Expand All @@ -21,9 +21,9 @@ def get_synthesis_output_callbacks(
output_key: str,
device: torch.device,
global_step: int,
vocoder_model: HiFiGAN,
vocoder_config: HiFiGANConfig,
vocoder_global_step: int,
vocoder_model: Optional[HiFiGAN] = None,
vocoder_config: Optional[HiFiGANConfig] = None,
vocoder_global_step: Optional[int] = None,
):
"""
Given a list of desired output file formats, return the proper callbacks
Expand All @@ -49,6 +49,8 @@ def get_synthesis_output_callbacks(
)
)
if SynthesizeOutputFormats.wav in output_type:
if vocoder_model is None or vocoder_config is None or vocoder_global_step is None:
raise ValueError("We cannot synthesize waveforms without a vocoder. Please ensure that a vocoder is specified.")
callbacks.append(
PredictionWritingWavCallback(
config=config,
Expand Down

0 comments on commit 1cfeb4e

Please sign in to comment.