Skip to content

Commit

Permalink
End to End Refactor for more modular scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
bgorlick committed Jun 7, 2024
1 parent d95e828 commit c185b41
Show file tree
Hide file tree
Showing 24 changed files with 947 additions and 692 deletions.
1 change: 0 additions & 1 deletion getai/0

This file was deleted.

40 changes: 25 additions & 15 deletions getai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
# getai/__init__.py for GetAI - Contains the core API functions for searching and downloading datasets and models.
"""
Initialization module for the getai package.
"""

from getai.utils import get_hf_token
from getai.session_manager import SessionManager

# Import the core API functions from the api module
from getai.api import (
search_datasets,
download_dataset,
search_models,
download_model,
from .core import (
AsyncDatasetDownloader,
AsyncDatasetSearch,
AsyncModelDownloader,
AsyncModelSearch,
SessionManager,
convert_to_bytes,
interactive_branch_selection,
get_hf_token,
get_hf_token_from_cli,
hf_login,
)

__all__ = [
"search_datasets",
"download_dataset",
"search_models",
"download_model",
"get_hf_token",
"AsyncDatasetDownloader",
"AsyncDatasetSearch",
"AsyncModelDownloader",
"AsyncModelSearch",
"SessionManager",
"convert_to_bytes",
"interactive_branch_selection",
"get_hf_token",
"get_hf_token_from_cli",
"hf_login",
]

# Path: getai/cli/__init__.py
8 changes: 4 additions & 4 deletions getai/__main__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# __main__.py
# getai/__main__.py
import asyncio
from getai.main import main
from getai.cli import cli_main


def run():
asyncio.run(main())
asyncio.run(cli_main())


if __name__ == '__main__':
if __name__ == "__main__":
run()
118 changes: 0 additions & 118 deletions getai/api.py

This file was deleted.

25 changes: 25 additions & 0 deletions getai/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# getai/api/__init__.py - Initialization module for the getai.api package.

from getai.api.datasets import DatasetAPI
from getai.api.models import ModelAPI
from getai.api.utils import UtilsAPI

# Exposing class methods as module-level functions
search_datasets = DatasetAPI.search_datasets
download_dataset = DatasetAPI.download_dataset
search_models = ModelAPI.search_models
download_model = ModelAPI.download_model
get_hf_token = UtilsAPI.get_hf_token
hf_login = UtilsAPI.hf_login

__all__ = [
"search_datasets",
"download_dataset",
"search_models",
"download_model",
"get_hf_token",
"hf_login",
"DatasetAPI",
"ModelAPI",
"UtilsAPI",
]
56 changes: 56 additions & 0 deletions getai/api/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# getai/api/datasets.py - This module contains the API methods for searching and downloading datasets.

from pathlib import Path
from typing import Optional
from getai.api.utils import get_hf_token


class DatasetAPI:

@staticmethod
async def search_datasets(
query: str,
hf_token: Optional[str] = None,
max_connections: int = 5,
output_dir: Optional[Path] = None,
**kwargs,
):
from getai.core.dataset_search import AsyncDatasetSearch

hf_token = hf_token or get_hf_token()
output_dir = output_dir or Path.home() / ".getai" / "datasets" / query

searcher = AsyncDatasetSearch(
query=query,
output_dir=output_dir,
max_connections=max_connections,
hf_token=hf_token,
)
await searcher.display_dataset_search_results()

@staticmethod
async def download_dataset(
identifier: str,
hf_token: Optional[str] = None,
max_connections: int = 5,
output_dir: Optional[Path] = None,
**kwargs,
):
from getai.core.dataset_downloader import AsyncDatasetDownloader

hf_token = hf_token or get_hf_token()
output_dir = output_dir or Path.home() / ".getai" / "datasets" / identifier

downloader = AsyncDatasetDownloader(
output_dir=output_dir,
max_connections=max_connections,
hf_token=hf_token,
)
await downloader.download_dataset_info(
dataset_id=identifier,
**{
key: value
for key, value in kwargs.items()
if key in downloader.get_expected_kwargs()
},
)
64 changes: 64 additions & 0 deletions getai/api/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# getai/api/models.py - GetAI API methods for searching and downloading models.

# getai/api/models.py - This module contains the API methods for searching and downloading models.

from functools import lru_cache
from typing import Optional
from pathlib import Path


class ModelAPI:

@staticmethod
@lru_cache(maxsize=None)
def _get_hf_token(token: Optional[str] = None) -> str:
from getai.core.utils import get_hf_token

return token or get_hf_token()

@staticmethod
async def search_models(
query: str,
hf_token: Optional[str] = None,
max_connections: int = 5,
**kwargs,
):
from getai.api.utils import get_hf_token
from getai.core.model_search import AsyncModelSearch
from getai.core.session_manager import SessionManager

hf_token = hf_token or get_hf_token()

async with SessionManager(
max_connections=max_connections, hf_token=hf_token
) as manager:
session = await manager.get_session()
searcher = AsyncModelSearch(
query=query,
max_connections=max_connections,
session=session,
hf_token=hf_token,
)
await searcher.search_models(query, **kwargs)

@staticmethod
async def download_model(
identifier: str,
branch: str = "main",
hf_token: Optional[str] = None,
max_connections: int = 5,
output_dir: Optional[Path] = None,
**kwargs,
):
from getai.api.utils import get_hf_token
from getai.core.model_downloader import AsyncModelDownloader

hf_token = hf_token or get_hf_token()
output_dir = output_dir or Path.home() / ".getai" / "models"

downloader = AsyncModelDownloader(
output_dir=output_dir,
max_connections=max_connections,
hf_token=hf_token,
)
await downloader.download_model(identifier, branch, **kwargs)
29 changes: 29 additions & 0 deletions getai/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# getai/api/utils.py - This module contains utility functions for the API.

from functools import lru_cache
from typing import Optional
from getai.core.utils import CoreUtils


class UtilsAPI:
@staticmethod
@lru_cache(maxsize=None)
def get_hf_token(token: Optional[str] = None) -> str:
"""Retrieve the Hugging Face token using caching for efficiency."""
return token or CoreUtils.get_hf_token()

@staticmethod
def hf_login():
"""Log in using Hugging Face CLI."""
CoreUtils.hf_login()


# Expose class methods as module-level functions
get_hf_token = UtilsAPI.get_hf_token
hf_login = UtilsAPI.hf_login

__all__ = [
"UtilsAPI",
"get_hf_token",
"hf_login",
]
13 changes: 13 additions & 0 deletions getai/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
getai/cli/__init__.py - Initialization module for the getai.cli package.
"""

from getai.cli.cli import main as cli_main, add_common_arguments, define_subparsers
from getai.cli.utils import CLIUtils

__all__ = [
"cli_main",
"add_common_arguments",
"define_subparsers",
"CLIUtils",
]
Loading

0 comments on commit c185b41

Please sign in to comment.