Skip to content

Commit

Permalink
Added logic to preserve the structure of a local repo during upload (#93
Browse files Browse the repository at this point in the history
)

[For the purpose of keras upload
integration](https://buganizer.corp.google.com/issues/329851144)
Currently, KaggleHub does not preserve the structure of the repo it's
uploading. The changes added in this PR make it preserve the repo
structure.
Furthermore, I added an integration test for uploading zip files and
updated the integration tests' logic to retry model uploading until the
instance's archive is updated.
  • Loading branch information
mohami2000 authored Mar 21, 2024
1 parent 280f22c commit cc57f28
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 88 deletions.
84 changes: 83 additions & 1 deletion integration_tests/test_model_upload.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,63 @@
import logging
import os
import tempfile
import time
import unittest
import uuid
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Type, TypeVar

from kagglehub import model_upload, models_helpers
from kagglehub.config import get_kaggle_credentials
from kagglehub.exceptions import BackendError

LICENSE_NAME = "MIT"

logger = logging.getLogger(__name__)


ReturnType = TypeVar("ReturnType")


def retry(
    times: int = 5, delay_seconds: int = 5, exception_to_check: Type[Exception] = Exception
) -> Callable[[Callable[..., ReturnType]], Callable[..., ReturnType]]:
    """Build a decorator that re-invokes a function when it raises a given exception.

    The decorated function is called up to ``times`` times. After every failed
    attempt except the last, the failure is logged at INFO level and the wrapper
    sleeps for ``delay_seconds`` before trying again. Exceptions that are not
    instances of ``exception_to_check`` propagate immediately, without a retry.

    Args:
        times: Maximum number of attempts, including the first call. Must be >= 1.
        delay_seconds: Seconds to sleep between two consecutive attempts.
        exception_to_check: Exception type that triggers a retry.

    Returns:
        A decorator that wraps a callable with the retry behaviour.

    Raises:
        ValueError: If ``times`` is smaller than 1 (raised when the decorator is built).
        TimeoutError: Raised by the wrapped callable once every attempt has failed;
            the last caught exception is attached as ``__cause__``.
    """
    # Fail fast on a meaningless attempt budget: otherwise the wrapped function
    # would never be called and the caller would get a confusing error.
    if times < 1:
        invalid_times_message = f"times must be at least 1, got {times}."
        raise ValueError(invalid_times_message)

    def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
        @wraps(func)
        def wrapper(*args: object, **kwargs: object) -> ReturnType:
            attempts = 0
            # The loop exits only by returning the result or raising TimeoutError.
            while True:
                try:
                    return func(*args, **kwargs)
                except exception_to_check as e:
                    attempts += 1
                    if attempts >= times:
                        time_out_message = "Maximum retries reached without success."
                        raise TimeoutError(time_out_message) from e
                    logger.info(f"Attempt {attempts} failed: {e}. Retrying in {delay_seconds} seconds...")
                    time.sleep(delay_seconds)

        return wrapper

    return decorator


@retry(times=5, delay_seconds=5, exception_to_check=BackendError)
def upload_with_retries(handle: str, temp_dir: str, license_name: str) -> None:
    """Upload a model, retrying when the backend rejects the request.

    The call is attempted up to 5 times, 5 seconds apart, and only a
    ``BackendError`` triggers another attempt (for example while the instance
    slug is still reported as being in use).

    Args:
        handle: The model handle.
        temp_dir: Directory holding the model files to upload.
        license_name: License name for the model.
    """
    model_upload(handle, temp_dir, license_name)


class TestModelUpload(unittest.TestCase):
def setUp(self) -> None:
Expand All @@ -30,10 +79,43 @@ def test_model_upload_and_versioning(self) -> None:
model_upload(self.handle, self.temp_dir, LICENSE_NAME)

# Create Version
model_upload(self.handle, self.temp_dir, LICENSE_NAME)
upload_with_retries(self.handle, self.temp_dir, LICENSE_NAME)

# If delete model does not raise an error, then the upload was successful.

def test_model_upload_and_versioning_zip(self) -> None:
    """Upload a flat directory of 60 empty files, then upload it again as a new version."""
    with TemporaryDirectory() as temp_dir:
        base_dir = Path(temp_dir)
        for index in range(60):
            (base_dir / f"temp_test_file_{index}").touch()

        # The first upload creates the instance.
        model_upload(self.handle, temp_dir, LICENSE_NAME)

        # The second upload creates a version; it is retried on BackendError.
        upload_with_retries(self.handle, temp_dir, LICENSE_NAME)

def test_model_upload_directory(self) -> None:
    """Upload a directory that contains files at the top level and in a sub-folder."""
    with TemporaryDirectory() as temp_dir:
        base_dir = Path(temp_dir)

        # Sub-folder that mirrors the file names of the top-level directory.
        inner_folder_path = base_dir / "inner_folder"
        inner_folder_path.mkdir()

        for index in range(60):
            file_name = f"temp_test_file_{index}"
            # Same empty file once at the top level and once inside the sub-folder.
            (base_dir / file_name).touch()
            (inner_folder_path / file_name).touch()

        # The first upload creates the instance.
        model_upload(self.handle, temp_dir, LICENSE_NAME)

        # The second upload creates a version; it is retried on BackendError.
        upload_with_retries(self.handle, temp_dir, LICENSE_NAME)

def test_model_upload_nested_dir(self) -> None:
# Create a nested directory within self.temp_dir
nested_dir = Path(self.temp_dir) / "nested"
Expand Down
98 changes: 17 additions & 81 deletions src/kagglehub/gcs_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import time
import zipfile
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional, Union
from typing import List, Union

import requests
from requests.exceptions import ConnectionError, Timeout
Expand Down Expand Up @@ -135,88 +136,23 @@ def _upload_blob(file_path: str, model_type: str) -> str:
return response["token"]


def upload_files(folder: str, model_type: str, quiet: bool = False) -> List[str]: # noqa: FBT002, FBT001
"""upload files in a folder. Zips the files if there are more than 50.
def upload_files(source_dir: str, model_type: str) -> List[str]:
"""Zip and Upload directory.
Parameters
==========
folder: the folder to upload from
quiet: suppress verbose output (default is False)
source_dir: the source_dir to upload from
model_type: Type of the model that is being uploaded.
"""

# Count the total number of files
file_count = 0
for _, _, files in os.walk(folder):
file_count += len(files)

if file_count > MAX_FILES_TO_UPLOAD:
if not quiet:
logger.info(f"More than {MAX_FILES_TO_UPLOAD} files detected, creating a zip archive...")

with TemporaryDirectory() as temp_dir:
zip_path = os.path.join(temp_dir, TEMP_ARCHIVE_FILE)
with zipfile.ZipFile(zip_path, "w") as zipf:
for root, _, files in os.walk(folder):
for file in files:
file_path = os.path.join(root, file)
zipf.write(file_path, os.path.relpath(file_path, folder))

# Upload the zip file
return [
token
for token in [_upload_file_or_folder(temp_dir, TEMP_ARCHIVE_FILE, model_type, quiet)]
if token is not None
]

tokens = []
for root, _, files in os.walk(folder):
for file in files:
token = _upload_file_or_folder(root, file, model_type, quiet)
if token is not None:
tokens.append(token)

return tokens


def _upload_file_or_folder(
parent_path: str, file_or_folder_name: str, model_type: str, quiet: bool = False # noqa: FBT002, FBT001
) -> Optional[str]:
"""
Uploads a file or each file inside a folder individually from a specified path to a remote service.
Parameters
==========
parent_path: The parent directory path from where the file or folder is to be uploaded.
file_or_folder_name: The name of the file or folder to be uploaded.
dir_mode: The mode to handle directories. Accepts 'zip', 'tar', or other values for skipping.
model_type: Type of the model that is being uploaded.
quiet: suppress verbose output (default is False)
:return: A token if the upload is successful, or None if the file is skipped or the upload fails.
"""
full_path = os.path.join(parent_path, file_or_folder_name)
if os.path.isfile(full_path):
return _upload_file(file_or_folder_name, full_path, quiet, model_type)
elif not quiet:
logger.info("Skipping: " + file_or_folder_name)
return None


def _upload_file(file_name: str, full_path: str, quiet: bool, model_type: str) -> Optional[str]: # noqa: FBT001
"""Helper function to upload a single file
Parameters
==========
file_name: name of the file to upload
full_path: path to the file to upload
quiet: suppress verbose output
model_type: Type of the model that is being uploaded.
:return: None - upload unsuccessful; instance of UploadFile - upload successful
"""

if not quiet:
logger.info("Starting upload for file " + file_name)

content_length = os.path.getsize(full_path)
token = _upload_blob(full_path, model_type)
if not quiet:
logger.info("Upload successful: " + file_name + " (" + File.get_size(content_length) + ")")
return token
with TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
zip_path = temp_dir_path / TEMP_ARCHIVE_FILE
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
source_dir_path = Path(source_dir)
for file_path in source_dir_path.rglob("*"):
if file_path.is_file():
arcname = file_path.relative_to(source_dir_path)
zipf.write(file_path, arcname)

# Upload the zip file
return [token for token in [_upload_blob(str(zip_path), model_type)] if token]
12 changes: 6 additions & 6 deletions tests/test_model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_model_upload_instance_with_valid_handle(self) -> None:
test_filepath.touch() # Create a temporary file in the temporary directory
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_instance_with_nested_directories(self) -> None:
# execution path: get_model -> create_model -> get_instance -> create_version
Expand All @@ -156,7 +156,7 @@ def test_model_upload_instance_with_nested_directories(self) -> None:
test_filepath.touch()
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_version_with_valid_handle(self) -> None:
# execution path: get_model -> get_instance -> create_instance
Expand All @@ -168,7 +168,7 @@ def test_model_upload_version_with_valid_handle(self) -> None:
test_filepath.touch() # Create a temporary file in the temporary directory
model_upload("metaresearch/llama-2/pyTorch/7b", temp_dir, APACHE_LICENSE, "model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_with_too_many_files(self) -> None:
with create_test_http_server(KaggleAPIHandler):
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_model_upload_resumable(self) -> None:
# Check that GcsAPIHandler received two PUT requests
self.assertEqual(GcsAPIHandler.put_requests_count, 2)
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_with_none_license(self) -> None:
with create_test_http_server(KaggleAPIHandler):
Expand All @@ -209,7 +209,7 @@ def test_model_upload_with_none_license(self) -> None:
test_filepath.touch() # Create a temporary file in the temporary directory
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, None, "model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_without_license(self) -> None:
with create_test_http_server(KaggleAPIHandler):
Expand All @@ -219,7 +219,7 @@ def test_model_upload_without_license(self) -> None:
test_filepath.touch() # Create a temporary file in the temporary directory
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, version_notes="model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_with_invalid_license_fails(self) -> None:
with create_test_http_server(KaggleAPIHandler):
Expand Down
46 changes: 46 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest
from unittest.mock import MagicMock, patch

from integration_tests.test_model_upload import retry


class FunctionToTest:
    """Callable test double that fails a configurable number of times before succeeding."""

    def __init__(self):
        # Number of calls that have raised so far.
        self.attempt = 0

    def __call__(self, success_on_attempt: int) -> str:
        """Raise ``ValueError`` until ``success_on_attempt`` failures have occurred, then succeed.

        Args:
            success_on_attempt: Number of failing calls required before a call returns.

        Returns:
            The string ``"Success"`` once enough failures have been recorded.

        Raises:
            ValueError: While fewer than ``success_on_attempt`` calls have failed.
        """
        if self.attempt >= success_on_attempt:
            return "Success"

        self.attempt += 1
        value_error_message = "Test error"
        raise ValueError(value_error_message)


class TestRetryDecorator(unittest.TestCase):
    """Checks the ``retry`` decorator with sleeping and INFO logging patched out."""

    def setUp(self) -> None:
        # A fresh callable per test so its failure counter starts at zero.
        self.function_to_test = FunctionToTest()

    @patch("integration_tests.test_model_upload.time.sleep", autospec=True)
    @patch("integration_tests.test_model_upload.logger.info", autospec=True)
    def test_retry_success_before_limit(self, mock_logger_info: MagicMock, mock_sleep: MagicMock) -> None:
        """Two failures followed by a success fit inside a three-attempt budget."""
        wrapped = retry(times=3, delay_seconds=1)(self.function_to_test)

        outcome = wrapped(2)

        self.assertEqual(outcome, "Success")
        self.assertEqual(self.function_to_test.attempt, 2)
        # One sleep and one log call per failed attempt.
        self.assertEqual(mock_sleep.call_count, 2)
        self.assertEqual(mock_logger_info.call_count, 2)

    @patch("integration_tests.test_model_upload.time.sleep", autospec=True)
    @patch("integration_tests.test_model_upload.logger.info", autospec=True)
    def test_retry_reaches_limit_raises_timeout(self, mock_logger_info: MagicMock, mock_sleep: MagicMock) -> None:
        """Exhausting all three attempts surfaces a ``TimeoutError``."""
        wrapped = retry(times=3, delay_seconds=2)(self.function_to_test)

        with self.assertRaises(TimeoutError):
            wrapped(4)

        self.assertEqual(self.function_to_test.attempt, 3)
        # The third failure raises instead of sleeping and logging again.
        self.assertEqual(mock_sleep.call_count, 2)
        self.assertEqual(mock_logger_info.call_count, 2)


if __name__ == "__main__":
unittest.main()

0 comments on commit cc57f28

Please sign in to comment.