feat(http): support openai /v1/completions streaming interface (#1373)
* switch goldentest streaming

* support streaming in openai
wsxiaoys authored Feb 4, 2024
1 parent cb8faed commit 9a0e8d3
Showing 7 changed files with 98 additions and 59 deletions.
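
At a glance: the OpenAI-compatible /v1/completions binding now sends `"stream": true` and consumes the response as Server-Sent Events via reqwest-eventsource, and the chat golden tests switch to the same SSE client. As a hedged sketch of what a standalone consumer of such a streaming endpoint could look like — the endpoint URL, model name, and chunk field names are assumptions modeled on the Response type in openai.rs, not part of the commit:

    // Hedged sketch, not part of this commit: a standalone consumer of an
    // OpenAI-style /v1/completions streaming endpoint, mirroring the
    // reqwest-eventsource pattern used in the diffs below.
    use futures::StreamExt;
    use reqwest_eventsource::{Error, Event, EventSource};
    use serde::Deserialize;
    use serde_json::json;

    #[derive(Deserialize)]
    struct Choice {
        text: String,
    }

    #[derive(Deserialize)]
    struct Chunk {
        choices: Vec<Choice>,
    }

    async fn stream_completion(api_endpoint: &str, prompt: &str) -> anyhow::Result<String> {
        let body = json!({
            "model": "some-model", // assumption: whatever the server expects
            "prompt": [prompt],
            "max_tokens": 64,
            "temperature": 0.1,
            "stream": true
        });

        let client = reqwest::Client::new();
        let mut es = EventSource::new(client.post(api_endpoint).json(&body))?;

        let mut out = String::new();
        while let Some(event) = es.next().await {
            match event {
                Ok(Event::Open) => {}
                Ok(Event::Message(message)) => {
                    let chunk: Chunk = serde_json::from_str(&message.data)?;
                    out.push_str(&chunk.choices[0].text);
                }
                // The server closing the connection ends the stream.
                Err(Error::StreamEnded) => break,
                Err(err) => return Err(err.into()),
            }
        }
        Ok(out)
    }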
37 changes: 36 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -43,6 +43,7 @@ axum = "0.6"
 hyper = "0.14"
 juniper = "0.15"
 chrono = "0.4"
+reqwest-eventsource = "0.5.0"
 
 [workspace.dependencies.uuid]
 version = "1.3.3"
3 changes: 2 additions & 1 deletion crates/http-api-bindings/Cargo.toml
@@ -4,10 +4,11 @@ version = "0.8.0"
 edition = "2021"
 
 [dependencies]
-anyhow.workspace = true
+async-stream.workspace = true
 async-trait.workspace = true
 futures.workspace = true
 reqwest = { workspace = true, features = ["json"] }
+reqwest-eventsource.workspace = true
 serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true }
 tabby-inference = { version = "0.8.0", path = "../tabby-inference" }
71 changes: 31 additions & 40 deletions crates/http-api-bindings/src/openai.rs
@@ -1,10 +1,10 @@
-use anyhow::{anyhow, Result};
+use async_stream::stream;
 use async_trait::async_trait;
 use futures::stream::BoxStream;
 use reqwest::header;
+use reqwest_eventsource::{Error, Event, EventSource};
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
-use tabby_inference::{helpers, TextGeneration, TextGenerationOptions, TextGenerationStream};
+use tabby_inference::{TextGenerationOptions, TextGenerationStream};
 use tracing::warn;
 
 #[derive(Serialize)]
@@ -13,7 +13,7 @@ struct Request {
     prompt: Vec<String>,
     max_tokens: usize,
     temperature: f32,
-    stop: Vec<String>,
+    stream: bool,
 }
 
 #[derive(Deserialize)]
@@ -52,53 +52,44 @@ impl OpenAIEngine {
             client,
         }
     }
-
-    async fn generate_impl(&self, prompt: &str, options: TextGenerationOptions) -> Result<String> {
-        // OpenAI's API usually handles stop words in an O(n) manner, so we just use a single stop word here.
-        // FIXME(meng): consider improving this for some external vendors, e.g vLLM.
-        let stop = vec!["\n\n".to_owned()];
+}
 
+#[async_trait]
+impl TextGenerationStream for OpenAIEngine {
+    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String> {
         let request = Request {
             model: self.model_name.to_owned(),
             prompt: vec![prompt.to_string()],
             max_tokens: options.max_decoding_length,
             temperature: options.sampling_temperature,
-            stop,
+            stream: true,
         };
 
-        // API Documentation: https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md
-        let resp = self
-            .client
-            .post(&self.api_endpoint)
-            .json(&request)
-            .send()
-            .await?;
-
-        if resp.status() != 200 {
-            let err: Value = resp.json().await.expect("Failed to parse response");
-            return Err(anyhow!("Request failed: {}", err));
-        }
-
-        let resp: Response = resp.json().await.expect("Failed to parse response");
-
-        Ok(resp.choices[0].text.clone())
-    }
+        let es = EventSource::new(self.client.post(&self.api_endpoint).json(&request));
 
-    // FIXME(meng): migrate to streaming implementation
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        match self.generate_impl(prompt, options).await {
-            Ok(output) => output,
-            Err(err) => {
-                warn!("Failed to generate completion: `{}`", err);
-                String::new()
-            }
-        }
-    }
-}
+        let s = stream! {
+            let Ok(es) = es else {
+                warn!("Failed to access api_endpoint: {}", &self.api_endpoint);
+                return;
+            };
+
+            for await event in es {
+                match event {
+                    Ok(Event::Open) => {}
+                    Ok(Event::Message(message)) => {
+                        let x: Response = serde_json::from_str(&message.data).unwrap();
+                        yield x.choices[0].text.clone();
+                    }
+                    Err(Error::StreamEnded) => {
+                        break;
+                    },
+                    Err(err) => {
+                        warn!("Failed to start streaming: {}", err);
+                    }
+                };
+            }
+        };
 
-#[async_trait]
-impl TextGenerationStream for OpenAIEngine {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String> {
-        helpers::string_to_stream(self.generate(prompt, options).await).await
+        Box::pin(s)
     }
 }
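
After this change the engine only exposes a stream: `generate` yields text fragments as a `BoxStream<String>`. A caller that still wants one buffered string can fold the stream itself; a minimal sketch, not part of the commit:

    // Sketch only, not in this commit: fold the streamed fragments back into
    // a single string, generic over the trait shape shown above.
    use futures::StreamExt;
    use tabby_inference::{TextGenerationOptions, TextGenerationStream};

    async fn generate_blocking(
        engine: &impl TextGenerationStream,
        prompt: &str,
        options: TextGenerationOptions,
    ) -> String {
        let mut stream = engine.generate(prompt, options).await;
        let mut output = String::new();
        while let Some(chunk) = stream.next().await {
            output.push_str(&chunk);
        }
        output
    }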
1 change: 1 addition & 0 deletions crates/tabby/Cargo.toml
@@ -70,6 +70,7 @@ assert-json-diff = "2.0.2"
 insta = { version = "1.34.0", features = ["yaml", "redactions"] }
 reqwest.workspace = true
 serde-jsonlines = "0.5.0"
+reqwest-eventsource = { workspace = true }
 
 [package.metadata.cargo-machete]
 ignored = ["openssl"]
8 changes: 7 additions & 1 deletion crates/tabby/src/routes/chat.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
 use axum::{
     body::StreamBody,
     extract::State,
+    http::HeaderValue,
     response::{IntoResponse, Response},
     Json,
 };
@@ -44,5 +45,10 @@ pub async fn chat_completions(
         Ok(s) => Ok(format!("data: {s}\n\n")),
         Err(e) => Err(anyhow::Error::from(e)),
     });
-    StreamBody::new(s).into_response()
+    let mut resp = StreamBody::new(s).into_response();
+    resp.headers_mut().append(
+        "Content-Type",
+        HeaderValue::from_str("text/event-stream").unwrap(),
+    );
+    resp
 }
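
The added header is what lets EventSource-based clients parse the body: SSE requires `Content-Type: text/event-stream`, with each event framed as a `data:` line followed by a blank line — exactly the `format!("data: {s}\n\n")` wrapper above. An illustrative check of that framing, not from the commit:

    // Illustrative, not from the commit: the SSE framing produced by the
    // format!("data: {s}\n\n") wrapper above.
    fn main() {
        let chunk = r#"{"choices":[{"delta":{"content":"hello"}}]}"#;
        let frame = format!("data: {chunk}\n\n");
        assert!(frame.starts_with("data: ") && frame.ends_with("\n\n"));
    }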
36 changes: 20 additions & 16 deletions crates/tabby/tests/goldentests_chat.rs
@@ -1,6 +1,8 @@
-use std::{io::BufRead, path::PathBuf};
+use std::path::PathBuf;
 
+use futures::StreamExt;
 use lazy_static::lazy_static;
+use reqwest_eventsource::{Event, EventSource};
 use serde::Deserialize;
 use serde_json::json;
 use serial_test::serial;
@@ -83,23 +85,25 @@ async fn wait_for_server(gpu_device: Option<&str>) {
 }
 
 async fn golden_test(body: serde_json::Value) -> String {
-    let bytes = CLIENT
-        .post("http://localhost:9090/v1beta/chat/completions")
-        .json(&body)
-        .send()
-        .await
-        .unwrap()
-        .bytes()
-        .await
-        .unwrap();
+    let mut es = EventSource::new(
+        CLIENT
+            .post("http://localhost:9090/v1beta/chat/completions")
+            .json(&body),
+    )
+    .unwrap();
 
     let mut actual = "".to_owned();
-    for x in bytes.lines() {
-        let content = x.unwrap();
-        if content.starts_with("data:") {
-            let content = content.strip_prefix("data:").unwrap();
-            let x: ChatCompletionChunk = serde_json::from_str(content).unwrap();
-            actual += &x.choices[0].delta.content;
+    while let Some(event) = es.next().await {
+        match event {
+            Ok(Event::Open) => {}
+            Ok(Event::Message(message)) => {
+                let x: ChatCompletionChunk = serde_json::from_str(&message.data).unwrap();
+                actual += &x.choices[0].delta.content;
+            }
+            Err(_) => {
+                // StreamEnd
+                break;
+            }
         }
     }
 
