diff --git a/Dockerfile b/Dockerfile
index 153bff9cf565f..088314eb38dbe 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -234,8 +234,8 @@ RUN mv vllm test_docs/
 #################### TEST IMAGE ####################
 
 #################### OPENAI API SERVER ####################
-# openai api server alternative
-FROM vllm-base AS vllm-openai
+# base openai image with additional requirements, for any subsequent openai-style images
+FROM vllm-base AS vllm-openai-base
 
 # install additional dependencies for openai api server
 RUN --mount=type=cache,target=/root/.cache/pip \
@@ -247,5 +247,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 
 ENV VLLM_USAGE_SOURCE production-docker-image
 
+# define sagemaker first, so it is not default from `docker build`
+FROM vllm-openai-base AS vllm-sagemaker
+
+COPY examples/sagemaker-entrypoint.sh .
+RUN chmod +x sagemaker-entrypoint.sh
+ENTRYPOINT ["./sagemaker-entrypoint.sh"]
+
+FROM vllm-openai-base AS vllm-openai
+
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
 #################### OPENAI API SERVER ####################
diff --git a/examples/sagemaker-entrypoint.sh b/examples/sagemaker-entrypoint.sh
new file mode 100644
index 0000000000000..75a99ffc1f155
--- /dev/null
+++ b/examples/sagemaker-entrypoint.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Define the prefix for environment variables to look for
+PREFIX="SM_VLLM_"
+ARG_PREFIX="--"
+
+# Initialize an array for storing the arguments
+# port 8080 required by sagemaker, https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
+ARGS=(--port 8080)
+
+# Loop through all environment variables
+while IFS='=' read -r key value; do
+    # Remove the prefix from the key, convert to lowercase, and replace underscores with dashes
+    arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
+
+    # Add the argument name and value to the ARGS array
+    ARGS+=("${ARG_PREFIX}${arg_name}")
+    if [ -n "$value" ]; then
+        ARGS+=("$value")
+    fi
+done < <(env | grep "^${PREFIX}")
+
+# Pass the collected arguments to the main entrypoint
+exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}"
\ No newline at end of file
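For reference, a minimal sketch of how the new build stage and entrypoint might be exercised locally. The image tag, model name, and the specific SM_VLLM_* variables below are illustrative only, not part of this change; per the Dockerfile comment above, a plain `docker build` without `--target` still produces the vllm-openai image.

    # Build the SageMaker variant explicitly
    docker build --target vllm-sagemaker -t vllm-sagemaker .

    # At runtime the entrypoint rewrites SM_VLLM_* variables into CLI flags,
    # e.g. SM_VLLM_MAX_MODEL_LEN=4096 -> --max-model-len 4096, and always adds --port 8080
    docker run -e SM_VLLM_MODEL=facebook/opt-125m -e SM_VLLM_MAX_MODEL_LEN=4096 vllm-sagemaker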
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 74fe378fdae42..e942b475535ad 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -16,7 +16,7 @@
 from typing import AsyncIterator, Optional, Set, Tuple
 
 import uvloop
-from fastapi import APIRouter, FastAPI, Request
+from fastapi import APIRouter, FastAPI, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -44,11 +44,15 @@
                                               CompletionResponse,
                                               DetokenizeRequest,
                                               DetokenizeResponse,
+                                              EmbeddingChatRequest,
+                                              EmbeddingCompletionRequest,
                                               EmbeddingRequest,
                                               EmbeddingResponse,
                                               EmbeddingResponseData,
                                               ErrorResponse,
                                               LoadLoraAdapterRequest,
+                                              PoolingChatRequest,
+                                              PoolingCompletionRequest,
                                               PoolingRequest, PoolingResponse,
                                               ScoreRequest, ScoreResponse,
                                               TokenizeRequest,
@@ -310,6 +314,12 @@ async def health(raw_request: Request) -> Response:
     return Response(status_code=200)
 
 
+@router.api_route("/ping", methods=["GET", "POST"])
+async def ping(raw_request: Request) -> Response:
+    """Ping check. Endpoint required for SageMaker"""
+    return await health(raw_request)
+
+
 @router.post("/tokenize")
 @with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
@@ -483,6 +493,54 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     return await create_score(request, raw_request)
 
 
+TASK_HANDLERS = {
+    "generate": {
+        "messages": (ChatCompletionRequest, create_chat_completion),
+        "default": (CompletionRequest, create_completion),
+    },
+    "embed": {
+        "messages": (EmbeddingChatRequest, create_embedding),
+        "default": (EmbeddingCompletionRequest, create_embedding),
+    },
+    "score": {
+        "default": (ScoreRequest, create_score),
+    },
+    "reward": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+    "classify": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+}
+
+
+@router.post("/invocations")
+async def invocations(raw_request: Request):
+    """
+    For SageMaker, routes requests to other handlers based on model `task`.
+    """
+    body = await raw_request.json()
+    task = raw_request.app.state.task
+
+    if task not in TASK_HANDLERS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unsupported task: '{task}' for '/invocations'. "
+            f"Expected one of {set(TASK_HANDLERS.keys())}")
+
+    handler_config = TASK_HANDLERS[task]
+    if "messages" in body:
+        request_model, handler = handler_config["messages"]
+    else:
+        request_model, handler = handler_config["default"]
+
+    # this is required since we lose the FastAPI automatic casting
+    request = request_model.model_validate(body)
+    return await handler(request, raw_request)
+
+
 if envs.VLLM_TORCH_PROFILER_DIR:
     logger.warning(
         "Torch Profiler is enabled in the API server. This should ONLY be "
@@ -687,6 +745,7 @@ def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     )
+    state.task = model_config.task
 
 
 def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
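For context, a rough sketch of how the new /invocations route might be called for a model whose task resolves to "generate". The port comes from the SageMaker entrypoint above; the model name and payloads are placeholders, not values defined by this change.

    # A body containing "messages" is dispatched to the chat-completion handler
    curl -s http://localhost:8080/invocations \
        -H "Content-Type: application/json" \
        -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hello"}]}'

    # Without "messages", the same task falls back to the completion handler
    curl -s http://localhost:8080/invocations \
        -H "Content-Type: application/json" \
        -d '{"model": "my-model", "prompt": "Hello"}'

Because this route dispatches by hand, the body is re-validated with model_validate, as noted in the diff, since FastAPI's automatic request casting is bypassed.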