refactor(inference): use TextGenerationStream trait when possible. #1927

Closed · wants to merge 1 commit

crates/http-api-bindings/src/lib.rs (9 changes: 4 additions & 5 deletions)

@@ -6,9 +6,9 @@
 use openai::OpenAIEngine;
 use openai_chat::OpenAIChatEngine;
 use serde_json::Value;
-use tabby_inference::{chat::ChatCompletionStream, make_text_generation, TextGeneration};
+use tabby_inference::{chat::ChatCompletionStream, TextGenerationStream};

-pub fn create(model: &str) -> (Arc<dyn TextGeneration>, Option<String>, Option<String>) {
+pub fn create(model: &str) -> (impl TextGenerationStream, Option<String>, Option<String>) {
     let params = serde_json::from_str(model).expect("Failed to parse model string");
     let kind = get_param(&params, "kind");
     if kind == "openai" {
@@ -17,9 +17,8 @@
         let api_key = get_optional_param(&params, "api_key");
         let prompt_template = get_optional_param(&params, "prompt_template");
         let chat_template = get_optional_param(&params, "chat_template");
-        let engine =
-            make_text_generation(OpenAIEngine::create(&api_endpoint, &model_name, api_key));
-        (Arc::new(engine), prompt_template, chat_template)
+        let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
+        (engine, prompt_template, chat_template)
     } else {
         panic!("Only openai are supported for http completion");
     }
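
The signature change here is the heart of the refactor: `create` now returns the engine as `impl TextGenerationStream` instead of an eagerly wrapped `Arc<dyn TextGeneration>`, so the call site decides how to own it. A minimal sketch of that pattern, using local stand-in types rather than the real tabby_inference trait:

```rust
// Stand-in trait and engine for illustration; not the real tabby_inference types.
trait TextGen {
    fn next_token(&self) -> Option<String>;
}

struct OpenAiStandIn;

impl TextGen for OpenAiStandIn {
    fn next_token(&self) -> Option<String> {
        Some("token".into())
    }
}

// Returning `impl TextGen` keeps the concrete engine type opaque to callers
// while avoiding an eager `Arc::new` inside the constructor.
fn create() -> impl TextGen {
    OpenAiStandIn
}

fn main() {
    // The caller, not the constructor, now decides how to own the engine.
    let engine: Box<dyn TextGen> = Box::new(create());
    assert_eq!(engine.next_token().as_deref(), Some("token"));
}
```
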
crates/tabby-inference/src/imp.rs (40 changes: 20 additions & 20 deletions)

@@ -18,26 +18,6 @@
             stop_condition_factory: StopConditionFactory::default(),
         }
     }
-}
-
-#[async_trait]
-impl<T: TextGenerationStream> TextGeneration for TextGenerationImpl<T> {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let prompt = prompt.to_owned();
-        let s = stream! {
-            for await (streaming, text) in self.generate_stream(&prompt, options).await {
-                if !streaming {
-                    yield text;
-                }
-            }
-        };
-
-        if let Some(text) = Box::pin(s).into_future().await.0 {
-            text
-        } else {
-            String::new()
-        }
-    }

     async fn generate_stream(
         &self,
@@ -71,3 +51,23 @@
         Box::pin(s)
     }
 }
+
+#[async_trait]
+impl<T: TextGenerationStream> TextGeneration for TextGenerationImpl<T> {
+    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
+        let prompt = prompt.to_owned();
+        let s = stream! {
+            for await (streaming, text) in self.generate_stream(&prompt, options).await {
+                if !streaming {
+                    yield text;
+                }
+            }
+        };
+
+        if let Some(text) = Box::pin(s).into_future().await.0 {
+            text
+        } else {
+            String::new()
+        }
+    }
+}
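
The `TextGeneration` impl moves below the inherent `generate_stream` method, but its body is unchanged: it filters the `(streaming, text)` pairs down to the single final item and awaits it. A runnable sketch of that drain-the-stream pattern, assuming the `futures` and `async-stream` crates this file already uses, plus `tokio` (assumed) for the example's runtime:

```rust
use async_stream::stream;
use futures::StreamExt;

#[tokio::main]
async fn main() {
    // Stand-in for generate_stream's (streaming, text) pairs: partial chunks
    // are flagged `true`, the accumulated final text `false`.
    let pairs = futures::stream::iter(vec![
        (true, "hello".to_string()),
        (true, "hello wor".to_string()),
        (false, "hello world".to_string()),
    ]);

    // Keep only the final item, mirroring the `generate` impl above.
    let s = stream! {
        for await (streaming, text) in pairs {
            if !streaming {
                yield text;
            }
        }
    };

    // `into_future()` resolves to (first_item, remaining_stream); the stream
    // must be pinned first because `stream!` produces a !Unpin type.
    let (first, _rest) = Box::pin(s).into_future().await;
    assert_eq!(first.as_deref(), Some("hello world"));
}
```
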
crates/tabby-inference/src/lib.rs (12 changes: 7 additions & 5 deletions)

@@ -41,14 +41,16 @@
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String>;
 }

+#[async_trait]
+impl TextGenerationStream for Box<dyn TextGenerationStream> {
+    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String> {
+        self.as_ref().generate(prompt, options).await
+    }
+}
+
 #[async_trait]
 pub trait TextGeneration: Sync + Send {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        options: TextGenerationOptions,
-    ) -> BoxStream<(bool, String)>;
 }

 pub fn make_text_generation(imp: impl TextGenerationStream) -> impl TextGeneration {
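
The new blanket impl makes `Box<dyn TextGenerationStream>` itself satisfy the trait, which is what lets a boxed engine flow into generic consumers such as `make_text_generation`. A self-contained sketch with a stand-in trait (assumes the `async-trait` crate, which this file already uses, plus `tokio` for the example runtime):

```rust
use async_trait::async_trait;

#[async_trait]
trait Gen: Sync + Send {
    async fn generate(&self, prompt: &str) -> String;
}

// Forward the boxed object's calls to the inner implementation; without this
// impl, Box<dyn Gen> would not satisfy an `impl Gen` bound.
#[async_trait]
impl Gen for Box<dyn Gen> {
    async fn generate(&self, prompt: &str) -> String {
        self.as_ref().generate(prompt).await
    }
}

struct Echo;

#[async_trait]
impl Gen for Echo {
    async fn generate(&self, prompt: &str) -> String {
        prompt.to_owned()
    }
}

// Generic bound, analogous to make_text_generation(imp: impl TextGenerationStream).
async fn run(engine: impl Gen) -> String {
    engine.generate("hi").await
}

#[tokio::main]
async fn main() {
    let boxed: Box<dyn Gen> = Box::new(Echo);
    // Compiles only because of the `impl Gen for Box<dyn Gen>` above.
    assert_eq!(run(boxed).await, "hi");
}
```
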
crates/tabby/src/services/completion.rs (12 changes: 7 additions & 5 deletions)

@@ -11,7 +11,9 @@
     },
     languages::get_language,
 };
-use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
+use tabby_inference::{
+    make_text_generation, TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder,
+};
 use thiserror::Error;
 use utoipa::ToSchema;

@@ -226,20 +228,20 @@
 }

 pub struct CompletionService {
-    engine: Arc<dyn TextGeneration>,
+    engine: Box<dyn TextGeneration>,
     logger: Arc<dyn EventLogger>,
     prompt_builder: completion_prompt::PromptBuilder,
 }

 impl CompletionService {
     fn new(
-        engine: Arc<dyn TextGeneration>,
+        engine: impl TextGeneration + 'static,
         code: Arc<dyn CodeSearch>,
         logger: Arc<dyn EventLogger>,
         prompt_template: Option<String>,
     ) -> Self {
         Self {
-            engine,
+            engine: Box::new(engine),
             prompt_builder: completion_prompt::PromptBuilder::new(prompt_template, Some(code)),
             logger,
         }
@@ -351,5 +353,5 @@
         },
     ) = model::load_text_generation(model, device, parallelism).await;

-    CompletionService::new(engine.clone(), code, logger, prompt_template)
+    CompletionService::new(make_text_generation(engine), code, logger, prompt_template)
 }
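
`CompletionService::new` now takes `impl TextGeneration + 'static` and boxes it once internally, so callers no longer construct an `Arc` themselves. A minimal sketch of that constructor pattern with stand-in types:

```rust
// Stand-in trait and service for illustration; not the real CompletionService.
trait Gen {
    fn generate(&self, prompt: &str) -> String;
}

struct Service {
    engine: Box<dyn Gen>,
}

impl Service {
    // The `'static` bound lets the generic engine live inside the box.
    fn new(engine: impl Gen + 'static) -> Self {
        Self {
            engine: Box::new(engine),
        }
    }

    fn complete(&self, prompt: &str) -> String {
        self.engine.generate(prompt)
    }
}

struct Echo;

impl Gen for Echo {
    fn generate(&self, prompt: &str) -> String {
        prompt.to_owned()
    }
}

fn main() {
    // Callers hand over a concrete engine; no Arc or Box at the call site.
    let svc = Service::new(Echo);
    assert_eq!(svc.complete("fn main"), "fn main");
}
```
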
crates/tabby/src/services/model/chat.rs (20 changes: 4 additions & 16 deletions)

@@ -1,13 +1,10 @@
-use std::sync::Arc;
-
 use anyhow::Result;
-use async_stream::stream;
 use futures::stream::BoxStream;
 use minijinja::{context, Environment};
 use tabby_common::api::chat::Message;
 use tabby_inference::{
     chat::{self, ChatCompletionStream},
-    TextGeneration, TextGenerationOptions, TextGenerationStream,
+    TextGenerationOptions, TextGenerationStream,
 };

 struct ChatPromptBuilder {
@@ -37,7 +34,7 @@
 }

 struct ChatCompletionImpl {
-    engine: Arc<dyn TextGeneration>,
+    engine: Box<dyn TextGenerationStream>,
     prompt_builder: ChatPromptBuilder,
 }

@@ -50,21 +47,12 @@
 #[async_trait::async_trait]
 impl TextGenerationStream for ChatCompletionImpl {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String> {
-        let prompt = prompt.to_owned();
-        let s = stream! {
-            for await (streaming, text) in self.engine.generate_stream(&prompt, options).await {
-                if streaming {
-                    yield text;
-                }
-            }
-        };
-
-        Box::pin(s)
+        self.engine.generate(prompt, options).await
     }
 }

 pub fn make_chat_completion(
-    engine: Arc<dyn TextGeneration>,
+    engine: Box<dyn TextGenerationStream>,
     prompt_template: String,
 ) -> impl ChatCompletionStream {
     ChatCompletionImpl {
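
With the engine stored as `Box<dyn TextGenerationStream>`, `generate` becomes a one-line delegation and the old re-filtering of `(bool, String)` pairs disappears. The prompt side still goes through minijinja, per the imports above; a small sketch of that templating step (the template string here is invented for illustration, not Tabby's actual chat template):

```rust
use minijinja::{context, Environment};

fn main() {
    let mut env = Environment::new();
    // Register a toy chat template; ChatPromptBuilder does the equivalent
    // with the template string loaded from the model registry.
    env.add_template("chat", "<|user|>{{ message }}<|assistant|>")
        .expect("template should parse");

    let prompt = env
        .get_template("chat")
        .expect("template was registered above")
        .render(context!(message => "Explain the Box<dyn Trait> pattern"))
        .expect("rendering should succeed");

    assert_eq!(
        prompt,
        "<|user|>Explain the Box<dyn Trait> pattern<|assistant|>"
    );
}
```
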
crates/tabby/src/services/model/mod.rs (12 changes: 5 additions & 7 deletions)

@@ -8,9 +8,7 @@
     terminal::{HeaderFormat, InfoMessage},
 };
 use tabby_download::download_model;
-use tabby_inference::{
-    chat::ChatCompletionStream, make_text_generation, TextGeneration, TextGenerationStream,
-};
+use tabby_inference::{chat::ChatCompletionStream, TextGenerationStream};
 use tracing::info;

 use crate::{fatal, Device};
@@ -39,12 +37,12 @@
     model_id: &str,
     device: &Device,
     parallelism: u8,
-) -> (Arc<dyn TextGeneration>, PromptInfo) {
+) -> (Box<dyn TextGenerationStream>, PromptInfo) {
     #[cfg(feature = "experimental-http")]
     if device == &Device::ExperimentalHttp {
         let (engine, prompt_template, chat_template) = http_api_bindings::create(model_id);
         return (
-            engine,
+            Box::new(engine),
             PromptInfo {
                 prompt_template,
                 chat_template,
@@ -61,15 +59,15 @@
             parallelism,
        );
        let engine_info = PromptInfo::read(path.join("tabby.json"));
-        (Arc::new(make_text_generation(engine)), engine_info)
+        (Box::new(engine), engine_info)
     } else {
         let (registry, name) = parse_model_id(model_id);
         let registry = ModelRegistry::new(registry).await;
         let model_path = registry.get_model_path(name).display().to_string();
         let model_info = registry.get_model_info(name);
         let engine = create_ggml_engine(device, &model_path, parallelism);
         (
-            Arc::new(make_text_generation(engine)),
+            Box::new(engine),
             PromptInfo {
                 prompt_template: model_info.prompt_template.clone(),
                 chat_template: model_info.chat_template.clone(),
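
`load_text_generation` returns `Box<dyn TextGenerationStream>` rather than `impl TextGenerationStream` because its branches construct different concrete engines (the experimental HTTP binding versus local and registry ggml engines), and boxing is what unifies them into one return type. A stand-in sketch of that design choice:

```rust
// Local stand-in types, not the real Tabby engines.
trait Gen {
    fn backend(&self) -> &'static str;
}

struct HttpEngine;
struct GgmlEngine;

impl Gen for HttpEngine {
    fn backend(&self) -> &'static str {
        "http"
    }
}

impl Gen for GgmlEngine {
    fn backend(&self) -> &'static str {
        "ggml"
    }
}

// `-> impl Gen` would not compile here: the arms have different concrete
// types, so the trait object erases the difference instead.
fn load(experimental_http: bool) -> Box<dyn Gen> {
    if experimental_http {
        Box::new(HttpEngine)
    } else {
        Box::new(GgmlEngine)
    }
}

fn main() {
    assert_eq!(load(true).backend(), "http");
    assert_eq!(load(false).backend(), "ggml");
}
```
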