vllm lm api
ixaxaar committed Feb 23, 2024
1 parent ca1f4c5 commit 09f51b8
Showing 1 changed file with 115 additions and 2 deletions.
117 changes: 115 additions & 2 deletions geniusrise_text/language_model/api.py
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict

from typing import Any, Dict, Optional
from concurrent.futures import ThreadPoolExecutor
import asyncio
import cherrypy
from geniusrise import BatchInput, BatchOutput, State
from geniusrise.logging import setup_logger
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.protocol import CompletionRequest

from geniusrise_text.base import TextAPI

@@ -85,6 +88,9 @@ def __init__(
"""
super().__init__(input=input, output=output, state=state)
self.log = setup_logger(self)
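# vLLM serving objects; created lazily by initialize_vllm on the first call to complete_vllm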
self.vllm_server: Optional[OpenAIServingCompletion] = None
self.event_loop: Any = None
self.executor = ThreadPoolExecutor(max_workers=4)

@cherrypy.expose
@cherrypy.tools.json_in()
@@ -130,3 +136,110 @@ def complete(self, **kwargs: Any) -> Dict[str, Any]:
"args": data,
"completion": self.generate(prompt=prompt, decoding_strategy=decoding_strategy, **generation_params),
}

def initialize_vllm(self):
"""Wrap the already-loaded vLLM engine in an OpenAI-compatible completion server."""
self.vllm_server = OpenAIServingCompletion(engine=self.model, served_model=self.model_name)
self.event_loop = asyncio.new_event_loop()

@cherrypy.expose
@cherrypy.tools.json_in()
@cherrypy.tools.json_out()
@cherrypy.tools.allow(methods=["POST"])
def complete_vllm(self, **kwargs: Any) -> Dict[str, Any]:
"""
Handles POST requests to generate completions using the vLLM engine's OpenAI-compatible completion API.
This method accepts various parameters for customizing the completion request, including message content,
generation settings, and more.
Parameters:
- **kwargs (Any): Arbitrary keyword arguments. Expects data in JSON format containing any of the following keys:
- messages (Union[str, List[Dict[str, str]]]): The messages for the chat context.
- temperature (float, optional): The sampling temperature. Defaults to 0.7.
- top_p (float, optional): The nucleus sampling probability. Defaults to 1.0.
- n (int, optional): The number of completions to generate. Defaults to 1.
- max_tokens (int, optional): The maximum number of tokens to generate.
- stop (Union[str, List[str]], optional): Stop sequence to end generation.
- stream (bool, optional): Whether to stream the response. Defaults to False.
- presence_penalty (float, optional): The presence penalty. Defaults to 0.0.
- frequency_penalty (float, optional): The frequency penalty. Defaults to 0.0.
- logit_bias (Dict[str, float], optional): Adjustments to the logits of specified tokens.
- user (str, optional): An identifier for the user making the request.
- (Additional model-specific parameters)
Returns:
Dict[str, Any]: A dictionary with the chat completion response or an error message.
Example CURL Request:
```bash
curl -X POST "http://localhost:3000/complete_vllm" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "Whats the weather like in London?"}
],
"temperature": 0.7,
"top_p": 1.0,
"n": 1,
"max_tokens": 50,
"stream": false,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"logit_bias": {},
"user": "example_user"
}'
```
This request asks the vLLM engine to generate a completion for the provided chat context, with the specified generation settings.
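Example Python Request (a minimal sketch; it assumes the server is reachable at http://localhost:3000 and that the `requests` package is installed; adjust host, port, and auth to your deployment):
```python
import requests

# Same payload as the CURL example above; fields left out fall back to the server-side defaults.
payload = {
    "messages": [{"role": "user", "content": "What is the weather like in London?"}],
    "temperature": 0.7,
    "max_tokens": 50,
}
response = requests.post("http://localhost:3000/complete_vllm", json=payload, timeout=60)
print(response.json())
```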
"""
# Extract data from the POST request
data = cherrypy.request.json

# Lazily initialize the vLLM OpenAI-compatible completion server if it has not been created yet
if not hasattr(self, "vllm_server") or self.vllm_server is None:
self.initialize_vllm()

# Build the CompletionRequest; fields missing from the request body fall back to the defaults below
chat_request = CompletionRequest(
model=self.model_name,
prompt=data.get("messages"),
temperature=data.get("temperature", 0.7),
top_p=data.get("top_p", 1.0),
n=data.get("n", 1),
max_tokens=data.get("max_tokens"),
stop=data.get("stop", []),
stream=data.get("stream", False),
logprobs=data.get("logprobs", None),
presence_penalty=data.get("presence_penalty", 0.0),
frequency_penalty=data.get("frequency_penalty", 0.0),
logit_bias=data.get("logit_bias", {}),
user=data.get("user"),
best_of=data.get("best_of"),
top_k=data.get("top_k", -1),
ignore_eos=data.get("ignore_eos", False),
use_beam_search=data.get("use_beam_search", False),
stop_token_ids=data.get("stop_token_ids", []),
skip_special_tokens=data.get("skip_special_tokens", True),
spaces_between_special_tokens=data.get("spaces_between_special_tokens", True),
echo=data.get("echo", False),
repetition_penalty=data.get("repetition_penalty", 1.0),
min_p=data.get("min_p", 0.0),
include_stop_str_in_output=data.get("include_stop_str_in_output", False),
length_penalty=data.get("length_penalty", 1.0),
)

# Generate the completion using the vLLM engine
try:

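# Stand-in for the Starlette Request that vLLM's create_completion expects;
# it only implements is_disconnected, so the server never sees a client disconnect.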
class DummyObject:
async def is_disconnected(self):
return False

async def async_call():
response = await self.vllm_server.create_completion(request=chat_request, raw_request=DummyObject())
return response

chat_completion = asyncio.run(async_call())

return chat_completion.model_dump() if chat_completion else {"error": "Failed to generate completion"}
except Exception as e:
self.log.exception("Error generating completion: %s", str(e))
raise
