[Misc] Minimum requirements for SageMaker compatibility (#11576)
nathan-az authored Jan 2, 2025
1 parent 5dba257 commit 68d3780
Showing 3 changed files with 95 additions and 3 deletions.
13 changes: 11 additions & 2 deletions Dockerfile
@@ -234,8 +234,8 @@ RUN mv vllm test_docs/
 #################### TEST IMAGE ####################

 #################### OPENAI API SERVER ####################
-# openai api server alternative
-FROM vllm-base AS vllm-openai
+# base openai image with additional requirements, for any subsequent openai-style images
+FROM vllm-base AS vllm-openai-base

 # install additional dependencies for openai api server
 RUN --mount=type=cache,target=/root/.cache/pip \
@@ -247,5 +247,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \

 ENV VLLM_USAGE_SOURCE production-docker-image

+# define sagemaker first, so it is not default from `docker build`
+FROM vllm-openai-base AS vllm-sagemaker
+
+COPY examples/sagemaker-entrypoint.sh .
+RUN chmod +x sagemaker-entrypoint.sh
+ENTRYPOINT ["./sagemaker-entrypoint.sh"]
+
+FROM vllm-openai-base AS vllm-openai
+
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
 #################### OPENAI API SERVER ####################
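
Because the `vllm-sagemaker` stage is declared before `vllm-openai`, an untargeted `docker build` still resolves to the final stage and yields the standard OpenAI server image; the SageMaker variant must be requested explicitly. A minimal sketch of both builds (the image tags are illustrative):

# build the SageMaker-compatible image explicitly
docker build --target vllm-sagemaker -t vllm-sagemaker:latest .

# an untargeted build still produces the plain OpenAI server image
docker build -t vllm-openai:latest .
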
24 changes: 24 additions & 0 deletions examples/sagemaker-entrypoint.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Define the prefix for environment variables to look for
+PREFIX="SM_VLLM_"
+ARG_PREFIX="--"
+
+# Initialize an array for storing the arguments
+# port 8080 required by sagemaker, https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
+ARGS=(--port 8080)
+
+# Loop through all environment variables
+while IFS='=' read -r key value; do
+    # Remove the prefix from the key, convert to lowercase, and replace underscores with dashes
+    arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
+
+    # Add the argument name and value to the ARGS array
+    ARGS+=("${ARG_PREFIX}${arg_name}")
+    if [ -n "$value" ]; then
+        ARGS+=("$value")
+    fi
+done < <(env | grep "^${PREFIX}")
+
+# Pass the collected arguments to the main entrypoint
+exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}"
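
Note how the script lets SageMaker's environment-variable plumbing drive the server CLI: each `SM_VLLM_*` variable is stripped of its prefix, lowercased, and converted underscore-to-dash, then appended as a `--flag value` pair (a variable with an empty value becomes a bare boolean flag). A sketch of the mapping, with an illustrative model and settings; `env` does not guarantee ordering, so the resulting flag order may vary:

# container environment, e.g. set on the SageMaker model definition
SM_VLLM_MODEL=mistralai/Mistral-7B-Instruct-v0.2
SM_VLLM_MAX_MODEL_LEN=4096
SM_VLLM_DTYPE=bfloat16

# invocation assembled by the script
python3 -m vllm.entrypoints.openai.api_server --port 8080 --model mistralai/Mistral-7B-Instruct-v0.2 --max-model-len 4096 --dtype bfloat16
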
61 changes: 60 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -16,7 +16,7 @@
 from typing import AsyncIterator, Optional, Set, Tuple

 import uvloop
-from fastapi import APIRouter, FastAPI, Request
+from fastapi import APIRouter, FastAPI, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -44,11 +44,15 @@
 CompletionResponse,
 DetokenizeRequest,
 DetokenizeResponse,
+EmbeddingChatRequest,
+EmbeddingCompletionRequest,
 EmbeddingRequest,
 EmbeddingResponse,
 EmbeddingResponseData,
 ErrorResponse,
 LoadLoraAdapterRequest,
+PoolingChatRequest,
+PoolingCompletionRequest,
 PoolingRequest, PoolingResponse,
 ScoreRequest, ScoreResponse,
 TokenizeRequest,
@@ -310,6 +314,12 @@ async def health(raw_request: Request) -> Response:
     return Response(status_code=200)


+@router.api_route("/ping", methods=["GET", "POST"])
+async def ping(raw_request: Request) -> Response:
+    """Ping check. Endpoint required for SageMaker"""
+    return await health(raw_request)
+
+
 @router.post("/tokenize")
 @with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
@@ -483,6 +493,54 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     return await create_score(request, raw_request)


+TASK_HANDLERS = {
+    "generate": {
+        "messages": (ChatCompletionRequest, create_chat_completion),
+        "default": (CompletionRequest, create_completion),
+    },
+    "embed": {
+        "messages": (EmbeddingChatRequest, create_embedding),
+        "default": (EmbeddingCompletionRequest, create_embedding),
+    },
+    "score": {
+        "default": (ScoreRequest, create_score),
+    },
+    "reward": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+    "classify": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+}
+
+
+@router.post("/invocations")
+async def invocations(raw_request: Request):
+    """
+    For SageMaker, routes requests to other handlers based on model `task`.
+    """
+    body = await raw_request.json()
+    task = raw_request.app.state.task
+
+    if task not in TASK_HANDLERS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unsupported task: '{task}' for '/invocations'. "
+            f"Expected one of {set(TASK_HANDLERS.keys())}")
+
+    handler_config = TASK_HANDLERS[task]
+    if "messages" in body:
+        request_model, handler = handler_config["messages"]
+    else:
+        request_model, handler = handler_config["default"]
+
+    # this is required since we lose the FastAPI automatic casting
+    request = request_model.model_validate(body)
+    return await handler(request, raw_request)
+
+
 if envs.VLLM_TORCH_PROFILER_DIR:
     logger.warning(
         "Torch Profiler is enabled in the API server. This should ONLY be "
@@ -687,6 +745,7 @@ def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     )
+    state.task = model_config.task


 def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
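
Taken together, `/ping` and `/invocations` cover SageMaker's container contract: the platform health-checks `GET /ping` and sends all inference traffic to `POST /invocations`, which dispatches on the model's `task` and on whether the JSON body carries a `messages` field. A quick smoke test against a locally running server with a `generate`-task model might look like this (the model name and payloads are illustrative):

# health check, as SageMaker's pings do
curl http://localhost:8080/ping

# `messages` present: routed to the chat completion handler
curl -X POST http://localhost:8080/invocations \
    -H 'Content-Type: application/json' \
    -d '{"model": "mistralai/Mistral-7B-Instruct-v0.2", "messages": [{"role": "user", "content": "Hello!"}]}'

# no `messages`: routed to the plain completion handler
curl -X POST http://localhost:8080/invocations \
    -H 'Content-Type: application/json' \
    -d '{"model": "mistralai/Mistral-7B-Instruct-v0.2", "prompt": "Hello", "max_tokens": 32}'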
