0.23 upstream merge fix part 4:
* Fix stream object
* Add assistant streaming + func call example
* Fix old OpenAI chat example
ifsheldon committed Jun 11, 2024
2 parents 72ea9c9 + c64d80b commit 1fa18fc
Showing 15 changed files with 533 additions and 126 deletions.
98 changes: 90 additions & 8 deletions async-openai-wasm/src/client.rs
@@ -446,10 +446,10 @@ impl<C: Config> Client<C> {
path: &str,
request: I,
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item=Result<O, OpenAIError>> + Send>>
) -> OpenAIEventMappedStream<O>
where
I: Serialize,
O: DeserializeOwned + Send + 'static,
O: DeserializeOwned + Send + 'static
{
let event_source = self
.http_client
@@ -460,8 +460,7 @@ impl<C: Config> Client<C> {
.eventsource()
.unwrap();

// stream_mapped_raw_events(event_source, event_mapper).await
todo!()
OpenAIEventMappedStream::new(event_source, event_mapper)
}

/// Make HTTP GET request to receive SSE
@@ -491,19 +490,21 @@ impl<C: Config> Client<C> {
/// Request which responds with SSE.
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
#[pin_project]
pub struct OpenAIEventStream<O> {
pub struct OpenAIEventStream<O: DeserializeOwned + Send + 'static> {
#[pin]
stream: Filter<EventSource, future::Ready<bool>, fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>>,
done: bool,
_phantom_data: PhantomData<O>,
}

impl<O> OpenAIEventStream<O> {
impl<O: DeserializeOwned + Send + 'static> OpenAIEventStream<O> {
pub(crate) fn new(event_source: EventSource) -> Self {
Self {
stream: event_source.filter(|result|
// filter out the first event which is always Event::Open
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))
),
done: false,
_phantom_data: PhantomData,
}
}
@@ -514,6 +515,9 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
if *this.done {
return Poll::Ready(None);
}
let stream: Pin<&mut _> = this.stream;
match stream.poll_next(cx) {
Poll::Ready(response) => {
@@ -524,17 +528,24 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
Event::Open => unreachable!(), // it has been filtered out
Event::Message(message) => {
if message.data == "[DONE]" {
*this.done = true;
Poll::Ready(None) // end of the stream, defined by OpenAI
} else {
// deserialize the data
match serde_json::from_str::<O>(&message.data) {
Err(e) => Poll::Ready(Some(Err(map_deserialization_error(e, &message.data.as_bytes())))),
Err(e) => {
*this.done = true;
Poll::Ready(Some(Err(map_deserialization_error(e, &message.data.as_bytes()))))
}
Ok(output) => Poll::Ready(Some(Ok(output))),
}
}
}
}
Err(e) => Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
Err(e) => {
*this.done = true;
Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
}
}
}
}
@@ -543,6 +554,77 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
}
}

#[pin_project]
pub struct OpenAIEventMappedStream<O>
where O: Send + 'static
{
#[pin]
stream: Filter<EventSource, future::Ready<bool>, fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>>,
event_mapper: Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>,
done: bool,
_phantom_data: PhantomData<O>,
}

impl<O> OpenAIEventMappedStream<O>
where O: Send + 'static
{
pub(crate) fn new<M>(event_source: EventSource, event_mapper: M) -> Self
where M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static {
Self {
stream: event_source.filter(|result|
// filter out the first event which is always Event::Open
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))
),
done: false,
event_mapper: Box::new(event_mapper),
_phantom_data: PhantomData,
}
}
}


impl<O> Stream for OpenAIEventMappedStream<O>
where O: Send + 'static
{
type Item = Result<O, OpenAIError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
if *this.done {
return Poll::Ready(None);
}
let stream: Pin<&mut _> = this.stream;
match stream.poll_next(cx) {
Poll::Ready(response) => {
match response {
None => Poll::Ready(None), // end of the stream
Some(result) => match result {
Ok(event) => match event {
Event::Open => unreachable!(), // it has been filtered out
Event::Message(message) => {
if message.data == "[DONE]" {
*this.done = true;
}
let response = (this.event_mapper)(message);
match response {
Ok(output) => Poll::Ready(Some(Ok(output))),
Err(_) => Poll::Ready(None)
}
}
}
Err(e) => {
*this.done = true;
Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
}
}
}
}
Poll::Pending => Poll::Pending
}
}
}


// pub(crate) async fn stream_mapped_raw_events<O>(
// mut event_source: EventSource,
// event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
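Both the existing `OpenAIEventStream<O>` and the new `OpenAIEventMappedStream<O>` implement `futures::Stream`, so callers consume them the same way as the old boxed `Pin<Box<dyn Stream>>` return type. Below is a minimal consumer sketch (the `drain` helper is hypothetical, not part of this commit); it also relies on the new `done` flag: once `[DONE]` or an error has been seen, the stream keeps yielding `None` and the loop simply ends.

```rust
// Sketch only (not part of this commit): a generic consumer that works for
// both OpenAIEventStream<O> and OpenAIEventMappedStream<O>, since both
// implement futures::Stream.
use async_openai_wasm::error::OpenAIError;
use futures::{Stream, StreamExt};

async fn drain<O, S>(stream: S) -> Vec<O>
where
    S: Stream<Item = Result<O, OpenAIError>>,
{
    futures::pin_mut!(stream); // pin on the stack so StreamExt::next can be called
    let mut items = Vec::new();
    while let Some(next) = stream.next().await {
        match next {
            Ok(item) => items.push(item),
            Err(e) => {
                // After an error the stream is fused by the `done` flag,
                // so breaking here loses nothing.
                eprintln!("stream error: {e}");
                break;
            }
        }
    }
    items
}
```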
8 changes: 2 additions & 6 deletions async-openai-wasm/src/types/assistant_stream.rs
@@ -1,8 +1,6 @@
use std::pin::Pin;

use futures::Stream;
use serde::Deserialize;

use crate::client::OpenAIEventMappedStream;
use crate::error::{ApiError, map_deserialization_error, OpenAIError};

use super::{
@@ -28,7 +26,6 @@ use super::{
/// We may add additional events over time, so we recommend handling unknown events gracefully
/// in your code. See the [Assistants API quickstart](https://platform.openai.com/docs/assistants/overview) to learn how to
/// integrate the Assistants API with streaming.
#[derive(Debug, Deserialize, Clone)]
#[serde(tag = "event", content = "data")]
#[non_exhaustive]
@@ -110,8 +107,7 @@ pub enum AssistantStreamEvent {
Done(String),
}

pub type AssistantEventStream =
Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent, OpenAIError>> + Send>>;
pub type AssistantEventStream = OpenAIEventMappedStream<AssistantStreamEvent>;

impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
type Error = OpenAIError;
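`AssistantStreamEvent` stays `#[non_exhaustive]`, and the doc comment above recommends handling unknown events gracefully, so a consumer match should keep a wildcard arm. A minimal sketch (the `handle_event` helper is hypothetical; payloads are only debug-printed):

```rust
use async_openai_wasm::error::OpenAIError;
use async_openai_wasm::types::AssistantStreamEvent;

// Sketch only: handle the events a UI typically cares about and fall through
// to a wildcard arm for anything unknown, as the doc comment recommends.
fn handle_event(event: Result<AssistantStreamEvent, OpenAIError>) {
    match event {
        Ok(AssistantStreamEvent::ThreadMessageDelta(delta)) => {
            // stream text deltas into the UI
            println!("delta: {delta:?}");
        }
        Ok(AssistantStreamEvent::ThreadRunRequiresAction(run)) => {
            // the run is waiting for tool outputs (function calling)
            println!("requires_action: {}", run.id);
        }
        Ok(other) => println!("unhandled event: {other:?}"),
        Err(e) => eprintln!("stream error: {e}"),
    }
}
```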
File renamed without changes.
16 changes: 16 additions & 0 deletions examples/openai-web-app-assistant/Cargo.toml
@@ -0,0 +1,16 @@
[package]
name = "openai-web-assistant-chat"
version = "0.1.0"
edition = "2021"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
dioxus = {version = "~0.5", features = ["web"]}
futures = "0.3.30"
async-openai-wasm = { path = "../../async-openai-wasm" }
# Debug
tracing = "0.1.40"
dioxus-logger = "~0.5"
serde_json = "1.0.117"
40 changes: 40 additions & 0 deletions examples/openai-web-app-assistant/Dioxus.toml
@@ -0,0 +1,40 @@
[application]

# App (Project) Name
name = "openai-web-app-assistant-dioxus"

# Dioxus App Default Platform
# desktop, web
default_platform = "web"

# `build` & `serve` dist path
out_dir = "dist"

[web.app]

# HTML title tag content
title = "openai-web-app-assistant-dioxus"

[web.watcher]

# when the watcher triggers, regenerate `index.html`
reload_html = true

# which files or dirs the watcher monitors
watch_path = ["src"]

# include `assets` in web platform
[web.resource]

# CSS style file

style = []

# Javascript code file
script = []

[web.resource.dev]

# Javascript code file
# serve: [dev-server] only
script = []
14 changes: 14 additions & 0 deletions examples/openai-web-app-assistant/README.md
@@ -0,0 +1,14 @@
# OpenAI Web App - Assistant

This example builds a `dioxus` web app that uses the OpenAI Assistants API to generate text.

To run it:
1. Set the OpenAI secrets in `./src/main.rs`. Please do NOT take this demo into production without using a secure secret store.
2. Install `dioxus-cli` with `cargo install dioxus-cli`.
3. Run `dx serve`.

Note: Safari may not work due to CORS issues. Please use Chrome or Edge.

## Reference

The code is adapted from the [assistants-func-call-stream example in async-openai](https://github.com/64bit/async-openai/tree/main/examples/assistants-func-call-stream).
98 changes: 98 additions & 0 deletions examples/openai-web-app-assistant/src/main.rs
@@ -0,0 +1,98 @@
#![allow(non_snake_case)]

use dioxus::prelude::*;
use dioxus_logger::tracing::{error, info, Level};
use futures::stream::StreamExt;

use async_openai_wasm::types::{AssistantStreamEvent, CreateMessageRequest, CreateRunRequest, CreateThreadRequest, MessageRole};

use crate::utils::*;

mod utils;

pub const API_BASE: &str = "...";
pub const API_KEY: &str = "...";


pub fn App() -> Element {
const QUERY: &str = "What's the weather in San Francisco today and the likelihood it'll rain?";
let reply = use_signal(String::new);
let _run_assistant: Coroutine<()> = use_coroutine(|_rx| {
let client = get_client();
async move {
//
// Step 1: Define functions
//
let assistant = client
.assistants()
.create(create_assistant_request())
.await
.expect("failed to create assistant");
//
// Step 2: Create a Thread and add Messages
//
let thread = client
.threads()
.create(CreateThreadRequest::default())
.await
.expect("failed to create thread");
let _message = client
.threads()
.messages(&thread.id)
.create(CreateMessageRequest {
role: MessageRole::User,
content: QUERY.into(),
..Default::default()
})
.await
.expect("failed to create message");
//
// Step 3: Initiate a Run
//
let mut event_stream = client
.threads()
.runs(&thread.id)
.create_stream(CreateRunRequest {
assistant_id: assistant.id.clone(),
stream: Some(true),
..Default::default()
})
.await
.expect("failed to create run");


while let Some(event) = event_stream.next().await {
match event {
Ok(event) => match event {
AssistantStreamEvent::ThreadRunRequiresAction(run_object) => {
info!("thread.run.requires_action: run_id:{}", run_object.id);
handle_requires_action(&client, run_object, reply.to_owned()).await
}
_ => info!("\nEvent: {event:?}\n"),
},
Err(e) => {
error!("Error: {e}");
}
}
}

client.threads().delete(&thread.id).await.expect("failed to delete thread");
client.assistants().delete(&assistant.id).await.expect("failed to delete assistant");
info!("Done!");
}
});

rsx! {
div {
p { "Using OpenAI" }
p { "User: {QUERY}" }
p { "Expected Stats (Debug): temperature = {TEMPERATURE}, rain_probability = {RAIN_PROBABILITY}" }
p { "Assistant: {reply}" }
}
}
}

fn main() {
dioxus_logger::init(Level::INFO).expect("failed to init logger");
launch(App);
}
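The `utils` module that `main.rs` imports (`get_client`, `create_assistant_request`, `handle_requires_action`, `TEMPERATURE`, `RAIN_PROBABILITY`) belongs to the example but is not expanded in this view. A hypothetical sketch of the client setup it presumably contains, assuming the usual `OpenAIConfig` builder; the constant values are placeholders:

```rust
// Hypothetical sketch of part of examples/openai-web-app-assistant/src/utils.rs
// (the real file is in the commit but not shown above). Values are placeholders.
use async_openai_wasm::{config::OpenAIConfig, Client};

use crate::{API_BASE, API_KEY};

// Fake "weather" stats that the demo tool calls return, echoed in the UI.
pub const TEMPERATURE: &str = "57";
pub const RAIN_PROBABILITY: &str = "0.06";

pub fn get_client() -> Client<OpenAIConfig> {
    let config = OpenAIConfig::new()
        .with_api_base(API_BASE)
        .with_api_key(API_KEY);
    Client::with_config(config)
}

// create_assistant_request() and handle_requires_action() would follow the
// upstream assistants-func-call-stream example: define the function tools,
// submit tool outputs via the streaming run API, and forward the resulting
// text deltas into the `reply` signal.
```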
