Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: change model to output mel-band oriented tensors instead of time-oriented ones #94

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fs2/cli/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@
help="""Which format(s) to synthesize to.
Multiple formats can be provided by repeating `--output-type`.
'**wav**' is the default and will synthesize to a playable audio file;
'**spec**' will generate predicted Mel spectrograms. Tensors are time-oriented (T, K) where T is equal to the number of frames and K is equal to the number of Mel bands.
'**spec**' will generate predicted Mel spectrograms. Tensors are Mel band-oriented (K, T) where K is equal to the number of Mel bands and T is equal to the number of frames.
'**textgrid**' will generate a Praat TextGrid with alignment labels. This can be helpful for evaluation.
'**readalong**' will generate a ReadAlong from the given text and synthesized audio (see https://github.com/ReadAlongs).
""",
Expand Down Expand Up @@ -436,8 +436,8 @@
global_step = get_global_step(model_path)

# load vocoder
logger.info(f"Loading Vocoder from {vocoder_path}")
if vocoder_path is not None:
logger.info(f"Loading Vocoder from {vocoder_path}")

Check warning on line 440 in fs2/cli/synthesize.py

View check run for this annotation

Codecov / codecov/patch

fs2/cli/synthesize.py#L440

Added line #L440 was not covered by tests
vocoder_ckpt = torch.load(vocoder_path, map_location=device)
try:
vocoder_model, vocoder_config = load_hifigan_from_checkpoint(
Expand Down
4 changes: 2 additions & 2 deletions fs2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@
self.config.preprocessing.audio.output_sampling_rate,
)
if self.config.training.vocoder_path:
input_ = batch["mel"]
input_ = batch["mel"].transpose(1, 2)

Check warning on line 356 in fs2/model.py

View check run for this annotation

Codecov / codecov/patch

fs2/model.py#L356

Added line #L356 was not covered by tests
vocoder_ckpt = torch.load(
self.config.training.vocoder_path, map_location=input_.device
)
Expand Down Expand Up @@ -431,7 +431,7 @@
)

if self.config.training.vocoder_path:
input_ = output[self.output_key]
input_ = output[self.output_key].transpose(1, 2)

Check warning on line 434 in fs2/model.py

View check run for this annotation

Codecov / codecov/patch

fs2/model.py#L434

Added line #L434 was not covered by tests
vocoder_ckpt = torch.load(
self.config.training.vocoder_path, map_location=input_.device
)
Expand Down
39 changes: 19 additions & 20 deletions fs2/prediction_writing_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_filename(
speaker: str,
language: str,
include_global_step: bool = False,
) -> Path:
) -> str:
# We don't truncate or alter the filename here because the basename is
# already truncated/cleaned in cli/synthesize.py
name_parts = [basename, speaker, language, self.file_extension]
Expand All @@ -115,7 +115,7 @@ def get_filename(
path = self.save_dir / self.sep.join(name_parts)
# synthesizing spec allows nested outputs so we may need to make subdirs
path.parent.mkdir(parents=True, exist_ok=True)
return path
return str(path)


class PredictionWritingSpecCallback(PredictionWritingCallbackBase):
Expand Down Expand Up @@ -159,7 +159,9 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
outputs["tgt_lens"],
):
torch.save(
data[:unmasked_len].cpu(),
data[:unmasked_len]
.cpu()
.transpose(0, 1), # save tensors as [K (bands), T (frames)]
self.get_filename(basename, speaker, language),
)

Expand Down Expand Up @@ -409,7 +411,7 @@ def synthesize_audio(self, outputs):

sr: int
wavs: np.ndarray
output_value = outputs[self.output_key]
output_value = outputs[self.output_key].transpose(1, 2)
if output_value is not None:
wavs, sr = synthesize_data(
output_value, self.vocoder_model, self.vocoder_config
Expand All @@ -419,21 +421,15 @@ def synthesize_audio(self, outputs):
f"{self.output_key} does not exist in the output of your model"
)

# wavs: [B (batch_size), T (samples)]
# wavs: [B (batch_size), C (channels), T (samples)]
assert (
wavs.ndim == 2
), f"The generated audio contained more than 2 dimensions. First dimension should be B(atch) and the second dimension should be T(ime) in samples. Got {wavs.shape} instead."
assert "output" in outputs and outputs["output"] is not None
assert wavs.shape[0] == outputs["output"].size(
wavs.ndim == 3
), f"The generated audio did not have exactly 3 dimensions. The first dimension should be B(atch), the second C(hannels), and the third T(ime) in samples. Got {wavs.shape} instead."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this happen due to a user error (like providing the wrong kind of input file), or is this strictly due to a programmer error? If the latter, OK; if the former, I don't like using assert.

assert self.output_key in outputs and outputs[self.output_key] is not None
assert wavs.shape[0] == outputs[self.output_key].size(
0
), f"You provided {outputs['output'].size(0)} utterances, but {wavs.shape[0]} audio files were synthesized instead."

# synthesize 16 bit audio
# we don't do this higher up in the inference methods
# because tensorboard logs require audio data as floats
if (wavs >= -1.0).all() & (wavs <= 1.0).all():
wavs = wavs * self.config.preprocessing.audio.max_wav_value
wavs = wavs.astype("int16")
), f"You provided {outputs[self.output_key].size(0)} utterances, but {wavs.shape[0]} audio files were synthesized instead."

return wavs, sr

def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
Expand All @@ -445,7 +441,7 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
_batch_idx: int,
_dataloader_idx: int = 0,
):
from scipy.io.wavfile import write
import torchaudio

logger.trace("Generating waveform...")

Expand All @@ -459,11 +455,14 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
wavs,
outputs["tgt_lens"],
):
write(
torchaudio.save(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the change of audio writer function related to this PR, or is it just an unrelated improvement? I assume you've tested this and can confirm it works well?

self.get_filename(
basename, speaker, language, include_global_step=True
),
sr,
# the vocoder output includes padding so we have to remove that
wav[: (unmasked_len * self.output_hop_size)],
sr,
format="wav",
encoding="PCM_S",
bits_per_sample=16,
)
Loading