diff --git a/laser_encoders/validate_models.py b/laser_encoders/validate_models.py index 08f54502..0748dfee 100644 --- a/laser_encoders/validate_models.py +++ b/laser_encoders/validate_models.py @@ -50,24 +50,10 @@ def test_validate_language_models_and_tokenize_laser2(lang): print(f"{lang} model validated successfully") -class MockLaserModelDownloader: +class MockLaserModelDownloader(LaserModelDownloader): def __init__(self, model_dir): self.model_dir = model_dir - def get_language_code(self, language_list: dict, lang: str) -> str: - try: - lang_3_4 = language_list[lang] - if isinstance(lang_3_4, tuple): - options = ", ".join(f"'{opt}'" for opt in lang_3_4) - raise ValueError( - f"Language '{lang_3_4}' has multiple options: {options}. Please specify using --lang." - ) - return lang_3_4 - except KeyError: - raise ValueError( - f"language name: {lang} not found in language list. Specify a supported language name" - ) - def download_laser3(self, lang): lang = self.get_language_code(LASER3_LANGUAGE, lang) file_path = os.path.join(self.model_dir, f"laser3-{lang}.v1.pt")