Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduces Streaming #25

Merged
merged 2 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,22 @@
if api_key is None:
raise EnvironmentError("No SUBSTRATE_API_KEY set")

from substrate import Substrate, GenerateText
from substrate import Substrate, GenerateText, sb

substrate = Substrate(api_key=api_key, timeout=60 * 5)

story = GenerateText(prompt="tell me a story")
# summary = GenerateText(prompt=sb.concat("Summarize this story: ", story.future.text))
summary = GenerateText(prompt=sb.concat("Summarize this story: ", story.future.text))

# response = substrate.run(story, summary)
response = substrate.run(story)
response = substrate.run(story, summary)
print(response)

summary_out = response.get(story)
print(summary_out.text)
print("=== story")
story_out = response.get(story)
print(story_out.text)

print("=== summary")
summary_out = response.get(summary)

# viz = Substrate.visualize(ry)
# print(viz)
38 changes: 38 additions & 0 deletions examples/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import sys
import asyncio
from pathlib import Path

# add parent dir to sys.path to make 'substrate' importable
parent_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(parent_dir))

api_key = os.environ.get("SUBSTRATE_API_KEY")
if api_key is None:
raise EnvironmentError("No SUBSTRATE_API_KEY set")

from substrate import Substrate, GenerateText, sb

substrate = Substrate(api_key=api_key, timeout=60 * 5)


a = GenerateText(prompt="tell me about windmills", max_tokens=10)
b = GenerateText(prompt=sb.concat("is this true? ", a.future.text), max_tokens=10)


async def amain():
response = await substrate.async_stream(a, b)
async for event in response.async_iter():
print(event)


asyncio.run(amain())


def main():
response = substrate.stream(a, b)
for message in response.iter():
print(message)


main()
29 changes: 29 additions & 0 deletions examples/streaming/fastapi-example/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import sys
from pathlib import Path

from substrate.nodes import Llama3Instruct8B

# add parent dir to sys.path to make 'substrate' importable
parent_dir = Path(__file__).resolve().parent.parent.parent.parent
sys.path.insert(0, str(parent_dir))

api_key = os.environ.get("SUBSTRATE_API_KEY")
if api_key is None:
raise EnvironmentError("No SUBSTRATE_API_KEY set")

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from substrate import Substrate, Llama3Instruct8B

app = FastAPI()
substrate = Substrate(api_key=api_key, timeout=60 * 5)

@app.get("/qotd")
def quote_of_the_day():
quote = Llama3Instruct8B(prompt="What's an inspirational quote of the day?")

response = substrate.stream(quote)

return StreamingResponse(response.iter_events(), media_type="text/event-stream")
564 changes: 562 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ httpx = ">=0.26.0"
distro = ">=1.8.0"
typing-extensions = "^4.10.0"
pydantic = ">=1.0.0"
httpx-sse = "^0.4.0"

[tool.ruff.lint]
ignore-init-module-imports = true
Expand Down Expand Up @@ -80,3 +81,4 @@ toml = "^0.10.2"
marimo = ">=0.3.2"
pre-commit = "^3.6.2"
twine = "^5.0.0"
fastapi = "^0.111.0"
31 changes: 31 additions & 0 deletions substrate/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import httpx
import distro
import httpx_sse

from ._version import __version__
from .core.id_generator import IDGenerator
Expand Down Expand Up @@ -172,6 +173,13 @@ def default_headers(self) -> Dict[str, str]:
**self._additional_headers,
}

@property
def streaming_headers(self) -> Dict[str, str]:
headers = self.default_headers
headers["Accept"] = "text/event-stream"
headers["X-Substrate-Streaming"] = "1"
return headers

def post_compose(self, dag: Dict[str, Any]) -> APIResponse:
url = f"{self._base_url}/compose"
body = {"dag": dag}
Expand All @@ -190,6 +198,29 @@ def post_compose(self, dag: Dict[str, Any]) -> APIResponse:
)
return res

def post_compose_streaming(self, dag: Dict[str, Any]):
url = f"{self._base_url}/compose"
body = {"dag": dag}

def iterator():
with httpx.Client(timeout=self._timeout, follow_redirects=True) as client:
with httpx_sse.connect_sse(client, "POST", url, json=body, headers=self.streaming_headers) as event_source:
for sse in event_source.iter_sse():
yield sse
return iterator()


