From 92bc89bd982e00cd2d0982b3c8dd0a7ed76fc113 Mon Sep 17 00:00:00 2001 From: Liam Griffiths Date: Tue, 18 Jun 2024 15:40:18 -0400 Subject: [PATCH] wip --- examples/basic.py | 12 ++++++++---- examples/streaming.py | 38 ++++++++++++++++++++++++++++++++++++++ poetry.lock | 15 +++++++++++++-- pyproject.toml | 1 + substrate/_client.py | 31 +++++++++++++++++++++++++++++++ substrate/streaming.py | 41 +++++++++++++++++++++++++++++++++++++++++ substrate/substrate.py | 19 +++++++++++++++++++ 7 files changed, 151 insertions(+), 6 deletions(-) create mode 100755 examples/streaming.py create mode 100644 substrate/streaming.py diff --git a/examples/basic.py b/examples/basic.py index e4abbf0..02cdb97 100755 --- a/examples/basic.py +++ b/examples/basic.py @@ -10,17 +10,21 @@ if api_key is None: raise EnvironmentError("No SUBSTRATE_API_KEY set") -from substrate import Substrate, GenerateText +from substrate import Substrate, GenerateText, sb substrate = Substrate(api_key=api_key, timeout=60 * 5) story = GenerateText(prompt="tell me a story") -# summary = GenerateText(prompt=sb.concat("Summarize this story: ", story.future.text)) +summary = GenerateText(prompt=sb.concat("Summarize this story: ", story.future.text)) -# response = substrate.run(story, summary) -response = substrate.run(story) +response = substrate.run(story, summary) print(response) +print("=== story") +story_out = response.get(story) +print(story_out.text) + +print("=== summary") summary_out = response.get(summary) print(summary_out.text) diff --git a/examples/streaming.py b/examples/streaming.py new file mode 100755 index 0000000..03cad10 --- /dev/null +++ b/examples/streaming.py @@ -0,0 +1,38 @@ +import os +import sys +import asyncio +from pathlib import Path + +# add parent dir to sys.path to make 'substrate' importable +parent_dir = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(parent_dir)) + +api_key = os.environ.get("SUBSTRATE_API_KEY") +if api_key is None: + raise EnvironmentError("No SUBSTRATE_API_KEY set") + +from substrate import Substrate, GenerateText + +substrate = Substrate(api_key=api_key, timeout=60 * 5) + + +a = GenerateText(prompt="tell me about windmills", max_tokens=10) +b = GenerateText(prompt="tell me more about cereal", max_tokens=10) + + +async def amain(): + response = await substrate.async_stream(a, b) + async for event in response.async_iter_events(): + print(event) + + +asyncio.run(amain()) + + +def main(): + response = substrate.stream(a, b) + for message in response.iter_events(): + print(message) + + +main() diff --git a/poetry.lock b/poetry.lock index bb1e7f1..5819257 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -488,6 +488,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "httpx-sse" +version = "0.4.0" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, + {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"}, +] + [[package]] name = "identify" version = "2.5.36" @@ -1668,4 +1679,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9" -content-hash = "eeb3c8e7b79916e553f863169d130bde5dc45a8841f5a35dcea65f85c51711fe" +content-hash = "a4e20c643ecda148c6406d6699c83b656818038c2b86e295a5aef7eea46b85e5" diff --git a/pyproject.toml b/pyproject.toml index df0863f..b730bc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ httpx = ">=0.26.0" distro = ">=1.8.0" typing-extensions = "^4.10.0" pydantic = ">=1.0.0" +httpx-sse = "^0.4.0" [tool.ruff.lint] ignore-init-module-imports = true diff --git a/substrate/_client.py b/substrate/_client.py index 9023be8..3249d79 100644 --- a/substrate/_client.py +++ b/substrate/_client.py @@ -5,6 +5,7 @@ import httpx import distro +import httpx_sse from ._version import __version__ from .core.id_generator import IDGenerator @@ -172,6 +173,13 @@ def default_headers(self) -> Dict[str, str]: **self._additional_headers, } + @property + def streaming_headers(self) -> Dict[str, str]: + headers = self.default_headers + headers["Accept"] = "text/event-stream" + headers["X-Substrate-Streaming"] = "1" + return headers + def post_compose(self, dag: Dict[str, Any]) -> APIResponse: url = f"{self._base_url}/compose" body = {"dag": dag} @@ -190,6 +198,29 @@ def post_compose(self, dag: Dict[str, Any]) -> APIResponse: ) return res + def post_compose_streaming(self, dag: Dict[str, Any]): + url = f"{self._base_url}/compose" + body = {"dag": dag} + + def iterator(): + with httpx.Client(timeout=self._timeout, follow_redirects=True) as client: + with httpx_sse.connect_sse(client, "POST", url, json=body, headers=self.streaming_headers) as event_source: + for sse in event_source.iter_sse(): + yield sse + return iterator() + + + async def async_post_compose_streaming(self, dag: Dict[str, Any]): + url = f"{self._base_url}/compose" + body = {"dag": dag} + + async def iterator(): + async with httpx.AsyncClient(timeout=self._timeout, follow_redirects=True) as client: + async with httpx_sse.aconnect_sse(client, "POST", url, json=body, headers=self.streaming_headers) as event_source: + async for sse in event_source.aiter_sse(): + yield sse + return iterator() + async def async_post_compose(self, dag: Dict[str, Any]) -> APIResponse: url = f"{self._base_url}/compose" body = {"dag": dag} diff --git a/substrate/streaming.py b/substrate/streaming.py new file mode 100644 index 0000000..7e8166c --- /dev/null +++ b/substrate/streaming.py @@ -0,0 +1,41 @@ +import json +from typing import Iterator, AsyncIterator + +import httpx_sse + + +class ServerSentEvent: + def __init__(self, event: httpx_sse.ServerSentEvent): + self.event = event + + @property + def data(self): + return json.loads(self.event.data) + + def __repr__(self): + return self.event.__repr__() + + def __str__(self): + """ + Render the Server-Sent Event as a string to be rendered in a streaming response + """ + fields = ["id", "event", "data", "retry"] + lines = [f"{field}: {getattr(self.event, field)}" for field in fields if getattr(self.event, field)] + return "\n".join(lines) + "\n" + + +class SubstrateStreamingResponse: + """ + Substrate stream response. + """ + + def __init__(self, *, iterator): + self.iterator = iterator + + def iter_events(self) -> Iterator[ServerSentEvent]: + for sse in self.iterator: + yield ServerSentEvent(sse) + + async def async_iter_events(self) -> AsyncIterator[ServerSentEvent]: + async for sse in self.iterator: + yield ServerSentEvent(sse) diff --git a/substrate/substrate.py b/substrate/substrate.py index 5b04f96..aa5c0e6 100644 --- a/substrate/substrate.py +++ b/substrate/substrate.py @@ -3,6 +3,8 @@ import base64 from typing import Any, Dict +from substrate.streaming import SubstrateStreamingResponse + from ._client import APIClient from .core.corenode import CoreNode from .core.client.graph import Graph @@ -48,6 +50,22 @@ async def async_run(self, *nodes: CoreNode) -> SubstrateResponse: api_response = await self._client.async_post_compose(dag=serialized) return SubstrateResponse(api_response=api_response) + def stream(self, *nodes: CoreNode) -> SubstrateStreamingResponse: + """ + Run the given nodes and receive results as Server-Sent Events. + """ + serialized = Substrate.serialize(*nodes) + iterator = self._client.post_compose_streaming(dag=serialized) + return SubstrateStreamingResponse(iterator=iterator) + + async def async_stream(self, *nodes: CoreNode) -> SubstrateStreamingResponse: + """ + Run the given nodes and receive results as Server-Sent Events. + """ + serialized = Substrate.serialize(*nodes) + iterator = await self._client.async_post_compose_streaming(dag=serialized) + return SubstrateStreamingResponse(iterator=iterator) + @staticmethod def visualize(*nodes): """ @@ -67,6 +85,7 @@ def serialize(*nodes): """ all_nodes = set() + def collect_nodes(node): all_nodes.add(node) for referenced_node in node.referenced_nodes: