diff --git a/geniusrise_text/language_model/api.py b/geniusrise_text/language_model/api.py
index 1315d64..e793995 100644
--- a/geniusrise_text/language_model/api.py
+++ b/geniusrise_text/language_model/api.py
@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
-
+from typing import Any, Dict, Optional
+from concurrent.futures import ThreadPoolExecutor
+import asyncio
 import cherrypy
 from geniusrise import BatchInput, BatchOutput, State
 from geniusrise.logging import setup_logger
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.protocol import CompletionRequest
 
 from geniusrise_text.base import TextAPI
 
@@ -85,6 +88,9 @@ def __init__(
         """
         super().__init__(input=input, output=output, state=state)
         self.log = setup_logger(self)
+        self.vllm_server: Optional[OpenAIServingCompletion] = None
+        self.event_loop: Any = None
+        self.executor = ThreadPoolExecutor(max_workers=4)
 
     @cherrypy.expose
     @cherrypy.tools.json_in()
@@ -130,3 +136,110 @@ def complete(self, **kwargs: Any) -> Dict[str, Any]:
             "args": data,
             "completion": self.generate(prompt=prompt, decoding_strategy=decoding_strategy, **generation_params),
         }
+
+    def initialize_vllm(self):
+        self.vllm_server = OpenAIServingCompletion(engine=self.model, served_model=self.model_name)
+        self.event_loop = asyncio.new_event_loop()
+
+    @cherrypy.expose
+    @cherrypy.tools.json_in()
+    @cherrypy.tools.json_out()
+    @cherrypy.tools.allow(methods=["POST"])
+    def complete_vllm(self, **kwargs: Any) -> Dict[str, Any]:
+        """
+        Handles POST requests to generate chat completions using the VLLM (Versatile Language Learning Model) engine.
+        This method accepts various parameters for customizing the chat completion request, including message content,
+        generation settings, and more.
+
+        Parameters:
+        - **kwargs (Any): Arbitrary keyword arguments. Expects data in JSON format containing any of the following keys:
+            - messages (Union[str, List[Dict[str, str]]]): The messages for the chat context.
+            - temperature (float, optional): The sampling temperature. Defaults to 0.7.
+            - top_p (float, optional): The nucleus sampling probability. Defaults to 1.0.
+            - n (int, optional): The number of completions to generate. Defaults to 1.
+            - max_tokens (int, optional): The maximum number of tokens to generate.
+            - stop (Union[str, List[str]], optional): Stop sequence to end generation.
+            - stream (bool, optional): Whether to stream the response. Defaults to False.
+            - presence_penalty (float, optional): The presence penalty. Defaults to 0.0.
+            - frequency_penalty (float, optional): The frequency penalty. Defaults to 0.0.
+            - logit_bias (Dict[str, float], optional): Adjustments to the logits of specified tokens.
+            - user (str, optional): An identifier for the user making the request.
+            - (Additional model-specific parameters)
+
+        Returns:
+            Dict[str, Any]: A dictionary with the chat completion response or an error message.
+
+        Example CURL Request:
+        ```bash
+        curl -X POST "http://localhost:3000/complete_vllm" \
+            -H "Content-Type: application/json" \
+            -d '{
+                "messages": [
+                    {"role": "user", "content": "Whats the weather like in London?"}
+                ],
+                "temperature": 0.7,
+                "top_p": 1.0,
+                "n": 1,
+                "max_tokens": 50,
+                "stream": false,
+                "presence_penalty": 0.0,
+                "frequency_penalty": 0.0,
+                "logit_bias": {},
+                "user": "example_user"
+            }'
+        ```
+        This request asks the VLLM engine to generate a completion for the provided chat context, with specified generation settings.
+        """
+        # Extract data from the POST request
+        data = cherrypy.request.json
+
+        # Initialize VLLM server with chat template and response role if not already initialized
+        if not hasattr(self, "vllm_server") or self.vllm_server is None:
+            self.initialize_vllm()
+
+        # Prepare the chat completion request
+        chat_request = CompletionRequest(
+            model=self.model_name,
+            prompt=data.get("messages"),
+            temperature=data.get("temperature", 0.7),
+            top_p=data.get("top_p", 1.0),
+            n=data.get("n", 1),
+            max_tokens=data.get("max_tokens"),
+            stop=data.get("stop", []),
+            stream=data.get("stream", False),
+            logprobs=data.get("logprobs", None),
+            presence_penalty=data.get("presence_penalty", 0.0),
+            frequency_penalty=data.get("frequency_penalty", 0.0),
+            logit_bias=data.get("logit_bias", {}),
+            user=data.get("user"),
+            best_of=data.get("best_of"),
+            top_k=data.get("top_k", -1),
+            ignore_eos=data.get("ignore_eos", False),
+            use_beam_search=data.get("use_beam_search", False),
+            stop_token_ids=data.get("stop_token_ids", []),
+            skip_special_tokens=data.get("skip_special_tokens", True),
+            spaces_between_special_tokens=data.get("spaces_between_special_tokens", True),
+            echo=data.get("echo", False),
+            repetition_penalty=data.get("repetition_penalty", 1.0),
+            min_p=data.get("min_p", 0.0),
+            include_stop_str_in_output=data.get("include_stop_str_in_output", False),
+            length_penalty=data.get("length_penalty", 1.0),
+        )
+
+        # Generate chat completion using the VLLM engine
+        try:
+
+            class DummyObject:
+                async def is_disconnected(self):
+                    return False
+
+            async def async_call():
+                response = await self.vllm_server.create_completion(request=chat_request, raw_request=DummyObject())
+                return response
+
+            chat_completion = asyncio.run(async_call())
+
+            return chat_completion.model_dump() if chat_completion else {"error": "Failed to generate lm completion"}
+        except Exception as e:
+            self.log.exception("Error generating chat completion: %s", str(e))
+            raise e
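
For reviewers, below is a minimal Python client sketch equivalent to the curl example in the docstring above. It assumes the API is running locally on port 3000 (as in that example) and that the `requests` package is installed; neither assumption is part of this diff.

```python
# Hypothetical client for the new /complete_vllm endpoint (not part of this diff).
# Assumes the server is reachable at http://localhost:3000, as in the docstring's
# curl example.
import requests

payload = {
    "messages": [{"role": "user", "content": "What's the weather like in London?"}],
    "temperature": 0.7,
    "top_p": 1.0,
    "n": 1,
    "max_tokens": 50,
    "stream": False,
    "presence_penalty": 0.0,
    "frequency_penalty": 0.0,
    "logit_bias": {},
    "user": "example_user",
}

# POST the JSON body; the endpoint returns the serialized completion response.
response = requests.post("http://localhost:3000/complete_vllm", json=payload, timeout=60)
response.raise_for_status()
print(response.json())
```

Design note: `complete_vllm` drives vLLM's OpenAI-compatible serving layer by awaiting `create_completion` inside `asyncio.run`, passing a stub `raw_request` whose `is_disconnected` always returns False, since the serving layer normally inspects the raw HTTP request to detect client disconnects.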