From 771a1c9c12b4840a0a56a352cb89bedacd4c82dc Mon Sep 17 00:00:00 2001 From: Bhimraj Yadav Date: Wed, 27 Nov 2024 05:52:31 +0545 Subject: [PATCH] Feat support OpenAI embedding (#367) * adds spec for embeddings * adds embedding spec * adds initial test for embedding specs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add license info * adds decode request and encode response methods * updated e2e test * adds TODO message * updates test API * adds validation message * updated example * updated validation message * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: typo in message * added a handy helper fn i.e `get_input_as_list` * refactor: replace get_input_as_list with ensure_list in decode_request method * refactor: rename get_input_as_list to ensure_list and improve error messages in OpenAIEmbeddingSpec * fix precommit error * remove comment * adds test cases * adds test cases * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update decode request return type * remove comment * fixed tests getting trapped * refactor: update decode_request return type * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix precommit error on test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add TestEmbedBatchedAPI for batch predictions * refactor: simplify batch prediction logic in TestEmbedBatchedAPI * test: add test for OpenAI embedding spec with batching * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aniket Maurya --- src/litserve/__init__.py | 14 +- src/litserve/specs/__init__.py | 3 +- src/litserve/specs/openai_embedding.py | 
176 ++++++++++++++++++ .../openai_embedding_spec_example.py | 49 +++++ tests/conftest.py | 10 + tests/e2e/default_openai_embedding_spec.py | 7 + tests/e2e/test_e2e.py | 25 +++ tests/test_specs.py | 131 +++++++++++++ 8 files changed, 412 insertions(+), 3 deletions(-) create mode 100644 src/litserve/specs/openai_embedding.py create mode 100644 src/litserve/test_examples/openai_embedding_spec_example.py create mode 100644 tests/e2e/default_openai_embedding_spec.py diff --git a/src/litserve/__init__.py b/src/litserve/__init__.py index 99d27c37..c1e45e62 100644 --- a/src/litserve/__init__.py +++ b/src/litserve/__init__.py @@ -17,6 +17,16 @@ from litserve.callbacks import Callback from litserve.loggers import Logger from litserve.server import LitServer, Request, Response -from litserve.specs.openai import OpenAISpec +from litserve.specs import OpenAIEmbeddingSpec, OpenAISpec -__all__ = ["LitAPI", "LitServer", "Request", "Response", "OpenAISpec", "test_examples", "Callback", "Logger"] +__all__ = [ + "LitAPI", + "LitServer", + "Request", + "Response", + "OpenAISpec", + "OpenAIEmbeddingSpec", + "test_examples", + "Callback", + "Logger", +] diff --git a/src/litserve/specs/__init__.py b/src/litserve/specs/__init__.py index 90395c68..71fe4ef1 100644 --- a/src/litserve/specs/__init__.py +++ b/src/litserve/specs/__init__.py @@ -1,3 +1,4 @@ from litserve.specs.openai import OpenAISpec +from litserve.specs.openai_embedding import OpenAIEmbeddingSpec -__all__ = ["OpenAISpec"] +__all__ = ["OpenAISpec", "OpenAIEmbeddingSpec"] diff --git a/src/litserve/specs/openai_embedding.py b/src/litserve/specs/openai_embedding.py new file mode 100644 index 00000000..7684e490 --- /dev/null +++ b/src/litserve/specs/openai_embedding.py @@ -0,0 +1,176 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import inspect +import logging +import time +import uuid +from typing import List, Literal, Optional, Union + +from fastapi import Request, Response, status +from pydantic import BaseModel + +from litserve.specs.base import LitSpec +from litserve.utils import LitAPIStatus + +logger = logging.getLogger(__name__) + + +class EmbeddingRequest(BaseModel): + input: Union[str, List[str]] + model: str + dimensions: Optional[int] = None + encoding_format: Literal["float"] = "float" + + def ensure_list(self): + return self.input if isinstance(self.input, list) else [self.input] + + +class Embedding(BaseModel): + index: int + embedding: List[float] + object: Literal["embedding"] = "embedding" + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + + +class EmbeddingResponse(BaseModel): + data: List[Embedding] + model: str + object: Literal["list"] = "list" + usage: UsageInfo + + +EMBEDDING_API_EXAMPLE = """ +Please follow the example below for guidance on how to use the OpenAI Embedding spec: + +```python +import numpy as np +from typing import List +from litserve import LitAPI, OpenAIEmbeddingSpec + + +class TestAPI(LitAPI): + def setup(self, device): + self.model = None + + def decode_request(self, request) -> List[str]: + return request.ensure_list() + + def predict(self, x) -> List[List[float]]: + return np.random.rand(len(x), 768).tolist() + + def encode_response(self, output) -> dict: + return {"embeddings": output} + +if __name__ == "__main__": + import litserve as ls + server = ls.LitServer(TestAPI(), 
spec=OpenAIEmbeddingSpec()) + server.run() +``` +""" + + +class OpenAIEmbeddingSpec(LitSpec): + def __init__(self): + super().__init__() + # register the endpoint + self.add_endpoint("/v1/embeddings", self.embeddings, ["POST"]) + self.add_endpoint("/v1/embeddings", self.options_embeddings, ["GET"]) + + def setup(self, server: "LitServer"): # noqa: F821 + from litserve import LitAPI + + super().setup(server) + + lit_api = self._server.lit_api + if inspect.isgeneratorfunction(lit_api.predict): + raise ValueError( + "You are using yield in your predict method, which is used for streaming.", + "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ", + "is not a sequential operation.", + "Please consider replacing yield with return in predict.\n", + EMBEDDING_API_EXAMPLE, + ) + + is_encode_response_original = lit_api.encode_response.__code__ is LitAPI.encode_response.__code__ + if not is_encode_response_original and inspect.isgeneratorfunction(lit_api.encode_response): + raise ValueError( + "You are using yield in your encode_response method, which is used for streaming.", + "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ", + "is not a sequential operation.", + "Please consider replacing yield with return in encode_response.\n", + EMBEDDING_API_EXAMPLE, + ) + + print("OpenAI Embedding Spec is ready.") + + def decode_request(self, request: EmbeddingRequest, context_kwargs: Optional[dict] = None) -> List[str]: + return request.ensure_list() + + def encode_response(self, output: List[List[float]], context_kwargs: Optional[dict] = None) -> dict: + return { + "embeddings": output, + "prompt_tokens": context_kwargs.get("prompt_tokens", 0), + "total_tokens": context_kwargs.get("total_tokens", 0), + } + + def validate_response(self, response: dict) -> None: + if not isinstance(response, dict): + raise ValueError( + "The response is not a dictionary." 
+ "The response should be a dictionary to ensure proper compatibility with the OpenAIEmbeddingSpec.\n\n" + "Please ensure that your response is a dictionary with the following keys:\n" + "- 'embeddings' (required)\n" + "- 'prompt_tokens' (optional)\n" + "- 'total_tokens' (optional)\n" + f"{EMBEDDING_API_EXAMPLE}" + ) + if "embeddings" not in response: + raise ValueError( + "The response does not contain the key 'embeddings'." + "The key 'embeddings' is required to ensure proper compatibility with the OpenAIEmbeddingSpec.\n" + "Please ensure that your response contains the key 'embeddings'.\n" + f"{EMBEDDING_API_EXAMPLE}" + ) + + async def embeddings(self, request: EmbeddingRequest): + response_queue_id = self.response_queue_id + logger.debug("Received embedding request: %s", request) + uid = uuid.uuid4() + event = asyncio.Event() + self._server.response_buffer[uid] = event + + self._server.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), request.model_copy())) + await event.wait() + + response, status = self._server.response_buffer.pop(uid) + + if status == LitAPIStatus.ERROR: + raise response + + logger.debug(response) + + self.validate_response(response) + + usage = UsageInfo(**response) + data = [Embedding(index=i, embedding=embedding) for i, embedding in enumerate(response["embeddings"])] + + return EmbeddingResponse(data=data, model=request.model, usage=usage) + + async def options_embeddings(self, request: Request): + return Response(status_code=status.HTTP_200_OK) diff --git a/src/litserve/test_examples/openai_embedding_spec_example.py b/src/litserve/test_examples/openai_embedding_spec_example.py new file mode 100644 index 00000000..d80dba95 --- /dev/null +++ b/src/litserve/test_examples/openai_embedding_spec_example.py @@ -0,0 +1,49 @@ +from typing import List + +import numpy as np + +from litserve.api import LitAPI + + +class TestEmbedAPI(LitAPI): + def setup(self, device): + self.model = None + + def decode_request(self, request) -> 
List[str]: + return request.ensure_list() + + def predict(self, x) -> List[List[float]]: + return np.random.rand(len(x), 768).tolist() + + def encode_response(self, output) -> dict: + return {"embeddings": output} + + +class TestEmbedBatchedAPI(TestEmbedAPI): + def predict(self, batch) -> List[List[List[float]]]: + return [np.random.rand(len(x), 768).tolist() for x in batch] + + +class TestEmbedAPIWithUsage(TestEmbedAPI): + def encode_response(self, output) -> dict: + return {"embeddings": output, "prompt_tokens": 10, "total_tokens": 10} + + +class TestEmbedAPIWithYieldPredict(TestEmbedAPI): + def predict(self, x): + yield from np.random.rand(768).tolist() + + +class TestEmbedAPIWithYieldEncodeResponse(TestEmbedAPI): + def encode_response(self, output): + yield {"embeddings": output} + + +class TestEmbedAPIWithNonDictOutput(TestEmbedAPI): + def encode_response(self, output): + return output + + +class TestEmbedAPIWithMissingEmbeddings(TestEmbedAPI): + def encode_response(self, output): + return {"output": output} diff --git a/tests/conftest.py b/tests/conftest.py index f2ce8b6d..aca14ff0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -265,3 +265,13 @@ def openai_request_data_with_response_format(): "frequency_penalty": 0, "user": "string", } + + +@pytest.fixture +def openai_embedding_request_data(): + return {"input": "A beautiful sunset over the beach.", "model": "lit", "encoding_format": "float"} + + +@pytest.fixture +def openai_embedding_request_data_array(): + return {"input": ["A beautiful sunset over the beach."] * 4, "model": "lit", "encoding_format": "float"} diff --git a/tests/e2e/default_openai_embedding_spec.py b/tests/e2e/default_openai_embedding_spec.py new file mode 100644 index 00000000..aad833b9 --- /dev/null +++ b/tests/e2e/default_openai_embedding_spec.py @@ -0,0 +1,7 @@ +import litserve as ls +from litserve import OpenAIEmbeddingSpec +from litserve.test_examples.openai_embedding_spec_example import TestEmbedAPI + +if __name__ == 
"__main__": + server = ls.LitServer(TestEmbedAPI(), spec=OpenAIEmbeddingSpec()) + server.run() diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index c64859f5..81d8a165 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -326,3 +326,28 @@ def test_e2e_single_streaming(): expected_values = [4.0, 8.0, 12.0] for i, output in enumerate(outputs): assert output["output"] == expected_values[i], f"Intermediate output {i} is not expected value" + + +@e2e_from_file("tests/e2e/default_openai_embedding_spec.py") +def test_openai_embedding_parity(): + client = OpenAI( + base_url="http://127.0.0.1:8000/v1", + api_key="lit", + ) + + model = "lit" + input_text = "The food was delicious and the waiter was very friendly." + input_text_list = [input_text] * 2 + response = client.embeddings.create( + model="lit", input="The food was delicious and the waiter...", encoding_format="float" + ) + assert response.model == model, f"Expected model to be {model} but got {response.model}" + assert len(response.data) == 1, f"Expected 1 embeddings but got {len(response.data)}" + assert len(response.data[0].embedding) == 768, f"Expected 768 dimensions but got {len(response.data[0].embedding)}" + assert isinstance(response.data[0].embedding[0], float), "Expected float datatype but got something else" + + response = client.embeddings.create(model="lit", input=input_text_list, encoding_format="float") + assert response.model == model, f"Expected model to be {model} but got {response.model}" + assert len(response.data) == 2, f"Expected 2 embeddings but got {len(response.data)}" + for data in response.data: + assert len(data.embedding) == 768, f"Expected 768 dimensions but got {len(data.embedding)}" diff --git a/tests/test_specs.py b/tests/test_specs.py index a02aad1c..d651df83 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio + import pytest from asgi_lifespan import LifespanManager from fastapi import HTTPException @@ -19,6 +21,16 @@ import litserve as ls from litserve.specs.openai import ChatMessage, OpenAISpec +from litserve.specs.openai_embedding import OpenAIEmbeddingSpec +from litserve.test_examples.openai_embedding_spec_example import ( + TestEmbedAPI, + TestEmbedAPIWithMissingEmbeddings, + TestEmbedAPIWithNonDictOutput, + TestEmbedAPIWithUsage, + TestEmbedAPIWithYieldEncodeResponse, + TestEmbedAPIWithYieldPredict, + TestEmbedBatchedAPI, +) from litserve.test_examples.openai_spec_example import ( OpenAIBatchingWithUsage, OpenAIWithUsage, @@ -204,3 +216,122 @@ async def test_fail_http(openai_request_data): res = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) assert res.status_code == 501, "Server raises 501 error" assert res.text == '{"detail":"test LitAPI.predict error"}' + + +@pytest.mark.asyncio +async def test_openai_embedding_spec_with_single_input_doc(openai_embedding_request_data): + spec = OpenAIEmbeddingSpec() + server = ls.LitServer(TestEmbedAPI(), spec=spec) + + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" + assert resp.json()["object"] == "list", "Object should be list" + assert resp.json()["data"][0]["index"] == 0, "Index should be 0" + assert len(resp.json()["data"]) == 1, "Length of data should be 1" + assert len(resp.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768" + + +@pytest.mark.asyncio +async def test_openai_embedding_spec_with_multiple_input_docs(openai_embedding_request_data_array): + spec = OpenAIEmbeddingSpec() + server = ls.LitServer(TestEmbedAPI(), spec=spec) + + with wrap_litserve_start(server) as server: + async with 
LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/embeddings", json=openai_embedding_request_data_array, timeout=10) + assert resp.status_code == 200, "Status code should be 200" + assert resp.json()["object"] == "list", "Object should be list" + assert resp.json()["data"][0]["index"] == 0, "Index should be 0" + assert len(resp.json()["data"]) == 4, "Length of data should be 1" + assert len(resp.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768" + + +@pytest.mark.asyncio +async def test_openai_embedding_spec_with_usage(openai_embedding_request_data): + spec = OpenAIEmbeddingSpec() + server = ls.LitServer(TestEmbedAPIWithUsage(), spec=spec) + + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" + assert resp.json()["object"] == "list", "Object should be list" + assert resp.json()["data"][0]["index"] == 0, "Index should be 0" + assert len(resp.json()["data"]) == 1, "Length of data should be 1" + assert len(resp.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768" + assert resp.json()["usage"]["prompt_tokens"] == 10, "Prompt tokens should be 10" + assert resp.json()["usage"]["total_tokens"] == 10, "Total tokens should be 10" + + +@pytest.mark.asyncio +async def test_openai_embedding_spec_validation(openai_request_data): + server = ls.LitServer(TestEmbedAPIWithYieldPredict(), spec=OpenAIEmbeddingSpec()) + with pytest.raises(ValueError, match="You are using yield in your predict method"), wrap_litserve_start( + server + ) as server: + async with LifespanManager(server.app) as manager: + await manager.shutdown() + + server = ls.LitServer(TestEmbedAPIWithYieldEncodeResponse(), 
spec=OpenAIEmbeddingSpec()) + with pytest.raises(ValueError, match="You are using yield in your encode_response method"), wrap_litserve_start( + server + ) as server: + async with LifespanManager(server.app) as manager: + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_openai_embedding_spec_with_non_dict_output(openai_embedding_request_data): + spec = OpenAIEmbeddingSpec() + server = ls.LitServer(TestEmbedAPIWithNonDictOutput(), spec=spec) + + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + with pytest.raises(ValueError, match="The response is not a dictionary"): + await ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10) + + +@pytest.mark.asyncio +async def test_openai_embedding_spec_with_missing_embeddings(openai_embedding_request_data): + spec = OpenAIEmbeddingSpec() + server = ls.LitServer(TestEmbedAPIWithMissingEmbeddings(), spec=spec) + + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + with pytest.raises(ValueError, match="The response does not contain the key 'embeddings'"): + await ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "batch_size", + [2, 4], +) +async def test_openai_embedding_spec_with_batching( + batch_size, openai_embedding_request_data, openai_embedding_request_data_array +): + spec = OpenAIEmbeddingSpec() + server = ls.LitServer(TestEmbedBatchedAPI(), spec=spec, max_batch_size=batch_size, batch_timeout=0.01) + + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + # send single request + resp = await ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10) + assert resp.status_code 
== 200, "Status code should be 200" + assert len(resp.json()["data"]) == 1, "Length of data should be 1" + assert len(resp.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768" + + # send concurrent requests + resp1, resp2 = await asyncio.gather( + ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10), + ac.post("/v1/embeddings", json=openai_embedding_request_data_array, timeout=10), + ) + + assert resp1.status_code == 200, "Status code should be 200" + assert resp2.status_code == 200, "Status code should be 200" + assert len(resp1.json()["data"]) == 1, "Length of data should be 1" + assert len(resp2.json()["data"]) == 4, "Length of data should be 4" + assert len(resp1.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768" + assert len(resp2.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768"