Skip to content

Commit

Permalink
Add streaming support
Browse files Browse the repository at this point in the history
  • Loading branch information
WoytenAA committed Oct 1, 2024
1 parent 8eecd0b commit bf4c62c
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 6 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ async with AsyncClient(token=os.environ["AA_TOKEN"]) as client:
prompt=Prompt.from_text("Provide a short description of AI:"),
maximum_tokens=64,
)
response = await client.complete(request, model="luminous-base")
print(response.completions[0].completion)
response = client.complete_with_streaming(request, model="luminous-base")
async for stream_item in response:
print(stream_item)
```

### Interactive Examples
Expand Down
85 changes: 83 additions & 2 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import warnings

from packaging import version
from tokenizers import Tokenizer # type: ignore
from types import TracebackType
from typing import (
Any,
AsyncGenerator,
List,
Mapping,
Optional,
Expand All @@ -30,7 +32,13 @@
)
from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse
from aleph_alpha_client.qa import QaRequest, QaResponse
from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
from aleph_alpha_client.completion import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamItem,
StreamChunk,
stream_item_from_json,
)
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
from aleph_alpha_client.detokenization import (
Expand Down Expand Up @@ -759,6 +767,38 @@ async def _post_request(
_raise_for_status(response.status, await response.text())
return await response.json()

SSE_DATA_PREFIX = "data: "

async def _post_request_with_streaming(
self,
endpoint: str,
request: AnyRequest,
model: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
json_body = self._build_json_body(request, model)
json_body["stream"] = True

query_params = self._build_query_parameters()

async with self.session.post(
self.host + endpoint, json=json_body, params=query_params
) as response:
if not response.ok:
_raise_for_status(response.status, await response.text())

async for stream_item in response.content:
stream_item_as_str = stream_item.decode().strip()

if not stream_item_as_str:
continue

if not stream_item_as_str.startswith(self.SSE_DATA_PREFIX):
raise ValueError(
f"Stream item did not start with `{self.SSE_DATA_PREFIX}`. Was `{stream_item_as_str}"
)

yield json.loads(stream_item_as_str[len(self.SSE_DATA_PREFIX) :])

def _build_query_parameters(self) -> Mapping[str, str]:
return {
# cannot use str() here because we want lowercase true/false in query string
Expand All @@ -768,7 +808,7 @@ def _build_query_parameters(self) -> Mapping[str, str]:

def _build_json_body(
self, request: AnyRequest, model: Optional[str]
) -> Mapping[str, Any]:
) -> Dict[str, Any]:
json_body = dict(request.to_json())

if model is not None:
Expand Down Expand Up @@ -824,6 +864,47 @@ async def complete(
)
return CompletionResponse.from_json(response)

async def complete_with_streaming(
self,
request: CompletionRequest,
model: str,
) -> AsyncGenerator[CompletionResponseStreamItem, None]:
"""Generates streamed completions given a prompt.
Parameters:
request (CompletionRequest, required):
Parameters for the requested completion.
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Examples:
>>> # create a prompt
>>> prompt = Prompt.from_text("An apple a day, ")
>>>
>>> # create a completion request
>>> request = CompletionRequest(
prompt=prompt,
maximum_tokens=32,
stop_sequences=["###","\\n"],
temperature=0.12
)
>>>
>>> # complete the prompt
>>> result = await client.complete_with_streaming(request, model=model_name)
>>>
>>> # consume the completion stream
>>> async for stream_item in result:
>>> do_something_with(stream_item)
"""
async for stream_item_json in self._post_request_with_streaming(
"complete",
request,
model,
):
yield stream_item_from_json(stream_item_json)

async def tokenize(
self,
request: TokenizationRequest,
Expand Down
125 changes: 124 additions & 1 deletion aleph_alpha_client/completion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Mapping, Optional, Sequence
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

from aleph_alpha_client.prompt import Prompt

Expand Down Expand Up @@ -301,3 +301,126 @@ def to_json(self) -> Mapping[str, Any]:

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)


CompletionResponseStreamItem = Union[
"StreamChunk", "StreamSummary", "CompletionSummary"
]


