diff --git a/getai/core/model_downloader.py b/getai/core/model_downloader.py index ee5e3c3..900055f 100644 --- a/getai/core/model_downloader.py +++ b/getai/core/model_downloader.py @@ -235,20 +235,24 @@ async def download_model_files( semaphore = asyncio.Semaphore(self.max_connections) async def download_file(link: str): - """Download a single file.""" + """Download a single file with improved semaphore management.""" async with semaphore: filename = Path(link).name file_hash: Optional[str] = sha256_dict.get(filename) + # Log the download start + self.logger.info(f"Starting download for {filename}") + if file_hash is None: print( f"Warning: No SHA256 hash found for {filename}. Downloading without sha256 verification." ) - await self._download_model_file(session, link, output_folder, file_hash) tasks = [asyncio.ensure_future(download_file(link)) for link in links] + # Ensure all tasks are completed before proceeding await asyncio.gather(*tasks) + self.logger.info("All download tasks completed") async def _download_model_file( self, @@ -257,39 +261,72 @@ async def _download_model_file( output_folder: Path, file_hash: Optional[str], ): - """Download and save a model file.""" + """Download and save a model file with support for resuming downloads.""" from rainbow_tqdm import tqdm filename = Path(url.rsplit("/", 1)[1]) output_path = output_folder / filename + # Check if the file already exists and get its size if output_path.exists(): - current_hash = await self.calculate_file_sha256(output_path) - if current_hash == file_hash: - print( - f"'{filename}' exists and matches expected SHA256 hash; skipping." - ) - return - else: - print( - f"'{filename}' exists but SHA256 hash matching failed; redownloading." - ) + file_size = output_path.stat().st_size + else: + file_size = 0 - async with session.get(url) as response: - response.raise_for_status() - total_size = response.content_length or 0 - async with aiofiles.open(output_path, "wb") as f: - progress_bar = tqdm( - total=total_size, - desc=filename.name, - unit="iB", - unit_scale=True, - ncols=100, - ) - async for chunk in response.content.iter_chunked(1024): - await f.write(chunk) - progress_bar.update(len(chunk)) - progress_bar.close() + headers = {} + if file_size > 0: + headers["Range"] = f"bytes={file_size}-" + self.logger.info( + f"Resuming download for '{filename}' from byte {file_size}" + ) + + # Retry logic for downloading the file + max_attempts = self.max_retries + attempt = 0 + while attempt < max_attempts: + try: + async with session.get(url, headers=headers) as response: + if response.status == 416: + # Handle the case where the file is already completely downloaded + self.logger.info("'%s' is already fully downloaded.", filename) + return + + response.raise_for_status() + total_size = response.content_length or 0 + # If resuming, adjust total_size to reflect the remaining bytes + if "Content-Range" in response.headers: + content_range = response.headers["Content-Range"] + total_size = int(content_range.split("/")[1]) - file_size + + mode = "ab" if file_size > 0 else "wb" + # Ensure correct handling of file path with aiofiles.open in binary mode + async with aiofiles.open(output_path, mode) as f: + progress_bar = tqdm( + total=total_size, + desc=filename.name, + unit="iB", + unit_scale=True, + ncols=100, + initial=file_size, # Start the progress bar from the current file size + ) + async for chunk in response.content.iter_chunked( + 1024 * 10 + ): # Use larger chunk size for efficiency + await f.write(chunk) + progress_bar.update(len(chunk)) + progress_bar.close() + + break + except Exception as e: + attempt += 1 + self.logger.warning(f"Attempt {attempt} failed: {e}") + if attempt >= max_attempts: + self.logger.error( + "Failed to download '%s' after %d attempts", + filename, + max_attempts, + ) + raise async def check_model_files( self, diff --git a/pyproject.toml b/pyproject.toml index 46496ef..e0689f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "getai" -version = "0.0.982" +version = "0.0.983" description = "GetAI - An asynchronous AI search and download tool for AI models, datasets, and tools. Designed to streamline the process of downloading machine learning models, datasets, and more." authors = ["Ben Gorlick "] license = "MIT - with attribution" diff --git a/setup.py b/setup.py index 1363b4f..2bfd1c7 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="getai", - version="0.0.982", + version="0.0.983", author="Ben Gorlick", author_email="ben@unifiedlearning.ai", description="GetAI - Asynchronous AI Downloader for models, datasets and tools",