Skip to content

Commit

Permalink
Fix distant http model json encoding (#194)
Browse files Browse the repository at this point in the history
* distant models: raise typerrors

* distant models: handle pydantic json serialization
  • Loading branch information
antoinejeannot authored Aug 24, 2023
1 parent 563a1cf commit ddb0a5c
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 4 deletions.
33 changes: 29 additions & 4 deletions modelkit/core/models/distant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ def _load(self):
async def _predict(self, item, **kwargs):
if self.aiohttp_session is None:
self.aiohttp_session = aiohttp.ClientSession()
try:
item = json.dumps(item)
except TypeError:
# TypeError: Object of type {ItemType} is not JSON serializable
# Try converting the pydantic model to json directly
item = item.json()
async with self.aiohttp_session.post(
self.endpoint,
params=kwargs.get("endpoint_params", self.endpoint_params),
data=json.dumps(item),
data=item,
headers={"content-type": "application/json"},
) as response:
if response.status != 200:
Expand Down Expand Up @@ -104,10 +110,16 @@ def _load(self):
def _predict(self, item, **kwargs):
if not self.requests_session:
self.requests_session = requests.Session()
try:
item = json.dumps(item)
except TypeError:
# TypeError: Object of type {ItemType} is not JSON serializable
# Try converting the pydantic model to json directly
item = item.json()
response = self.requests_session.post(
self.endpoint,
params=kwargs.get("endpoint_params", self.endpoint_params),
data=json.dumps(item),
data=item,
headers={"content-type": "application/json"},
)
if response.status_code != 200:
Expand Down Expand Up @@ -146,10 +158,16 @@ def _load(self):
def _predict_batch(self, items, **kwargs):
if not self.requests_session:
self.requests_session = requests.Session()
try:
items = json.dumps(items)
except TypeError:
# TypeError: Object of type {ItemType} is not JSON serializable
# Try converting a list of pydantic models to dict
items = json.dumps([item.dict() for item in items])
response = self.requests_session.post(
self.endpoint,
params=kwargs.get("endpoint_params", self.endpoint_params),
data=json.dumps(items),
data=items,
headers={"content-type": "application/json"},
)
if response.status_code != 200:
Expand Down Expand Up @@ -185,10 +203,17 @@ def __init__(self, **kwargs):
async def _predict_batch(self, items, **kwargs):
if self.aiohttp_session is None:
self.aiohttp_session = aiohttp.ClientSession()
try:
items = json.dumps(items)
except TypeError:
# TypeError: Object of type {ItemType} is not JSON serializable
# Try converting a list of pydantic models to dict
items = json.dumps([item.dict() for item in items])

async with self.aiohttp_session.post(
self.endpoint,
params=kwargs.get("endpoint_params", self.endpoint_params),
data=json.dumps(items),
data=items,
headers={"content-type": "application/json"},
) as response:
if response.status != 200:
Expand Down
70 changes: 70 additions & 0 deletions tests/test_distant_http_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess
import time

import pydantic
import pytest
import requests

Expand Down Expand Up @@ -39,13 +40,27 @@ def run_mocked_service():
proc.terminate()


class SomeContentModel(pydantic.BaseModel):
some_content: str


@pytest.mark.asyncio
@pytest.mark.parametrize(
"item,params,expected",
[
({"some_content": "something"}, {}, {"some_content": "something"}),
(
SomeContentModel(**{"some_content": "something"}),
{},
{"some_content": "something"},
),
(
{"some_content": "something"},
{"limit": 10},
{"some_content": "something", "limit": 10},
),
(
SomeContentModel(**{"some_content": "something"}),
{"limit": 10},
{"some_content": "something", "limit": 10},
),
Expand All @@ -54,11 +69,21 @@ def run_mocked_service():
{"skip": 5},
{"some_content": "something", "skip": 5},
),
(
SomeContentModel(**{"some_content": "something"}),
{"skip": 5},
{"some_content": "something", "skip": 5},
),
(
{"some_content": "something"},
{"limit": 10, "skip": 5},
{"some_content": "something", "limit": 10, "skip": 5},
),
(
SomeContentModel(**{"some_content": "something"}),
{"limit": 10, "skip": 5},
{"some_content": "something", "limit": 10, "skip": 5},
),
],
)
async def test_distant_http_model(
Expand Down Expand Up @@ -108,6 +133,10 @@ class SomeAsyncDistantHTTPModel(AsyncDistantHTTPModel):
assert expected == m(item, endpoint_params=params)


class SomeOtherContentModel(pydantic.BaseModel):
some_other_content: str


@pytest.mark.asyncio
@pytest.mark.parametrize(
"items, params, expected",
Expand All @@ -117,6 +146,14 @@ class SomeAsyncDistantHTTPModel(AsyncDistantHTTPModel):
{},
[{"some_content": "something"}, {"some_other_content": "something_else"}],
),
(
[
SomeContentModel(**{"some_content": "something"}),
SomeOtherContentModel(**{"some_other_content": "something_else"}),
],
{},
[{"some_content": "something"}, {"some_other_content": "something_else"}],
),
(
[{"some_content": "something"}, {"some_other_content": "something_else"}],
{"limit": 10},
Expand All @@ -125,6 +162,17 @@ class SomeAsyncDistantHTTPModel(AsyncDistantHTTPModel):
{"some_other_content": "something_else", "limit": 10},
],
),
(
[
SomeContentModel(**{"some_content": "something"}),
SomeOtherContentModel(**{"some_other_content": "something_else"}),
],
{"limit": 10},
[
{"some_content": "something", "limit": 10},
{"some_other_content": "something_else", "limit": 10},
],
),
(
[{"some_content": "something"}, {"some_other_content": "something_else"}],
{"skip": 5},
Expand All @@ -133,6 +181,17 @@ class SomeAsyncDistantHTTPModel(AsyncDistantHTTPModel):
{"some_other_content": "something_else", "skip": 5},
],
),
(
[
SomeContentModel(**{"some_content": "something"}),
SomeOtherContentModel(**{"some_other_content": "something_else"}),
],
{"skip": 5},
[
{"some_content": "something", "skip": 5},
{"some_other_content": "something_else", "skip": 5},
],
),
(
[{"some_content": "something"}, {"some_other_content": "something_else"}],
{"limit": 10, "skip": 5},
Expand All @@ -141,6 +200,17 @@ class SomeAsyncDistantHTTPModel(AsyncDistantHTTPModel):
{"some_other_content": "something_else", "limit": 10, "skip": 5},
],
),
(
[
SomeContentModel(**{"some_content": "something"}),
SomeOtherContentModel(**{"some_other_content": "something_else"}),
],
{"limit": 10, "skip": 5},
[
{"some_content": "something", "limit": 10, "skip": 5},
{"some_other_content": "something_else", "limit": 10, "skip": 5},
],
),
],
)
async def test_distant_http_batch_model(
Expand Down

0 comments on commit ddb0a5c

Please sign in to comment.