[Bug]: Authorization ignored when root_path is set (vllm-project#10606)
Signed-off-by: chaunceyjiang <[email protected]>
chaunceyjiang authored Nov 25, 2024
1 parent 2b0879b commit d04b13a
Showing 2 changed files with 107 additions and 2 deletions.
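
For context, the linked issue reports that the API-key check was skipped whenever the server was launched with --root-path. A minimal manual check of the fixed behavior, assuming a server started the same way as the test fixture below; the port, model name, and keys mirror the test and are otherwise placeholders:

# Hypothetical manual repro of the fix: with the server running with
# --root-path /llm and VLLM_API_KEY=abc-123, a request under the
# prefixed URL using a wrong key should now fail authentication.
import openai

client = openai.OpenAI(
    api_key="abc",                            # deliberately wrong key
    base_url="http://localhost:8000/llm/v1",  # root-path-prefixed base URL
    max_retries=0,
)
try:
    client.chat.completions.create(
        model="Qwen/Qwen2-1.5B-Instruct",
        messages=[{"role": "user", "content": "tell me a common saying"}],
    )
except openai.AuthenticationError:
    print("API key enforced under the root path, as intended by this fix.")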
103 changes: 103 additions & 0 deletions tests/entrypoints/openai/test_root_path.py
@@ -0,0 +1,103 @@
import contextlib
import os
from typing import Any, List, NamedTuple

import openai # use the official client for correctness check
import pytest

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
API_KEY = "abc-123"
ERROR_API_KEY = "abc"
ROOT_PATH = "llm"


@pytest.fixture(scope="module")
def server():
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--enforce-eager",
        "--max-model-len",
        "4080",
        "--root-path",  # use --root-path=/llm for testing
        "/" + ROOT_PATH,
        "--chat-template",
        DUMMY_CHAT_TEMPLATE,
    ]
    envs = os.environ.copy()

    envs["VLLM_API_KEY"] = API_KEY
    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
        yield remote_server


class TestCase(NamedTuple):
    model_name: str
    base_url: List[str]
    api_key: str
    expected_error: Any


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            model_name=MODEL_NAME,
            base_url=["v1"],  # http://localhost:8000/v1
            api_key=ERROR_API_KEY,
            expected_error=openai.AuthenticationError),
        TestCase(
            model_name=MODEL_NAME,
            base_url=[ROOT_PATH, "v1"],  # http://localhost:8000/llm/v1
            api_key=ERROR_API_KEY,
            expected_error=openai.AuthenticationError),
        TestCase(
            model_name=MODEL_NAME,
            base_url=["v1"],  # http://localhost:8000/v1
            api_key=API_KEY,
            expected_error=None),
        TestCase(
            model_name=MODEL_NAME,
            base_url=[ROOT_PATH, "v1"],  # http://localhost:8000/llm/v1
            api_key=API_KEY,
            expected_error=None),
    ],
)
async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer,
                                                   test_case: TestCase):
    saying: str = "Here is a common saying about apple. An apple a day, keeps"
    ctx = contextlib.nullcontext()
    if test_case.expected_error is not None:
        ctx = pytest.raises(test_case.expected_error)
    with ctx:
        client = openai.AsyncOpenAI(
            api_key=test_case.api_key,
            base_url=server.url_for(*test_case.base_url),
            max_retries=0)
        chat_completion = await client.chat.completions.create(
            model=test_case.model_name,
            messages=[{
                "role": "user",
                "content": "tell me a common saying"
            }, {
                "role": "assistant",
                "content": saying
            }],
            extra_body={
                "continue_final_message": True,
                "add_generation_prompt": False
            })

        assert chat_completion.id is not None
        assert len(chat_completion.choices) == 1
        choice = chat_completion.choices[0]
        assert choice.finish_reason == "stop"
        message = choice.message
        assert len(message.content) > 0
        assert message.role == "assistant"
6 changes: 4 additions & 2 deletions vllm/entrypoints/openai/api_server.py
@@ -499,10 +499,12 @@ async def validation_exception_handler(_, exc):

@app.middleware("http")
async def authentication(request: Request, call_next):
root_path = "" if args.root_path is None else args.root_path
if request.method == "OPTIONS":
return await call_next(request)
if not request.url.path.startswith(f"{root_path}/v1"):
url_path = request.url.path
if app.root_path and url_path.startswith(app.root_path):
url_path = url_path[len(app.root_path):]
if not url_path.startswith("/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
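
In effect, the middleware now strips app.root_path (when present) from the incoming path before testing for the /v1 prefix, so the Authorization check applies to both the bare and the root-path-prefixed URL forms exercised by the new test. A standalone sketch of that check; the helper name is illustrative and not part of the commit:

# Illustrative mirror of the fixed prefix check, not vLLM code.
def needs_api_key(path: str, root_path: str) -> bool:
    if root_path and path.startswith(root_path):
        path = path[len(root_path):]
    return path.startswith("/v1")

# Both URL forms covered by the new test now reach the Authorization check:
assert needs_api_key("/v1/chat/completions", "/llm")
assert needs_api_key("/llm/v1/chat/completions", "/llm")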
