fix: allow slash to be used as a character
Fixes #460
Fixes #540
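
Previously, "/" was hard-coded as the character used to join and split encoded token sequences, so text containing a literal slash could not round-trip through decode_tokens() and encode_escaped_string_sequence(). This commit names the joiner CHARACTER_JOINER, escapes literal slashes as JOINER_SUBSTITUTION ("<SLASH>") when joining, and adds split_tokens() to undo the substitution when splitting. It also fixes the misspelled heavy_clip_detction argument (now heavy_clip_detection).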
joanise committed Nov 12, 2024
1 parent 237278f commit 1f661dd
Showing 6 changed files with 104 additions and 24 deletions.
2 changes: 1 addition & 1 deletion everyvoice/cli.py
@@ -474,7 +474,7 @@ def check_data(
     preprocessor = Preprocessor(config)
     checked_data = preprocessor.check_data(
         filelist=combined_filelist_data,
-        heavy_clip_detction=heavy_clip_detection,
+        heavy_clip_detection=heavy_clip_detection,
         heavy_objective_evaluation=heavy_objective_evaluation,
     )
     if not combined_filelist_data:
16 changes: 8 additions & 8 deletions everyvoice/preprocessor/preprocessor.py
@@ -628,7 +628,7 @@ def process_attn_prior(self, item):
         process_phones = False
         character_tokens = (
             self.text_processor.encode_string_tokens(
-                item["character_tokens"].split("/")
+                self.text_processor.split_tokens(item["character_tokens"])
             )
             if "character_tokens" in item
             and item[
@@ -637,7 +637,9 @@ def process_attn_prior(self, item):
             else None
         )
         phone_tokens = (
-            self.text_processor.encode_string_tokens(item["phone_tokens"].split("/"))
+            self.text_processor.encode_string_tokens(
+                self.text_processor.split_tokens(item["phone_tokens"])
+            )
             if "phone_tokens" in item and item["phone_tokens"]
             else None
         )
@@ -792,11 +794,9 @@ def process_text(
         # encode to string
         if encode_as_string:
             if phone_tokens is not None:
-                phones = text_processor.decode_tokens(phone_tokens, join_character="/")
+                phones = text_processor.decode_tokens(phone_tokens)
             if character_tokens is not None:
-                characters = text_processor.decode_tokens(
-                    character_tokens, join_character="/"
-                )
+                characters = text_processor.decode_tokens(character_tokens)
             return (characters, phones, pfs)
         else:
             return (character_tokens, phone_tokens, pfs)
@@ -883,7 +883,7 @@ def check_data(
         self,
         filelist,
         word_seg_token=" ",
-        heavy_clip_detction=False,
+        heavy_clip_detection=False,
         heavy_objective_evaluation=False,
     ):
         data = []
@@ -943,7 +943,7 @@ def check_data(
             ), f"Audio has {audio.size(0)} channels, but should be mono"
             audio = audio.squeeze()
 
-            if heavy_clip_detction:
+            if heavy_clip_detection:
                 _, total_clipping = detect_clipping(audio)
             else:
                 # this isn't a great way of detecting clipping,
6 changes: 6 additions & 0 deletions everyvoice/tests/data/metadata_slash_pipe.psv
@@ -0,0 +1,6 @@
+basename|raw_text|characters|speaker|language|clean_text|label|real_lang
+LJ050-0269|The essential terms of such memoranda might well be embodied in an Executive order.|The essential terms of such memoranda might well be embodied in an Executive order.|default|default|the essential terms of such memoranda might well be embodied in an executive order.|LJ_TEST|eng
+LJ050-0270|This Commission can\|recommend no procedures for the future protection of our Presidents which will guarantee security.|This Commission can\|recommend no procedures for the future protection of our Presidents which will guarantee security.|default|default|this commission can\|recommend no procedures for the future protection of our presidents which will guarantee security.|LJ_TEST|eng
+LJ050-0271|The demands on/the President in the execution of His responsibilities in today's world are so varied and complex|The demands on/the President in the execution of His responsibilities in today's world are so varied and complex|default|default|the demands on/the president in the execution of his responsibilities in today's world are so varied and complex|LJ_TEST|eng
+LJ050-0272.wav|and the traditions of the office in a democracy such as ours are so deep-seated as to preclude absolute security.|and the traditions of the office in a democracy such as ours are so deep-seated as to preclude absolute security.|default|default|and the traditions of the office in a democracy such as ours are so deep-seated as to preclude absolute security.|LJ_TEST|eng
+LJ050-0273|The Commission has, however, from its examination of the facts of President Kennedy's assassination|The Commission has, however, from its examination of the facts of President Kennedy's assassination|default|default|the commission has, however, from its examination of the facts of president kennedy's assassination|LJ_TEST|eng
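
This new test filelist exercises both special characters at once: LJ050-0270 carries an escaped pipe (\|, escaped because | is the field delimiter in a .psv file) and LJ050-0271 a literal slash, presumably covering the pipe and slash cases behind the two issues referenced in the commit message.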
6 changes: 6 additions & 0 deletions everyvoice/tests/test_preprocessing.py
@@ -382,6 +382,7 @@ def test_text_processing(self):
         mixed_representation_filelist = (
             self.data_dir / "metadata_mixed_representation.psv"
         )
+        slash_pipe_filelist = self.data_dir / "metadata_slash_pipe.psv"
         fp_config = FeaturePredictionConfig(**self.fp_config.model_dump())
         filelists_to_test = [
             {
@@ -409,6 +410,11 @@ def test_text_processing(self):
                 "contains_characters": True,
                 "contains_phones": True,
             },  # will tokenize characters and tokenize phones
+            {
+                "path": slash_pipe_filelist,
+                "contains_characters": True,
+                "contains_phones": False,
+            },
         ]
         for filelist_test_info in filelists_to_test:
             with tempfile.TemporaryDirectory(prefix="inputs", dir=".") as tmpdir:
40 changes: 34 additions & 6 deletions everyvoice/tests/test_text.py
@@ -9,13 +9,14 @@
 
 import everyvoice.demo.app
 import everyvoice.text.utils
+from everyvoice import exceptions
 from everyvoice.config.text_config import Punctuation, Symbols, TextConfig
 from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
 from everyvoice.tests.basic_test_case import BasicTestCase
 from everyvoice.text.features import N_PHONOLOGICAL_FEATURES
 from everyvoice.text.lookups import build_lookup, lookuptables_from_data
 from everyvoice.text.phonemizer import AVAILABLE_G2P_ENGINES, get_g2p_engine
-from everyvoice.text.text_processor import TextProcessor
+from everyvoice.text.text_processor import JOINER_SUBSTITUTION, TextProcessor
 from everyvoice.utils import (
     collapse_whitespace,
     generic_psv_filelist_reader,
@@ -46,7 +47,7 @@ def test_run_doctest(self):
     def test_text_to_sequence(self):
         text = "hello world"
         sequence = self.base_text_processor.encode_text(text)
-        self.assertEqual(self.base_text_processor.decode_tokens(sequence, ""), text)
+        self.assertEqual(self.base_text_processor.decode_tokens(sequence, "", ""), text)
 
     def test_token_sequence_to_text(self):
         sequence = [51, 48, 55, 55, 58, 1, 66, 58, 61, 55, 47]
@@ -69,7 +70,7 @@ def test_cleaners_with_upper(self):
             ),
         )
         sequence = upper_text_processor.encode_text(text_upper)
-        self.assertEqual(upper_text_processor.decode_tokens(sequence, ""), text)
+        self.assertEqual(upper_text_processor.decode_tokens(sequence, "", ""), text)
 
     def test_no_duplicate_punctuation(self):
         with self.assertRaises(ValidationError):
@@ -198,7 +199,7 @@ def test_phonological_features(self):
             apply_g2p=True,
             encode_as_phonological_features=True,
         )
-        self.assertEqual(moh_text_processor.decode_tokens(g2p_tokens, ""), "séːɡũ")
+        self.assertEqual(moh_text_processor.decode_tokens(g2p_tokens, "", ""), "séːɡũ")
         self.assertEqual(len(g2p_tokens), len(feats))
         self.assertNotEqual(len(g2p_tokens), len(one_hot_tokens))
         self.assertEqual(len(feats[0]), N_PHONOLOGICAL_FEATURES)
@@ -239,9 +240,11 @@ def test_normalization(self):
         )
         text = "he\u0301llo world"
         sequence = accented_text_processor.encode_text(text)
-        self.assertNotEqual(accented_text_processor.decode_tokens(sequence, ""), text)
+        self.assertNotEqual(
+            accented_text_processor.decode_tokens(sequence, "", ""), text
+        )
         self.assertEqual(
-            accented_text_processor.decode_tokens(sequence, ""),
+            accented_text_processor.decode_tokens(sequence, "", ""),
             normalize("NFC", text),
         )
         self.assertNotEqual(
@@ -255,6 +258,31 @@ def test_missing_symbol(self):
         self.assertIn("3", self.base_text_processor.missing_symbols)
         self.assertEqual(self.base_text_processor.missing_symbols["3"], 1)
 
+    def test_use_slash(self):
+        text = "word/token"
+        text_processor = TextProcessor(
+            TextConfig(symbols=Symbols(letters=list(string.ascii_letters) + ["/"])),
+        )
+        sequence = text_processor.encode_text(text)
+        decoded = text_processor.decode_tokens(sequence)
+        self.assertEqual(decoded, "w/o/r/d/" + JOINER_SUBSTITUTION + "/t/o/k/e/n")
+        encoded = text_processor.encode_escaped_string_sequence(decoded)
+        self.assertEqual(encoded, sequence)
+
+        with self.assertRaises(exceptions.OutOfVocabularySymbolError):
+            # / is OOV, so JOINER_SUBSTITUTION will also be OOV
+            self.base_text_processor.encode_escaped_string_sequence(decoded)
+
+    def test_encode_string_tokens(self):
+        self.assertEqual(
+            self.base_text_processor.encode_string_tokens(["a", "b", ",", " ", "c"]),
+            self.base_text_processor.encode_escaped_string_sequence("a/b/,/ /c"),
+        )
+        with self.assertRaises(exceptions.OutOfVocabularySymbolError):
+            self.base_text_processor.encode_string_tokens(["oov"])
+        with self.assertRaises(exceptions.OutOfVocabularySymbolError):
+            self.base_text_processor.encode_string_tokens([JOINER_SUBSTITUTION])
+
+
 class LookupTableTest(TestCase):
     def test_build_lookup(self):
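
Note the extra "" argument threaded through the existing decode_tokens() calls above: the signature now takes joiner_substitution in addition to join_character, so tests that joined on the empty string now pass an empty substitution as well.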
58 changes: 49 additions & 9 deletions everyvoice/text/text_processor.py
@@ -18,6 +18,8 @@
 )
 
 PAD_SYMBOL = "\x80"
+CHARACTER_JOINER = "/"
+JOINER_SUBSTITUTION = "<SLASH>"
 
 
 class TextProcessor:
@@ -402,35 +404,56 @@ def encode_string_tokens(self, sequence: list[str]) -> list[int]:
                 encoded_tokens.append(self._symbol_to_id[string_token])
             except KeyError as e:
                 raise OutOfVocabularySymbolError(
-                    f"Sequence {sequence} contains item {string_token}"
+                    f"Sequence {sequence} contains item '{string_token}'"
                 ) from e
         return encoded_tokens
 
     def encode_escaped_string_sequence(
-        self, string_of_tokens: str, split_character="/"
+        self,
+        string_of_tokens: str,
+        split_character=CHARACTER_JOINER,
+        joiner_substitution=JOINER_SUBSTITUTION,
     ):
         assert (
             len(split_character) >= 1
         ), "An escaped string sequence must have a character to split on (default is '/')"
         return self.encode_string_tokens(
-            [token for token in string_of_tokens.split(split_character) if token]
+            [
+                token
+                for token in self.split_tokens(
+                    string_of_tokens, split_character, joiner_substitution
+                )
+                if token
+            ]
         )
 
     @overload
-    def decode_tokens(  # noqa E704
-        self, sequence: list[int], join_character: None
+    def decode_tokens(  # noqa E704 # pragma: no cover
+        self, sequence: list[int], join_character: None, joiner_substitution: None
     ) -> list[str]: ...
 
     @overload
-    def decode_tokens(  # noqa E704
-        self, sequence: list[int], join_character: str
+    def decode_tokens(  # noqa E704 # pragma: no cover
+        self, sequence: list[int], join_character: str, joiner_substitution: str
     ) -> str: ...
 
+    @overload
+    def decode_tokens(  # noqa E704 # pragma: no cover
+        self, sequence: list[int]
+    ) -> str: ...
+
-    def decode_tokens(self, sequence: list[int], join_character="/") -> str | list[str]:
+    def decode_tokens(
+        self,
+        sequence: list[int],
+        join_character=CHARACTER_JOINER,
+        joiner_substitution=JOINER_SUBSTITUTION,
+    ) -> str | list[str]:
         """Decode a sequence of encoded phone or character tokens into a sequence of strings
         Args:
             sequence (List[int]): sequence of phone or character tokens
             join_character: if given, join the sequence with it
+            joiner_substitution: if joining, sub any occurrence of join_character by joiner_substitution
         Returns:
             str: the string equivalent of the sequence
@@ -442,4 +465,21 @@ def decode_tokens(
         if join_character is None:
             return self.token_sequence_to_text_sequence(sequence)
         else:
-            return join_character.join(self.token_sequence_to_text_sequence(sequence))
+            assert joiner_substitution is not None
+            return join_character.join(
+                x.replace(join_character, joiner_substitution)
+                for x in self.token_sequence_to_text_sequence(sequence)
+            )
+
+    def split_tokens(
+        self,
+        joined_sequence: str,
+        join_character=CHARACTER_JOINER,
+        joiner_substitution=JOINER_SUBSTITUTION,
+    ) -> list[str]:
+        """Split a sequence of decoded phone or character tokens that was joined by decode_tokens(),
+        undoing any joiner substitutions."""
+        return [
+            x.replace(joiner_substitution, join_character)
+            for x in joined_sequence.split(join_character)
+        ]
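
Taken together, the join branch of decode_tokens() and split_tokens() are inverses as long as no real token ever equals the substitution string, which is why the tests above reject JOINER_SUBSTITUTION itself as out-of-vocabulary. A minimal self-contained sketch of that invariant follows; join_tokens here is a hypothetical stand-in for the join branch of decode_tokens(), not part of the EveryVoice API:

# Sketch of the joiner-substitution invariant; join_tokens is a
# hypothetical helper mirroring the join branch of decode_tokens().
CHARACTER_JOINER = "/"
JOINER_SUBSTITUTION = "<SLASH>"


def join_tokens(tokens: list[str]) -> str:
    # Escape any literal joiner inside a token, then join on the joiner.
    return CHARACTER_JOINER.join(
        t.replace(CHARACTER_JOINER, JOINER_SUBSTITUTION) for t in tokens
    )


def split_tokens(joined: str) -> list[str]:
    # Split on the joiner, then undo the escaping (mirrors split_tokens above).
    return [
        t.replace(JOINER_SUBSTITUTION, CHARACTER_JOINER)
        for t in joined.split(CHARACTER_JOINER)
    ]


tokens = list("word") + ["/"] + list("token")
joined = join_tokens(tokens)  # "w/o/r/d/<SLASH>/t/o/k/e/n"
assert split_tokens(joined) == tokens  # lossless round trip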
