Implemented chunked download for all models #125

Merged · 3 commits · Sep 27, 2024 · Changes from all commits
nexa/general.py: 122 changes (103 additions, 19 deletions)
@@ -4,6 +4,10 @@
 from typing import Tuple
 import shutil
 import requests
+import concurrent.futures
+import time
+import os
+from tqdm import tqdm
 
 from nexa.constants import (
     NEXA_API_URL,
@@ -140,7 +144,7 @@ def pull_model_from_hub(model_path):

     try:
         result = get_model_presigned_link(model_path, token)
-        run_type = result['type']
+        run_type = result['run_type']
         presigned_links = result['presigned_urls']
     except Exception as e:
         print(f"Failed to get download models: {e}")
@@ -170,7 +174,7 @@ def pull_model_from_hub(model_path):
     for file_path, presigned_link in presigned_links.items():
         try:
             download_path = NEXA_MODELS_HUB_DIR / file_path
-            download_file_with_progress(presigned_link, download_path)
+            download_file_with_progress(presigned_link, download_path, use_processes=True)
 
             if local_path is None:
                 if model_type == "onnx" or model_type == "bin":
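
Note: every call site in this PR passes use_processes=True, opting in to a process pool for the chunk fetches (see the new download_file_with_progress below). A minimal usage sketch of the two modes; the URL and destination paths here are hypothetical, not from this PR, and assume the nexa package is importable:

    from pathlib import Path
    from nexa.general import download_file_with_progress

    url = "https://example.com/big-model.gguf"   # hypothetical model URL
    dest = Path("models/big-model.gguf")         # hypothetical destination

    # Default: a ThreadPoolExecutor fetches chunks. Threads are usually
    # sufficient for I/O-bound downloads and start up cheaply.
    download_file_with_progress(url, dest)

    # As called in this PR: a ProcessPoolExecutor fetches chunks in separate
    # worker processes, avoiding GIL contention between chunk workers.
    download_file_with_progress(url, dest, use_processes=True)
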
@@ -269,27 +273,107 @@ def get_model_presigned_link(full_path, token):
         raise
 
 
-def download_file_with_progress(url, file_path: Path):
+def download_chunk(url, start, end, output_file, chunk_number):
+    headers = {"Range": f"bytes={start}-{end}"}
+    max_retries = 3
+    for attempt in range(max_retries):
+        try:
+            response = requests.get(url, headers=headers, timeout=30)
+            response.raise_for_status()
+            chunk_file = f"{output_file}.part{chunk_number}"
+            with open(chunk_file, "wb") as f:
+                f.write(response.content)
+            return len(response.content), chunk_number
+        except requests.RequestException:
+            if attempt == max_retries - 1:
+                raise
+            time.sleep(2 ** attempt)  # Exponential backoff: 1 s, then 2 s
+
+
+def download_file_with_progress(
+    url: str,
+    file_path: Path,
+    chunk_size: int = 40 * 1024 * 1024,
+    max_workers: int = 20,
+    use_processes: bool = False,
+):
     file_path.parent.mkdir(parents=True, exist_ok=True)
 
-    response = requests.get(url, stream=True)
-    response.raise_for_status()
+    try:
+        # A HEAD request learns the total size before splitting into ranges.
+        response = requests.head(url, timeout=30)
+        response.raise_for_status()
+        file_size = int(response.headers.get("Content-Length", 0))
+        if file_size == 0:
+            raise ValueError("File size is 0 or Content-Length header is missing")
 
-    total_size = int(response.headers.get("content-length", 0))
-    block_size = 1024
+        # Inclusive (start, end) byte ranges, one per chunk; the final range
+        # is clamped to the last byte of the file.
+        chunks = [
+            (i, min(i + chunk_size - 1, file_size - 1))
+            for i in range(0, file_size, chunk_size)
+        ]
+
+        progress_bar = tqdm(
+            total=file_size,
+            unit="B",
+            unit_scale=True,
+            desc=file_path.name,
+            unit_divisor=1024,
+        )
+
+        start_time = time.time()
 
-    with open(file_path, "wb") as file, tqdm(
-        desc=file_path.name,
-        total=total_size,
-        unit="iB",
-        unit_scale=True,
-        unit_divisor=1024,
-    ) as progress_bar:
-        for data in response.iter_content(block_size):
-            size = file.write(data)
-            progress_bar.update(size)
+        executor_class = (
+            concurrent.futures.ProcessPoolExecutor
+            if use_processes
+            else concurrent.futures.ThreadPoolExecutor
+        )
+
+        with executor_class(max_workers=max_workers) as executor:
+            future_to_chunk = {
+                executor.submit(
+                    download_chunk, url, start, end, str(file_path), i
+                ): (i, start, end)
+                for i, (start, end) in enumerate(chunks)
+            }
+
+            completed_chunks = [False] * len(chunks)
+            for future in concurrent.futures.as_completed(future_to_chunk):
+                # Take the chunk index from the mapping so it is defined even
+                # when future.result() raises.
+                chunk_number = future_to_chunk[future][0]
+                try:
+                    downloaded_size, chunk_number = future.result()
+                    completed_chunks[chunk_number] = True
+                    progress_bar.update(downloaded_size)
+                except Exception as e:
+                    print(f"Error downloading chunk {chunk_number}: {e}")
+
+        progress_bar.close()
+
+        if all(completed_chunks):
+            # Stitch the .partN files back together in order, deleting each
+            # part as it is consumed.
+            with open(file_path, "wb") as final_file:
+                for i in range(len(chunks)):
+                    chunk_file = f"{file_path}.part{i}"
+                    with open(chunk_file, "rb") as part_file:
+                        final_file.write(part_file.read())
+                    os.remove(chunk_file)
+
+            end_time = time.time()
+            total_time = end_time - start_time
+            average_speed = file_size / total_time / (1024 * 1024)  # in MB/s, not currently reported
+        else:
+            raise Exception("Some chunks failed to download")
+
+    except requests.exceptions.RequestException as e:
+        print(f"Error occurred while making the request: {e}")
+    except ValueError as e:
+        print(f"Error: {e}")
+    except Exception as e:
+        print(f"An unexpected error occurred: {e}")
+        # Clean up partial files
+        for i in range(len(chunks)):
+            chunk_file = f"{file_path}.part{i}"
+            if os.path.exists(chunk_file):
+                os.remove(chunk_file)
+        if os.path.exists(file_path):
+            os.remove(file_path)
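
Note: a quick worked example of the range arithmetic above, under the default 40 MiB chunk_size and a hypothetical 100 MiB file. Ranges are inclusive on both ends (hence the - 1), matching HTTP Range semantics, and the final range is clamped to the last byte:

    chunk_size = 40 * 1024 * 1024   # 40 MiB, the default above
    file_size = 100 * 1024 * 1024   # hypothetical 100 MiB Content-Length

    chunks = [
        (i, min(i + chunk_size - 1, file_size - 1))
        for i in range(0, file_size, chunk_size)
    ]
    print(chunks)
    # [(0, 41943039), (41943040, 83886079), (83886080, 104857599)]

Each tuple becomes a "Range: bytes=start-end" header in download_chunk, so the three requests above fetch 40 MiB, 40 MiB, and the remaining 20 MiB.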


@@ -304,7 +388,7 @@ def download_model_from_official(model_path, model_type):
     download_url = f"{NEXA_OFFICIAL_BUCKET}{filepath}"
 
     full_path.parent.mkdir(parents=True, exist_ok=True)
-    download_file_with_progress(download_url, full_path)
+    download_file_with_progress(download_url, full_path, use_processes=True)
 
     if model_type == "onnx" or model_type == "bin":
         unzipped_folder = full_path.parent / model_version
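
Note: the retry loop in download_chunk attempts each range at most three times, sleeping 2 ** attempt seconds after a failure (1 s, then 2 s) and re-raising on the final attempt. A standalone sketch of the same shape; the helper name is illustrative, not part of this PR:

    import time
    import requests

    def get_with_retry(url: str, headers: dict, max_retries: int = 3) -> requests.Response:
        # Mirrors download_chunk's loop: back off after a failure, but
        # propagate the error once the last attempt is exhausted.
        for attempt in range(max_retries):
            try:
                response = requests.get(url, headers=headers, timeout=30)
                response.raise_for_status()
                return response
            except requests.RequestException:
                if attempt == max_retries - 1:
                    raise
                time.sleep(2 ** attempt)  # 1 s, then 2 s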