Feat support OpenAI embedding (#367)
* adds spec for embeddings

* adds embedding spec

* adds initial test for embedding specs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add license info

* adds decode request and encode response methods

* updated e2e test

* adds TODO message

* updates test API

* adds validation message

* updated example

* updated validation message

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: typo in message

* added a handy helper fn, i.e. `get_input_as_list`

* refactor: replace get_input_as_list with ensure_list in decode_request method

* refactor: rename get_input_as_list to ensure_list and improve error messages in OpenAIEmbeddingSpec

* fix precommit error

* remove comment

* adds test cases

* adds test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update decode request return type

* remove comment

* fixed tests getting trapped

* refactor: update decode_request return type

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix precommit error on test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add TestEmbedBatchedAPI for batch predictions

* refactor: simplify batch prediction logic in TestEmbedBatchedAPI

* test: add test for OpenAI embedding spec with batching

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Aniket Maurya <[email protected]>
3 people authored Nov 27, 2024
1 parent fbeb010 commit 771a1c9
Showing 8 changed files with 412 additions and 3 deletions.
14 changes: 12 additions & 2 deletions src/litserve/__init__.py
@@ -17,6 +17,16 @@
 from litserve.callbacks import Callback
 from litserve.loggers import Logger
 from litserve.server import LitServer, Request, Response
-from litserve.specs.openai import OpenAISpec
+from litserve.specs import OpenAIEmbeddingSpec, OpenAISpec
 
-__all__ = ["LitAPI", "LitServer", "Request", "Response", "OpenAISpec", "test_examples", "Callback", "Logger"]
+__all__ = [
+    "LitAPI",
+    "LitServer",
+    "Request",
+    "Response",
+    "OpenAISpec",
+    "OpenAIEmbeddingSpec",
+    "test_examples",
+    "Callback",
+    "Logger",
+]
3 changes: 2 additions & 1 deletion src/litserve/specs/__init__.py
@@ -1,3 +1,4 @@
 from litserve.specs.openai import OpenAISpec
+from litserve.specs.openai_embedding import OpenAIEmbeddingSpec
 
-__all__ = ["OpenAISpec"]
+__all__ = ["OpenAISpec", "OpenAIEmbeddingSpec"]
176 changes: 176 additions & 0 deletions src/litserve/specs/openai_embedding.py
@@ -0,0 +1,176 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import inspect
import logging
import time
import uuid
from typing import List, Literal, Optional, Union

from fastapi import Request, Response, status
from pydantic import BaseModel

from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus

logger = logging.getLogger(__name__)


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str
    dimensions: Optional[int] = None
    encoding_format: Literal["float"] = "float"

    def ensure_list(self):
        return self.input if isinstance(self.input, list) else [self.input]


class Embedding(BaseModel):
    index: int
    embedding: List[float]
    object: Literal["embedding"] = "embedding"


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0


class EmbeddingResponse(BaseModel):
    data: List[Embedding]
    model: str
    object: Literal["list"] = "list"
    usage: UsageInfo


EMBEDDING_API_EXAMPLE = """
Please follow the example below for guidance on how to use the OpenAI Embedding spec:
```python
import numpy as np
from typing import List
from litserve import LitAPI, OpenAIEmbeddingSpec


class TestAPI(LitAPI):
    def setup(self, device):
        self.model = None

    def decode_request(self, request) -> List[str]:
        return request.ensure_list()

    def predict(self, x) -> List[List[float]]:
        return np.random.rand(len(x), 768).tolist()

    def encode_response(self, output) -> dict:
        return {"embeddings": output}


if __name__ == "__main__":
    import litserve as ls

    server = ls.LitServer(TestAPI(), spec=OpenAIEmbeddingSpec())
    server.run()
```
"""


class OpenAIEmbeddingSpec(LitSpec):
    def __init__(self):
        super().__init__()
        # register the endpoint
        self.add_endpoint("/v1/embeddings", self.embeddings, ["POST"])
        self.add_endpoint("/v1/embeddings", self.options_embeddings, ["GET"])

    def setup(self, server: "LitServer"):  # noqa: F821
        from litserve import LitAPI

        super().setup(server)

        lit_api = self._server.lit_api
        if inspect.isgeneratorfunction(lit_api.predict):
            raise ValueError(
                "You are using yield in your predict method, which is used for streaming.",
                "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
                "is not a sequential operation.",
                "Please consider replacing yield with return in predict.\n",
                EMBEDDING_API_EXAMPLE,
            )

        is_encode_response_original = lit_api.encode_response.__code__ is LitAPI.encode_response.__code__
        if not is_encode_response_original and inspect.isgeneratorfunction(lit_api.encode_response):
            raise ValueError(
                "You are using yield in your encode_response method, which is used for streaming.",
                "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
                "is not a sequential operation.",
                "Please consider replacing yield with return in encode_response.\n",
                EMBEDDING_API_EXAMPLE,
            )

        print("OpenAI Embedding Spec is ready.")

    def decode_request(self, request: EmbeddingRequest, context_kwargs: Optional[dict] = None) -> List[str]:
        return request.ensure_list()

    def encode_response(self, output: List[List[float]], context_kwargs: Optional[dict] = None) -> dict:
        return {
            "embeddings": output,
            "prompt_tokens": context_kwargs.get("prompt_tokens", 0),
            "total_tokens": context_kwargs.get("total_tokens", 0),
        }

    def validate_response(self, response: dict) -> None:
        if not isinstance(response, dict):
            raise ValueError(
                "The response is not a dictionary."
                "The response should be a dictionary to ensure proper compatibility with the OpenAIEmbeddingSpec.\n\n"
                "Please ensure that your response is a dictionary with the following keys:\n"
                "- 'embeddings' (required)\n"
                "- 'prompt_tokens' (optional)\n"
                "- 'total_tokens' (optional)\n"
                f"{EMBEDDING_API_EXAMPLE}"
            )
        if "embeddings" not in response:
            raise ValueError(
                "The response does not contain the key 'embeddings'."
                "The key 'embeddings' is required to ensure proper compatibility with the OpenAIEmbeddingSpec.\n"
                "Please ensure that your response contains the key 'embeddings'.\n"
                f"{EMBEDDING_API_EXAMPLE}"
            )

    async def embeddings(self, request: EmbeddingRequest):
        response_queue_id = self.response_queue_id
        logger.debug("Received embedding request: %s", request)
        uid = uuid.uuid4()
        event = asyncio.Event()
        self._server.response_buffer[uid] = event

        self._server.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), request.model_copy()))
        await event.wait()

        response, status = self._server.response_buffer.pop(uid)

        if status == LitAPIStatus.ERROR:
            raise response

        logger.debug(response)

        self.validate_response(response)

        usage = UsageInfo(**response)
        data = [Embedding(index=i, embedding=embedding) for i, embedding in enumerate(response["embeddings"])]

        return EmbeddingResponse(data=data, model=request.model, usage=usage)

    async def options_embeddings(self, request: Request):
        return Response(status_code=status.HTTP_200_OK)
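For context, here is a sketch of how the new spec could be wired to a real embedding model instead of the random-vector test API. The `sentence-transformers` dependency and the model name are assumptions for illustration, not part of this commit:

```python
# Hedged sketch (not from this commit): serving a real encoder behind
# OpenAIEmbeddingSpec. Assumes sentence-transformers is installed; the
# model name below is an arbitrary choice.
from typing import List

import litserve as ls
from litserve import LitAPI, OpenAIEmbeddingSpec
from sentence_transformers import SentenceTransformer


class EmbeddingAPI(LitAPI):
    def setup(self, device):
        # Load the encoder once per worker process.
        self.model = SentenceTransformer("all-MiniLM-L6-v2", device=device)

    def decode_request(self, request) -> List[str]:
        # EmbeddingRequest.ensure_list() normalizes a single string to [str].
        return request.ensure_list()

    def predict(self, inputs: List[str]) -> List[List[float]]:
        return self.model.encode(inputs).tolist()

    def encode_response(self, output) -> dict:
        # "embeddings" is the only required key; token counts are optional.
        return {"embeddings": output}


if __name__ == "__main__":
    server = ls.LitServer(EmbeddingAPI(), spec=OpenAIEmbeddingSpec())
    server.run()
```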
49 changes: 49 additions & 0 deletions src/litserve/test_examples/openai_embedding_spec_example.py
@@ -0,0 +1,49 @@
from typing import List

import numpy as np

from litserve.api import LitAPI


class TestEmbedAPI(LitAPI):
    def setup(self, device):
        self.model = None

    def decode_request(self, request) -> List[str]:
        return request.ensure_list()

    def predict(self, x) -> List[List[float]]:
        return np.random.rand(len(x), 768).tolist()

    def encode_response(self, output) -> dict:
        return {"embeddings": output}


class TestEmbedBatchedAPI(TestEmbedAPI):
    def predict(self, batch) -> List[List[List[float]]]:
        return [np.random.rand(len(x), 768).tolist() for x in batch]


class TestEmbedAPIWithUsage(TestEmbedAPI):
    def encode_response(self, output) -> dict:
        return {"embeddings": output, "prompt_tokens": 10, "total_tokens": 10}


class TestEmbedAPIWithYieldPredict(TestEmbedAPI):
    def predict(self, x):
        yield from np.random.rand(768).tolist()


class TestEmbedAPIWithYieldEncodeResponse(TestEmbedAPI):
    def encode_response(self, output):
        yield {"embeddings": output}


class TestEmbedAPIWithNonDictOutput(TestEmbedAPI):
    def encode_response(self, output):
        return output


class TestEmbedAPIWithMissingEmbeddings(TestEmbedAPI):
    def encode_response(self, output):
        return {"output": output}
10 changes: 10 additions & 0 deletions tests/conftest.py
@@ -265,3 +265,13 @@ def openai_request_data_with_response_format():
         "frequency_penalty": 0,
         "user": "string",
     }
+
+
+@pytest.fixture
+def openai_embedding_request_data():
+    return {"input": "A beautiful sunset over the beach.", "model": "lit", "encoding_format": "float"}
+
+
+@pytest.fixture
+def openai_embedding_request_data_array():
+    return {"input": ["A beautiful sunset over the beach."] * 4, "model": "lit", "encoding_format": "float"}
7 changes: 7 additions & 0 deletions tests/e2e/default_openai_embedding_spec.py
@@ -0,0 +1,7 @@
import litserve as ls
from litserve import OpenAIEmbeddingSpec
from litserve.test_examples.openai_embedding_spec_example import TestEmbedAPI

if __name__ == "__main__":
    server = ls.LitServer(TestEmbedAPI(), spec=OpenAIEmbeddingSpec())
    server.run()
25 changes: 25 additions & 0 deletions tests/e2e/test_e2e.py
@@ -326,3 +326,28 @@ def test_e2e_single_streaming():
     expected_values = [4.0, 8.0, 12.0]
     for i, output in enumerate(outputs):
         assert output["output"] == expected_values[i], f"Intermediate output {i} is not expected value"
+
+
+@e2e_from_file("tests/e2e/default_openai_embedding_spec.py")
+def test_openai_embedding_parity():
+    client = OpenAI(
+        base_url="http://127.0.0.1:8000/v1",
+        api_key="lit",
+    )
+
+    model = "lit"
+    input_text = "The food was delicious and the waiter was very friendly."
+    input_text_list = [input_text] * 2
+    response = client.embeddings.create(
+        model="lit", input="The food was delicious and the waiter...", encoding_format="float"
+    )
+    assert response.model == model, f"Expected model to be {model} but got {response.model}"
+    assert len(response.data) == 1, f"Expected 1 embeddings but got {len(response.data)}"
+    assert len(response.data[0].embedding) == 768, f"Expected 768 dimensions but got {len(response.data[0].embedding)}"
+    assert isinstance(response.data[0].embedding[0], float), "Expected float datatype but got something else"
+
+    response = client.embeddings.create(model="lit", input=input_text_list, encoding_format="float")
+    assert response.model == model, f"Expected model to be {model} but got {response.model}"
+    assert len(response.data) == 2, f"Expected 2 embeddings but got {len(response.data)}"
+    for data in response.data:
+        assert len(data.embedding) == 768, f"Expected 768 dimensions but got {len(data.embedding)}"