Skip to content

Commit

Permalink
refactor(tests): silence test_text/utils/wizard
Browse files Browse the repository at this point in the history
  • Loading branch information
joanise committed Nov 27, 2024
1 parent cad6b82 commit 9b89c74
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
32 changes: 18 additions & 14 deletions everyvoice/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from everyvoice.config.text_config import Punctuation, Symbols, TextConfig
from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
from everyvoice.tests.basic_test_case import BasicTestCase
from everyvoice.tests.stubs import silence_c_stderr
from everyvoice.text.features import N_PHONOLOGICAL_FEATURES
from everyvoice.text.lookups import build_lookup, lookuptables_from_data
from everyvoice.text.phonemizer import AVAILABLE_G2P_ENGINES, get_g2p_engine
Expand Down Expand Up @@ -50,12 +51,13 @@ def test_hardcoded_symbols(self):
def test_cleaners_with_upper(self):
text = "hello world"
text_upper = "HELLO WORLD"
upper_text_processor = TextProcessor(
TextConfig(
cleaners=[collapse_whitespace, lower],
symbols=Symbols(letters=list(string.ascii_letters)),
),
)
with silence_c_stderr():
upper_text_processor = TextProcessor(
TextConfig(
cleaners=[collapse_whitespace, lower],
symbols=Symbols(letters=list(string.ascii_letters)),
),
)
sequence = upper_text_processor.encode_text(text_upper)
self.assertEqual(upper_text_processor.decode_tokens(sequence, "", ""), text)

Expand All @@ -65,12 +67,13 @@ def test_no_duplicate_punctuation(self):

def test_punctuation(self):
text = "hello! How are you? My name's: foo;."
upper_text_processor = TextProcessor(
TextConfig(
cleaners=[collapse_whitespace, lower],
symbols=Symbols(letters=list(string.ascii_letters)),
),
)
with silence_c_stderr():
upper_text_processor = TextProcessor(
TextConfig(
cleaners=[collapse_whitespace, lower],
symbols=Symbols(letters=list(string.ascii_letters)),
),
)
tokens = upper_text_processor.apply_tokenization(
upper_text_processor.normalize_text(text)
)
Expand Down Expand Up @@ -197,7 +200,7 @@ def test_duplicates_removed(self):
symbols=Symbols(letters=list(string.ascii_letters), duplicate=["e"])
)
)
self.assertEquals(
self.assertEqual(
len([x for x in duplicate_symbols_text_processor.symbols if x == "e"]), 1
)

Expand Down Expand Up @@ -240,7 +243,8 @@ def test_normalization(self):

def test_missing_symbol(self):
text = "h3llo world"
sequence = self.base_text_processor.encode_text(text)
with silence_c_stderr():
sequence = self.base_text_processor.encode_text(text)
self.assertNotEqual(self.base_text_processor.decode_tokens(sequence), text)
self.assertIn("3", self.base_text_processor.missing_symbols)
self.assertEqual(self.base_text_processor.missing_symbols["3"], 1)
Expand Down
11 changes: 6 additions & 5 deletions everyvoice/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
path_is_a_directory,
relative_to_absolute_path,
)
from everyvoice.tests.stubs import capture_logs, patch_logger
from everyvoice.tests.stubs import capture_logs, patch_logger, silence_c_stderr
from everyvoice.utils import write_filelist
from everyvoice.utils.heavy import get_device_from_accelerator

Expand Down Expand Up @@ -259,10 +259,11 @@ def test_using_a_directory(self):

class GetDeviceFromAcceleratorTest(TestCase):
def test_auto(self):
self.assertEqual(
get_device_from_accelerator("auto"),
torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)
with silence_c_stderr():
self.assertEqual(
get_device_from_accelerator("auto"),
torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)

def test_cpu(self):
self.assertEqual(get_device_from_accelerator("cpu"), torch.device("cpu"))
Expand Down
4 changes: 4 additions & 0 deletions everyvoice/tests/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,8 @@ def test_festival(self):

def test_multilingual_multispeaker_true_config(self):
"""
Test mismatched multi-monolingual and multi-monospeaker datasets.
Makes sure that multilingual and multispeaker parameters of config are set to true when two monolingual and monospeaker datasets are provided with different specified languages and speakers.
"""
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down Expand Up @@ -1621,6 +1623,8 @@ def test_multilingual_multispeaker_true_config(self):

def test_multilingual_multispeaker_false_config(self):
"""
Test matched multi-monospeaker and multi-monolingual datasets.
Makes sure that multilingual and multispeaker parameters of config are set to false when two monolingual and monospeaker datasets are provided with the specified languages and speakers which are the same.
"""
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit 9b89c74

Please sign in to comment.