feat: add the ability to output .readalong files
On the path to implementing EveryVoiceTTS/EveryVoice#439
joanise committed Dec 5, 2024
1 parent 2fd8eb4 commit f7c6b6c
Showing 3 changed files with 148 additions and 0 deletions.
1 change: 1 addition & 0 deletions fs2/cli/synthesize.py
@@ -345,6 +345,7 @@ def synthesize( # noqa: C901
**wav** is the default and will synthesize to a playable audio file;
**spec** will generate predicted Mel spectrograms. Tensors are time-oriented (T, K) where T is equal to the number of frames and K is equal to the number of Mel bands.
**textgrid** will generate a Praat TextGrid with alignment labels. This can be helpful for evaluation.
**readalong** will generate a ReadAlong from the given text and synthesized audio.
""",
),
teacher_forcing_directory: Path = typer.Option(
146 changes: 146 additions & 0 deletions fs2/prediction_writing_callback.py
@@ -9,6 +9,7 @@
from loguru import logger
from pympi import TextGrid
from pytorch_lightning.callbacks import Callback
from readalongs.api import Token, convert_to_readalong

from .config import FastSpeech2Config
from .type_definitions import SynthesizeOutputFormats
@@ -48,6 +49,15 @@ def get_synthesis_output_callbacks(
output_key=output_key,
)
)
if SynthesizeOutputFormats.readalong in output_type:
callbacks.append(
PredictionWritingReadAlongCallback(
config=config,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
)
)
if SynthesizeOutputFormats.wav in output_type:
if (
vocoder_model is None
@@ -305,6 +315,142 @@ def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
new_tg.to_file(filename)


class PredictionWritingReadAlongCallback(PredictionWritingCallbackBase):
"""
    This callback runs inference on a provided text-to-spec model and saves the
    resulting ReadAlong file, with word timings derived from the predicted
    durations, to disk. Combined with the .wav output, it can be loaded in the
    ReadAlongs Web Component for viewing.
"""

def __init__(
self,
config: FastSpeech2Config,
global_step: int,
output_dir: Path,
output_key: str,
):
super().__init__(
global_step=global_step,
file_extension=f"{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.readalong",
save_dir=output_dir / "readalongs",
)
self.text_processor = TextProcessor(config.text)
self.output_key = output_key
self.config = config
        logger.info(f"Saving readalong output to {self.save_dir}")

def _get_filename(self, basename: str, speaker: str, language: str) -> Path:
# We don't truncate or alter the filename here because the basename is
# already truncated/cleaned in cli/synthesize.py
        # like the TextGrid, the readalong filename should not include the
        # global step: these files can be used for fine-tuning, and the
        # dataloader does not expect a global step in the filename
path = self.save_dir / self.sep.join(
[
basename,
speaker,
language,
self.file_extension,
]
)
path.parent.mkdir(
parents=True, exist_ok=True
) # synthesizing spec allows nested outputs
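        # The resulting name has the shape
        # <basename><sep><speaker><sep><language><sep><rate>-<spec_type>.readalong
        # under save_dir, where <sep> is the separator defined on the base
        # class (shape shown for illustration)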
return path

def frames_to_seconds(self, frames: int) -> float:
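        # Each spectrogram frame advances the audio by fft_hop_size samples at
        # the output sampling rate, so seconds = frames * hop / rate. Worked
        # example (illustrative values): with a hop size of 256 at 22050 Hz,
        # 100 frames -> 100 * 256 / 22050 ≈ 1.16 s.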
return (
frames * self.config.preprocessing.audio.fft_hop_size
) / self.config.preprocessing.audio.output_sampling_rate

def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride]
self,
_trainer,
_pl_module,
outputs: dict[str, torch.Tensor | None],
batch: dict[str, Any],
_batch_idx: int,
_dataloader_idx: int = 0,
):
assert self.output_key in outputs and outputs[self.output_key] is not None
assert (
"duration_prediction" in outputs
and outputs["duration_prediction"] is not None
)
for basename, speaker, language, raw_text, text, duration in zip(
batch["basename"],
batch["speaker"],
batch["language"],
batch["raw_text"],
batch["text"], # type: ignore
outputs["duration_prediction"],
):
# Get all durations in frames
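            # (the duration predictor works in the log domain, outputting
            # log(duration + 1), hence the exp(x) - 1 inversion, rounding,
            # and clamping to zero below)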
duration_frames = (
torch.clamp(torch.round(torch.exp(duration) - 1), min=0).int().tolist()
)
# Get all input labels
tokens: list[int] = text.tolist()
text_labels = self.text_processor.decode_tokens(tokens, join_character=None)
assert len(duration_frames) == len(
text_labels
), f"can't synthesize {raw_text} because the number of predicted duration steps ({len(duration_frames)}) doesn't equal the number of input text labels ({len(text_labels)})"
token_count = len(text_labels)
# Create the sentence as a list of Tokens for readalong
sentence: list[Token] = []
for i in range(token_count):
                if i > 0:  # add a separator between tokens, not before the first
sentence.append(Token(" "))
sentence.append(
Token(
text_labels[i],
)
)

phone_count = 0
word_count = 0
words: list[Token] = []
raw_text_words = raw_text.split()
current_word_duration = 0.0
last_phone_end = 0.0
last_word_end = 0.0
# skip padding
text_labels_no_padding = [tl for tl in text_labels if tl != "\x80"]
duration_frames_no_padding = duration_frames[: len(text_labels_no_padding)]
            for label, frames in zip(
                text_labels_no_padding, duration_frames_no_padding
            ):
                # add phone label
                phone_duration = self.frames_to_seconds(frames)
current_phone_end = last_phone_end + phone_duration
phone_count += 1
last_phone_end = current_phone_end
# accumulate phone to word label
current_word_duration += phone_duration
# if label is space or the last phone, add the word and recount
if label == " " or phone_count == len(text_labels_no_padding):
if word_count > 0:
words.append(Token(" "))
current_word_end = last_word_end + current_word_duration
words.append(
Token(
raw_text_words[word_count], last_word_end, current_word_end
)
)
word_count += 1
last_word_end = current_word_end
                        current_word_duration = 0.0
            # Convert the word Tokens to a ReadAlong document
readalong = convert_to_readalong([words], [language])
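            # convert_to_readalong serializes the tokens to the ReadAlong XML
            # format and returns it as a string: one inner list of Tokens per
            # sentence; here, a single sentence with this sample's language code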
# get the filename
filename = self._get_filename(
basename=basename,
speaker=speaker,
language=language,
)
# write the file
with open(filename, "w", encoding="utf8") as f:
f.write(readalong)


class PredictionWritingWavCallback(PredictionWritingCallbackBase):
"""
Given text-to-spec, this callback does spec-to-wav and writes wav files.
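To make the word-timing logic in PredictionWritingReadAlongCallback concrete, here is a minimal, self-contained sketch with toy labels and durations (not part of the commit; Token and convert_to_readalong are the readalongs.api symbols imported above, used positionally the same way the callback uses them):

from readalongs.api import Token, convert_to_readalong

# toy per-phone durations in seconds, already converted from frames
phone_durations = [("h", 0.08), ("i", 0.12), (" ", 0.02), ("y", 0.07), ("o", 0.15)]
raw_text_words = ["hi", "yo"]

words: list[Token] = []
elapsed = 0.0
word_start = 0.0
word_index = 0
for i, (label, dur) in enumerate(phone_durations):
    elapsed += dur
    # flush a word at each space and after the final phone, as the callback does
    if label == " " or i == len(phone_durations) - 1:
        if words:
            words.append(Token(" "))  # untimed separator between words
        words.append(Token(raw_text_words[word_index], word_start, elapsed))
        word_index += 1
        word_start = elapsed

readalong = convert_to_readalong([words], ["und"])  # "und": undetermined language
print(readalong)  # ReadAlong XML, ready to be saved as a .readalong file

This yields "hi" spanning 0.00-0.22 s and "yo" spanning 0.22-0.44 s; note that, as in the callback, a space's own duration is folded into the word that precedes it.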
1 change: 1 addition & 0 deletions fs2/type_definitions.py
@@ -13,3 +13,4 @@ class SynthesizeOutputFormats(str, Enum):
wav = "wav"
spec = "spec"
textgrid = "textgrid"
readalong = "readalong"
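Because SynthesizeOutputFormats subclasses str, the new member compares equal to its plain string value, which is what lets the CLI string "readalong" map cleanly onto the enum. A standalone illustration (standard-library behavior only):

from enum import Enum

class SynthesizeOutputFormats(str, Enum):
    wav = "wav"
    spec = "spec"
    textgrid = "textgrid"
    readalong = "readalong"

assert SynthesizeOutputFormats.readalong == "readalong"  # str-enum equality
assert SynthesizeOutputFormats("readalong") is SynthesizeOutputFormats.readalong
# so membership tests like `SynthesizeOutputFormats.readalong in output_type`
# in get_synthesis_output_callbacks work whether output_type holds enum
# members or raw strings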
