diff --git a/Cargo.lock b/Cargo.lock index 7ec175fcce51..7ad187e9d0bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1095,6 +1095,17 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom 7.1.3", + "pin-project-lite", +] + [[package]] name = "fastdivide" version = "0.4.0" @@ -1310,6 +1321,12 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.29" @@ -1585,10 +1602,11 @@ dependencies = [ name = "http-api-bindings" version = "0.8.0" dependencies = [ - "anyhow", + "async-stream", "async-trait", "futures", "reqwest", + "reqwest-eventsource", "serde", "serde_json", "tabby-inference", @@ -3297,6 +3315,22 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-eventsource" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f529a5ff327743addc322af460761dff5b50e0c826b9e6ac44c3195c50bb2026" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom 7.1.3", + "pin-project-lite", + "reqwest", + "thiserror", +] + [[package]] name = "ring" version = "0.17.5" @@ -4256,6 +4290,7 @@ dependencies = [ "opentelemetry-otlp", "regex", "reqwest", + "reqwest-eventsource", "serde", "serde-jsonlines 0.5.0", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index a618d5a59696..c97a55e37d20 
100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ axum = "0.6" hyper = "0.14" juniper = "0.15" chrono = "0.4" +reqwest-eventsource = "0.5.0" [workspace.dependencies.uuid] version = "1.3.3" diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index f996842e3a8c..703cd5900371 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -4,10 +4,11 @@ version = "0.8.0" edition = "2021" [dependencies] -anyhow.workspace = true +async-stream.workspace = true async-trait.workspace = true futures.workspace = true reqwest = { workspace = true, features = ["json"] } +reqwest-eventsource.workspace = true serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tabby-inference = { version = "0.8.0", path = "../tabby-inference" } diff --git a/crates/http-api-bindings/src/openai.rs b/crates/http-api-bindings/src/openai.rs index 6c58b4b82af6..d94cc1aa809d 100644 --- a/crates/http-api-bindings/src/openai.rs +++ b/crates/http-api-bindings/src/openai.rs @@ -1,10 +1,10 @@ -use anyhow::{anyhow, Result}; +use async_stream::stream; use async_trait::async_trait; use futures::stream::BoxStream; use reqwest::header; +use reqwest_eventsource::{Error, Event, EventSource}; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tabby_inference::{helpers, TextGeneration, TextGenerationOptions, TextGenerationStream}; +use tabby_inference::{TextGenerationOptions, TextGenerationStream}; use tracing::warn; #[derive(Serialize)] @@ -13,7 +13,7 @@ struct Request { prompt: Vec, max_tokens: usize, temperature: f32, - stop: Vec, + stream: bool, } #[derive(Deserialize)] @@ -52,53 +52,44 @@ impl OpenAIEngine { client, } } +} - async fn generate_impl(&self, prompt: &str, options: TextGenerationOptions) -> Result { - // OpenAI's API usually handles stop words in an O(n) manner, so we just use a single stop word here. - // FIXME(meng): consider improving this for some external vendors, e.g vLLM. 
-        let stop = vec!["\n\n".to_owned()];
-
+#[async_trait]
+impl TextGenerationStream for OpenAIEngine {
+    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String> {
         let request = Request {
             model: self.model_name.to_owned(),
             prompt: vec![prompt.to_string()],
             max_tokens: options.max_decoding_length,
             temperature: options.sampling_temperature,
-            stop,
+            stream: true,
         };
 
+        let es = EventSource::new(self.client.post(&self.api_endpoint).json(&request));
         // API Documentation: https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md
-        let resp = self
-            .client
-            .post(&self.api_endpoint)
-            .json(&request)
-            .send()
-            .await?;
-
-        if resp.status() != 200 {
-            let err: Value = resp.json().await.expect("Failed to parse response");
-            return Err(anyhow!("Request failed: {}", err));
-        }
-
-        let resp: Response = resp.json().await.expect("Failed to parse response");
-
-        Ok(resp.choices[0].text.clone())
-    }
+        let s = stream! {
+            let Ok(es) = es else {
+                warn!("Failed to access api_endpoint: {}", &self.api_endpoint);
+                return;
+            };
 
-    // FIXME(meng): migrate to streaming implementation
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        match self.generate_impl(prompt, options).await {
-            Ok(output) => output,
-            Err(err) => {
-                warn!("Failed to generate completion: `{}`", err);
-                String::new()
+            for await event in es {
+                match event {
+                    Ok(Event::Open) => {}
+                    Ok(Event::Message(message)) => {
+                        let Ok(x) = serde_json::from_str::<Response>(&message.data) else { break }; // "[DONE]" terminator / malformed chunk: end stream instead of panicking
+                        yield x.choices[0].text.clone();
+                    }
+                    Err(Error::StreamEnded) => {
+                        break;
+                    },
+                    Err(err) => {
+                        warn!("Failed to start streaming: {}", err); break; // stop here; EventSource would otherwise auto-retry and re-send the POST
+                    }
+                };
             }
-        }
-    }
-}
+        };
-
-#[async_trait]
-impl TextGenerationStream for OpenAIEngine {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String> {
-        helpers::string_to_stream(self.generate(prompt, options).await).await
+        Box::pin(s)
     }
 }
diff --git a/crates/tabby/Cargo.toml
b/crates/tabby/Cargo.toml index 0af7c66d1e77..e7277dbd615a 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -70,6 +70,7 @@ assert-json-diff = "2.0.2" insta = { version = "1.34.0", features = ["yaml", "redactions"] } reqwest.workspace = true serde-jsonlines = "0.5.0" +reqwest-eventsource = { workspace = true } [package.metadata.cargo-machete] ignored = ["openssl"] diff --git a/crates/tabby/src/routes/chat.rs b/crates/tabby/src/routes/chat.rs index 05a7431edbf2..0e5286e6e338 100644 --- a/crates/tabby/src/routes/chat.rs +++ b/crates/tabby/src/routes/chat.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use axum::{ body::StreamBody, extract::State, + http::HeaderValue, response::{IntoResponse, Response}, Json, }; @@ -44,5 +45,10 @@ pub async fn chat_completions( Ok(s) => Ok(format!("data: {s}\n\n")), Err(e) => Err(anyhow::Error::from(e)), }); - StreamBody::new(s).into_response() + let mut resp = StreamBody::new(s).into_response(); + resp.headers_mut().append( + "Content-Type", + HeaderValue::from_str("text/event-stream").unwrap(), + ); + resp } diff --git a/crates/tabby/tests/goldentests_chat.rs b/crates/tabby/tests/goldentests_chat.rs index 4e760b18878f..6ce8eeba9485 100644 --- a/crates/tabby/tests/goldentests_chat.rs +++ b/crates/tabby/tests/goldentests_chat.rs @@ -1,6 +1,8 @@ -use std::{io::BufRead, path::PathBuf}; +use std::path::PathBuf; +use futures::StreamExt; use lazy_static::lazy_static; +use reqwest_eventsource::{Event, EventSource}; use serde::Deserialize; use serde_json::json; use serial_test::serial; @@ -83,23 +85,25 @@ async fn wait_for_server(gpu_device: Option<&str>) { } async fn golden_test(body: serde_json::Value) -> String { - let bytes = CLIENT - .post("http://localhost:9090/v1beta/chat/completions") - .json(&body) - .send() - .await - .unwrap() - .bytes() - .await - .unwrap(); + let mut es = EventSource::new( + CLIENT + .post("http://localhost:9090/v1beta/chat/completions") + .json(&body), + ) + .unwrap(); let mut actual = "".to_owned(); 
- for x in bytes.lines() { - let content = x.unwrap(); - if content.starts_with("data:") { - let content = content.strip_prefix("data:").unwrap(); - let x: ChatCompletionChunk = serde_json::from_str(content).unwrap(); - actual += &x.choices[0].delta.content; + while let Some(event) = es.next().await { + match event { + Ok(Event::Open) => {} + Ok(Event::Message(message)) => { + let x: ChatCompletionChunk = serde_json::from_str(&message.data).unwrap(); + actual += &x.choices[0].delta.content; + } + Err(_) => { + // StreamEnd + break; + } } }