Skip to content

Commit

Permalink
Update validate_models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
NIXBLACK11 authored Nov 14, 2023
1 parent 3944556 commit 0a4d983
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions laser_encoders/validate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,15 @@ def get_language_code(self, language_list: dict, lang: str) -> str:
def download_laser3(self, lang):
lang = self.get_language_code(LASER3_LANGUAGE, lang)
file_path = os.path.join(self.model_dir, f"laser3-{lang}.v1.pt")
if os.path.exists(file_path):
return False
else:
return True
if not os.path.exists(file_path):
raise FileNotFoundError(f"Could not find {file_path}.")

def download_laser2(self):
files = ["laser2.pt", "laser2.spm", "laser2.cvocab"]
for file_name in files:
file_path = os.path.join(self.model_dir, file_name)
if not os.path.exists(file_path):
return True

return False
raise FileNotFoundError(f"Could not find {file_path}.")


CACHE_DIR = "/home/user/.cache/models" # Change this to the desired cache directory
Expand All @@ -93,10 +89,12 @@ def download_laser2(self):
@pytest.mark.parametrize("lang", LASER3_LANGUAGE)
def test_validate_language_models_and_tokenize_mock_laser3(lang):
downloader = MockLaserModelDownloader(model_dir=CACHE_DIR)
err = downloader.download_laser3(lang)
if err == True:
raise pytest.error(f"Skipping test for {lang} language.")


try:
downloader.download_laser3(lang)
except FileNotFoundError as e:
raise pytest.error(str(e))

encoder = initialize_encoder(lang, model_dir=CACHE_DIR)
tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR)

Expand All @@ -110,9 +108,11 @@ def test_validate_language_models_and_tokenize_mock_laser3(lang):
@pytest.mark.parametrize("lang", LASER2_LANGUAGE)
def test_validate_language_models_and_tokenize_mock_laser2(lang):
downloader = MockLaserModelDownloader(model_dir=CACHE_DIR)
err = downloader.download_laser2()
if err == True:
raise pytest.error()

try:
downloader.download_laser2()
except FileNotFoundError as e:
raise pytest.error(str(e))

encoder = initialize_encoder(lang, model_dir=CACHE_DIR)
tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR)
Expand Down

0 comments on commit 0a4d983

Please sign in to comment.