
Commit

Clean up
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 24, 2024
1 parent c0b35dc commit 6ff8b70
Showing 2 changed files with 87 additions and 63 deletions.
79 changes: 45 additions & 34 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -40,36 +40,6 @@ def _get_embedding(
     assert_never(encoding_format)
 
 
-def request_output_to_embedding_response(
-        final_res_batch: List[PoolingRequestOutput], request_id: str,
-        created_time: int, model_name: str,
-        encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
-    data: List[EmbeddingResponseData] = []
-    num_prompt_tokens = 0
-    for idx, final_res in enumerate(final_res_batch):
-        embedding_res = EmbeddingRequestOutput.from_base(final_res)
-        prompt_token_ids = final_res.prompt_token_ids
-
-        embedding = _get_embedding(embedding_res.outputs, encoding_format)
-        embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
-        data.append(embedding_data)
-
-        num_prompt_tokens += len(prompt_token_ids)
-
-    usage = UsageInfo(
-        prompt_tokens=num_prompt_tokens,
-        total_tokens=num_prompt_tokens,
-    )
-
-    return EmbeddingResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        data=data,
-        usage=usage,
-    )
-
-
 class OpenAIServingEmbedding(OpenAIServing):
 
     def __init__(
@@ -114,7 +84,7 @@ async def create_embedding(

         model_name = request.model
         request_id = f"embd-{self._base_request_id(raw_request)}"
-        created_time = int(time.monotonic())
+        created_time = int(time.time())
 
         truncate_prompt_tokens = None
 
@@ -218,13 +188,54 @@ async def create_embedding(
             final_res_batch_checked = cast(List[PoolingRequestOutput],
                                            final_res_batch)
 
-            response = request_output_to_embedding_response(
-                final_res_batch_checked, request_id, created_time, model_name,
-                encoding_format)
+            response = self.request_output_to_embedding_response(
+                final_res_batch_checked,
+                request_id,
+                created_time,
+                model_name,
+                encoding_format,
+            )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
         return response
+
+    def request_output_to_embedding_response(
+        self,
+        final_res_batch: List[PoolingRequestOutput],
+        request_id: str,
+        created_time: int,
+        model_name: str,
+        encoding_format: Literal["float", "base64"],
+    ) -> EmbeddingResponse:
+        items: List[EmbeddingResponseData] = []
+        num_prompt_tokens = 0
+
+        for idx, final_res in enumerate(final_res_batch):
+            embedding_res = EmbeddingRequestOutput.from_base(final_res)
+
+            item = EmbeddingResponseData(
+                index=idx,
+                embedding=_get_embedding(embedding_res.outputs,
+                                         encoding_format),
+            )
+            prompt_token_ids = final_res.prompt_token_ids
+
+            items.append(item)
+            num_prompt_tokens += len(prompt_token_ids)
+
+        usage = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            total_tokens=num_prompt_tokens,
+        )
+
+        return EmbeddingResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            data=items,
+            usage=usage,
+        )
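
Aside (illustration only, not part of the commit): the diff above, and the serving_score.py diff below, switch created_time from time.monotonic() to time.time(). A minimal sketch of the difference, using only the Python standard library:

import time

# time.monotonic() counts seconds from an arbitrary reference point
# (for example, since the process or system started), so it is only
# meaningful for measuring elapsed intervals, not wall-clock time.
interval_clock = int(time.monotonic())

# time.time() returns seconds since the Unix epoch, which is what an
# OpenAI-style `created` field in the response is expected to hold.
created_time = int(time.time())

print(interval_clock)  # small, host-dependent value
print(created_time)    # Unix timestamp, e.g. 1735000000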
71 changes: 42 additions & 29 deletions vllm/entrypoints/openai/serving_score.py
@@ -20,32 +20,6 @@
 logger = init_logger(__name__)
 
 
-def request_output_to_score_response(
-        final_res_batch: List[PoolingRequestOutput], request_id: str,
-        created_time: int, model_name: str) -> ScoreResponse:
-    data: List[ScoreResponseData] = []
-    num_prompt_tokens = 0
-    for idx, final_res in enumerate(final_res_batch):
-        classify_res = ScoringRequestOutput.from_base(final_res)
-
-        score_data = ScoreResponseData(index=idx,
-                                       score=classify_res.outputs.score)
-        data.append(score_data)
-
-    usage = UsageInfo(
-        prompt_tokens=num_prompt_tokens,
-        total_tokens=num_prompt_tokens,
-    )
-
-    return ScoreResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        data=data,
-        usage=usage,
-    )
-
-
 def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
                                                              str]) -> List:
     if isinstance(text_1, (str, dict)):
@@ -103,7 +77,7 @@ async def create_score(

         model_name = request.model
         request_id = f"score-{self._base_request_id(raw_request)}"
-        created_time = int(time.monotonic())
+        created_time = int(time.time())
         truncate_prompt_tokens = request.truncate_prompt_tokens
 
         request_prompts = []
@@ -203,12 +177,51 @@ async def create_score(
             final_res_batch_checked = cast(List[PoolingRequestOutput],
                                            final_res_batch)
 
-            response = request_output_to_score_response(
-                final_res_batch_checked, request_id, created_time, model_name)
+            response = self.request_output_to_score_response(
+                final_res_batch_checked,
+                request_id,
+                created_time,
+                model_name,
+            )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
         return response
+
+    def request_output_to_score_response(
+        self,
+        final_res_batch: List[PoolingRequestOutput],
+        request_id: str,
+        created_time: int,
+        model_name: str,
+    ) -> ScoreResponse:
+        items: List[ScoreResponseData] = []
+        num_prompt_tokens = 0
+
+        for idx, final_res in enumerate(final_res_batch):
+            classify_res = ScoringRequestOutput.from_base(final_res)
+
+            item = ScoreResponseData(
+                index=idx,
+                score=classify_res.outputs.score,
+            )
+            prompt_token_ids = final_res.prompt_token_ids
+
+            items.append(item)
+            num_prompt_tokens += len(prompt_token_ids)
+
+        usage = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            total_tokens=num_prompt_tokens,
+        )
+
+        return ScoreResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            data=items,
+            usage=usage,
+        )
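
Aside (illustration only, not part of the commit): with request_output_to_embedding_response and request_output_to_score_response now defined as instance methods rather than module-level functions, a subclass can override them to customize how the response is built. A hypothetical sketch — the subclass name, the logging behaviour, and the import locations of EmbeddingResponse and PoolingRequestOutput are assumptions, not taken from the diff:

from typing import List, Literal

# Assumed import locations; these classes may live in different modules
# depending on the vLLM version.
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.outputs import PoolingRequestOutput


class CustomServingEmbedding(OpenAIServingEmbedding):
    """Hypothetical subclass that augments the default embedding response."""

    def request_output_to_embedding_response(
        self,
        final_res_batch: List[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> EmbeddingResponse:
        # Reuse the base implementation to build the response...
        response = super().request_output_to_embedding_response(
            final_res_batch, request_id, created_time, model_name,
            encoding_format)
        # ...then apply subclass-specific behaviour, e.g. report token usage.
        print(f"{request_id}: {response.usage.prompt_tokens} prompt tokens")
        return response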
