Skip to content

Commit

Permalink
Resume download works now. Downloads are faster. Other small fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
bgorlick committed Jun 18, 2024
1 parent 27c3b1f commit 1ad0703
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 30 deletions.
93 changes: 65 additions & 28 deletions getai/core/model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,24 @@ async def download_model_files(
semaphore = asyncio.Semaphore(self.max_connections)

async def download_file(link: str):
"""Download a single file."""
"""Download a single file with improved semaphore management."""
async with semaphore:
filename = Path(link).name
file_hash: Optional[str] = sha256_dict.get(filename)

# Log the download start
self.logger.info(f"Starting download for {filename}")

if file_hash is None:
print(
f"Warning: No SHA256 hash found for {filename}. Downloading without sha256 verification."
)

await self._download_model_file(session, link, output_folder, file_hash)

tasks = [asyncio.ensure_future(download_file(link)) for link in links]
# Ensure all tasks are completed before proceeding
await asyncio.gather(*tasks)
self.logger.info("All download tasks completed")

async def _download_model_file(
self,
Expand All @@ -257,39 +261,72 @@ async def _download_model_file(
output_folder: Path,
file_hash: Optional[str],
):
"""Download and save a model file."""
"""Download and save a model file with support for resuming downloads."""
from rainbow_tqdm import tqdm

filename = Path(url.rsplit("/", 1)[1])
output_path = output_folder / filename

# Check if the file already exists and get its size
if output_path.exists():
current_hash = await self.calculate_file_sha256(output_path)
if current_hash == file_hash:
print(
f"'{filename}' exists and matches expected SHA256 hash; skipping."
)
return
else:
print(
f"'{filename}' exists but SHA256 hash matching failed; redownloading."
)
file_size = output_path.stat().st_size
else:
file_size = 0

async with session.get(url) as response:
response.raise_for_status()
total_size = response.content_length or 0
async with aiofiles.open(output_path, "wb") as f:
progress_bar = tqdm(
total=total_size,
desc=filename.name,
unit="iB",
unit_scale=True,
ncols=100,
)
async for chunk in response.content.iter_chunked(1024):
await f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
headers = {}
if file_size > 0:
headers["Range"] = f"bytes={file_size}-"
self.logger.info(
f"Resuming download for '{filename}' from byte {file_size}"
)

# Retry logic for downloading the file
max_attempts = self.max_retries
attempt = 0
while attempt < max_attempts:
try:
async with session.get(url, headers=headers) as response:
if response.status == 416:
# Handle the case where the file is already completely downloaded
self.logger.info("'%s' is already fully downloaded.", filename)
return

response.raise_for_status()
total_size = response.content_length or 0
# If resuming, adjust total_size to reflect the remaining bytes
if "Content-Range" in response.headers:
content_range = response.headers["Content-Range"]
total_size = int(content_range.split("/")[1]) - file_size

mode = "ab" if file_size > 0 else "wb"
# Ensure correct handling of file path with aiofiles.open in binary mode
async with aiofiles.open(output_path, mode) as f:
progress_bar = tqdm(
total=total_size,
desc=filename.name,
unit="iB",
unit_scale=True,
ncols=100,
initial=file_size, # Start the progress bar from the current file size
)
async for chunk in response.content.iter_chunked(
1024 * 10
): # Use larger chunk size for efficiency
await f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()

break
except Exception as e:
attempt += 1
self.logger.warning(f"Attempt {attempt} failed: {e}")
if attempt >= max_attempts:
self.logger.error(
"Failed to download '%s' after %d attempts",
filename,
max_attempts,
)
raise

async def check_model_files(
self,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "getai"
version = "0.0.982"
version = "0.0.983"
description = "GetAI - An asynchronous AI search and download tool for AI models, datasets, and tools. Designed to streamline the process of downloading machine learning models, datasets, and more."
authors = ["Ben Gorlick <[email protected]>"]
license = "MIT - with attribution"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="getai",
version="0.0.982",
version="0.0.983",
author="Ben Gorlick",
author_email="[email protected]",
description="GetAI - Asynchronous AI Downloader for models, datasets and tools",
Expand Down

0 comments on commit 1ad0703

Please sign in to comment.