def stream_item_from_json(json: Dict[str, Any]) -> CompletionResponseStreamItem:
if json["type"] == "stream_chunk":
return StreamChunk.from_json(json)
elif json["type"] == "stream_summary":
return StreamSummary.from_json(json)
elif json["type"] == "completion_summary":
return CompletionSummary.from_json(json)
else:
raise ValueError(f"Unknown stream item type: {json['type']}")


@dataclass(frozen=True)
class StreamChunk:
"""
Describes a chunk of a completion stream
Parameters:
index:
The index of the stream that this chunk belongs to.
This is relevant if multiple completion streams are requested (see parameter n).
log_probs:
The log probabilities of the generated tokens.
completion:
The generated tokens formatted as single a string.
raw_completion:
The generated tokens including special tokens formatted as single a string.
completion_tokens:
The generated tokens as a list of strings.
"""

index: int
log_probs: Optional[Sequence[Mapping[str, Optional[float]]]]
completion: str
raw_completion: Optional[str]
completion_tokens: Optional[Sequence[str]]

@staticmethod
def from_json(json: Dict[str, Any]) -> "StreamChunk":
return StreamChunk(
index=json["index"],
log_probs=json.get("log_probs"),
completion=json["completion"],
raw_completion=json.get("raw_completion"),
completion_tokens=json.get("completion_tokens"),
)

def to_json(self) -> Mapping[str, Any]:
return asdict(self)


@dataclass(frozen=True)
class StreamSummary:
"""
Denotes the end of a completion stream
Parameters:
index:
The index of the stream that is being terminated.
This is relevant if multiple completion streams are requested (see parameter n).
model_version:
Model name and version (if any) of the used model for inference.
finish_reason:
The reason why the model stopped generating new tokens.
"""

index: int
model_version: str
finish_reason: str

@staticmethod
def from_json(json: Dict[str, Any]) -> "StreamSummary":
return StreamSummary(
index=json["index"],
model_version=json["model_version"],
finish_reason=json["finish_reason"],
)

def to_json(self) -> Mapping[str, Any]:
return asdict(self)


@dataclass(frozen=True)
class CompletionSummary:
"""
Denotes the end of all completion streams
Parameters:
optimized_prompt:
Describes prompt after optimizations. This field is only returned if the flag
`disable_optimizations` flag is not set and the prompt has actually changed.
num_tokens_prompt_total:
Number of tokens combined across all completion tasks.
In particular, if you set best_of or n to a number larger than 1 then we report the
combined prompt token count for all best_of or n tasks.
num_tokens_generated:
Number of tokens combined across all completion tasks.
If multiple completions are returned or best_of is set to a value greater than 1 then
this value contains the combined generated token count.
"""

optimized_prompt: Optional[Prompt]
num_tokens_prompt_total: int
num_tokens_generated: int

@staticmethod
def from_json(json: Dict[str, Any]) -> "CompletionSummary":
optimized_prompt_json = json.get("optimized_prompt")
return CompletionSummary(
optimized_prompt=(
Prompt.from_json(optimized_prompt_json)
if optimized_prompt_json
else None
),
num_tokens_prompt_total=json["num_tokens_prompt_total"],
num_tokens_generated=json["num_tokens_generated"],
)
29 changes: 28 additions & 1 deletion tests/test_complete.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pytest
from aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.completion import CompletionRequest
from aleph_alpha_client.completion import (
CompletionRequest,
CompletionSummary,
StreamChunk,
StreamSummary,
)
from aleph_alpha_client.prompt import (
ControlTokenOverlap,
Image,
Expand Down Expand Up @@ -32,6 +37,28 @@ async def test_can_complete_with_async_client(
assert response.model_version is not None


@pytest.mark.system_test
async def test_can_use_streaming_support_with_async_client(
async_client: AsyncClient, model_name: str
):
request = CompletionRequest(
prompt=Prompt.from_text(""),
maximum_tokens=7,
)

stream_items = [
stream_item
async for stream_item in async_client.complete_with_streaming(
request, model=model_name
)
]

assert len(stream_items) >= 3
assert isinstance(stream_items[-3], StreamChunk)
assert isinstance(stream_items[-2], StreamSummary)
assert isinstance(stream_items[-1], CompletionSummary)


@pytest.mark.system_test
def test_complete_maximum_tokens_none(sync_client: Client, model_name: str):
request = CompletionRequest(
Expand Down

0 comments on commit bf4c62c

Please sign in to comment.