async def async_post_compose_streaming(self, dag: Dict[str, Any]):
url = f"{self._base_url}/compose"
body = {"dag": dag}

async def iterator():
async with httpx.AsyncClient(timeout=self._timeout, follow_redirects=True) as client:
async with httpx_sse.aconnect_sse(client, "POST", url, json=body, headers=self.streaming_headers) as event_source:
async for sse in event_source.aiter_sse():
yield sse
return iterator()

async def async_post_compose(self, dag: Dict[str, Any]) -> APIResponse:
url = f"{self._base_url}/compose"
body = {"dag": dag}
Expand Down
49 changes: 49 additions & 0 deletions substrate/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import json
from typing import Iterator, AsyncIterator

import httpx_sse


class ServerSentEvent:
def __init__(self, event: httpx_sse.ServerSentEvent):
self.event = event

@property
def data(self):
return json.loads(self.event.data)

def __repr__(self):
return self.event.__repr__()

def __str__(self):
"""
Render the Server-Sent Event as a string to be rendered in a streaming response
"""
fields = ["id", "event", "data", "retry"]
lines = [f"{field}: {getattr(self.event, field)}" for field in fields if getattr(self.event, field)]
return "\n".join(lines) + "\n\n"


class SubstrateStreamingResponse:
"""
Substrate stream response.
"""

def __init__(self, *, iterator):
self.iterator = iterator

def iter(self) -> Iterator[ServerSentEvent]:
for sse in self.iterator:
yield ServerSentEvent(sse)

def iter_events(self) -> Iterator[str]:
for sse in self.iterator:
yield str(ServerSentEvent(sse))

async def async_iter(self) -> Iterator[ServerSentEvent]:
async for sse in self.iterator:
yield ServerSentEvent(sse)

async def async_iter_events(self) -> AsyncIterator[str]:
async for sse in self.iterator:
yield str(ServerSentEvent(sse))
19 changes: 19 additions & 0 deletions substrate/substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import base64
from typing import Any, Dict

from substrate.streaming import SubstrateStreamingResponse

from ._client import APIClient
from .core.corenode import CoreNode
from .core.client.graph import Graph
Expand Down Expand Up @@ -48,6 +50,22 @@ async def async_run(self, *nodes: CoreNode) -> SubstrateResponse:
api_response = await self._client.async_post_compose(dag=serialized)
return SubstrateResponse(api_response=api_response)

def stream(self, *nodes: CoreNode) -> SubstrateStreamingResponse:
"""
Run the given nodes and receive results as Server-Sent Events.
"""
serialized = Substrate.serialize(*nodes)
iterator = self._client.post_compose_streaming(dag=serialized)
return SubstrateStreamingResponse(iterator=iterator)

async def async_stream(self, *nodes: CoreNode) -> SubstrateStreamingResponse:
"""
Run the given nodes and receive results as Server-Sent Events.
"""
serialized = Substrate.serialize(*nodes)
iterator = await self._client.async_post_compose_streaming(dag=serialized)
return SubstrateStreamingResponse(iterator=iterator)

@staticmethod
def visualize(*nodes):
"""
Expand All @@ -67,6 +85,7 @@ def serialize(*nodes):
"""

all_nodes = set()

def collect_nodes(node):
all_nodes.add(node)
for referenced_node in node.referenced_nodes:
Expand Down
15 changes: 13 additions & 2 deletions tests/pydantic-1/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/pydantic-1/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pydantic = "<2.0.0"
networkx = "^3.3"
httpx = "^0.27.0"
distro = "^1.9.0"
httpx-sse = "^0.4.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.0"
Expand Down
15 changes: 13 additions & 2 deletions tests/pydantic-2/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/pydantic-2/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pydantic = "^2.0.0"
networkx = "^3.3"
httpx = "^0.27.0"
distro = "^1.9.0"
httpx-sse = "^0.4.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.0"
Expand Down
15 changes: 13 additions & 2 deletions tests/python-3-9/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/python-3-9/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pydantic = "^2.7.1"
networkx = "3.2.1"
httpx = "^0.27.0"
distro = "^1.9.0"
httpx-sse = "^0.4.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.0"
Expand Down
Loading