From bf4c62c44ca1d75d9d0a349aea66a13133ec0dbb Mon Sep 17 00:00:00 2001
From: Woyten Tielesch
Date: Mon, 30 Sep 2024 18:50:05 +0200
Subject: [PATCH] Add streaming support

---
 README.md                                |   5 +-
 aleph_alpha_client/aleph_alpha_client.py |  85 ++++++++++++++-
 aleph_alpha_client/completion.py         | 125 ++++++++++++++++++++++-
 tests/test_complete.py                   |  29 +++++-
 4 files changed, 238 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 45df4ad..0671c39 100644
--- a/README.md
+++ b/README.md
@@ -40,8 +40,9 @@ async with AsyncClient(token=os.environ["AA_TOKEN"]) as client:
         prompt=Prompt.from_text("Provide a short description of AI:"),
         maximum_tokens=64,
     )
-    response = await client.complete(request, model="luminous-base")
-    print(response.completions[0].completion)
+    response = client.complete_with_streaming(request, model="luminous-base")
+    async for stream_item in response:
+        print(stream_item)
 ```
 
 ### Interactive Examples
diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py
index a039e9b..1aaace9 100644
--- a/aleph_alpha_client/aleph_alpha_client.py
+++ b/aleph_alpha_client/aleph_alpha_client.py
@@ -1,3 +1,4 @@
+import json
 import warnings
 
 from packaging import version
@@ -5,6 +6,7 @@
 from types import TracebackType
 from typing import (
     Any,
+    AsyncGenerator,
     List,
     Mapping,
     Optional,
@@ -30,7 +32,13 @@
 )
 from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse
 from aleph_alpha_client.qa import QaRequest, QaResponse
-from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
+from aleph_alpha_client.completion import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamItem,
+    StreamChunk,
+    stream_item_from_json,
+)
 from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
 from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
 from aleph_alpha_client.detokenization import (
@@ -759,6 +767,38 @@ async def _post_request(
             _raise_for_status(response.status, await response.text())
         return await response.json()
 
+    SSE_DATA_PREFIX = "data: "
+
+    async def _post_request_with_streaming(
+        self,
+        endpoint: str,
+        request: AnyRequest,
+        model: Optional[str] = None,
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        json_body = self._build_json_body(request, model)
+        json_body["stream"] = True
+
+        query_params = self._build_query_parameters()
+
+        async with self.session.post(
+            self.host + endpoint, json=json_body, params=query_params
+        ) as response:
+            if not response.ok:
+                _raise_for_status(response.status, await response.text())
+
+            async for stream_item in response.content:
+                stream_item_as_str = stream_item.decode().strip()
+
+                if not stream_item_as_str:
+                    continue
+
+                if not stream_item_as_str.startswith(self.SSE_DATA_PREFIX):
+                    raise ValueError(
+                        f"Stream item did not start with `{self.SSE_DATA_PREFIX}`. "
+                        f"Was `{stream_item_as_str}`"
+                    )
+
+                yield json.loads(stream_item_as_str[len(self.SSE_DATA_PREFIX) :])
+
     def _build_query_parameters(self) -> Mapping[str, str]:
         return {
             # cannot use str() here because we want lowercase true/false in query string
@@ -768,7 +808,7 @@
 
     def _build_json_body(
         self, request: AnyRequest, model: Optional[str]
-    ) -> Mapping[str, Any]:
+    ) -> Dict[str, Any]:
         json_body = dict(request.to_json())
 
         if model is not None:
@@ -824,6 +864,47 @@ async def complete(
         )
         return CompletionResponse.from_json(response)
 
+    async def complete_with_streaming(
+        self,
+        request: CompletionRequest,
+        model: str,
+    ) -> AsyncGenerator[CompletionResponseStreamItem, None]:
+        """Generates streamed completions given a prompt.
+
+        Parameters:
+            request (CompletionRequest, required):
+                Parameters for the requested completion.
+
+            model (string, required):
+                Name of the model to use. A model name refers to a model architecture (number of parameters, among others).
+                The latest version of the model is always used.
+
+        Examples:
+            >>> # create a prompt
+            >>> prompt = Prompt.from_text("An apple a day, ")
+            >>>
+            >>> # create a completion request
+            >>> request = CompletionRequest(
+                    prompt=prompt,
+                    maximum_tokens=32,
+                    stop_sequences=["###","\\n"],
+                    temperature=0.12
+                )
+            >>>
+            >>> # complete the prompt
+            >>> result = await client.complete_with_streaming(request, model=model_name)
+            >>>
+            >>> # consume the completion stream
+            >>> async for stream_item in result:
+            >>>     do_something_with(stream_item)
+        """
+        async for stream_item_json in self._post_request_with_streaming(
+            "complete",
+            request,
+            model,
+        ):
+            yield stream_item_from_json(stream_item_json)
+
     async def tokenize(
         self,
         request: TokenizationRequest,
diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py
index e49b92b..6b8aff5 100644
--- a/aleph_alpha_client/completion.py
+++ b/aleph_alpha_client/completion.py
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass
-from typing import Any, Dict, List, Mapping, Optional, Sequence
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
 
 from aleph_alpha_client.prompt import Prompt
 
@@ -301,3 +301,126 @@ def to_json(self) -> Mapping[str, Any]:
 
     def _asdict(self) -> Mapping[str, Any]:
         return asdict(self)
+
+
+CompletionResponseStreamItem = Union[
+    "StreamChunk", "StreamSummary", "CompletionSummary"
+]
+
+
+def stream_item_from_json(json: Dict[str, Any]) -> CompletionResponseStreamItem:
+    if json["type"] == "stream_chunk":
+        return StreamChunk.from_json(json)
+    elif json["type"] == "stream_summary":
+        return StreamSummary.from_json(json)
+    elif json["type"] == "completion_summary":
+        return CompletionSummary.from_json(json)
+    else:
+        raise ValueError(f"Unknown stream item type: {json['type']}")
+
+
+@dataclass(frozen=True)
+class StreamChunk:
+    """
+    Describes a chunk of a completion stream
+
+    Parameters:
+        index:
+            The index of the stream that this chunk belongs to.
+            This is relevant if multiple completion streams are requested (see parameter n).
+        log_probs:
+            The log probabilities of the generated tokens.
+        completion:
+            The generated tokens formatted as a single string.
+        raw_completion:
+            The generated tokens including special tokens formatted as a single string.
+        completion_tokens:
+            The generated tokens as a list of strings.
+ """ + + index: int + log_probs: Optional[Sequence[Mapping[str, Optional[float]]]] + completion: str + raw_completion: Optional[str] + completion_tokens: Optional[Sequence[str]] + + @staticmethod + def from_json(json: Dict[str, Any]) -> "StreamChunk": + return StreamChunk( + index=json["index"], + log_probs=json.get("log_probs"), + completion=json["completion"], + raw_completion=json.get("raw_completion"), + completion_tokens=json.get("completion_tokens"), + ) + + def to_json(self) -> Mapping[str, Any]: + return asdict(self) + + +@dataclass(frozen=True) +class StreamSummary: + """ + Denotes the end of a completion stream + + Parameters: + index: + The index of the stream that is being terminated. + This is relevant if multiple completion streams are requested (see parameter n). + model_version: + Model name and version (if any) of the used model for inference. + finish_reason: + The reason why the model stopped generating new tokens. + """ + + index: int + model_version: str + finish_reason: str + + @staticmethod + def from_json(json: Dict[str, Any]) -> "StreamSummary": + return StreamSummary( + index=json["index"], + model_version=json["model_version"], + finish_reason=json["finish_reason"], + ) + + def to_json(self) -> Mapping[str, Any]: + return asdict(self) + + +@dataclass(frozen=True) +class CompletionSummary: + """ + Denotes the end of all completion streams + + Parameters: + optimized_prompt: + Describes prompt after optimizations. This field is only returned if the flag + `disable_optimizations` flag is not set and the prompt has actually changed. + num_tokens_prompt_total: + Number of tokens combined across all completion tasks. + In particular, if you set best_of or n to a number larger than 1 then we report the + combined prompt token count for all best_of or n tasks. + num_tokens_generated: + Number of tokens combined across all completion tasks. + If multiple completions are returned or best_of is set to a value greater than 1 then + this value contains the combined generated token count. 
+ """ + + optimized_prompt: Optional[Prompt] + num_tokens_prompt_total: int + num_tokens_generated: int + + @staticmethod + def from_json(json: Dict[str, Any]) -> "CompletionSummary": + optimized_prompt_json = json.get("optimized_prompt") + return CompletionSummary( + optimized_prompt=( + Prompt.from_json(optimized_prompt_json) + if optimized_prompt_json + else None + ), + num_tokens_prompt_total=json["num_tokens_prompt_total"], + num_tokens_generated=json["num_tokens_generated"], + ) diff --git a/tests/test_complete.py b/tests/test_complete.py index 6c21063..62c8e72 100644 --- a/tests/test_complete.py +++ b/tests/test_complete.py @@ -1,6 +1,11 @@ import pytest from aleph_alpha_client import AsyncClient, Client -from aleph_alpha_client.completion import CompletionRequest +from aleph_alpha_client.completion import ( + CompletionRequest, + CompletionSummary, + StreamChunk, + StreamSummary, +) from aleph_alpha_client.prompt import ( ControlTokenOverlap, Image, @@ -32,6 +37,28 @@ async def test_can_complete_with_async_client( assert response.model_version is not None +@pytest.mark.system_test +async def test_can_use_streaming_support_with_async_client( + async_client: AsyncClient, model_name: str +): + request = CompletionRequest( + prompt=Prompt.from_text(""), + maximum_tokens=7, + ) + + stream_items = [ + stream_item + async for stream_item in async_client.complete_with_streaming( + request, model=model_name + ) + ] + + assert len(stream_items) >= 3 + assert isinstance(stream_items[-3], StreamChunk) + assert isinstance(stream_items[-2], StreamSummary) + assert isinstance(stream_items[-1], CompletionSummary) + + @pytest.mark.system_test def test_complete_maximum_tokens_none(sync_client: Client, model_name: str): request = CompletionRequest(