From 4d273567887cb47a5292d9c041ba324db8e8230c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 16 Dec 2023 16:02:49 +0300 Subject: [PATCH] set_handlers: changed type of `models_to_fetch`, removed "models_download_params" (#184) * set_handlers: `models_to_fetch` and `models_download_params` united in one more flexible parameter. Signed-off-by: Alexander Piskun --- CHANGELOG.md | 1 + docs/NextcloudTalkBotTransformers.rst | 2 +- examples/as_app/talk_bot_ai/lib/main.py | 2 +- nc_py_api/ex_app/integration_fastapi.py | 18 ++++++------------ tests/_install_init_handler_models.py | 2 +- tests/actual_tests/nc_app_test.py | 2 +- 6 files changed, 11 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36b4b857..61a44127 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file. ### Changed - set_handlers: `enabled_handler`, `heartbeat_handler`, `init_handler` now can be async(Coroutines). #175 #181 +- set_handlers: `models_to_fetch` and `models_download_params` united in one more flexible parameter. #184 - drop Python 3.9 support. #180 - internal code refactoring and clean-up #177 diff --git a/docs/NextcloudTalkBotTransformers.rst b/docs/NextcloudTalkBotTransformers.rst index c3bae7f8..ad785985 100644 --- a/docs/NextcloudTalkBotTransformers.rst +++ b/docs/NextcloudTalkBotTransformers.rst @@ -60,7 +60,7 @@ This library also provides an additional functionality over this endpoint for ea @asynccontextmanager async def lifespan(_app: FastAPI): - set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME]) + set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME:{}}) yield This will automatically download models specified in ``models_to_fetch`` parameter to the application persistent storage. diff --git a/examples/as_app/talk_bot_ai/lib/main.py b/examples/as_app/talk_bot_ai/lib/main.py index f9cb412d..4d6e0c26 100644 --- a/examples/as_app/talk_bot_ai/lib/main.py +++ b/examples/as_app/talk_bot_ai/lib/main.py @@ -15,7 +15,7 @@ @asynccontextmanager async def lifespan(_app: FastAPI): - set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME]) + set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {}}) yield diff --git a/nc_py_api/ex_app/integration_fastapi.py b/nc_py_api/ex_app/integration_fastapi.py index 553cece4..d356fc7b 100644 --- a/nc_py_api/ex_app/integration_fastapi.py +++ b/nc_py_api/ex_app/integration_fastapi.py @@ -75,8 +75,7 @@ def set_handlers( enabled_handler: typing.Callable[[bool, AsyncNextcloudApp | NextcloudApp], typing.Awaitable[str] | str], heartbeat_handler: typing.Callable[[], typing.Awaitable[str] | str] | None = None, init_handler: typing.Callable[[AsyncNextcloudApp | NextcloudApp], typing.Awaitable[None] | None] | None = None, - models_to_fetch: list[str] | None = None, - models_download_params: dict | None = None, + models_to_fetch: dict[str, dict] | None = None, map_app_static: bool = True, ): """Defines handlers for the application. @@ -92,7 +91,6 @@ def set_handlers( .. note:: ```huggingface_hub`` package should be present for automatic models fetching. - :param models_download_params: Parameters to pass to ``snapshot_download`` function from **huggingface_hub**. :param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not. .. note:: First, presence of these directories in the current working dir is checked, then one directory higher. @@ -140,8 +138,7 @@ async def init_callback( background_tasks.add_task( __fetch_models_task, nc, - models_to_fetch if models_to_fetch else [], - models_download_params if models_download_params else {}, + models_to_fetch if models_to_fetch else {}, ) return responses.JSONResponse(content={}, status_code=200) @@ -181,8 +178,7 @@ def __map_app_static_folders(fast_api_app: FastAPI): def __fetch_models_task( nc: NextcloudApp, - models: list[str], - params: dict[str, typing.Any], + models: dict[str, dict], ) -> None: if models: from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401 @@ -193,10 +189,8 @@ def display(self, msg=None, pos=None): nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100)) return super().display(msg, pos) - if "max_workers" not in params: - params["max_workers"] = 2 - if "cache_dir" not in params: - params["cache_dir"] = persistent_storage() for model in models: - snapshot_download(model, tqdm_class=TqdmProgress, **params) # noqa + workers = models[model].pop("max_workers", 2) + cache = models[model].pop("cache_dir", persistent_storage()) + snapshot_download(model, tqdm_class=TqdmProgress, **models[model], max_workers=workers, cache_dir=cache) nc.set_init_status(100) diff --git a/tests/_install_init_handler_models.py b/tests/_install_init_handler_models.py index 4f866e8c..625a9b20 100644 --- a/tests/_install_init_handler_models.py +++ b/tests/_install_init_handler_models.py @@ -10,7 +10,7 @@ @asynccontextmanager async def lifespan(_app: FastAPI): - ex_app.set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME]) + ex_app.set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {}}) yield diff --git a/tests/actual_tests/nc_app_test.py b/tests/actual_tests/nc_app_test.py index 7c13e589..502fe41f 100644 --- a/tests/actual_tests/nc_app_test.py +++ b/tests/actual_tests/nc_app_test.py @@ -116,4 +116,4 @@ async def test_set_user_same_value_async(anc_app): def test_set_handlers_invalid_param(nc_any): with pytest.raises(ValueError): - set_handlers(None, None, init_handler=set_handlers, models_to_fetch=["some"]) # noqa + set_handlers(None, None, init_handler=set_handlers, models_to_fetch={"some": {}}) # noqa