
Commit 70e1bfc
updated api server to pass in the prompt token ids in the right format depending on the vllm version (#167)
wongjingping authored Jun 14, 2024
1 parent 1e1c68e commit 70e1bfc
Showing 2 changed files with 18 additions and 7 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -19,4 +19,7 @@ results_fn_postgres/*.json
 .vscode
 
 # all eda notebooks
-eda_*.ipynb
+eda_*.ipynb
+
+# wandb output (created when running upload_wandb.ipynb)
+wandb/
20 changes: 14 additions & 6 deletions utils/api_server.py
@@ -10,6 +10,7 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm import __version__ as vllm_version
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 app = FastAPI()
@@ -56,12 +57,19 @@ async def generate(request: Request) -> Response:
     if prompt_token_ids[0] != tokenizer.bos_token_id:
         prompt_token_ids = [tokenizer.bos_token_id] + prompt_token_ids
 
-    results_generator = engine.generate(
-        prompt=None,
-        sampling_params=sampling_params,
-        request_id=request_id,
-        prompt_token_ids=prompt_token_ids,
-    )
+    if vllm_version >= "0.4.2":
+        results_generator = engine.generate(
+            inputs={"prompt_token_ids": prompt_token_ids},
+            sampling_params=sampling_params,
+            request_id=request_id,
+        )
+    else:
+        results_generator = engine.generate(
+            prompt=None,
+            sampling_params=sampling_params,
+            request_id=request_id,
+            prompt_token_ids=prompt_token_ids,
+        )
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
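A note on the version gate above: the diff compares version strings lexicographically (vllm_version >= "0.4.2"), which holds for the releases involved here but would misorder a version such as "0.10.0", since it sorts below "0.4.2" as a string. A minimal sketch of a numeric comparison using the packaging library is below; it is not part of the commit, and the USE_INPUTS_DICT name is hypothetical:

    from packaging.version import parse as parse_version
    from vllm import __version__ as vllm_version

    # parse_version compares release components numerically, so
    # 0.10.0 > 0.4.2 as intended, unlike plain string comparison.
    USE_INPUTS_DICT = parse_version(vllm_version) >= parse_version("0.4.2")

The flag would then select between the new inputs-dict call and the legacy prompt_token_ids keyword call, mirroring the branch in the diff above.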
