diff --git a/fs2/prediction_writing_callback.py b/fs2/prediction_writing_callback.py
index f02d285..027e166 100644
--- a/fs2/prediction_writing_callback.py
+++ b/fs2/prediction_writing_callback.py
@@ -117,11 +117,6 @@ def _get_filename(
         path.parent.mkdir(parents=True, exist_ok=True)
         return path
 
-    def frames_to_seconds(self, frames: int) -> float:
-        return (
-            frames * self.config.preprocessing.audio.fft_hop_size
-        ) / self.config.preprocessing.audio.output_sampling_rate
-
 
 class PredictionWritingSpecCallback(PredictionWritingCallbackBase):
     """
@@ -169,28 +164,39 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
         )
 
 
-class PredictionWritingTextGridCallback(PredictionWritingCallbackBase):
+class PredictionWritingAlignedTextCallback(PredictionWritingCallbackBase):
     """
-    This callback runs inference on a provided text-to-spec model and saves the resulting textgrid of the predicted durations to disk. This can be used for evaluation.
+    This callback runs inference on a provided text-to-spec model and extracts the aligned text from the predicted durations for various output format options.
     """
 
     def __init__(
         self,
         config: FastSpeech2Config,
         global_step: int,
-        output_dir: Path,
         output_key: str,
+        file_extension: str,
+        save_dir: Path,
     ):
         super().__init__(
             config=config,
             global_step=global_step,
-            file_extension=f"{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.TextGrid",
-            save_dir=output_dir / "textgrids",
+            file_extension=file_extension,
+            save_dir=save_dir,
         )
         self.text_processor = TextProcessor(config.text)
         self.output_key = output_key
         logger.info(f"Saving pytorch output to {self.save_dir}")
 
+    def save_text(
+        self, max_seconds, phones, words, language, filename
+    ):  # pragma: no cover
+        raise NotImplementedError  # subclasses must implement this
+
+    def frames_to_seconds(self, frames: int) -> float:
+        return (
+            frames * self.config.preprocessing.audio.fft_hop_size
+        ) / self.config.preprocessing.audio.output_sampling_rate
+
     def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
         self,
         _trainer,
@@ -225,8 +231,6 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
             ), f"can't synthesize {raw_text} because the number of predicted duration steps ({len(duration_frames)}) doesn't equal the number of input text labels ({len(text_labels)})"
             # get the duration of the audio: (sum_of_frames * hop_size) / sample_rate
             xmax_seconds = self.frames_to_seconds(sum(duration_frames))
-            # create new textgrid
-            new_tg = TextGrid(xmax=xmax_seconds)
             # create the tiers
             words: list[tuple[float, float, str]] = []
             phones: list[tuple[float, float, str]] = []
@@ -234,10 +238,6 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
             current_word_duration = 0.0
             last_phone_end = 0.0
             last_word_end = 0.0
-            phone_tier = new_tg.add_tier("phones")
-            phone_annotation_tier = new_tg.add_tier("phone annotations")
-            word_tier = new_tg.add_tier("words")
-            word_annotation_tier = new_tg.add_tier("word annotations")
             # skip padding
             text_labels_no_padding = [tl for tl in text_labels if tl != "\x80"]
             duration_frames_no_padding = duration_frames[: len(text_labels_no_padding)]
@@ -249,8 +249,6 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
                 current_phone_end = last_phone_end + phone_duration
                 interval = (last_phone_end, current_phone_end, label)
                 phones.append(interval)
-                phone_annotation_tier.add_interval(interval[0], interval[1], "")
-                phone_tier.add_interval(*interval)
                 last_phone_end = current_phone_end
                 # accumulate phone to word label
                 current_word_duration += phone_duration
@@ -263,17 +261,54 @@ def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
                         raw_text_words[len(words)],
                     )
                     words.append(interval)
-                    word_tier.add_interval(*interval)
-                    word_annotation_tier.add_interval(interval[0], interval[1], "")
                     last_word_end = current_word_end
                     current_word_duration = 0
+
             # get the filename
             filename = self._get_filename(basename, speaker, language)
-            # write the file
-            new_tg.to_file(filename)
+            # Save the output (the subclass has to implement this)
+            self.save_text(xmax_seconds, phones, words, language, filename)
 
 
-class PredictionWritingReadAlongCallback(PredictionWritingCallbackBase):
+class PredictionWritingTextGridCallback(PredictionWritingAlignedTextCallback):
+    """
+    This callback runs inference on a provided text-to-spec model and saves the resulting TextGrid of the predicted durations to disk. This can be used for evaluation.
+    """
+
+    def __init__(
+        self,
+        config: FastSpeech2Config,
+        global_step: int,
+        output_dir: Path,
+        output_key: str,
+    ):
+        super().__init__(
+            config=config,
+            global_step=global_step,
+            output_key=output_key,
+            file_extension=f"{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.TextGrid",
+            save_dir=output_dir / "textgrids",
+        )
+
+    def save_text(self, max_seconds, phones, words, language, filename):
+        """Save the aligned text as a TextGrid with phones and words tiers"""
+        new_tg = TextGrid(xmax=max_seconds)
+        phone_tier = new_tg.add_tier("phones")
+        phone_annotation_tier = new_tg.add_tier("phone annotations")
+        for interval in phones:
+            phone_annotation_tier.add_interval(interval[0], interval[1], "")
+            phone_tier.add_interval(*interval)
+
+        word_tier = new_tg.add_tier("words")
+        word_annotation_tier = new_tg.add_tier("word annotations")
+        for interval in words:
+            word_tier.add_interval(*interval)
+            word_annotation_tier.add_interval(interval[0], interval[1], "")
+
+        new_tg.to_file(filename)
+
+
+class PredictionWritingReadAlongCallback(PredictionWritingAlignedTextCallback):
     """
     This callback runs inference on a provided text-to-spec model and saves the resulting readalong of the predicted durations to disk. Combined with the .wav output, this can be loaded in the ReadAlongs Web-Component for viewing.
""" @@ -288,6 +323,7 @@ def __init__( super().__init__( config=config, global_step=global_step, + output_key=output_key, file_extension=f"{config.preprocessing.audio.input_sampling_rate}-{config.preprocessing.audio.spec_type}.readalong", save_dir=output_dir / "readalongs", ) @@ -295,90 +331,19 @@ def __init__( self.output_key = output_key logger.info(f"Saving pytorch output to {self.save_dir}") - def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride] - self, - _trainer, - _pl_module, - outputs: dict[str, torch.Tensor | None], - batch: dict[str, Any], - _batch_idx: int, - _dataloader_idx: int = 0, - ): - assert self.output_key in outputs and outputs[self.output_key] is not None - assert ( - "duration_prediction" in outputs - and outputs["duration_prediction"] is not None - ) - for basename, speaker, language, raw_text, text, duration in zip( - batch["basename"], - batch["speaker"], - batch["language"], - batch["raw_text"], - batch["text"], # type: ignore - outputs["duration_prediction"], - ): - # Get all durations in frames - duration_frames = ( - torch.clamp(torch.round(torch.exp(duration) - 1), min=0).int().tolist() - ) - # Get all input labels - tokens: list[int] = text.tolist() - text_labels = self.text_processor.decode_tokens(tokens, join_character=None) - assert len(duration_frames) == len( - text_labels - ), f"can't synthesize {raw_text} because the number of predicted duration steps ({len(duration_frames)}) doesn't equal the number of input text labels ({len(text_labels)})" - token_count = len(text_labels) - # Create the sentence as a list of Tokens for readalong - sentence: list[Token] = [] - for i in range(token_count): - if i >= 0: - sentence.append(Token(" ")) - sentence.append( - Token( - text_labels[i], - ) - ) + def save_text(self, max_seconds, phones, words, language, filename): + """Save the aligned text as a .readalong file""" - phone_count = 0 - word_count = 0 - words: list[Token] = [] - raw_text_words = raw_text.split() - current_word_duration = 0.0 - last_phone_end = 0.0 - last_word_end = 0.0 - # skip padding - text_labels_no_padding = [tl for tl in text_labels if tl != "\x80"] - duration_frames_no_padding = duration_frames[: len(text_labels_no_padding)] - for label, duration in zip( - text_labels_no_padding, duration_frames_no_padding - ): - # add phone label - phone_duration = self.frames_to_seconds(duration) - current_phone_end = last_phone_end + phone_duration - phone_count += 1 - last_phone_end = current_phone_end - # accumulate phone to word label - current_word_duration += phone_duration - # if label is space or the last phone, add the word and recount - if label == " " or phone_count == len(text_labels_no_padding): - if word_count > 0: - words.append(Token(" ")) - current_word_end = last_word_end + current_word_duration - words.append( - Token( - raw_text_words[word_count], last_word_end, current_word_end - ) - ) - word_count += 1 - last_word_end = current_word_end - current_word_duration = 0 - # Convert the ras_tokens to a readalong - readalong = convert_to_readalong([words], [language]) - # get the filename - filename = self._get_filename(basename, speaker, language) - # write the file - with open(filename, "w", encoding="utf8") as f: - f.write(readalong) + # Convert the (time, end, label) word tuples into RAS Tokens + ras_tokens: list[Token] = [] + for word in words: + if ras_tokens: + ras_tokens.append(Token(" ")) + ras_tokens.append(Token(word[2], word[0], word[1])) + + readalong = convert_to_readalong([ras_tokens], [language]) + 
+        with open(filename, "w", encoding="utf8") as f:
+            f.write(readalong)
 
 
 class PredictionWritingWavCallback(PredictionWritingCallbackBase):
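
Taken together, the diff turns PredictionWritingAlignedTextCallback into a template: the shared on_predict_batch_end decodes the input tokens, converts the predicted duration frames to seconds via frames_to_seconds, builds (start, end, label) tuples for phones and words, and hands everything to the subclass's save_text. A new output format therefore only needs a constructor that picks a file extension and save directory, plus a save_text implementation. As a minimal sketch of that extension point, here is a hypothetical subclass that writes word-level SRT subtitles; the class name, the "srt" extension, and the output directory are illustrative only and not part of this diff:

    class PredictionWritingSrtCallback(PredictionWritingAlignedTextCallback):
        """Hypothetical example: save the aligned words as SRT subtitles."""

        def __init__(
            self,
            config: FastSpeech2Config,
            global_step: int,
            output_dir: Path,
            output_key: str,
        ):
            super().__init__(
                config=config,
                global_step=global_step,
                output_key=output_key,
                file_extension="srt",  # illustrative; not a format this PR adds
                save_dir=output_dir / "srt",
            )

        def save_text(self, max_seconds, phones, words, language, filename):
            """Save the aligned text as an .srt file (one cue per word)"""

            def fmt(t: float) -> str:
                # SRT timestamps are HH:MM:SS,mmm
                ms = int(round(t * 1000))
                h, ms = divmod(ms, 3_600_000)
                m, ms = divmod(ms, 60_000)
                s, ms = divmod(ms, 1_000)
                return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

            # words arrives as (start, end, label) tuples from the base class
            cues = [
                f"{i}\n{fmt(start)} --> {fmt(end)}\n{label}\n"
                for i, (start, end, label) in enumerate(words, start=1)
            ]
            with open(filename, "w", encoding="utf8") as f:
                f.write("\n".join(cues))

Passing both phones and words (plus max_seconds) to every save_text keeps the alignment logic in one place and lets each format pick what it needs: the TextGrid writer uses both tiers and the overall duration, while the readalong writer and the sketch above use only the words.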