From c185b41ef7d58179c5299177fb1a1df1515234ec Mon Sep 17 00:00:00 2001
From: Benjamin Gorlick
Date: Fri, 7 Jun 2024 05:23:09 -0700
Subject: [PATCH] End to End Refactor for more modular scaling

---
 getai/0                                       |   1 -
 getai/__init__.py                             |  40 ++-
 getai/__main__.py                             |   8 +-
 getai/api.py                                  | 118 --------
 getai/api/__init__.py                         |  25 ++
 getai/api/datasets.py                         |  56 ++++
 getai/api/models.py                           |  64 +++++
 getai/api/utils.py                            |  29 ++
 getai/cli/__init__.py                         |  13 +
 getai/cli/cli.py                              | 200 ++++++++++++++
 getai/cli/utils.py                            | 100 +++++++
 getai/core/__init__.py                        |  31 +++
 getai/{ => core}/dataset_downloader.py        | 119 ++++----
 getai/{ => core}/dataset_search.py            | 158 +++++------
 getai/{ => core}/model_downloader.py          | 259 +++++++++---------
 getai/{ => core}/model_search.py              |  23 +-
 getai/{ => core}/session_manager.py           |   2 +-
 getai/core/utils.py                           | 116 ++++++++
 .../example_getai_stanfordnlp_imdb_dataset.py |   0
 getai/getai_config.yaml                       |   1 -
 getai/main.py                                 | 184 -------------
 getai/utils.py                                |  81 ------
 pyproject.toml                                |  10 +-
 setup.py                                      |   1 -
 24 files changed, 947 insertions(+), 692 deletions(-)
 delete mode 100644 getai/0
 delete mode 100644 getai/api.py
 create mode 100644 getai/api/__init__.py
 create mode 100644 getai/api/datasets.py
 create mode 100644 getai/api/models.py
 create mode 100644 getai/api/utils.py
 create mode 100644 getai/cli/__init__.py
 create mode 100644 getai/cli/cli.py
 create mode 100644 getai/cli/utils.py
 create mode 100644 getai/core/__init__.py
 rename getai/{ => core}/dataset_downloader.py (85%)
 rename getai/{ => core}/dataset_search.py (66%)
 rename getai/{ => core}/model_downloader.py (57%)
 rename getai/{ => core}/model_search.py (98%)
 rename getai/{ => core}/session_manager.py (90%)
 create mode 100644 getai/core/utils.py
 rename {examples => getai/examples}/example_getai_stanfordnlp_imdb_dataset.py (100%)
 delete mode 100644 getai/getai_config.yaml
 delete mode 100644 getai/main.py
 delete mode 100644 getai/utils.py

diff --git a/getai/0 b/getai/0
deleted file mode 100644
index 38d373d..0000000
--- a/getai/0
+++ /dev/null
@@ -1 +0,0 @@
-6920
diff --git a/getai/__init__.py b/getai/__init__.py
index 93dcf6c..bae0c0c 100644
--- a/getai/__init__.py
+++ b/getai/__init__.py
@@ -1,21 +1,31 @@
-# getai/__init__.py for GetAI - Contains the core API functions for searching and downloading datasets and models.
+"""
+Initialization module for the getai package.
+""" -from getai.utils import get_hf_token -from getai.session_manager import SessionManager - -# Import the core API functions from the api module -from getai.api import ( - search_datasets, - download_dataset, - search_models, - download_model, +from .core import ( + AsyncDatasetDownloader, + AsyncDatasetSearch, + AsyncModelDownloader, + AsyncModelSearch, + SessionManager, + convert_to_bytes, + interactive_branch_selection, + get_hf_token, + get_hf_token_from_cli, + hf_login, ) __all__ = [ - "search_datasets", - "download_dataset", - "search_models", - "download_model", - "get_hf_token", + "AsyncDatasetDownloader", + "AsyncDatasetSearch", + "AsyncModelDownloader", + "AsyncModelSearch", "SessionManager", + "convert_to_bytes", + "interactive_branch_selection", + "get_hf_token", + "get_hf_token_from_cli", + "hf_login", ] + +# Path: getai/cli/__init__.py diff --git a/getai/__main__.py b/getai/__main__.py index 34690b7..87e4e69 100644 --- a/getai/__main__.py +++ b/getai/__main__.py @@ -1,11 +1,11 @@ -# __main__.py +# getai/__main__.py import asyncio -from getai.main import main +from getai.cli import cli_main def run(): - asyncio.run(main()) + asyncio.run(cli_main()) -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/getai/api.py b/getai/api.py deleted file mode 100644 index b5e6e65..0000000 --- a/getai/api.py +++ /dev/null @@ -1,118 +0,0 @@ -""" api.py for GetAI - Contains the core API functions for searching and downloading datasets and models. """ - -from pathlib import Path -import logging - -from getai.utils import get_hf_token -from getai.session_manager import SessionManager - -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - -# Define default directories -DEFAULT_MODEL_DIR = Path.home() / ".getai" / "models" -DEFAULT_DATASET_DIR = Path.home() / ".getai" / "datasets" - - -async def search_datasets( - query, hf_token=None, max_connections=5, output_dir=None, **kwargs -): - """Search datasets on Hugging Face based on a query.""" - hf_token = hf_token or get_hf_token() - output_dir = Path(output_dir) if output_dir else DEFAULT_DATASET_DIR / query - - from getai.dataset_search import AsyncDatasetSearch - - async with SessionManager( - max_connections=max_connections, hf_token=hf_token - ) as manager: - session = await manager.get_session() - searcher = AsyncDatasetSearch( - query=query, - filtered_datasets=[], # Initial empty list - total_datasets=0, - output_dir=output_dir, - max_connections=max_connections, - hf_token=hf_token, - session=session, - ) - # Display search results interactively - await searcher.display_dataset_search_results() - - -async def download_dataset( - identifier, hf_token=None, max_connections=5, output_dir=None, **kwargs -): - """Download a dataset from Hugging Face by its identifier.""" - hf_token = hf_token or get_hf_token() - output_dir = Path(output_dir) if output_dir else DEFAULT_DATASET_DIR / identifier - - from getai.dataset_downloader import AsyncDatasetDownloader - - async with SessionManager( - max_connections=max_connections, hf_token=hf_token - ) as manager: - session = await manager.get_session() - downloader = AsyncDatasetDownloader( - session=session, max_connections=max_connections, output_dir=output_dir - ) - await downloader.download_dataset_info( - identifier, - **{ - key: value - for key, value in kwargs.items() - if key in downloader.get_expected_kwargs() - } - ) - - -async def search_models(query, hf_token=None, max_connections=5, **kwargs): - """Search models on Hugging Face based on a 
query.""" - hf_token = hf_token or get_hf_token() - - from getai.model_search import AsyncModelSearch - from getai.model_downloader import AsyncModelDownloader - - async with SessionManager( - max_connections=max_connections, hf_token=hf_token - ) as manager: - session = await manager.get_session() - searcher = AsyncModelSearch( - query=query, - max_connections=max_connections, - session=session, - hf_token=hf_token, - ) - await searcher.search_models(query, **kwargs) - - downloader = AsyncModelDownloader( - session=session, - max_connections=max_connections, - output_dir=DEFAULT_MODEL_DIR, - ) - for model in searcher.filtered_models: - await downloader.download_model(model["id"], branch="main", **kwargs) - - -async def download_model( - identifier, - branch="main", - hf_token=None, - max_connections=5, - output_dir=None, - **kwargs -): - """Download a model from Hugging Face by its identifier and branch.""" - hf_token = hf_token or get_hf_token() - output_dir = Path(output_dir) if output_dir else DEFAULT_MODEL_DIR - - from getai.model_downloader import AsyncModelDownloader - - async with SessionManager( - max_connections=max_connections, hf_token=hf_token - ) as manager: - session = await manager.get_session() - downloader = AsyncModelDownloader( - session=session, max_connections=max_connections, output_dir=output_dir - ) - await downloader.download_model(identifier, branch, **kwargs) diff --git a/getai/api/__init__.py b/getai/api/__init__.py new file mode 100644 index 0000000..d336f59 --- /dev/null +++ b/getai/api/__init__.py @@ -0,0 +1,25 @@ +# getai/api/__init__.py - Initialization module for the getai.api package. + +from getai.api.datasets import DatasetAPI +from getai.api.models import ModelAPI +from getai.api.utils import UtilsAPI + +# Exposing class methods as module-level functions +search_datasets = DatasetAPI.search_datasets +download_dataset = DatasetAPI.download_dataset +search_models = ModelAPI.search_models +download_model = ModelAPI.download_model +get_hf_token = UtilsAPI.get_hf_token +hf_login = UtilsAPI.hf_login + +__all__ = [ + "search_datasets", + "download_dataset", + "search_models", + "download_model", + "get_hf_token", + "hf_login", + "DatasetAPI", + "ModelAPI", + "UtilsAPI", +] diff --git a/getai/api/datasets.py b/getai/api/datasets.py new file mode 100644 index 0000000..da19487 --- /dev/null +++ b/getai/api/datasets.py @@ -0,0 +1,56 @@ +# getai/api/datasets.py - This module contains the API methods for searching and downloading datasets. 
+
+from pathlib import Path
+from typing import Optional
+from getai.api.utils import get_hf_token
+
+
+class DatasetAPI:
+
+    @staticmethod
+    async def search_datasets(
+        query: str,
+        hf_token: Optional[str] = None,
+        max_connections: int = 5,
+        output_dir: Optional[Path] = None,
+        **kwargs,
+    ):
+        from getai.core.dataset_search import AsyncDatasetSearch
+
+        hf_token = hf_token or get_hf_token()
+        output_dir = output_dir or Path.home() / ".getai" / "datasets" / query
+
+        searcher = AsyncDatasetSearch(
+            query=query,
+            output_dir=output_dir,
+            max_connections=max_connections,
+            hf_token=hf_token,
+        )
+        await searcher.display_dataset_search_results()
+
+    @staticmethod
+    async def download_dataset(
+        identifier: str,
+        hf_token: Optional[str] = None,
+        max_connections: int = 5,
+        output_dir: Optional[Path] = None,
+        **kwargs,
+    ):
+        from getai.core.dataset_downloader import AsyncDatasetDownloader
+
+        hf_token = hf_token or get_hf_token()
+        output_dir = output_dir or Path.home() / ".getai" / "datasets" / identifier
+
+        downloader = AsyncDatasetDownloader(
+            output_dir=output_dir,
+            max_connections=max_connections,
+            hf_token=hf_token,
+        )
+        await downloader.download_dataset_info(
+            dataset_id=identifier,
+            **{
+                key: value
+                for key, value in kwargs.items()
+                if key in downloader.get_expected_kwargs()
+            },
+        )
diff --git a/getai/api/models.py b/getai/api/models.py
new file mode 100644
index 0000000..8d6c7b5
--- /dev/null
+++ b/getai/api/models.py
@@ -0,0 +1,64 @@
+# getai/api/models.py - GetAI API methods for searching and downloading models.
+
+# getai/api/models.py - This module contains the API methods for searching and downloading models.
+
+from functools import lru_cache
+from typing import Optional
+from pathlib import Path
+
+
+class ModelAPI:
+
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def _get_hf_token(token: Optional[str] = None) -> str:
+        from getai.core.utils import get_hf_token
+
+        return token or get_hf_token()
+
+    @staticmethod
+    async def search_models(
+        query: str,
+        hf_token: Optional[str] = None,
+        max_connections: int = 5,
+        **kwargs,
+    ):
+        from getai.api.utils import get_hf_token
+        from getai.core.model_search import AsyncModelSearch
+        from getai.core.session_manager import SessionManager
+
+        hf_token = hf_token or get_hf_token()
+
+        async with SessionManager(
+            max_connections=max_connections, hf_token=hf_token
+        ) as manager:
+            session = await manager.get_session()
+            searcher = AsyncModelSearch(
+                query=query,
+                max_connections=max_connections,
+                session=session,
+                hf_token=hf_token,
+            )
+            await searcher.search_models(query, **kwargs)
+
+    @staticmethod
+    async def download_model(
+        identifier: str,
+        branch: str = "main",
+        hf_token: Optional[str] = None,
+        max_connections: int = 5,
+        output_dir: Optional[Path] = None,
+        **kwargs,
+    ):
+        from getai.api.utils import get_hf_token
+        from getai.core.model_downloader import AsyncModelDownloader
+
+        hf_token = hf_token or get_hf_token()
+        output_dir = output_dir or Path.home() / ".getai" / "models"
+
+        downloader = AsyncModelDownloader(
+            output_dir=output_dir,
+            max_connections=max_connections,
+            hf_token=hf_token,
+        )
+        await downloader.download_model(identifier, branch, **kwargs)
diff --git a/getai/api/utils.py b/getai/api/utils.py
new file mode 100644
index 0000000..9c1030a
--- /dev/null
+++ b/getai/api/utils.py
@@ -0,0 +1,29 @@
+# getai/api/utils.py - This module contains utility functions for the API.
+
+from functools import lru_cache
+from typing import Optional
+from getai.core.utils import CoreUtils
+
+
+class UtilsAPI:
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def get_hf_token(token: Optional[str] = None) -> str:
+        """Retrieve the Hugging Face token using caching for efficiency."""
+        return token or CoreUtils.get_hf_token()
+
+    @staticmethod
+    def hf_login():
+        """Log in using Hugging Face CLI."""
+        CoreUtils.hf_login()
+
+
+# Expose class methods as module-level functions
+get_hf_token = UtilsAPI.get_hf_token
+hf_login = UtilsAPI.hf_login
+
+__all__ = [
+    "UtilsAPI",
+    "get_hf_token",
+    "hf_login",
+]
diff --git a/getai/cli/__init__.py b/getai/cli/__init__.py
new file mode 100644
index 0000000..d8f0e9a
--- /dev/null
+++ b/getai/cli/__init__.py
@@ -0,0 +1,13 @@
+"""
+getai/cli/__init__.py - Initialization module for the getai.cli package.
+"""
+
+from getai.cli.cli import main as cli_main, add_common_arguments, define_subparsers
+from getai.cli.utils import CLIUtils
+
+__all__ = [
+    "cli_main",
+    "add_common_arguments",
+    "define_subparsers",
+    "CLIUtils",
+]
diff --git a/getai/cli/cli.py b/getai/cli/cli.py
new file mode 100644
index 0000000..4fd38b8
--- /dev/null
+++ b/getai/cli/cli.py
@@ -0,0 +1,200 @@
+import argparse
+from pathlib import Path
+import asyncio
+import logging
+from aiohttp import ClientError
+
+from getai.api import search_datasets, download_dataset, search_models, download_model
+
+from getai.cli.utils import CLIUtils
+from getai.core.dataset_search import AsyncDatasetSearch
+
+
+# Configure logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def add_common_arguments(parser: argparse.ArgumentParser):
+    """Add common arguments for dataset and model parsers."""
+    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
+    parser.add_argument(
+        "--max-connections", type=int, default=5, help="Max connections"
+    )
+
+
+def define_subparsers(parser: argparse.ArgumentParser):
+    """Define top-level subparsers for search and download commands."""
+    subparsers = parser.add_subparsers(dest="mode", help="Mode of operation")
+
+    # Search mode
+    search_parser = subparsers.add_parser(
+        "search", help="Search for models or datasets"
+    )
+    search_subparsers = search_parser.add_subparsers(
+        dest="search_mode", help="Search mode"
+    )
+
+    # Search datasets
+    search_datasets_parser = search_subparsers.add_parser(
+        "datasets", help="Search for datasets"
+    )
+    search_datasets_parser.add_argument(
+        "query", type=str, help="Search query for datasets"
+    )
+    search_datasets_parser.add_argument("--author", type=str, help="Filter by author")
+    search_datasets_parser.add_argument(
+        "--filter-criteria", type=str, help="Filter criteria"
+    )
+    search_datasets_parser.add_argument("--sort", type=str, help="Sort by")
+    search_datasets_parser.add_argument("--direction", type=str, help="Sort direction")
+    search_datasets_parser.add_argument("--limit", type=int, help="Limit results")
+    search_datasets_parser.add_argument(
+        "--full", action="store_true", help="Full dataset info"
+    )
+    add_common_arguments(search_datasets_parser)
+
+    # Search models
+    search_models_parser = search_subparsers.add_parser(
+        "models", help="Search for models"
+    )
+    search_models_parser.add_argument("query", type=str, help="Search query for models")
+    add_common_arguments(search_models_parser)
+
+    # Download mode
+    download_parser = subparsers.add_parser(
+        "download", help="Download models or datasets"
+    )
+    download_subparsers = download_parser.add_subparsers(
+        dest="download_mode", help="Download mode"
+    )
+
+    # Download dataset
+    download_dataset_parser = download_subparsers.add_parser(
+        "dataset", help="Download a dataset"
+    )
+    download_dataset_parser.add_argument(
+        "identifier", type=str, help="Dataset identifier"
+    )
+    download_dataset_parser.add_argument(
+        "--revision", type=str, help="Dataset revision"
+    )
+    download_dataset_parser.add_argument(
+        "--full", action="store_true", help="Full dataset info"
+    )
+    add_common_arguments(download_dataset_parser)
+
+    # Download model
+    download_model_parser = download_subparsers.add_parser(
+        "model", help="Download a model"
+    )
+    download_model_parser.add_argument("identifier", type=str, help="Model identifier")
+    download_model_parser.add_argument(
+        "--branch", type=str, default="main", help="Model branch"
+    )
+    download_model_parser.add_argument(
+        "--clean", action="store_true", help="Start from scratch"
+    )
+    download_model_parser.add_argument(
+        "--check", action="store_true", help="Check files"
+    )
+    add_common_arguments(download_model_parser)
+
+
+async def main():
+    """Main function for the GetAI CLI"""
+    parser = argparse.ArgumentParser(description="GetAI CLI")
+    define_subparsers(parser)
+    parser.add_argument(
+        "--hf-login", action="store_true", help="Log in using Hugging Face CLI"
+    )
+    args = parser.parse_args()
+
+    logger.info("Parsed arguments: %s", args)
+
+    if args.hf_login:
+        logger.info("Logging in to Hugging Face CLI")
+        CLIUtils.hf_login()
+        return
+
+    hf_token = CLIUtils.get_hf_token()
+
+    if not args.mode:
+        logger.error("Invalid mode. Please specify a valid mode.")
+        parser.print_help()
+        return
+
+    try:
+        if args.mode == "search":
+            if args.search_mode == "datasets":
+                logger.info("Searching datasets with query: %s", args.query)
+                search_instance = AsyncDatasetSearch(
+                    query=args.query,
+                    output_dir=args.output_dir or Path.home() / ".getai" / "datasets",
+                    max_connections=args.max_connections,
+                    hf_token=hf_token,
+                )
+                await search_instance.display_dataset_search_results()
+            elif args.search_mode == "models":
+                logger.info("Searching models with query: %s", args.query)
+                await search_models(
+                    query=args.query,
+                    hf_token=hf_token,
+                    max_connections=args.max_connections,
+                )
+            else:
+                logger.error(
+                    "Invalid search subcommand. Please specify 'datasets' or 'models'."
+                )
+                parser.print_help()
+
+        elif args.mode == "download":
+            if args.download_mode == "dataset":
+                logger.info("Downloading dataset: %s", args.identifier)
+                await download_dataset(
+                    identifier=args.identifier,
+                    revision=args.revision,
+                    full=args.full,
+                    output_dir=args.output_dir,
+                    max_connections=args.max_connections,
+                )
+            elif args.download_mode == "model":
+                logger.info("Downloading model: %s", args.identifier)
+                await download_model(
+                    identifier=args.identifier,
+                    branch=args.branch,
+                    hf_token=hf_token,
+                    max_connections=args.max_connections,
+                    output_dir=args.output_dir,
+                    clean=args.clean,
+                    check=args.check,
+                )
+            else:
+                logger.error(
+                    "Invalid download subcommand. Please specify 'dataset' or 'model'."
+                )
+                parser.print_help()
+
+    except KeyboardInterrupt:
+        logger.info("KeyboardInterrupt received. Closing operation...")
+    except ClientError as e:
+        logger.error("HTTP error during operation: %s", e)
+    except asyncio.CancelledError:
+        logger.info("Task cancelled during operation.")
+    except ValueError as e:
+        logger.error("Value error during operation: %s", e)
+    except Exception as e:
+        logger.error("Unexpected error during operation: %s", e)
+
+
+if __name__ == "__main__":
+    loop = asyncio.get_event_loop()
+    try:
+        loop.run_until_complete(main())
+    except KeyboardInterrupt:
+        logger.info("KeyboardInterrupt received. Closing operation...")
+    finally:
+        pending_tasks = asyncio.all_tasks(loop)
+        loop.run_until_complete(asyncio.gather(*pending_tasks, return_exceptions=True))
+        loop.close()
diff --git a/getai/cli/utils.py b/getai/cli/utils.py
new file mode 100644
index 0000000..be19252
--- /dev/null
+++ b/getai/cli/utils.py
@@ -0,0 +1,100 @@
+# getai/cli/utils.py - Utility functions for the CLI.
+
+import os
+from pathlib import Path
+import logging
+import subprocess
+
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+
+
+class CLIUtils:
+    @staticmethod
+    def convert_to_bytes(size_str):
+        """Convert size string like '2.3 GB' or '200 MB' to bytes."""
+        try:
+            size_units = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}
+            size, unit = size_str.split()
+            return int(float(size) * size_units[unit])
+        except Exception as e:
+            logging.exception("Error converting size to bytes: %s", e)
+            raise
+
+    @staticmethod
+    async def interactive_branch_selection(branches):
+        """Prompt user to select a branch interactively from a list."""
+        try:
+            from prompt_toolkit import PromptSession
+            from prompt_toolkit.completion import WordCompleter
+
+            branch_completer = WordCompleter(branches, ignore_case=True)
+            session = PromptSession(completer=branch_completer)
+            selected_branch = await session.prompt_async(
+                "Select a branch [Press TAB]: "
+            )
+            return selected_branch if selected_branch in branches else "main"
+        except Exception as e:
+            logging.exception("Error during interactive branch selection: %s", e)
+            raise
+
+    @staticmethod
+    def get_hf_token():
+        """Retrieve the Hugging Face token securely from environment variables or the CLI."""
+        try:
+            hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
+            if hf_token:
+                logging.info("Using Hugging Face token from environment variable.")
+                return hf_token
+
+            hf_token_file = Path.home() / ".huggingface" / "token"
+            if hf_token_file.exists():
+                with open(hf_token_file, "r", encoding="utf-8") as f:
+                    logging.info("Using Hugging Face token from ~/.huggingface/token.")
+                    return f.read().strip()
+
+            hf_token = CLIUtils.get_hf_token_from_cli()
+            if hf_token:
+                logging.info("Using Hugging Face token from Hugging Face CLI.")
+                return hf_token
+
+            raise ValueError(
+                "No Hugging Face token found. Please log in using the Hugging Face CLI."
+            )
+        except Exception as e:
+            logging.exception("Error retrieving Hugging Face token: %s", e)
+            raise
+
+    @staticmethod
+    def get_hf_token_from_cli():
+        """Retrieve Hugging Face token using the CLI."""
+        token_file = os.path.expanduser("~/.cache/huggingface/token")
+        try:
+            with open(token_file, "r", encoding="utf-8") as f:
+                return f.read().strip()
+        except FileNotFoundError:
+            logging.error(
+                "Hugging Face token file not found. Please log in using `huggingface-cli login`."
+            )
+            return None
+        except Exception as e:
+            logging.exception("Error retrieving Hugging Face token from file: %s", e)
+            return None
+
+    @staticmethod
+    def hf_login():
+        """Log in using Hugging Face CLI."""
+        try:
+            result = subprocess.run(
+                ["huggingface-cli", "login"], check=True, capture_output=True, text=True
+            )
+            logging.info("Hugging Face CLI login successful: %s", result.stdout)
+        except subprocess.CalledProcessError as e:
+            logging.error("Hugging Face CLI login failed: %s", e.stderr)
+        except FileNotFoundError:
+            logging.error(
+                "Hugging Face CLI not found. Please install it and try again."
+            )
+        except Exception as e:
+            logging.exception("Unexpected error during Hugging Face CLI login: %s", e)
diff --git a/getai/core/__init__.py b/getai/core/__init__.py
new file mode 100644
index 0000000..d122c4c
--- /dev/null
+++ b/getai/core/__init__.py
@@ -0,0 +1,31 @@
+"""
+getai/core/__init__.py - Initialization module for the getai.core package.
+"""
+
+from getai.core.dataset_downloader import AsyncDatasetDownloader
+from getai.core.dataset_search import AsyncDatasetSearch
+from getai.core.model_downloader import AsyncModelDownloader
+from getai.core.model_search import AsyncModelSearch
+from getai.core.session_manager import SessionManager
+from getai.core.utils import (
+    CoreUtils,
+    convert_to_bytes,
+    interactive_branch_selection,
+    get_hf_token,
+    get_hf_token_from_cli,
+    hf_login,
+)
+
+__all__ = [
+    "AsyncDatasetDownloader",
+    "AsyncDatasetSearch",
+    "AsyncModelDownloader",
+    "AsyncModelSearch",
+    "SessionManager",
+    "convert_to_bytes",
+    "interactive_branch_selection",
+    "get_hf_token",
+    "get_hf_token_from_cli",
+    "hf_login",
+    "CoreUtils",  # Exposing the class as well
+]
diff --git a/getai/dataset_downloader.py b/getai/core/dataset_downloader.py
similarity index 85%
rename from getai/dataset_downloader.py
rename to getai/core/dataset_downloader.py
index 65e9af7..1ca459e 100644
--- a/getai/dataset_downloader.py
+++ b/getai/core/dataset_downloader.py
@@ -1,4 +1,4 @@
-""" dataset_downloader.py - GetAI Asynchronous Dataset Downloader. """
+""" getai/core/dataset_downloader.py - GetAI Asynchronous Dataset Downloader. """
 
 import logging
 import json
@@ -8,7 +8,6 @@
 import asyncio
 import aiohttp
 import aiofiles
-from aiohttp import ClientSession
 from rainbow_tqdm import tqdm
 from tenacity import retry, stop_after_attempt, wait_exponential
 
@@ -28,13 +27,11 @@ class AsyncDatasetDownloader:
 
     def __init__(
        self,
-        session: ClientSession,
         output_dir: Optional[Path] = None,
         max_connections: int = 5,
         hf_token: Optional[str] = None,
    ):
         """Initialize the downloader with settings."""
-        self.session = session
         self.output_dir: Path = (
             output_dir if output_dir else Path.home() / ".getai" / "datasets"
         )
@@ -50,36 +47,43 @@ async def download_dataset_info(
         output_folder: Optional[Path] = None,
     ):
         """Download dataset info from Hugging Face."""
-        if not self.session:
-            raise RuntimeError("Session is not initialized")
+        from getai.core.session_manager import SessionManager
 
-        try:
-            url = f"{BASE_URL}/api/datasets/{dataset_id}"
-            if revision:
-                url += f"/revision/{revision}"
-
-            params = {"full": str(full).lower()}
-            async with self.session.get(url, params=params) as response:
-                if response.status == 200:
-                    dataset_info = await response.json()
-                    logger.info("Dataset info for %s:", dataset_id)
-                    logger.info("%s", dataset_info)
-                    output_folder = output_folder or self.output_dir / dataset_id
-                    output_folder.mkdir(parents=True, exist_ok=True)
-                    await self.download_dataset_files(
-                        dataset_id, revision, output_folder, dataset_info
-                    )
-                    await self.validate_checksums_and_sizes(output_folder, dataset_info)
-                else:
-                    logger.error(
-                        "Error fetching dataset info: HTTP %s", response.status
-                    )
-        except aiohttp.ClientError as e:
-            logger.exception("Client error while downloading dataset info: %s", e)
-        except asyncio.TimeoutError as e:
-            logger.exception("Timeout error while downloading dataset info: %s", e)
-        except Exception as e:
-            logger.exception("Unexpected error while downloading dataset info: %s", e)
+        async with SessionManager(
+            max_connections=self.max_connections, hf_token=self.hf_token
+        ) as manager:
+            session = await manager.get_session()
+            try:
+                url = f"{BASE_URL}/api/datasets/{dataset_id}"
+                if revision:
+                    url += f"/revision/{revision}"
+
+                params = {"full": str(full).lower()}
+                async with session.get(url, params=params) as response:
+                    if response.status == 200:
+                        dataset_info = await response.json()
+                        logger.info("Dataset info for %s:", dataset_id)
+                        logger.info("%s", dataset_info)
+                        output_folder = output_folder or self.output_dir / dataset_id
+                        output_folder.mkdir(parents=True, exist_ok=True)
+                        await self.download_dataset_files(
+                            session, dataset_id, revision, output_folder, dataset_info
+                        )
+                        await self.validate_checksums_and_sizes(
+                            output_folder, dataset_info
+                        )
+                    else:
+                        logger.error(
+                            "Error fetching dataset info: HTTP %s", response.status
+                        )
+            except aiohttp.ClientError as e:
+                logger.exception("Client error while downloading dataset info: %s", e)
+            except asyncio.TimeoutError as e:
+                logger.exception("Timeout error while downloading dataset info: %s", e)
+            except Exception as e:
+                logger.exception(
+                    "Unexpected error while downloading dataset info: %s", e
+                )
 
     def get_expected_kwargs(self):
         """Return the expected keyword arguments for download_dataset_info."""
@@ -87,21 +91,19 @@ def get_expected_kwargs(self):
 
     async def download_dataset_files(
         self,
+        session: aiohttp.ClientSession,
         dataset_id: str,
         revision: Optional[str] = None,
         output_folder: Optional[Path] = None,
         dataset_info: Optional[Dict[str, Any]] = None,
     ):
         """Download dataset files from Hugging Face."""
-        if not self.session:
-            raise RuntimeError("Session is not initialized")
-
         output_folder = output_folder or self.output_dir / dataset_id
         output_folder.mkdir(parents=True, exist_ok=True)
 
         try:
             url = f"{BASE_URL}/api/datasets/{dataset_id}/tree/{revision or 'main'}"
-            async with self.session.get(url) as response:
+            async with session.get(url) as response:
                 if response.status == 200:
                     file_tree = await response.json()
                     logger.info("Downloading dataset files to: %s", output_folder)
@@ -115,7 +117,7 @@ async def download_dataset_files(
                         else:
                             tasks.append(
                                 self.download_dataset_file(
-                                    file_url, output_folder / file["path"]
+                                    session, file_url, output_folder / file["path"]
                                 )
                             )
                     await asyncio.gather(*tasks)
@@ -126,7 +128,9 @@ async def download_dataset_files(
                             if sibling["rfilename"].endswith(".parquet"):
                                 lfs_url = f"{BASE_URL}/datasets/{dataset_id}/resolve/{revision or 'main'}/{sibling['rfilename']}"
                                 await self.download_git_lfs_file(
-                                    lfs_url, output_folder / sibling["rfilename"]
+                                    session,
+                                    lfs_url,
+                                    output_folder / sibling["rfilename"],
                                 )
 
                     metadata = await self.parse_json_files_for_metadata(output_folder)
@@ -135,7 +139,7 @@ async def download_dataset_files(
                         if lfs_metadata:
                             lfs_url = f"https://huggingface.co/{dataset_id}/resolve/{revision or 'main'}/{lfs_file['path']}"
                             await self.download_file_from_url(
-                                lfs_url, output_folder / lfs_file["path"]
+                                session, lfs_url, output_folder / lfs_file["path"]
                             )
                         else:
                             logger.error(
@@ -145,7 +149,7 @@ async def download_dataset_files(
                     source_datasets = metadata.get("source_datasets", [])
                     for source_url in source_datasets:
                         await self.download_file_from_url(
-                            source_url, output_folder / Path(source_url).name
+                            session, source_url, output_folder / Path(source_url).name
                         )
                 else:
                     logger.error(
@@ -158,13 +162,12 @@ async def download_dataset_files(
         except Exception as e:
             logger.exception("Unexpected error while downloading dataset files: %s", e)
 
-    async def download_dataset_file(self, url: str, file_path: Path):
+    async def download_dataset_file(
+        self, session: aiohttp.ClientSession, url: str, file_path: Path
+    ):
         """Download a single dataset file."""
-        if not self.session:
-            raise RuntimeError("Session is not initialized")
-
         try:
-            async with self.session.get(url) as response:
+            async with session.get(url) as response:
                 if response.status == 200:
                     total_size = response.content_length or 0
                     output_path = Path(file_path)
@@ -184,7 +187,7 @@ async def download_dataset_file(
                     logger.info("Downloaded dataset file: %s", output_path)
                 elif response.status == 404:
                     logger.info("Attempting to use LFS for %s", url)
-                    await self.download_git_lfs_file(url, file_path)
+                    await self.download_git_lfs_file(session, url, file_path)
                 else:
                     logger.error(
                         "Error downloading dataset file: HTTP %s", response.status
@@ -199,21 +202,20 @@ async def download_dataset_file(
     @retry(
         stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
     )
-    async def download_git_lfs_file(self, url: str, file_path: Path):
+    async def download_git_lfs_file(
+        self, session: aiohttp.ClientSession, url: str, file_path: Path
+    ):
         """Download a Git LFS file."""
-        if not self.session:
-            raise RuntimeError("Session is not initialized")
-
         try:
-            async with self.session.get(url) as response:
+            async with session.get(url) as response:
                 if response.status == 200:
                     size = response.content_length or 0
                     if size > 0:
-                        await self.download_file_from_url(url, file_path)
+                        await self.download_file_from_url(session, url, file_path)
                     else:
                         logger.error("Missing size in LFS metadata.")
-                        await self.download_file_from_url(url, file_path)
+                        await self.download_file_from_url(session, url, file_path)
                 else:
                     logger.error(
                         "Error fetching Git LFS file info: HTTP %s", response.status
@@ -228,13 +230,12 @@ async def download_git_lfs_file(
     @retry(
         stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
     )
-    async def download_file_from_url(self, url: str, file_path: Path):
+    async def download_file_from_url(
+        self, session: aiohttp.ClientSession, url: str, file_path: Path
+    ):
         """Download a file from a URL."""
-        if not self.session:
-            raise RuntimeError("Session is not initialized")
-
         try:
-            async with self.session.get(url) as response:
+            async with session.get(url) as response:
                 if response.status == 200:
                     total_size = response.content_length or 0
                     output_path = Path(file_path)
diff --git a/getai/dataset_search.py b/getai/core/dataset_search.py
similarity index 66%
rename from getai/dataset_search.py
rename to getai/core/dataset_search.py
index 5b1a331..9318138 100644
--- a/getai/dataset_search.py
+++ b/getai/core/dataset_search.py
@@ -1,10 +1,11 @@
+""" getai/core/dataset_search.py - This module handles the asynchronous dataset search functionality. """
+
 import logging
 from datetime import datetime
 from pathlib import Path
 from typing import List, Dict, Any, Optional
 import aiohttp
 import asyncio
-from aiohttp import ClientSession
 from prompt_toolkit import PromptSession
 from prompt_toolkit.completion import WordCompleter
 
@@ -25,41 +26,34 @@ class AsyncDatasetSearch:
     def __init__(
         self,
         query: str,
-        filtered_datasets: List[Dict[str, Any]],
-        total_datasets: int,
         output_dir: Path,
         max_connections: int,
         hf_token: Optional[str],
-        session: ClientSession,
     ):
         self.config = {
             "query": query,
-            "total_datasets": total_datasets,
             "output_dir": output_dir,
             "max_connections": max_connections,
            "hf_token": hf_token,
-            "session": session,
             "page_size": 20,
             "timeout": aiohttp.ClientTimeout(total=None),
         }
-        self.data = {
-            "filtered_datasets": self.sort_by_last_modified(filtered_datasets),
-            "filtered_dataset_ids": {dataset["id"] for dataset in filtered_datasets},
-            "main_search_datasets": filtered_datasets.copy(),
-            "search_history": [(filtered_datasets.copy(), 1)],
+        self.data: Dict[str, Any] = {
+            "filtered_datasets": [],
+            "filtered_dataset_ids": set(),
+            "main_search_datasets": [],
+            "search_history": [],
         }
-        self.session = session
 
     @staticmethod
     def sort_by_last_modified(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Sort the datasets by lastModified date in descending order."""
         return sorted(datasets, key=lambda x: x.get("lastModified", ""), reverse=True)
 
-    async def search_datasets(self, query: str, **kwargs):
+    async def search_datasets(
+        self, query: str, session: aiohttp.ClientSession, **kwargs
+    ):
         """Search datasets based on a query."""
-        if not self.session:
-            raise RuntimeError("Session is not initialized")
-
         params = {
             "search": query,
             "limit": 50,
@@ -74,7 +68,7 @@ async def search_datasets(
         }
 
         search_url = f"{BASE_URL}/api/datasets"
-        async with self.session.get(search_url, params=params) as response:
+        async with session.get(search_url, params=params) as response:
             if response.status == 200:
                 datasets = await response.json()
                 sorted_datasets = self.sort_by_last_modified(datasets)
@@ -87,79 +81,87 @@ async def search_datasets(
 
     async def display_dataset_search_results(self):
         """Display dataset search results and handle user interaction for downloading."""
-        await self.search_datasets(self.config["query"])
-        while True:
-            total_pages = (
-                len(self.data["filtered_datasets"]) + self.config["page_size"] - 1
-            ) // self.config["page_size"]
-            current_page = self.data["search_history"][-1][1]
+        from getai.core import SessionManager
+
+        async with SessionManager(
+            max_connections=self.config["max_connections"],
+            hf_token=self.config["hf_token"],
+        ) as manager:
+            session = await manager.get_session()
+            await self.search_datasets(self.config["query"], session)
             while True:
-                await self.display_current_dataset_page(current_page, total_pages)
-                user_input = await self.get_dataset_user_input(
-                    current_page, total_pages
-                )
+                total_pages = (
+                    len(self.data["filtered_datasets"]) + self.config["page_size"] - 1
+                ) // self.config["page_size"]
+                current_page = self.data["search_history"][-1][1]
 
-                if user_input.lower() == "n" and current_page < total_pages:
-                    current_page += 1
-                    self.data["search_history"][-1] = (
-                        self.data["filtered_datasets"],
-                        current_page,
+                while True:
+                    await self.display_current_dataset_page(current_page, total_pages)
+                    user_input = await self.get_dataset_user_input(
+                        current_page, total_pages
                     )
-                elif user_input.lower() == "p" and current_page > 1:
-                    current_page -= 1
-                    self.data["search_history"][-1] = (
-                        self.data["filtered_datasets"],
-                        current_page,
-                    )
-                elif user_input.lower() == "f":
-                    await self.filter_dataset_search_results()
-                    break
-                elif user_input.lower() == "s":
-                    await self.sort_dataset_search_results()
-                    break
-                elif user_input.lower() == "r":
-                    if len(self.data["search_history"]) > 1:
-                        self.data["search_history"].pop()
-                        (
+
+                    if user_input.lower() == "n" and current_page < total_pages:
+                        current_page += 1
+                        self.data["search_history"][-1] = (
+                            self.data["filtered_datasets"],
+                            current_page,
+                        )
+                    elif user_input.lower() == "p" and current_page > 1:
+                        current_page -= 1
+                        self.data["search_history"][-1] = (
                             self.data["filtered_datasets"],
                             current_page,
-                        ) = self.data[
-                            "search_history"
-                        ][-1]
-                        self.config["total_datasets"] = len(
-                            self.data["filtered_datasets"]
                         )
+                    elif user_input.lower() == "f":
+                        await self.filter_dataset_search_results()
                         break
-                    else:
-                        logger.info("You are already at the main search results.")
-                elif user_input.isdigit() and 1 <= int(user_input) <= len(
-                    self.get_current_datasets(current_page)
-                ):
-                    selected_dataset = self.get_current_datasets(current_page)[
-                        int(user_input) - 1
-                    ]
-                    output_folder = (
-                        Path(self.config["output_dir"]) / selected_dataset["id"]
-                    )
-                    output_folder.mkdir(parents=True, exist_ok=True)
+                    elif user_input.lower() == "s":
+                        await self.sort_dataset_search_results()
+                        break
+                    elif user_input.lower() == "r":
+                        if len(self.data["search_history"]) > 1:
+                            self.data["search_history"].pop()
+                            (
+                                self.data["filtered_datasets"],
+                                current_page,
+                            ) = self.data[
+                                "search_history"
+                            ][-1]
+                            self.config["total_datasets"] = len(
+                                self.data["filtered_datasets"]
+                            )
+                            break
+                        else:
+                            logger.info("You are already at the main search results.")
+                    elif user_input.isdigit() and 1 <= int(user_input) <= len(
+                        self.get_current_datasets(current_page)
+                    ):
+                        selected_dataset = self.get_current_datasets(current_page)[
+                            int(user_input) - 1
+                        ]
+                        output_folder = (
+                            Path(self.config["output_dir"]) / selected_dataset["id"]
+                        )
+                        output_folder.mkdir(parents=True, exist_ok=True)
 
-                    # Delayed import to avoid circular import issue
-                    from getai.api import download_dataset
+                        # Delayed import to avoid circular import issue
+                        from getai.api import download_dataset
 
-                    # Call download_dataset from getai.api
-                    await download_dataset(
-                        identifier=selected_dataset["id"],
-                        hf_token=self.config["hf_token"],
-                        max_connections=self.config["max_connections"],
-                        output_dir=output_folder,
-                    )
-                    break
-                else:
-                    logger.error("Invalid input. Please try again.")
+                        # Call download_dataset from getai.api
+                        await download_dataset(
+                            identifier=selected_dataset["id"],
+                            hf_token=self.config["hf_token"],
+                            max_connections=self.config["max_connections"],
+                            output_dir=output_folder,
+                        )
+                        break
+                    else:
+                        logger.error("Invalid input. Please try again.")
 
-            await self.handle_dataset_user_choice()
+                await self.handle_dataset_user_choice()
 
     async def display_current_dataset_page(self, current_page, total_pages):
         """Display the current page of dataset search results."""
diff --git a/getai/model_downloader.py b/getai/core/model_downloader.py
similarity index 57%
rename from getai/model_downloader.py
rename to getai/core/model_downloader.py
index ad59a2b..633b56b 100644
--- a/getai/model_downloader.py
+++ b/getai/core/model_downloader.py
@@ -1,4 +1,4 @@
-"""model_downloader.py: Async Downloads models from the Hugging Face Hub."""
+""" getai/core/model_downloader.py: Async Downloads models from the Hugging Face Hub."""
 
 import base64
 import datetime
@@ -11,7 +11,7 @@
 from aiofiles import open as aio_open
 import aiofiles
 from aiohttp import ClientSession
-from rainbow_tqdm import tqdm
+
 
 BASE_URL = "https://huggingface.co"
 logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
@@ -27,13 +27,12 @@ class AsyncModelDownloader:
 
     def __init__(
         self,
-        session: ClientSession,
         max_retries: int = 5,
         output_dir: Optional[Path] = None,
         max_connections: int = 7,
         hf_token: Optional[str] = None,
     ):
-        """Initialize downloader with session and settings."""
+        """Initialize downloader with settings."""
         self.logger = logging.getLogger(__name__)
         self.logger.setLevel(logging.INFO)
         self.output_dir = (
@@ -42,7 +41,6 @@ def __init__(
         self.max_retries = max_retries
         self.max_connections = max_connections
         self.token = hf_token
-        self.session = session
         self.branch_sizes: Dict[str, int] = {}
         self.file_locks: Dict[Path, asyncio.Lock] = {}
 
@@ -51,143 +49,145 @@ async def download_model(
     ) -> None:
         """Download and optionally verify a model."""
         print(f"Downloading model '{model_id}' from branch '{branch}'")
-        links, sha256, is_lora, is_llamacpp = (
-            await self.get_download_links_from_huggingface(model_id, branch)
-        )
-        output_folder = self.get_output_folder(model_id, branch, is_lora, is_llamacpp)
-        await self.download_model_files(
-            model_id, branch, links, dict(sha256), output_folder
-        )
-        if check:
-            await self.check_model_files(
-                model_id, branch, links, dict(sha256), output_folder
+
+        # Import SessionManager inside the function to avoid circular import
+        from getai.core import SessionManager
+
+        async with SessionManager(
+            max_connections=self.max_connections, hf_token=self.token
+        ) as manager:
+            session = await manager.get_session()
+            links, sha256, is_lora, is_llamacpp = (
+                await self.get_download_links_from_huggingface(
+                    session, model_id, branch
+                )
+            )
+            output_folder = self.get_output_folder(
+                model_id, branch, is_lora, is_llamacpp
+            )
+            await self.download_model_files(
+                session, model_id, branch, links, dict(sha256), output_folder
             )
+            if check:
+                await self.check_model_files(
+                    session, model_id, branch, links, dict(sha256), output_folder
+                )
 
     async def get_download_links_from_huggingface(
         self,
+        session: ClientSession,
         model: str,
         branch: str,
         text_only: bool = False,
         specific_file: Optional[str] = None,
     ) -> Tuple[List[str], List[Tuple[str, str]], bool, bool]:
         """Fetch model download links from Hugging Face."""
-        if self.session:
-            page = f"/api/models/{model}/tree/{branch}"
-            cursor = b""
-            links: List[str] = []
-            sha256: List[Tuple[str, str]] = []
-            classifications: List[str] = []
-            has_pytorch = False
-            has_pt = False
-            has_gguf = False
-            has_safetensors = False
-            is_lora = False
-
-            while True:
-                url = f"{BASE_URL}{page}" + (
-                    f"?cursor={cursor.decode()}" if cursor else ""
-                )
-                self.logger.debug("Making request to: %s", url)
-
-                async with self.session.get(url) as response:
-                    response.raise_for_status()
-                    content = await response.json()
+        page = f"/api/models/{model}/tree/{branch}"
+        cursor = b""
+        links: List[str] = []
+        sha256: List[Tuple[str, str]] = []
+        classifications: List[str] = []
+        has_pytorch = False
+        has_pt = False
+        has_gguf = False
+        has_safetensors = False
+        is_lora = False
+
+        while True:
+            url = f"{BASE_URL}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
+            self.logger.debug("Making request to: %s", url)
+
+            async with session.get(url) as response:
+                response.raise_for_status()
+                content = await response.json()
+
+                if not content:
+                    break
 
-                    if not content:
-                        break
+                for i, item in enumerate(content):
+                    fname = item.get("path", "")
+                    if specific_file and fname != specific_file:
+                        continue
+
+                    if not is_lora and fname.endswith(
+                        ("adapter_config.json", "adapter_model.bin")
+                    ):
+                        is_lora = True
+
+                    is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname)
+                    is_safetensors = re.match(r".*\.safetensors", fname)
+                    is_pt = re.match(r".*\.pt", fname)
+                    is_gguf = re.match(r".*\.gguf", fname)
+                    is_tiktoken = re.match(r".*\.tiktoken", fname)
+                    is_tokenizer = (
+                        re.match(r"(tokenizer|ice|spiece).*\.model", fname)
+                        or is_tiktoken
+                    )
+                    is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
+
+                    if any(
+                        (
+                            is_pytorch,
+                            is_safetensors,
+                            is_pt,
+                            is_gguf,
+                            is_tokenizer,
+                            is_text,
+                        )
+                    ):
+                        if "lfs" in item:
+                            sha256.append((fname, item["lfs"]["oid"]))
 
-                    for i, item in enumerate(content):
-                        fname = item.get("path", "")
-                        if specific_file and fname != specific_file:
+                        if is_text:
+                            links.append(
+                                f"https://huggingface.co/{model}/resolve/{branch}/{fname}"
+                            )
+                            classifications.append("text")
                             continue
 
-                        if not is_lora and fname.endswith(
-                            ("adapter_config.json", "adapter_model.bin")
-                        ):
-                            is_lora = True
-
-                        is_pytorch = re.match(
-                            r"(pytorch|adapter|gptq)_model.*\.bin", fname
-                        )
-                        is_safetensors = re.match(r".*\.safetensors", fname)
-                        is_pt = re.match(r".*\.pt", fname)
-                        is_gguf = re.match(r".*\.gguf", fname)
-                        is_tiktoken = re.match(r".*\.tiktoken", fname)
-                        is_tokenizer = (
-                            re.match(r"(tokenizer|ice|spiece).*\.model", fname)
-                            or is_tiktoken
-                        )
-                        is_text = (
-                            re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
-                        )
-
-                        if any(
-                            (
-                                is_pytorch,
-                                is_safetensors,
-                                is_pt,
-                                is_gguf,
-                                is_tokenizer,
-                                is_text,
+                        if not text_only:
+                            links.append(
+                                f"https://huggingface.co/{model}/resolve/{branch}/{fname}"
                             )
-                        ):
-                            if "lfs" in item:
-                                sha256.append((fname, item["lfs"]["oid"]))
-
-                            if is_text:
-                                links.append(
-                                    f"https://huggingface.co/{model}/resolve/{branch}/{fname}"
-                                )
-                                classifications.append("text")
-                                continue
-
-                            if not text_only:
-                                links.append(
-                                    f"https://huggingface.co/{model}/resolve/{branch}/{fname}"
-                                )
-                                if is_safetensors:
-                                    has_safetensors = True
-                                    classifications.append("safetensors")
-                                elif is_pytorch:
-                                    has_pytorch = True
-                                    classifications.append("pytorch")
-                                elif is_pt:
-                                    has_pt = True
-                                    classifications.append("pt")
-                                elif is_gguf:
-                                    has_gguf = True
-                                    classifications.append("gguf")
-
-                cursor = (
-                    base64.b64encode(
-                        f'{{"file_name":"{content[-1]["path"]}"}}'.encode()
-                    )
-                    + b":50"
+                            if is_safetensors:
+                                has_safetensors = True
+                                classifications.append("safetensors")
+                            elif is_pytorch:
+                                has_pytorch = True
+                                classifications.append("pytorch")
+                            elif is_pt:
+                                has_pt = True
+                                classifications.append("pt")
+                            elif is_gguf:
+                                has_gguf = True
+                                classifications.append("gguf")
+
+            cursor = (
+                base64.b64encode(
+                    f'{{"file_name":"{content[-1]["path"]}"}}'.encode()
                 )
-                cursor = base64.b64encode(cursor)
-                cursor = cursor.replace(b"=", b"%3D")
-
-            if (has_pytorch or has_pt) and has_safetensors:
-                links = [
-                    link
-                    for link, classification in zip(links, classifications)
-                    if classification not in ("pytorch", "pt")
-                ]
-
-            if has_gguf and specific_file is None:
-                has_q4km = any("q4_k_m" in link.lower() for link in links)
-                if has_q4km:
-                    links = [link for link in links if "q4_k_m" in link.lower()]
-                else:
-                    links = [
-                        link for link in links if not link.lower().endswith(".gguf")
-                    ]
-
-            is_llamacpp = has_gguf and specific_file is not None
+                + b":50"
+            )
+            cursor = base64.b64encode(cursor)
+            cursor = cursor.replace(b"=", b"%3D")
+
+        if (has_pytorch or has_pt) and has_safetensors:
+            links = [
+                link
+                for link, classification in zip(links, classifications)
+                if classification not in ("pytorch", "pt")
+            ]
+
+        if has_gguf and specific_file is None:
+            has_q4km = any("q4_k_m" in link.lower() for link in links)
+            if has_q4km:
+                links = [link for link in links if "q4_k_m" in link.lower()]
+            else:
+                links = [link for link in links if not link.lower().endswith(".gguf")]
 
-            return links, sha256, is_lora, is_llamacpp
+        is_llamacpp = has_gguf and specific_file is not None
 
-        return [], [], False, False
+        return links, sha256, is_lora, is_llamacpp
 
     def get_output_folder(
         self,
@@ -210,6 +210,7 @@ def get_output_folder(
 
     async def download_model_files(
         self,
+        session: ClientSession,
         model: str,
         branch: str,
         links: List[str],
@@ -237,9 +238,6 @@ async def download_model_files(
 
         async def download_file(link: str):
             """Download a single file."""
-            if not self.session:
-                raise RuntimeError("Session is not initialized")
-
             async with semaphore:
                 filename = Path(link).name
                 file_hash: Optional[str] = sha256_dict.get(filename)
@@ -249,9 +247,7 @@ async def download_file(link: str):
                        f"Warning: No SHA256 hash found for {filename}. Downloading without sha256 verification."
                    )
 
-                await self._download_model_file(
-                    self.session, link, output_folder, file_hash
-                )
+                await self._download_model_file(session, link, output_folder, file_hash)
 
         tasks = [asyncio.ensure_future(download_file(link)) for link in links]
         await asyncio.gather(*tasks)
@@ -264,6 +260,8 @@ async def _download_model_file(
         file_hash: Optional[str],
     ):
         """Download and save a model file."""
+        from rainbow_tqdm import tqdm
+
         filename = Path(url.rsplit("/", 1)[1])
         output_path = output_folder / filename
 
@@ -297,6 +295,7 @@ async def _download_model_file(
 
     async def check_model_files(
         self,
+        session: ClientSession,
         model: str,
         branch: str,
         links: List[str],
diff --git a/getai/model_search.py b/getai/core/model_search.py
similarity index 98%
rename from getai/model_search.py
rename to getai/core/model_search.py
index fc9f717..1b9769a 100644
--- a/getai/model_search.py
+++ b/getai/core/model_search.py
@@ -4,21 +4,14 @@
 import logging
 import re
 from datetime import datetime
-from typing import (
-    Optional,
-    Dict,
-    List,
-    Tuple,
-    Set,
-    Any,  # pylint: disable=unused-import noqa: F401
-)
+from typing import Optional, Dict, List, Tuple, Set, Any  # noqa
 from aiohttp import ClientSession
 from prompt_toolkit import PromptSession
 from prompt_toolkit.completion import WordCompleter
 
-# Import API functions
-from getai import download_model
-from getai.utils import interactive_branch_selection, convert_to_bytes
+
+from getai import api
+
 
 BASE_URL = "https://huggingface.co"
 logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
@@ -39,7 +32,7 @@ def __init__(
         session: ClientSession,
         max_connections: int = 10,
         hf_token: Optional[str] = None,
-        **kwargs: Any,  # Accepting additional keyword arguments for flexibility
+        **kwargs: Any,
     ):
         """Initialize AsyncModelSearch with query and session."""
         self.query = query
@@ -190,7 +183,7 @@ async def display_search_results(self):
                 ]
                 branches = await self.get_model_branches(selected_model["id"])
                 selected_branch = await self.select_branch_interactive(branches)
-                await download_model(
+                await api.download_model(
                     identifier=selected_model["id"],
                     branch=selected_branch,
                     hf_token=self.token,
@@ -385,6 +378,8 @@ async def select_branch(
         branch_arg: Optional[str],
     ) -> str:
         """Select a branch based on user input or default to main."""
+        from getai.core import interactive_branch_selection
+
         if branch_arg:
             if isinstance(branch_arg, str):
                 return branch_arg if branch_arg in branches else "main"
@@ -451,6 +446,8 @@ async def get_branch_file_sizes(
         self, model: str, quiet: bool = False
     ) -> Dict[str, int]:
         """Get file sizes for all branches of a model."""
+        from getai.core import convert_to_bytes
+
         if not quiet:
             self.logger.debug("Fetching file sizes for %s...", model)
         branches = await self.get_model_branches(model)
diff --git a/getai/session_manager.py b/getai/core/session_manager.py
similarity index 90%
rename from getai/session_manager.py
rename to getai/core/session_manager.py
index 1cb4d15..3be04e6 100644
--- a/getai/session_manager.py
+++ b/getai/core/session_manager.py
@@ -1,4 +1,4 @@
-"""session_manager.py for GetAI - Contains the SessionManager class for managing aiohttp.ClientSession instances."""
+""" getai/core/session_manager.py for GetAI - Contains the SessionManager class for managing aiohttp.ClientSession instances."""
 
 from typing import Optional
 import aiohttp
diff --git a/getai/core/utils.py b/getai/core/utils.py
new file mode 100644
index 0000000..a28af5a
--- /dev/null
+++ b/getai/core/utils.py
@@ -0,0 +1,116 @@
+# getai/core/utils.py - GetAI utility functions for the core functionality.
+
+import os
+from pathlib import Path
+import logging
+import subprocess
+
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+
+
+class CoreUtils:
+    @staticmethod
+    def convert_to_bytes(size_str):
+        """Convert size string like '2.3 GB' or '200 MB' to bytes."""
+        try:
+            size_units = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}
+            size, unit = size_str.split()
+            return int(float(size) * size_units[unit])
+        except Exception as e:
+            logging.exception("Error converting size to bytes: %s", e)
+            raise
+
+    @staticmethod
+    async def interactive_branch_selection(branches):
+        """Prompt user to select a branch interactively from a list."""
+        try:
+            from prompt_toolkit import PromptSession
+            from prompt_toolkit.completion import WordCompleter
+
+            branch_completer = WordCompleter(branches, ignore_case=True)
+            session = PromptSession(completer=branch_completer)
+            selected_branch = await session.prompt_async(
+                "Select a branch [Press TAB]: "
+            )
+            return selected_branch if selected_branch in branches else "main"
+        except Exception as e:
+            logging.exception("Error during interactive branch selection: %s", e)
+            raise
+
+    @staticmethod
+    def get_hf_token():
+        """Retrieve the Hugging Face token securely from environment variables or the CLI."""
+        try:
+            hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
+            if hf_token:
+                logging.info("Using Hugging Face token from environment variable.")
+                return hf_token
+
+            hf_token_file = Path.home() / ".huggingface" / "token"
+            if hf_token_file.exists():
+                with open(hf_token_file, "r", encoding="utf-8") as f:
+                    logging.info("Using Hugging Face token from ~/.huggingface/token.")
+                    return f.read().strip()
+
+            hf_token = CoreUtils.get_hf_token_from_cli()
+            if hf_token:
+                logging.info("Using Hugging Face token from Hugging Face CLI.")
+                return hf_token
+
+            raise ValueError(
+                "No Hugging Face token found. Please log in using the Hugging Face CLI."
+            )
+        except Exception as e:
+            logging.exception("Error retrieving Hugging Face token: %s", e)
+            raise
+
+    @staticmethod
+    def get_hf_token_from_cli():
+        """Retrieve Hugging Face token using the CLI."""
+        token_file = os.path.expanduser("~/.cache/huggingface/token")
+        try:
+            with open(token_file, "r", encoding="utf-8") as f:
+                return f.read().strip()
+        except FileNotFoundError:
+            logging.error(
+                "Hugging Face token file not found. Please log in using `huggingface-cli login`."
+            )
+            return None
+        except Exception as e:
+            logging.exception("Error retrieving Hugging Face token from file: %s", e)
+            return None
+
+    @staticmethod
+    def hf_login():
+        """Log in using Hugging Face CLI."""
+        try:
+            result = subprocess.run(
+                ["huggingface-cli", "login"], check=True, capture_output=True, text=True
+            )
+            logging.info("Hugging Face CLI login successful: %s", result.stdout)
+        except subprocess.CalledProcessError as e:
+            logging.error("Hugging Face CLI login failed: %s", e.stderr)
+        except FileNotFoundError:
+            logging.error(
+                "Hugging Face CLI not found. Please install it and try again."
+ ) + except Exception as e: + logging.exception("Unexpected error during Hugging Face CLI login: %s", e) + + +__all__ = [ + "CoreUtils", + "convert_to_bytes", + "interactive_branch_selection", + "get_hf_token", + "get_hf_token_from_cli", + "hf_login", +] + +convert_to_bytes = CoreUtils.convert_to_bytes +interactive_branch_selection = CoreUtils.interactive_branch_selection +get_hf_token = CoreUtils.get_hf_token +get_hf_token_from_cli = CoreUtils.get_hf_token_from_cli +hf_login = CoreUtils.hf_login diff --git a/examples/example_getai_stanfordnlp_imdb_dataset.py b/getai/examples/example_getai_stanfordnlp_imdb_dataset.py similarity index 100% rename from examples/example_getai_stanfordnlp_imdb_dataset.py rename to getai/examples/example_getai_stanfordnlp_imdb_dataset.py diff --git a/getai/getai_config.yaml b/getai/getai_config.yaml deleted file mode 100644 index 11dec00..0000000 --- a/getai/getai_config.yaml +++ /dev/null @@ -1 +0,0 @@ -hf_token: your_huggingface_token_here \ No newline at end of file diff --git a/getai/main.py b/getai/main.py deleted file mode 100644 index 0998e67..0000000 --- a/getai/main.py +++ /dev/null @@ -1,184 +0,0 @@ -import argparse -import asyncio -import logging -from aiohttp import ClientError -from getai.api import search_datasets, download_dataset, search_models, download_model -from getai.utils import get_hf_token - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -async def main(): - """Main function for the GetAI CLI""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(dest="mode", help="Mode of operation") - - # Model mode - model_parser = subparsers.add_parser("model", help="Download a model") - model_parser.add_argument( - "identifier", type=str, help="Model identifier on Hugging Face" - ) - model_parser.add_argument( - "--branch", nargs="?", const="main", default="main", help="Branch name" - ) - model_parser.add_argument( - "--output-dir", type=str, default=None, help="Directory to save the model" - ) - model_parser.add_argument( - "--max-connections", type=int, default=5, help="Max connections for downloads" - ) - model_parser.add_argument( - "--clean", action="store_true", help="Start download from scratch" - ) - model_parser.add_argument( - "--check", - action="store_true", - help="Validate the checksums of files after download", - ) - - # Dataset mode - dataset_parser = subparsers.add_parser("dataset", help="Download a dataset") - dataset_parser.add_argument( - "identifier", type=str, help="Dataset identifier on Hugging Face" - ) - dataset_parser.add_argument("--revision", type=str, help="Revision of the dataset") - dataset_parser.add_argument( - "--output-dir", type=str, default=None, help="Directory to save the dataset" - ) - dataset_parser.add_argument( - "--max-connections", type=int, default=5, help="Max connections for downloads" - ) - dataset_parser.add_argument( - "--full", action="store_true", help="Fetch full dataset information" - ) - - # Search mode - search_parser = subparsers.add_parser( - "search", help="Search for models or datasets" - ) - search_subparsers = search_parser.add_subparsers( - dest="search_mode", help="Search mode" - ) - - # Model search mode - model_search_parser = search_subparsers.add_parser( - "model", help="Search for models" - ) - model_search_parser.add_argument("query", type=str, help="Search query for models") - model_search_parser.add_argument( - "--max-connections", type=int, default=5, help="Max connections for searching" - ) - - # Dataset search mode - 
diff --git a/getai/main.py b/getai/main.py
deleted file mode 100644
index 0998e67..0000000
--- a/getai/main.py
+++ /dev/null
@@ -1,184 +0,0 @@
-import argparse
-import asyncio
-import logging
-from aiohttp import ClientError
-from getai.api import search_datasets, download_dataset, search_models, download_model
-from getai.utils import get_hf_token
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-
-async def main():
-    """Main function for the GetAI CLI"""
-    parser = argparse.ArgumentParser()
-    subparsers = parser.add_subparsers(dest="mode", help="Mode of operation")
-
-    # Model mode
-    model_parser = subparsers.add_parser("model", help="Download a model")
-    model_parser.add_argument(
-        "identifier", type=str, help="Model identifier on Hugging Face"
-    )
-    model_parser.add_argument(
-        "--branch", nargs="?", const="main", default="main", help="Branch name"
-    )
-    model_parser.add_argument(
-        "--output-dir", type=str, default=None, help="Directory to save the model"
-    )
-    model_parser.add_argument(
-        "--max-connections", type=int, default=5, help="Max connections for downloads"
-    )
-    model_parser.add_argument(
-        "--clean", action="store_true", help="Start download from scratch"
-    )
-    model_parser.add_argument(
-        "--check",
-        action="store_true",
-        help="Validate the checksums of files after download",
-    )
-
-    # Dataset mode
-    dataset_parser = subparsers.add_parser("dataset", help="Download a dataset")
-    dataset_parser.add_argument(
-        "identifier", type=str, help="Dataset identifier on Hugging Face"
-    )
-    dataset_parser.add_argument("--revision", type=str, help="Revision of the dataset")
-    dataset_parser.add_argument(
-        "--output-dir", type=str, default=None, help="Directory to save the dataset"
-    )
-    dataset_parser.add_argument(
-        "--max-connections", type=int, default=5, help="Max connections for downloads"
-    )
-    dataset_parser.add_argument(
-        "--full", action="store_true", help="Fetch full dataset information"
-    )
-
-    # Search mode
-    search_parser = subparsers.add_parser(
-        "search", help="Search for models or datasets"
-    )
-    search_subparsers = search_parser.add_subparsers(
-        dest="search_mode", help="Search mode"
-    )
-
-    # Model search mode
-    model_search_parser = search_subparsers.add_parser(
-        "model", help="Search for models"
-    )
-    model_search_parser.add_argument("query", type=str, help="Search query for models")
-    model_search_parser.add_argument(
-        "--max-connections", type=int, default=5, help="Max connections for searching"
-    )
-
-    # Dataset search mode
-    dataset_search_parser = search_subparsers.add_parser(
-        "dataset", help="Search for datasets"
-    )
-    dataset_search_parser.add_argument(
-        "query", type=str, help="Search query for datasets"
-    )
-    dataset_search_parser.add_argument(
-        "--output-dir", type=str, default=None, help="Directory to save the dataset"
-    )
-    dataset_search_parser.add_argument(
-        "--max-connections", type=int, default=5, help="Max connections for downloads"
-    )
-    dataset_search_parser.add_argument(
-        "--author", type=str, help="Filter datasets by author or organization"
-    )
-    dataset_search_parser.add_argument(
-        "--filter-criteria", type=str, help="Filter datasets based on tags"
-    )
-    dataset_search_parser.add_argument(
-        "--sort", type=str, help="Property to use when sorting datasets"
-    )
-    dataset_search_parser.add_argument(
-        "--direction", type=str, help="Direction to sort datasets"
-    )
-    dataset_search_parser.add_argument(
-        "--limit", type=int, help="Limit the number of datasets fetched"
-    )
-    dataset_search_parser.add_argument(
-        "--full", action="store_true", help="Fetch full dataset information"
-    )
-
-    # Token update
-    parser.add_argument(
-        "--update-token", type=str, help="Update the Hugging Face token"
-    )
-
-    args = parser.parse_args()
-    hf_token = get_hf_token(update_token=args.update_token)
-    if args.update_token:
-        logger.info("Hugging Face token updated successfully.")
-
-    if args.mode not in ["model", "dataset", "search"]:
-        logger.error("Invalid mode. Please specify 'model', 'dataset', or 'search'.")
-        return
-
-    try:
-        if args.mode == "search":
-            if not args.search_mode:
-                logger.error("Please specify the search mode (model or dataset).")
-                return
-
-            if args.search_mode == "model":
-                await search_models(
-                    query=args.query,
-                    hf_token=hf_token,
-                    max_connections=args.max_connections,
-                )
-            else:
-                await search_datasets(
-                    query=args.query,
-                    hf_token=hf_token,
-                    max_connections=args.max_connections,
-                    output_dir=args.output_dir,
-                    author=args.author,
-                    filter_criteria=args.filter_criteria,
-                    sort=args.sort,
-                    direction=args.direction,
-                    limit=args.limit,
-                    full=args.full,
-                )
-        elif args.mode == "dataset":
-            await download_dataset(
-                identifier=args.identifier,
-                hf_token=hf_token,
-                max_connections=args.max_connections,
-                output_dir=args.output_dir,
-                revision=args.revision,
-                full=args.full,
-            )
-        else:  # args.mode == 'model'
-            await download_model(
-                identifier=args.identifier,
-                branch=args.branch,
-                hf_token=hf_token,
-                max_connections=args.max_connections,
-                output_dir=args.output_dir,
-                clean=args.clean,
-                check=args.check,
-            )
-    except KeyboardInterrupt:
-        logger.info("\nKeyboardInterrupt received. Closing operation...")
-    except ClientError as e:
-        logger.error("HTTP error during operation: %s", e)
-    except asyncio.CancelledError:
-        logger.info("Task cancelled during operation.")
-    except ValueError as e:
-        logger.error("Value error during operation: %s", e)
-    except Exception as e:
-        logger.error("Unexpected error during operation: %s", e)
-
-
-if __name__ == "__main__":
-    loop = asyncio.get_event_loop()
-    try:
-        loop.run_until_complete(main())
-    except KeyboardInterrupt:
-        logger.info("\nKeyboardInterrupt received. Closing operation...")
-    finally:
-        pending_tasks = asyncio.all_tasks(loop)
-        loop.run_until_complete(asyncio.gather(*pending_tasks, return_exceptions=True))
-        loop.close()
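The deleted entrypoint above wired argparse straight into the API coroutines it imported. For reference, a minimal sketch of driving the same coroutines programmatically, assuming the refactored getai.api package keeps the names the old module imported (the model identifier below is a placeholder):

    import asyncio
    from getai.api import search_models, download_model

    async def demo():
        await search_models(query="mistral", max_connections=5)
        await download_model(identifier="some-org/some-model", branch="main")

    asyncio.run(demo())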
diff --git a/getai/utils.py b/getai/utils.py
deleted file mode 100644
index 9c5de84..0000000
--- a/getai/utils.py
+++ /dev/null
@@ -1,81 +0,0 @@
-""" utils.py - Contains utility functions for the GetAI CLI, AsyncModelDownloader, and AsyncDatasetDownloader classes."""
-
-import os
-from pathlib import Path
-import logging
-import yaml
-from prompt_toolkit import PromptSession
-from prompt_toolkit.completion import WordCompleter
-
-
-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
-)
-
-
-def convert_to_bytes(size_str):
-    """Convert size string like '2.3 GB' or '200 MB' to bytes."""
-    try:
-        size_units = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}
-        size, unit = size_str.split()
-        return int(float(size) * size_units[unit])
-    except Exception as e:
-        logging.exception("Error converting size to bytes: %s", e)
-        raise
-
-
-async def interactive_branch_selection(branches):
-    """Prompt user to select a branch interactively from a list."""
-    try:
-        branch_completer = WordCompleter(branches, ignore_case=True)
-        session = PromptSession(completer=branch_completer)
-        selected_branch = await session.prompt_async("Select a branch [Press TAB]: ")
-        return selected_branch if selected_branch in branches else "main"
-    except Exception as e:
-        logging.exception("Error during interactive branch selection: %s", e)
-        raise
-
-
-def get_hf_token(update_token=None):
-    """Retrieve or update Hugging Face token securely."""
-    try:
-        getai_config_file = Path.home() / ".getai" / "getai_config.yaml"
-
-        if update_token:
-            logging.info("Updating Hugging Face token in ~/.getai/getai_config.yaml.")
-            getai_config_file.parent.mkdir(parents=True, exist_ok=True)
-            with open(getai_config_file, "w", encoding="utf-8") as f:
-                yaml.dump({"hf_token": update_token}, f)
-            return update_token
-
-        hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
-        if hf_token:
-            logging.info("Using Hugging Face token from environment variable.")
-            return hf_token
-
-        hf_token_file = Path.home() / ".huggingface" / "token"
-        if hf_token_file.exists():
-            with open(hf_token_file, "r", encoding="utf-8") as f:
-                logging.info("Using Hugging Face token from ~/.huggingface/token.")
-                return f.read().strip()
-
-        if getai_config_file.exists():
-            with open(getai_config_file, "r", encoding="utf-8") as f:
-                config = yaml.safe_load(f)
-                logging.info(
-                    "Using Hugging Face token from ~/.getai/getai_config.yaml."
-                )
-                return config.get("hf_token")
-
-        logging.warning("No Hugging Face token found. Prompting user for input.")
-        hf_token = input("Enter your Hugging Face token: ")
-
-        logging.info("Saving Hugging Face token to ~/.getai/getai_config.yaml.")
-        getai_config_file.parent.mkdir(parents=True, exist_ok=True)
-        with open(getai_config_file, "w", encoding="utf-8") as f:
-            yaml.dump({"hf_token": hf_token}, f)
-
-        return hf_token
-    except Exception as e:
-        logging.exception("Error retrieving or updating Hugging Face token: %s", e)
-        raise
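With PyYAML dropped from the dependencies and getai_config.yaml deleted elsewhere in this patch, the YAML branches of the removed get_hf_token above presumably go away. A sketch of the remaining lookup order, assuming the new CoreUtils.get_hf_token keeps the env-var-then-token-file precedence of the old code:

    import os
    from pathlib import Path
    from typing import Optional

    def resolve_hf_token() -> Optional[str]:
        # 1. The environment variable takes precedence.
        token = os.getenv("HUGGING_FACE_HUB_TOKEN")
        if token:
            return token
        # 2. Fall back to the huggingface-cli token file.
        token_file = Path.home() / ".huggingface" / "token"
        if token_file.exists():
            return token_file.read_text(encoding="utf-8").strip()
        return None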
authors = ["Ben Gorlick "] license = "MIT - with attribution" readme = "README.md" +include = ["getai/getai_config.yaml"] [tool.poetry.dependencies] python = "^3.9" @@ -12,20 +13,17 @@ aiohttp = "^3.9.3" aiofiles = "^23.2.1" prompt-toolkit = "^3.0.43" rainbow-tqdm = "^0.1.3" -PyYAML = "^6.0.1" types-aiofiles = "^0.1.0" tenacity = "^8.0.1" - [tool.poetry.group.dev.dependencies] pytest = "^8.0.0" +pytest-asyncio = "^0.17.0" +pytest-cov = "^3.0.0" # this is for coverage [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" -include = ["getai/getai_config.yaml"] - [tool.poetry.scripts] getai = "getai.__main__:run" - diff --git a/setup.py b/setup.py index fa55aef..d70f306 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,6 @@ "aiofiles", "prompt_toolkit", "rainbow-tqdm", - "PyYAML", "types-aiofiles", "tenacity", ],