From 1f661ddcca08245ccff49814ab6120d4b6623f8c Mon Sep 17 00:00:00 2001
From: Eric Joanis
Date: Mon, 4 Nov 2024 17:41:04 -0500
Subject: [PATCH] fix: allow slash to be used as a character

Fixes #460
Fixes #540
---
 everyvoice/cli.py                             |  2 +-
 everyvoice/preprocessor/preprocessor.py       | 16 ++--
 everyvoice/tests/data/metadata_slash_pipe.psv |  6 ++
 everyvoice/tests/test_preprocessing.py        |  6 ++
 everyvoice/tests/test_text.py                 | 40 +++++++++++--
 everyvoice/text/text_processor.py             | 58 ++++++++++++++++---
 6 files changed, 104 insertions(+), 24 deletions(-)
 create mode 100644 everyvoice/tests/data/metadata_slash_pipe.psv

diff --git a/everyvoice/cli.py b/everyvoice/cli.py
index 055ffc10..3c7d75d4 100644
--- a/everyvoice/cli.py
+++ b/everyvoice/cli.py
@@ -474,7 +474,7 @@ def check_data(
     preprocessor = Preprocessor(config)
     checked_data = preprocessor.check_data(
         filelist=combined_filelist_data,
-        heavy_clip_detction=heavy_clip_detection,
+        heavy_clip_detection=heavy_clip_detection,
         heavy_objective_evaluation=heavy_objective_evaluation,
     )
     if not combined_filelist_data:
diff --git a/everyvoice/preprocessor/preprocessor.py b/everyvoice/preprocessor/preprocessor.py
index af5fb16d..d0476640 100644
--- a/everyvoice/preprocessor/preprocessor.py
+++ b/everyvoice/preprocessor/preprocessor.py
@@ -628,7 +628,7 @@ def process_attn_prior(self, item):
         process_phones = False
         character_tokens = (
             self.text_processor.encode_string_tokens(
-                item["character_tokens"].split("/")
+                self.text_processor.split_tokens(item["character_tokens"])
             )
             if "character_tokens" in item
             and item[
@@ -637,7 +637,9 @@
             else None
         )
         phone_tokens = (
-            self.text_processor.encode_string_tokens(item["phone_tokens"].split("/"))
+            self.text_processor.encode_string_tokens(
+                self.text_processor.split_tokens(item["phone_tokens"])
+            )
             if "phone_tokens" in item and item["phone_tokens"]
             else None
         )
@@ -792,11 +794,9 @@ def process_text(
         # encode to string
         if encode_as_string:
             if phone_tokens is not None:
-                phones = text_processor.decode_tokens(phone_tokens, join_character="/")
+                phones = text_processor.decode_tokens(phone_tokens)
             if character_tokens is not None:
-                characters = text_processor.decode_tokens(
-                    character_tokens, join_character="/"
-                )
+                characters = text_processor.decode_tokens(character_tokens)
             return (characters, phones, pfs)
         else:
             return (character_tokens, phone_tokens, pfs)
@@ -883,7 +883,7 @@ def check_data(
         self,
         filelist,
         word_seg_token=" ",
-        heavy_clip_detction=False,
+        heavy_clip_detection=False,
         heavy_objective_evaluation=False,
     ):
         data = []
@@ -943,7 +943,7 @@ def check_data(
             ), f"Audio has {audio.size(0)} channels, but should be mono"
             audio = audio.squeeze()

-            if heavy_clip_detction:
+            if heavy_clip_detection:
                 _, total_clipping = detect_clipping(audio)
             else:
                 # this isn't a great way of detecting clipping,
diff --git a/everyvoice/tests/data/metadata_slash_pipe.psv b/everyvoice/tests/data/metadata_slash_pipe.psv
new file mode 100644
index 00000000..ce407a13
--- /dev/null
+++ b/everyvoice/tests/data/metadata_slash_pipe.psv
@@ -0,0 +1,6 @@
+basename|raw_text|characters|speaker|language|clean_text|label|real_lang
+LJ050-0269|The essential terms of such memoranda might well be embodied in an Executive order.|The essential terms of such memoranda might well be embodied in an Executive order.|default|default|the essential terms of such memoranda might well be embodied in an executive order.|LJ_TEST|eng
+LJ050-0270|This Commission can\|recommend no procedures for the future protection of our Presidents which will guarantee security.|This Commission can\|recommend no procedures for the future protection of our Presidents which will guarantee security.|default|default|this commission can\|recommend no procedures for the future protection of our presidents which will guarantee security.|LJ_TEST|eng
+LJ050-0271|The demands on/the President in the execution of His responsibilities in today's world are so varied and complex|The demands on/the President in the execution of His responsibilities in today's world are so varied and complex|default|default|the demands on/the president in the execution of his responsibilities in today's world are so varied and complex|LJ_TEST|eng
+LJ050-0272.wav|and the traditions of the office in a democracy such as ours are so deep-seated as to preclude absolute security.|and the traditions of the office in a democracy such as ours are so deep-seated as to preclude absolute security.|default|default|and the traditions of the office in a democracy such as ours are so deep-seated as to preclude absolute security.|LJ_TEST|eng
+LJ050-0273|The Commission has, however, from its examination of the facts of President Kennedy's assassination|The Commission has, however, from its examination of the facts of President Kennedy's assassination|default|default|the commission has, however, from its examination of the facts of president kennedy's assassination|LJ_TEST|eng
diff --git a/everyvoice/tests/test_preprocessing.py b/everyvoice/tests/test_preprocessing.py
index 5f568834..784f389d 100644
--- a/everyvoice/tests/test_preprocessing.py
+++ b/everyvoice/tests/test_preprocessing.py
@@ -382,6 +382,7 @@ def test_text_processing(self):
         mixed_representation_filelist = (
             self.data_dir / "metadata_mixed_representation.psv"
         )
+        slash_pipe_filelist = self.data_dir / "metadata_slash_pipe.psv"
         fp_config = FeaturePredictionConfig(**self.fp_config.model_dump())
         filelists_to_test = [
             {
@@ -409,6 +410,11 @@ def test_text_processing(self):
                 "contains_characters": True,
                 "contains_phones": True,
             },  # will tokenize characters and tokenize phones
+            {
+                "path": slash_pipe_filelist,
+                "contains_characters": True,
+                "contains_phones": False,
+            },
         ]
         for filelist_test_info in filelists_to_test:
             with tempfile.TemporaryDirectory(prefix="inputs", dir=".") as tmpdir:
diff --git a/everyvoice/tests/test_text.py b/everyvoice/tests/test_text.py
index 223319ec..ec8bf955 100644
--- a/everyvoice/tests/test_text.py
+++ b/everyvoice/tests/test_text.py
@@ -9,13 +9,14 @@

 import everyvoice.demo.app
 import everyvoice.text.utils
+from everyvoice import exceptions
 from everyvoice.config.text_config import Punctuation, Symbols, TextConfig
 from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
 from everyvoice.tests.basic_test_case import BasicTestCase
 from everyvoice.text.features import N_PHONOLOGICAL_FEATURES
 from everyvoice.text.lookups import build_lookup, lookuptables_from_data
 from everyvoice.text.phonemizer import AVAILABLE_G2P_ENGINES, get_g2p_engine
-from everyvoice.text.text_processor import TextProcessor
+from everyvoice.text.text_processor import JOINER_SUBSTITUTION, TextProcessor
 from everyvoice.utils import (
     collapse_whitespace,
     generic_psv_filelist_reader,
@@ -46,7 +47,7 @@ def test_run_doctest(self):
     def test_text_to_sequence(self):
         text = "hello world"
         sequence = self.base_text_processor.encode_text(text)
-        self.assertEqual(self.base_text_processor.decode_tokens(sequence, ""), text)
+        self.assertEqual(self.base_text_processor.decode_tokens(sequence, "", ""), text)

     def test_token_sequence_to_text(self):
         sequence = [51, 48, 55, 55, 58, 1, 66, 58, 61, 55, 47]
@@ -69,7 +70,7 @@ def test_cleaners_with_upper(self):
             ),
         )
         sequence = upper_text_processor.encode_text(text_upper)
-        self.assertEqual(upper_text_processor.decode_tokens(sequence, ""), text)
+        self.assertEqual(upper_text_processor.decode_tokens(sequence, "", ""), text)

     def test_no_duplicate_punctuation(self):
         with self.assertRaises(ValidationError):
@@ -198,7 +199,7 @@ def test_phonological_features(self):
             apply_g2p=True,
             encode_as_phonological_features=True,
         )
-        self.assertEqual(moh_text_processor.decode_tokens(g2p_tokens, ""), "séːɡũ")
+        self.assertEqual(moh_text_processor.decode_tokens(g2p_tokens, "", ""), "séːɡũ")
         self.assertEqual(len(g2p_tokens), len(feats))
         self.assertNotEqual(len(g2p_tokens), len(one_hot_tokens))
         self.assertEqual(len(feats[0]), N_PHONOLOGICAL_FEATURES)
@@ -239,9 +240,11 @@ def test_normalization(self):
         )
         text = "he\u0301llo world"
         sequence = accented_text_processor.encode_text(text)
-        self.assertNotEqual(accented_text_processor.decode_tokens(sequence, ""), text)
+        self.assertNotEqual(
+            accented_text_processor.decode_tokens(sequence, "", ""), text
+        )
         self.assertEqual(
-            accented_text_processor.decode_tokens(sequence, ""),
+            accented_text_processor.decode_tokens(sequence, "", ""),
             normalize("NFC", text),
         )
         self.assertNotEqual(
@@ -255,6 +258,31 @@ def test_missing_symbol(self):
         self.assertIn("3", self.base_text_processor.missing_symbols)
         self.assertEqual(self.base_text_processor.missing_symbols["3"], 1)

+    def test_use_slash(self):
+        text = "word/token"
+        text_processor = TextProcessor(
+            TextConfig(symbols=Symbols(letters=list(string.ascii_letters) + ["/"])),
+        )
+        sequence = text_processor.encode_text(text)
+        decoded = text_processor.decode_tokens(sequence)
+        self.assertEqual(decoded, "w/o/r/d/" + JOINER_SUBSTITUTION + "/t/o/k/e/n")
+        encoded = text_processor.encode_escaped_string_sequence(decoded)
+        self.assertEqual(encoded, sequence)
+
+        with self.assertRaises(exceptions.OutOfVocabularySymbolError):
+            # / is OOV, so JOINER_SUBSTITUTION will also be OOV
+            self.base_text_processor.encode_escaped_string_sequence(decoded)
+
+    def test_encode_string_tokens(self):
+        self.assertEqual(
+            self.base_text_processor.encode_string_tokens(["a", "b", ",", " ", "c"]),
+            self.base_text_processor.encode_escaped_string_sequence("a/b/,/ /c"),
+        )
+        with self.assertRaises(exceptions.OutOfVocabularySymbolError):
+            self.base_text_processor.encode_string_tokens(["oov"])
+        with self.assertRaises(exceptions.OutOfVocabularySymbolError):
+            self.base_text_processor.encode_string_tokens([JOINER_SUBSTITUTION])
+

 class LookupTableTest(TestCase):
     def test_build_lookup(self):
diff --git a/everyvoice/text/text_processor.py b/everyvoice/text/text_processor.py
index a00c5fcc..00f02d4a 100644
--- a/everyvoice/text/text_processor.py
+++ b/everyvoice/text/text_processor.py
@@ -18,6 +18,8 @@
 )

 PAD_SYMBOL = "\x80"
+CHARACTER_JOINER = "/"
+JOINER_SUBSTITUTION = "\x91"  # substituted for "/" inside a token when joining on "/"


 class TextProcessor:
@@ -402,35 +404,56 @@ def encode_string_tokens(self, sequence: list[str]) -> list[int]:
                 encoded_tokens.append(self._symbol_to_id[string_token])
             except KeyError as e:
                 raise OutOfVocabularySymbolError(
-                    f"Sequence {sequence} contains item {string_token}"
+                    f"Sequence {sequence} contains item '{string_token}'"
                 ) from e
         return encoded_tokens

     def encode_escaped_string_sequence(
-        self, string_of_tokens: str, split_character="/"
+        self,
+        string_of_tokens: str,
+        split_character=CHARACTER_JOINER,
+        joiner_substitution=JOINER_SUBSTITUTION,
     ):
         assert (
             len(split_character) >= 1
         ), "An escaped string sequence must have a character to split on (default is '/')"
         return self.encode_string_tokens(
-            [token for token in string_of_tokens.split(split_character) if token]
+            [
+                token
+                for token in self.split_tokens(
+                    string_of_tokens, split_character, joiner_substitution
+                )
+                if token
+            ]
         )

     @overload
-    def decode_tokens(  # noqa E704
-        self, sequence: list[int], join_character: None
+    def decode_tokens(  # noqa E704  # pragma: no cover
+        self, sequence: list[int], join_character: None, joiner_substitution: None
     ) -> list[str]: ...

     @overload
-    def decode_tokens(  # noqa E704
-        self, sequence: list[int], join_character: str
+    def decode_tokens(  # noqa E704  # pragma: no cover
+        self, sequence: list[int], join_character: str, joiner_substitution: str
+    ) -> str: ...
+
+    @overload
+    def decode_tokens(  # noqa E704  # pragma: no cover
+        self, sequence: list[int]
     ) -> str: ...

-    def decode_tokens(self, sequence: list[int], join_character="/") -> str | list[str]:
+    def decode_tokens(
+        self,
+        sequence: list[int],
+        join_character=CHARACTER_JOINER,
+        joiner_substitution=JOINER_SUBSTITUTION,
+    ) -> str | list[str]:
         """Decode a sequence of encoded phone or character tokens into a sequence of strings

         Args:
             sequence (List[int]): sequence of phone or character tokens
+            join_character: if given, join the sequence with it
+            joiner_substitution: if joining, substitute any occurrence of join_character by joiner_substitution

         Returns:
             str: the string equivalent of the sequence
@@ -442,4 +465,21 @@ def decode_tokens(self, sequence: list[int], join_character="/") -> str | list[str]:
         if join_character is None:
             return self.token_sequence_to_text_sequence(sequence)
         else:
-            return join_character.join(self.token_sequence_to_text_sequence(sequence))
+            assert joiner_substitution is not None
+            return join_character.join(
+                x.replace(join_character, joiner_substitution)
+                for x in self.token_sequence_to_text_sequence(sequence)
+            )
+
+    def split_tokens(
+        self,
+        joined_sequence: str,
+        join_character=CHARACTER_JOINER,
+        joiner_substitution=JOINER_SUBSTITUTION,
+    ) -> list[str]:
+        """Split a sequence of decoded phone or character tokens that was joined by decode_tokens(),
+        undoing any joiner substitutions."""
+        return [
+            x.replace(joiner_substitution, join_character)
+            for x in joined_sequence.split(join_character)
+        ]
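
The core of the fix is an escaping scheme: decode_tokens() joins tokens with CHARACTER_JOINER ("/") but first replaces any literal "/" inside a token with JOINER_SUBSTITUTION, and split_tokens() inverts both steps, so a slash in the vocabulary survives the round trip through the serialized filelist. Below is a minimal self-contained sketch of the idea; the helper names are illustrative rather than the TextProcessor API, and "\x91" is a reconstructed value for JOINER_SUBSTITUTION (any symbol that never occurs in the vocabulary works).

    # Standalone sketch of the joiner-escaping round trip (assumptions noted above).
    CHARACTER_JOINER = "/"
    JOINER_SUBSTITUTION = "\x91"

    def join_tokens(tokens: list[str]) -> str:
        # Escape any literal "/" inside a token, then join the tokens with "/".
        return CHARACTER_JOINER.join(
            t.replace(CHARACTER_JOINER, JOINER_SUBSTITUTION) for t in tokens
        )

    def split_tokens(joined: str) -> list[str]:
        # Invert join_tokens(): split on "/", then restore escaped slashes.
        return [
            t.replace(JOINER_SUBSTITUTION, CHARACTER_JOINER)
            for t in joined.split(CHARACTER_JOINER)
        ]

    tokens = ["w", "o", "r", "d", "/", "t", "o", "k", "e", "n"]
    assert join_tokens(tokens) == "w/o/r/d/" + JOINER_SUBSTITUTION + "/t/o/k/e/n"
    assert split_tokens(join_tokens(tokens)) == tokens  # "/" survives the round trip

Because the substitution symbol is never itself a vocabulary entry, feeding it directly to encode_string_tokens() raises OutOfVocabularySymbolError, which is exactly what test_encode_string_tokens checks above.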