diff --git a/modelkit/core/models/distant_model.py b/modelkit/core/models/distant_model.py index 2b26d0ad..04f0aa97 100644 --- a/modelkit/core/models/distant_model.py +++ b/modelkit/core/models/distant_model.py @@ -106,3 +106,39 @@ def _predict(self, item, **kwargs): def close(self): if self.requests_session: return self.requests_session.close() + + +class DistantHTTPBatchModel(Model[ItemType, ReturnType]): + """ + Model to extend to be able to call a batch endpoint + expecting a list of ItemType as input. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.endpoint = self.model_settings["endpoint"] + self.endpoint_params = self.model_settings.get("endpoint_params", {}) + self.requests_session: Optional[requests.Session] = None + + def _load(self): + pass + + @retry(**SERVICE_MODEL_RETRY_POLICY) + def _predict_batch(self, items, **kwargs): + if not self.requests_session: + self.requests_session = requests.Session() + response = self.requests_session.post( + self.endpoint, + params=kwargs.get("endpoint_params", self.endpoint_params), + data=json.dumps(items), + headers={"content-type": "application/json"}, + ) + if response.status_code != 200: + raise DistantHTTPModelError( + response.status_code, response.reason, response.text + ) + return response.json() + + def close(self): + if self.requests_session: + return self.requests_session.close() diff --git a/tests/test_distant_http_model.py b/tests/test_distant_http_model.py index ed4e9fb2..0947b499 100644 --- a/tests/test_distant_http_model.py +++ b/tests/test_distant_http_model.py @@ -7,7 +7,11 @@ import requests from modelkit.core.library import ModelLibrary -from modelkit.core.models.distant_model import AsyncDistantHTTPModel, DistantHTTPModel +from modelkit.core.models.distant_model import ( + AsyncDistantHTTPModel, + DistantHTTPBatchModel, + DistantHTTPModel, +) from tests import TEST_DIR @@ -101,3 +105,64 @@ class SomeAsyncDistantHTTPModel(AsyncDistantHTTPModel): # Test with synchronous mode m = lib.get("some_model_sync") assert expected == m(item, endpoint_params=params) + + +@pytest.mark.parametrize( + "items, params, expected", + [ + ( + [{"some_content": "something"}, {"some_other_content": "something_else"}], + {}, + [{"some_content": "something"}, {"some_other_content": "something_else"}], + ), + ( + [{"some_content": "something"}, {"some_other_content": "something_else"}], + {"limit": 10}, + [ + {"some_content": "something", "limit": 10}, + {"some_other_content": "something_else", "limit": 10}, + ], + ), + ( + [{"some_content": "something"}, {"some_other_content": "something_else"}], + {"skip": 5}, + [ + {"some_content": "something", "skip": 5}, + {"some_other_content": "something_else", "skip": 5}, + ], + ), + ( + [{"some_content": "something"}, {"some_other_content": "something_else"}], + {"limit": 10, "skip": 5}, + [ + {"some_content": "something", "limit": 10, "skip": 5}, + {"some_other_content": "something_else", "limit": 10, "skip": 5}, + ], + ), + ], +) +def test_distant_http_batch_model( + items, params, expected, run_mocked_service, event_loop +): + model_settings = { + "endpoint": "http://127.0.0.1:8000/api/path/endpoint/batch", + "async_mode": False, + } + + class SomeDistantHTTPBatchModel(DistantHTTPBatchModel): + CONFIGURATIONS = { + "some_model_batch": {"model_settings": model_settings}, + } + + lib_without_params = ModelLibrary(models=[SomeDistantHTTPBatchModel]) + lib_with_params = ModelLibrary( + models=[SomeDistantHTTPBatchModel], + configuration={ + "some_model_batch": { + "model_settings": {**params, **model_settings}, + } + }, + ) + for lib in [lib_without_params, lib_with_params]: + m = lib.get("some_model_batch") + assert expected == m.predict_batch(items, endpoint_params=params) diff --git a/tests/testdata/mocked_service.py b/tests/testdata/mocked_service.py index 8ad0d3c9..60f72432 100644 --- a/tests/testdata/mocked_service.py +++ b/tests/testdata/mocked_service.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, List, Optional import fastapi from starlette.responses import JSONResponse @@ -17,3 +17,17 @@ async def some_endpoint( if skip: item["skip"] = skip return JSONResponse(item) + + +@app.post("/api/path/endpoint/batch") +async def some_endpoint_batch( + items: List[Dict[str, str]], + limit: Optional[int] = None, + skip: Optional[int] = None, +): + for item in items: + if limit: + item["limit"] = limit + if skip: + item["skip"] = skip + return JSONResponse(items)