diff --git a/laser_encoders/download_models.py b/laser_encoders/download_models.py index 17a5db35..1f6ecb77 100644 --- a/laser_encoders/download_models.py +++ b/laser_encoders/download_models.py @@ -46,21 +46,31 @@ def __init__(self, model_dir: str = None): def download(self, filename: str): url = os.path.join(self.base_url, filename) - local_file_path = self.model_dir / filename - if local_file_path.exists(): + local_file_path = os.path.join(self.model_dir, filename) + temp_file_path = os.path.join('/tmp', filename) + + if os.path.exists(local_file_path): logger.info(f" - {filename} already downloaded") else: logger.info(f" - Downloading {filename}") + + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + response = requests.get(url, stream=True) total_size = int(response.headers.get("Content-Length", 0)) progress_bar = tqdm(total=total_size, unit_scale=True, unit="B") - with open(local_file_path, "wb") as f: + + # Download to /tmp first + with open(temp_file_path, "wb") as f: for chunk in response.iter_content(chunk_size=1024): f.write(chunk) progress_bar.update(len(chunk)) progress_bar.close() + os.rename(temp_file_path, local_file_path) + def get_language_code(self, language_list: dict, lang: str) -> str: try: lang_3_4 = language_list[lang]