-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
End to End Refactor for more modular scaling
- Loading branch information
Showing
24 changed files
with
947 additions
and
692 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,31 @@ | ||
# getai/__init__.py for GetAI - Contains the core API functions for searching and downloading datasets and models. | ||
""" | ||
Initialization module for the getai package. | ||
""" | ||
|
||
from getai.utils import get_hf_token | ||
from getai.session_manager import SessionManager | ||
|
||
# Import the core API functions from the api module | ||
from getai.api import ( | ||
search_datasets, | ||
download_dataset, | ||
search_models, | ||
download_model, | ||
from .core import ( | ||
AsyncDatasetDownloader, | ||
AsyncDatasetSearch, | ||
AsyncModelDownloader, | ||
AsyncModelSearch, | ||
SessionManager, | ||
convert_to_bytes, | ||
interactive_branch_selection, | ||
get_hf_token, | ||
get_hf_token_from_cli, | ||
hf_login, | ||
) | ||
|
||
__all__ = [ | ||
"search_datasets", | ||
"download_dataset", | ||
"search_models", | ||
"download_model", | ||
"get_hf_token", | ||
"AsyncDatasetDownloader", | ||
"AsyncDatasetSearch", | ||
"AsyncModelDownloader", | ||
"AsyncModelSearch", | ||
"SessionManager", | ||
"convert_to_bytes", | ||
"interactive_branch_selection", | ||
"get_hf_token", | ||
"get_hf_token_from_cli", | ||
"hf_login", | ||
] | ||
|
||
# Path: getai/cli/__init__.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
# __main__.py | ||
# getai/__main__.py | ||
import asyncio | ||
from getai.main import main | ||
from getai.cli import cli_main | ||
|
||
|
||
def run(): | ||
asyncio.run(main()) | ||
asyncio.run(cli_main()) | ||
|
||
|
||
if __name__ == '__main__': | ||
if __name__ == "__main__": | ||
run() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# getai/api/__init__.py - Initialization module for the getai.api package. | ||
|
||
from getai.api.datasets import DatasetAPI | ||
from getai.api.models import ModelAPI | ||
from getai.api.utils import UtilsAPI | ||
|
||
# Exposing class methods as module-level functions | ||
search_datasets = DatasetAPI.search_datasets | ||
download_dataset = DatasetAPI.download_dataset | ||
search_models = ModelAPI.search_models | ||
download_model = ModelAPI.download_model | ||
get_hf_token = UtilsAPI.get_hf_token | ||
hf_login = UtilsAPI.hf_login | ||
|
||
__all__ = [ | ||
"search_datasets", | ||
"download_dataset", | ||
"search_models", | ||
"download_model", | ||
"get_hf_token", | ||
"hf_login", | ||
"DatasetAPI", | ||
"ModelAPI", | ||
"UtilsAPI", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# getai/api/datasets.py - This module contains the API methods for searching and downloading datasets. | ||
|
||
from pathlib import Path | ||
from typing import Optional | ||
from getai.api.utils import get_hf_token | ||
|
||
|
||
class DatasetAPI:
    """Static async entry points for searching and downloading datasets.

    Both methods resolve a Hugging Face token and an output directory, build
    the matching helper from ``getai.core`` and await it. The ``getai.core``
    classes are imported at call time inside each method.
    """

    @staticmethod
    async def search_datasets(
        query: str,
        hf_token: Optional[str] = None,
        max_connections: int = 5,
        output_dir: Optional[Path] = None,
        **kwargs,
    ) -> None:
        """Search for datasets matching ``query`` and display the results.

        Args:
            query: Search string; also used as the last component of the
                default output directory.
            hf_token: Explicit token. When falsy, ``get_hf_token()`` is used.
            max_connections: Forwarded to ``AsyncDatasetSearch``.
            output_dir: Target directory. Defaults to
                ``~/.getai/datasets/<query>``.
            **kwargs: Accepted for call-site compatibility; this method does
                not forward them anywhere.
        """
        from getai.core.dataset_search import AsyncDatasetSearch

        hf_token = hf_token or get_hf_token()
        output_dir = output_dir or Path.home() / ".getai" / "datasets" / query

        searcher = AsyncDatasetSearch(
            query=query,
            output_dir=output_dir,
            max_connections=max_connections,
            hf_token=hf_token,
        )
        await searcher.display_dataset_search_results()

    @staticmethod
    async def download_dataset(
        identifier: str,
        hf_token: Optional[str] = None,
        max_connections: int = 5,
        output_dir: Optional[Path] = None,
        **kwargs,
    ) -> None:
        """Download the dataset named by ``identifier``.

        Args:
            identifier: Dataset id, passed to the downloader as
                ``dataset_id`` and used as the last component of the default
                output directory.
            hf_token: Explicit token. When falsy, ``get_hf_token()`` is used.
            max_connections: Forwarded to ``AsyncDatasetDownloader``.
            output_dir: Target directory. Defaults to
                ``~/.getai/datasets/<identifier>``.
            **kwargs: Extra options. Only the keys reported by
                ``downloader.get_expected_kwargs()`` are forwarded; the rest
                are dropped silently.
        """
        from getai.core.dataset_downloader import AsyncDatasetDownloader

        hf_token = hf_token or get_hf_token()
        output_dir = output_dir or Path.home() / ".getai" / "datasets" / identifier

        downloader = AsyncDatasetDownloader(
            output_dir=output_dir,
            max_connections=max_connections,
            hf_token=hf_token,
        )

        # Ask the downloader for its accepted option names once instead of
        # once per key. Skipped when no extra options were given, so the
        # empty-kwargs path never calls get_expected_kwargs().
        accepted = downloader.get_expected_kwargs() if kwargs else ()
        forwarded = {key: value for key, value in kwargs.items() if key in accepted}

        await downloader.download_dataset_info(dataset_id=identifier, **forwarded)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# getai/api/models.py - This module contains the API methods for searching and downloading models. | ||
|
||
from functools import lru_cache | ||
from typing import Optional | ||
from pathlib import Path | ||
|
||
|
||
class ModelAPI:
    """Static async entry points for searching and downloading models.

    The ``getai`` helpers are imported at call time inside each method.
    """

    @staticmethod
    def _get_hf_token(token: Optional[str] = None) -> str:
        """Return ``token`` if truthy, else the token from ``getai.api.utils``.

        Single place where this class resolves the Hugging Face token. It
        uses the same ``getai.api.utils.get_hf_token`` the public methods
        relied on, and adds no cache of its own so there is only one cache
        layer to invalidate.
        """
        if token:
            return token

        from getai.api.utils import get_hf_token

        return get_hf_token()

    @staticmethod
    async def search_models(
        query: str,
        hf_token: Optional[str] = None,
        max_connections: int = 5,
        **kwargs,
    ) -> None:
        """Search for models matching ``query``.

        Args:
            query: Search string, given both to the ``AsyncModelSearch``
                constructor and to its ``search_models`` call.
            hf_token: Explicit token. When falsy, the default token is
                resolved via ``_get_hf_token``.
            max_connections: Forwarded to ``SessionManager`` and
                ``AsyncModelSearch``.
            **kwargs: Forwarded unchanged to ``searcher.search_models``.
        """
        from getai.core.model_search import AsyncModelSearch
        from getai.core.session_manager import SessionManager

        hf_token = ModelAPI._get_hf_token(hf_token)

        # The session is obtained from the manager and used only inside this
        # ``async with`` block.
        async with SessionManager(
            max_connections=max_connections, hf_token=hf_token
        ) as manager:
            session = await manager.get_session()
            searcher = AsyncModelSearch(
                query=query,
                max_connections=max_connections,
                session=session,
                hf_token=hf_token,
            )
            await searcher.search_models(query, **kwargs)

    @staticmethod
    async def download_model(
        identifier: str,
        branch: str = "main",
        hf_token: Optional[str] = None,
        max_connections: int = 5,
        output_dir: Optional[Path] = None,
        **kwargs,
    ) -> None:
        """Download the model named by ``identifier`` from ``branch``.

        Args:
            identifier: Model id passed to the downloader.
            branch: Branch name passed to the downloader; defaults to
                ``"main"``.
            hf_token: Explicit token. When falsy, the default token is
                resolved via ``_get_hf_token``.
            max_connections: Forwarded to ``AsyncModelDownloader``.
            output_dir: Target directory. Defaults to ``~/.getai/models``.
            **kwargs: Forwarded unchanged to ``downloader.download_model``.
        """
        from getai.core.model_downloader import AsyncModelDownloader

        hf_token = ModelAPI._get_hf_token(hf_token)
        output_dir = output_dir or Path.home() / ".getai" / "models"

        downloader = AsyncModelDownloader(
            output_dir=output_dir,
            max_connections=max_connections,
            hf_token=hf_token,
        )
        await downloader.download_model(identifier, branch, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# getai/api/utils.py - This module contains utility functions for the API. | ||
|
||
from functools import lru_cache | ||
from typing import Optional | ||
from getai.core.utils import CoreUtils | ||
|
||
|
||
class UtilsAPI:
    """Thin API-level wrappers around ``CoreUtils`` token helpers."""

    @staticmethod
    @lru_cache(maxsize=None)
    def get_hf_token(token: Optional[str] = None) -> str:
        """Return ``token`` if truthy, else ``CoreUtils.get_hf_token()``.

        Results are memoized per ``token`` argument, so the no-argument
        lookup calls ``CoreUtils.get_hf_token()`` only once until the cache
        is cleared (``hf_login`` clears it).
        """
        return token or CoreUtils.get_hf_token()

    @staticmethod
    def hf_login() -> None:
        """Log in using Hugging Face CLI and drop any memoized token."""
        try:
            CoreUtils.hf_login()
        finally:
            # A login attempt may change the stored token. Without this, a
            # value resolved before the login would keep being served from
            # the cache. Clearing is harmless if the login failed.
            UtilsAPI.get_hf_token.cache_clear()
|
||
|
||
# Expose class methods as module-level functions | ||
get_hf_token = UtilsAPI.get_hf_token | ||
hf_login = UtilsAPI.hf_login | ||
|
||
__all__ = [ | ||
"UtilsAPI", | ||
"get_hf_token", | ||
"hf_login", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" | ||
getai/cli/__init__.py - Initialization module for the getai.cli package. | ||
""" | ||
|
||
from getai.cli.cli import main as cli_main, add_common_arguments, define_subparsers | ||
from getai.cli.utils import CLIUtils | ||
|
||
__all__ = [ | ||
"cli_main", | ||
"add_common_arguments", | ||
"define_subparsers", | ||
"CLIUtils", | ||
] |
Oops, something went wrong.