Skip to content

Commit

Permalink
Merge pull request #183 from Cornerstone-OnDemand/distant-http-model-…
Browse files Browse the repository at this point in the history
…batch

Implement distant model with predict batch
  • Loading branch information
mathilde-leval authored Mar 30, 2023
2 parents 0772848 + 2dbf581 commit b3887ef
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 2 deletions.
36 changes: 36 additions & 0 deletions modelkit/core/models/distant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,39 @@ def _predict(self, item, **kwargs):
def close(self):
if self.requests_session:
return self.requests_session.close()


class DistantHTTPBatchModel(Model[ItemType, ReturnType]):
"""
Model to extend to be able to call a batch endpoint
expecting a list of ItemType as input.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.endpoint = self.model_settings["endpoint"]
self.endpoint_params = self.model_settings.get("endpoint_params", {})
self.requests_session: Optional[requests.Session] = None

def _load(self):
pass

@retry(**SERVICE_MODEL_RETRY_POLICY)
def _predict_batch(self, items, **kwargs):
if not self.requests_session:
self.requests_session = requests.Session()
response = self.requests_session.post(
self.endpoint,
params=kwargs.get("endpoint_params", self.endpoint_params),
data=json.dumps(items),
headers={"content-type": "application/json"},
)
if response.status_code != 200:
raise DistantHTTPModelError(
response.status_code, response.reason, response.text
)
return response.json()

def close(self):
if self.requests_session:
return self.requests_session.close()
67 changes: 66 additions & 1 deletion tests/test_distant_http_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import requests

from modelkit.core.library import ModelLibrary
from modelkit.core.models.distant_model import AsyncDistantHTTPModel, DistantHTTPModel
from modelkit.core.models.distant_model import (
AsyncDistantHTTPModel,
DistantHTTPBatchModel,
DistantHTTPModel,
)
from tests import TEST_DIR


Expand Down Expand Up @@ -101,3 +105,64 @@ class SomeAsyncDistantHTTPModel(AsyncDistantHTTPModel):
# Test with synchronous mode
m = lib.get("some_model_sync")
assert expected == m(item, endpoint_params=params)


@pytest.mark.parametrize(
"items, params, expected",
[
(
[{"some_content": "something"}, {"some_other_content": "something_else"}],
{},
[{"some_content": "something"}, {"some_other_content": "something_else"}],
),
(
[{"some_content": "something"}, {"some_other_content": "something_else"}],
{"limit": 10},
[
{"some_content": "something", "limit": 10},
{"some_other_content": "something_else", "limit": 10},
],
),
(
[{"some_content": "something"}, {"some_other_content": "something_else"}],
{"skip": 5},
[
{"some_content": "something", "skip": 5},
{"some_other_content": "something_else", "skip": 5},
],
),
(
[{"some_content": "something"}, {"some_other_content": "something_else"}],
{"limit": 10, "skip": 5},
[
{"some_content": "something", "limit": 10, "skip": 5},
{"some_other_content": "something_else", "limit": 10, "skip": 5},
],
),
],
)
def test_distant_http_batch_model(
items, params, expected, run_mocked_service, event_loop
):
model_settings = {
"endpoint": "http://127.0.0.1:8000/api/path/endpoint/batch",
"async_mode": False,
}

class SomeDistantHTTPBatchModel(DistantHTTPBatchModel):
CONFIGURATIONS = {
"some_model_batch": {"model_settings": model_settings},
}

lib_without_params = ModelLibrary(models=[SomeDistantHTTPBatchModel])
lib_with_params = ModelLibrary(
models=[SomeDistantHTTPBatchModel],
configuration={
"some_model_batch": {
"model_settings": {**params, **model_settings},
}
},
)
for lib in [lib_without_params, lib_with_params]:
m = lib.get("some_model_batch")
assert expected == m.predict_batch(items, endpoint_params=params)
16 changes: 15 additions & 1 deletion tests/testdata/mocked_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict, List, Optional

import fastapi
from starlette.responses import JSONResponse
Expand All @@ -17,3 +17,17 @@ async def some_endpoint(
if skip:
item["skip"] = skip
return JSONResponse(item)


@app.post("/api/path/endpoint/batch")
async def some_endpoint_batch(
items: List[Dict[str, str]],
limit: Optional[int] = None,
skip: Optional[int] = None,
):
for item in items:
if limit:
item["limit"] = limit
if skip:
item["skip"] = skip
return JSONResponse(items)

0 comments on commit b3887ef

Please sign in to comment.