Commit 3b4c4fe
Merge pull request #66 from NexaAI/paul/stream
Let the stream completion response use the same format as the stream chat completion
zhiyuan8 authored Sep 3, 2024
2 parents b7c2fdf + f0c9488 commit 3b4c4fe
Showing 4 changed files with 58 additions and 27 deletions.
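In short: before this change, streamed text-completion chunks carried their text at choices[0]["text"] with object set to "text_completion", while streamed chat completions used a delta object. After this change, both stream types share the chat-style shape: object is "text_completion.chunk" and the text arrives at choices[0]["delta"]["content"]. A minimal consumer sketch of the new shape, assuming the Llama class is importable from the module changed below; the model path and prompt are placeholders, and the fallback to the old "text" key is only for code that may still receive pre-#66 chunks:

from nexa.gguf.llama.llama import Llama

llm = Llama(model_path="./models/example.gguf")  # placeholder model path

for chunk in llm.create_completion("Q: Name the planets.\nA:", max_tokens=48, stream=True):
    choice = chunk["choices"][0]
    # New format: text is nested under a chat-style delta.
    piece = choice["delta"]["content"] if "delta" in choice else choice.get("text")
    print(piece or "", end="", flush=True)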
61 changes: 37 additions & 24 deletions nexa/gguf/llama/llama.py
@@ -1337,22 +1337,25 @@ def logit_bias_processor(
             returned_tokens += 1
             yield {
                 "id": completion_id,
-                "object": "text_completion",
-                "created": created,
                 "model": model_name,
+                "object": "text_completion.chunk",
+                "created": created,
                 "choices": [
                     {
-                        "text": self.detokenize(
-                            [token],
-                            prev_tokens=prompt_tokens
-                            + completion_tokens[:returned_tokens],
-                        ).decode("utf-8", errors="ignore"),
                         "index": 0,
-                        "logprobs": logprobs_or_none,
+                        "delta": {
+                            "content": self.detokenize(
+                                [token],
+                                prev_tokens=prompt_tokens
+                                + completion_tokens[:returned_tokens],
+                            ).decode("utf-8", errors="ignore")
+                        },
                         "finish_reason": None,
+                        "logprobs": logprobs_or_none,
                     }
                 ],
             }
+
         else:
             while len(remaining_tokens) > 0:
                 decode_success = False
@@ -1383,19 +1386,22 @@ def logit_bias_processor(

             yield {
                 "id": completion_id,
-                "object": "text_completion",
-                "created": created,
                 "model": model_name,
+                "object": "text_completion.chunk",
+                "created": created,
                 "choices": [
                     {
-                        "text": ts,
                         "index": 0,
                         "logprobs": None,
+                        "delta": {
+                            "content": ts,
+                        },
                         "finish_reason": None,
                     }
                 ],
             }
+

         if len(completion_tokens) >= max_tokens:
             text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
             finish_reason = "length"
@@ -1475,52 +1481,59 @@ def logit_bias_processor(
                 returned_tokens += 1
                 yield {
                     "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
                     "model": model_name,
+                    "object": "text_completion.chunk",
+                    "created": created,
                     "choices": [
                         {
-                            "text": last_text[
-                                : len(last_text) - (token_end_position - end)
-                            ].decode("utf-8", errors="ignore"),
                             "index": 0,
+                            "delta": {
+                                "content": last_text[: len(last_text) - (token_end_position - end)].decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            },
                             "logprobs": logprobs_or_none,
                             "finish_reason": None,
                         }
                     ],
                 }
+
                 break
             returned_tokens += 1
             yield {
                 "id": completion_id,
-                "object": "text_completion",
-                "created": created,
                 "model": model_name,
+                "object": "text_completion.chunk",
+                "created": created,
                 "choices": [
                     {
-                        "text": self.detokenize([token]).decode(
-                            "utf-8", errors="ignore"
-                        ),
                         "index": 0,
+                        "delta": {
+                            "content": self.detokenize([token]).decode("utf-8", errors="ignore")
+                        },
                         "logprobs": logprobs_or_none,
                         "finish_reason": None,
                     }
                 ],
             }
+
         yield {
             "id": completion_id,
-            "object": "text_completion",
-            "created": created,
             "model": model_name,
+            "object": "text_completion.chunk",
+            "created": created,
             "choices": [
                 {
-                    "text": "",
                     "index": 0,
+                    "delta": {
+                        "content": "",
+                    },
                     "logprobs": None,
                     "finish_reason": finish_reason,
                 }
             ],
         }
+
         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
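One detail worth noting from the hunks above: the terminal chunk of the stream now carries an empty delta ("content": "") together with a populated finish_reason, just as chat streaming does. A small illustrative helper, not part of the commit, that prints the stream and reports why it stopped:

from typing import Iterable, Optional

def print_stream(chunks: Iterable[dict]) -> Optional[str]:
    """Print streamed delta content and return the finish_reason from the final chunk."""
    finish_reason: Optional[str] = None
    for chunk in chunks:
        choice = chunk["choices"][0]
        print(choice["delta"]["content"] or "", end="", flush=True)
        if choice["finish_reason"] is not None:
            finish_reason = choice["finish_reason"]  # "stop" or "length"
    return finish_reason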
3 changes: 2 additions & 1 deletion nexa/gguf/llama/llama_chat_format.py
@@ -304,6 +304,7 @@ def _convert_text_completion_chunks_to_chat(
                     }
                 ],
             }
+
         yield {
             "id": "chat" + chunk["id"],
             "model": chunk["model"],
@@ -314,7 +315,7 @@ def _convert_text_completion_chunks_to_chat(
                     "index": 0,
                     "delta": (
                         {
-                            "content": chunk["choices"][0]["text"],
+                            "content": chunk["choices"][0]["delta"]["content"],
                         }
                         if chunk["choices"][0]["finish_reason"] is None
                         else {}
19 changes: 18 additions & 1 deletion nexa/gguf/llama/llama_types.py
@@ -278,14 +278,31 @@ class ChatCompletionNamedToolChoice(TypedDict):
     Literal["none", "auto", "required"], ChatCompletionNamedToolChoice
 ]

+class CompletionStreamResponseDelta(TypedDict):
+    content: Optional[str]
+
+
+class CompletionStreamResponseChoice(TypedDict):
+    index: int
+    delta: CompletionStreamResponseDelta
+    finish_reason: Optional[Literal["stop", "length"]]
+    logprobs: NotRequired[Optional[CompletionLogprobs]]
+
+
+class CreateCompletionStreamResponse(TypedDict):
+    id: str
+    model: str
+    object: Literal["text_completion.chunk"]
+    created: int
+    choices: List[CompletionStreamResponseChoice]
+

 # NOTE: The following type names are not part of the OpenAI OpenAPI specification
 # and will be removed in a future major release.

 EmbeddingData = Embedding
 CompletionChunk = CreateCompletionResponse
 Completion = CreateCompletionResponse
-CreateCompletionStreamResponse = CreateCompletionResponse
 ChatCompletionMessage = ChatCompletionResponseMessage
 ChatCompletionChoice = ChatCompletionResponseChoice
 ChatCompletion = CreateChatCompletionResponse
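The new TypedDicts also give downstream code a concrete type to annotate against. A sketch, assuming the type is imported from the module shown above; the helper name is mine, not part of the commit:

from typing import Iterable, List

from nexa.gguf.llama.llama_types import CreateCompletionStreamResponse


def collect_stream_text(chunks: Iterable[CreateCompletionStreamResponse]) -> str:
    """Join the delta content carried by a stream of completion chunks."""
    parts: List[str] = []
    for chunk in chunks:
        content = chunk["choices"][0]["delta"]["content"]
        if content:  # the final chunk carries an empty string
            parts.append(content)
    return "".join(parts)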
2 changes: 1 addition & 1 deletion tests/test_text_generation.py
@@ -32,7 +32,7 @@ def test_streaming():
     )
     for chunk in output:
         if "choices" in chunk:
-            print(chunk["choices"][0]["text"], end="", flush=True)
+            print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
     # TODO: add assertions here

     # Test conversation mode with chat format
