refactor(core): restructure CodeGeneration, CompletionStream, ChatCompletionStream (#1998)

* refactor(core): remove the generate method from TextGeneration

* refactor out the TextGeneration struct

* eliminate the chat submodule

* refactor: extract CompletionStream

* restructure ChatCompletionStream <-> CompletionStream

* restructure with CodeGeneration

* handle output length limits inside the Stream traits
wsxiaoys authored Apr 29, 2024
1 parent 602cf5e commit b76a101
Showing 15 changed files with 187 additions and 237 deletions.
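
Taken together, the diffs below split the old TextGenerationStream abstraction into three narrower pieces: a raw CompletionStream trait, a standalone ChatCompletionStream trait, and a CodeGeneration wrapper that owns stop-condition handling. A condensed sketch of the resulting shapes, abridged from the diffs that follow:

// Abridged from the diffs below (imports elided); see the individual files
// for the full definitions.

// Plain completion: one prompt in, a stream of text chunks out.
#[async_trait]
pub trait CompletionStream: Sync + Send {
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String>;
}

// Chat completion: a message list in, a fallible stream out.
#[async_trait]
pub trait ChatCompletionStream: Sync + Send {
    async fn chat_completion(
        &self,
        messages: &[Message],
        options: ChatCompletionOptions,
    ) -> Result<BoxStream<String>>;
}

// Code completion: wraps any CompletionStream and layers stop conditions on top.
pub struct CodeGeneration {
    imp: Arc<dyn CompletionStream>,
    stop_condition_factory: StopConditionFactory,
}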
4 changes: 2 additions & 2 deletions crates/http-api-bindings/src/lib.rs
@@ -6,9 +6,9 @@ use std::sync::Arc;
 use openai::OpenAIEngine;
 use openai_chat::OpenAIChatEngine;
 use serde_json::Value;
-use tabby_inference::{chat::ChatCompletionStream, TextGenerationStream};
+use tabby_inference::{ChatCompletionStream, CompletionStream};

-pub fn create(model: &str) -> (impl TextGenerationStream, Option<String>, Option<String>) {
+pub fn create(model: &str) -> (impl CompletionStream, Option<String>, Option<String>) {
     let params = serde_json::from_str(model).expect("Failed to parse model string");
     let kind = get_param(&params, "kind");
     if kind == "openai" {
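
For context, a hypothetical call site for the new signature — only the "kind" key is visible in this diff, so the JSON shape and the meaning of the two Option<String> returns are assumptions rather than documented API:

// Hypothetical: the engine now comes back as an impl CompletionStream rather
// than the removed TextGenerationStream; keys other than "kind" are assumed.
let (engine, prompt_template, chat_template) = create(r#"{"kind": "openai"}"#);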
8 changes: 4 additions & 4 deletions crates/http-api-bindings/src/openai.rs
@@ -2,7 +2,7 @@ use async_openai::{config::OpenAIConfig, error::OpenAIError, types::CreateComple
 use async_stream::stream;
 use async_trait::async_trait;
 use futures::stream::BoxStream;
-use tabby_inference::{TextGenerationOptions, TextGenerationStream};
+use tabby_inference::{CompletionOptions, CompletionStream};
 use tracing::warn;

 pub struct OpenAIEngine {
@@ -26,12 +26,12 @@ impl OpenAIEngine {
 }

 #[async_trait]
-impl TextGenerationStream for OpenAIEngine {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String> {
+impl CompletionStream for OpenAIEngine {
+    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
         let request = CreateCompletionRequestArgs::default()
             .model(&self.model_name)
-            .max_tokens(options.max_decoding_length as u16)
             .temperature(options.sampling_temperature)
+            .max_tokens(options.max_decoding_tokens as u16)
             .stream(true)
             .prompt(prompt)
             .build();
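
Since generate now returns a plain BoxStream<String> under the unified CompletionStream trait, consumption looks the same for every backend. A minimal sketch — engine is any CompletionStream implementor and options a built CompletionOptions:

use futures::StreamExt;
use tabby_inference::{CompletionOptions, CompletionStream};

// Drain a completion stream chunk by chunk and print it.
async fn print_completion(engine: &impl CompletionStream, options: CompletionOptions) {
    let mut stream = engine.generate("fn main() {", options).await;
    while let Some(chunk) = stream.next().await {
        print!("{chunk}");
    }
}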
3 changes: 2 additions & 1 deletion crates/http-api-bindings/src/openai_chat.rs
@@ -7,7 +7,7 @@ use async_stream::stream;
 use async_trait::async_trait;
 use futures::stream::BoxStream;
 use tabby_common::api::chat::Message;
-use tabby_inference::chat::{ChatCompletionOptions, ChatCompletionStream};
+use tabby_inference::{ChatCompletionOptions, ChatCompletionStream};
 use tracing::{debug, warn};

 pub struct OpenAIChatEngine {
@@ -49,6 +49,7 @@ impl ChatCompletionStream for OpenAIChatEngine {

         let request = CreateChatCompletionRequestArgs::default()
             .seed(options.seed as i64)
+            .max_tokens(options.max_decoding_tokens as u16)
             .model(&self.model_name)
             .temperature(options.sampling_temperature)
             .stream(true)
11 changes: 8 additions & 3 deletions crates/llama-cpp-bindings/src/lib.rs
@@ -8,7 +8,7 @@ use derive_builder::Builder;
 use ffi::create_engine;
 use futures::stream::BoxStream;
 use llama::{LlamaInitRequest, LlamaService};
-use tabby_inference::{TextGenerationOptions, TextGenerationStream};
+use tabby_inference::{CompletionOptions, CompletionStream};

 #[cxx::bridge(namespace = "llama")]
 mod ffi {
@@ -68,8 +68,9 @@ impl LlamaTextGeneration {
 }

 #[async_trait]
-impl TextGenerationStream for LlamaTextGeneration {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream<String> {
+impl CompletionStream for LlamaTextGeneration {
+    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
+        let mut output_token_budget = options.max_decoding_tokens;
         let mut rx = self
             .service
             .add_request(
@@ -83,6 +84,10 @@ impl TextGenerationStream for LlamaTextGeneration {
         let s = stream! {
             while let Some(new_text) = rx.recv().await {
                 yield new_text;
+                output_token_budget -= 1;
+                if output_token_budget <= 0 {
+                    break;
+                }
             }

             rx.close();
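
That budget decrement implements the commit's "handle output length limits inside the Stream traits" bullet: the binding now caps its own output instead of delegating the length check to StopCondition (see the decoding.rs diff below). The same pattern in isolation, as a generic sketch not taken from this commit:

use async_stream::stream;
use futures::{stream::BoxStream, StreamExt};

// Cap any string stream at max_tokens items, mirroring the budget logic above.
fn limit_tokens(mut inner: BoxStream<'static, String>, max_tokens: i32) -> BoxStream<'static, String> {
    let mut budget = max_tokens;
    Box::pin(stream! {
        while let Some(chunk) = inner.next().await {
            yield chunk;
            budget -= 1;
            if budget <= 0 {
                break;
            }
        }
    })
}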
40 changes: 4 additions & 36 deletions crates/tabby-inference/src/chat.rs
@@ -1,19 +1,19 @@
 use anyhow::Result;
-use async_stream::stream;
 use async_trait::async_trait;
 use derive_builder::Builder;
 use futures::stream::BoxStream;
 use tabby_common::api::chat::Message;

-use crate::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
-
 #[derive(Builder, Debug)]
 pub struct ChatCompletionOptions {
     #[builder(default = "0.1")]
     pub sampling_temperature: f32,

-    #[builder(default = "TextGenerationOptions::default_seed()")]
+    #[builder(default = "crate::default_seed()")]
     pub seed: u64,
+
+    #[builder(default = "1920")]
+    pub max_decoding_tokens: i32,
 }

 #[async_trait]
@@ -24,35 +24,3 @@ pub trait ChatCompletionStream: Sync + Send {
         options: ChatCompletionOptions,
     ) -> Result<BoxStream<String>>;
 }
-
-pub trait ChatPromptBuilder {
-    fn build_chat_prompt(&self, messages: &[Message]) -> Result<String>;
-}
-
-#[async_trait]
-impl<T: ChatPromptBuilder + TextGeneration> ChatCompletionStream for T {
-    async fn chat_completion(
-        &self,
-        messages: &[Message],
-        options: ChatCompletionOptions,
-    ) -> Result<BoxStream<String>> {
-        let options = TextGenerationOptionsBuilder::default()
-            .max_input_length(2048)
-            .max_decoding_length(1920)
-            .seed(options.seed)
-            .sampling_temperature(options.sampling_temperature)
-            .build()?;
-
-        let prompt = self.build_chat_prompt(messages)?;
-
-        let s = stream! {
-            for await (streaming, content) in self.generate_stream(&prompt, options).await {
-                if streaming {
-                    yield content
-                }
-            }
-        };
-
-        Ok(Box::pin(s))
-    }
-}
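
With the blanket impl<T: ChatPromptBuilder + TextGeneration> ChatCompletionStream for T removed, each chat backend implements the trait directly, as OpenAIChatEngine does above. A skeletal direct implementation — MyChatEngine is hypothetical and shown only to illustrate the trait contract:

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tabby_common::api::chat::Message;
use tabby_inference::{ChatCompletionOptions, ChatCompletionStream};

struct MyChatEngine;

#[async_trait]
impl ChatCompletionStream for MyChatEngine {
    async fn chat_completion(
        &self,
        _messages: &[Message],
        _options: ChatCompletionOptions,
    ) -> Result<BoxStream<String>> {
        // A real backend would build a request from the messages and honor
        // seed, sampling_temperature, and max_decoding_tokens.
        Ok(Box::pin(futures::stream::empty()))
    }
}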
79 changes: 79 additions & 0 deletions crates/tabby-inference/src/code.rs
@@ -0,0 +1,79 @@
+use std::sync::Arc;
+
+use async_stream::stream;
+use derive_builder::Builder;
+use futures::StreamExt;
+use tabby_common::languages::Language;
+
+use crate::{decoding::StopConditionFactory, CompletionOptionsBuilder, CompletionStream};
+
+#[derive(Builder, Debug)]
+pub struct CodeGenerationOptions {
+    #[builder(default = "1024")]
+    pub max_input_length: usize,
+
+    #[builder(default = "256")]
+    pub max_decoding_tokens: i32,
+
+    #[builder(default = "0.1")]
+    pub sampling_temperature: f32,
+
+    #[builder(default = "crate::default_seed()")]
+    pub seed: u64,
+
+    #[builder(default = "None")]
+    pub language: Option<&'static Language>,
+}
+
+pub struct CodeGeneration {
+    imp: Arc<dyn CompletionStream>,
+    stop_condition_factory: StopConditionFactory,
+}
+
+impl CodeGeneration {
+    pub fn new(imp: Arc<dyn CompletionStream>) -> Self {
+        Self {
+            imp,
+            stop_condition_factory: StopConditionFactory::default(),
+        }
+    }
+}
+
+impl CodeGeneration {
+    pub async fn generate(&self, prompt: &str, options: CodeGenerationOptions) -> String {
+        let s = stream! {
+            let mut text = String::new();
+            let mut stop_condition = self.stop_condition_factory.create(
+                prompt,
+                options.language,
+            );
+
+            let options = CompletionOptionsBuilder::default()
+                .max_input_length(options.max_input_length)
+                .max_decoding_tokens(options.max_decoding_tokens)
+                .sampling_temperature(options.sampling_temperature)
+                .seed(options.seed)
+                .build()
+                .expect("Failed to build completion options");
+
+            for await new_text in self.imp.generate(prompt, options).await {
+                let (should_stop, stop_length) = stop_condition.should_stop(&new_text);
+                text += &new_text;
+                if should_stop {
+                    // stop condition matched against prompt + generated text. There's a chance that stop_length >= text.len();
+                    let new_text_length = text.len().checked_sub(stop_length).unwrap_or_default();
+                    text.truncate(new_text_length);
+                    break;
+                }
+            }
+
+            yield text;
+        };
+
+        if let Some(text) = Box::pin(s).into_future().await.0 {
+            text
+        } else {
+            String::new()
+        }
+    }
+}
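
A usage sketch for the new wrapper, assuming these types are re-exported at the tabby_inference crate root (the other files' use tabby_inference::{...} imports suggest they are) and that engine is an Arc<dyn CompletionStream> obtained from one of the bindings above; the prompt string is illustrative:

use std::sync::Arc;
use tabby_inference::{CodeGeneration, CodeGenerationOptionsBuilder, CompletionStream};

async fn complete(engine: Arc<dyn CompletionStream>) -> String {
    let code = CodeGeneration::new(engine);
    let options = CodeGenerationOptionsBuilder::default()
        .max_decoding_tokens(64) // other fields fall back to their builder defaults
        .build()
        .expect("valid options");
    code.generate("def fib(n):", options).await
}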
19 changes: 19 additions & 0 deletions crates/tabby-inference/src/completion.rs
@@ -0,0 +1,19 @@
+use async_trait::async_trait;
+use derive_builder::Builder;
+use futures::stream::BoxStream;
+
+#[derive(Builder, Debug)]
+pub struct CompletionOptions {
+    pub max_input_length: usize,
+
+    pub max_decoding_tokens: i32,
+
+    pub sampling_temperature: f32,
+
+    pub seed: u64,
+}
+
+#[async_trait]
+pub trait CompletionStream: Sync + Send {
+    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String>;
+}
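
The whole trait surface is small enough that a toy implementor fits in a few lines — a sketch of a fake engine that echoes the prompt back as one chunk, illustrative only:

use async_stream::stream;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tabby_inference::{CompletionOptions, CompletionStream};

struct EchoEngine;

#[async_trait]
impl CompletionStream for EchoEngine {
    async fn generate(&self, prompt: &str, _options: CompletionOptions) -> BoxStream<String> {
        let prompt = prompt.to_owned();
        // A real engine would stream model output chunk by chunk; this just
        // echoes the prompt back as a single item.
        Box::pin(stream! {
            yield prompt;
        })
    }
}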
21 changes: 7 additions & 14 deletions crates/tabby-inference/src/decoding.rs
@@ -24,16 +24,11 @@ impl Default for StopConditionFactory {
 type CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;

 impl StopConditionFactory {
-    pub fn create(
-        &self,
-        text: &str,
-        max_decoding_length: usize,
-        language: Option<&'static Language>,
-    ) -> StopCondition {
+    pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {
         if let Some(language) = language {
-            StopCondition::new(self.get_trie(language), max_decoding_length, text)
+            StopCondition::new(self.get_trie(language), text)
         } else {
-            StopCondition::new(None, max_decoding_length, text)
+            StopCondition::new(None, text)
         }
     }

@@ -65,16 +60,14 @@ fn create_stop_trie(stop_words: Vec<String>) -> Trie<u8> {

 pub struct StopCondition<'a> {
     stop_trie: Option<CachedTrie<'a>>,
-    max_decoding_length: usize,
     reversed_text: String,
     num_decoded: usize,
 }

 impl<'a> StopCondition<'a> {
-    pub fn new(stop_trie: Option<CachedTrie<'a>>, max_decoding_length: usize, text: &str) -> Self {
+    pub fn new(stop_trie: Option<CachedTrie<'a>>, text: &str) -> Self {
         Self {
             stop_trie,
-            max_decoding_length,
             reversed_text: reverse(text),
             num_decoded: 0,
         }
@@ -93,7 +86,7 @@ impl<'a> StopCondition<'a> {
                 }
             }
         }
-        (self.num_decoded >= self.max_decoding_length, 0)
+        (false, 0)
     }
 }

@@ -121,14 +114,14 @@ mod tests {
     #[test]
     fn test_stop_condition_max_length() {
         let factory = StopConditionFactory::default();
-        let mut cond = factory.create("", 4, Some(&UNKNOWN_LANGUAGE));
+        let mut cond = factory.create("", Some(&UNKNOWN_LANGUAGE));
         let (should_stop, _) = cond.should_stop("1");
         assert!(!should_stop);
         let (should_stop, _) = cond.should_stop("2");
         assert!(!should_stop);
         let (should_stop, _) = cond.should_stop("3");
         assert!(!should_stop);
         let (should_stop, _) = cond.should_stop("4");
-        assert!(should_stop)
+        assert!(!should_stop)
     }
 }
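
With max_decoding_length gone, StopCondition is purely a stop-word matcher and never reports exhaustion on its own; length capping is now the stream implementations' job (the llama binding's token budget and CodeGeneration::generate above). A sketch of the surviving contract, mirroring the updated test — prompt, language, and new_text are placeholders:

let factory = StopConditionFactory::default();
let mut cond = factory.create(prompt, Some(language));
// Returns (false, 0) unless a stop word matches; on a match, stop_length
// tells the caller how much of the generated tail to truncate.
let (should_stop, stop_length) = cond.should_stop(new_text);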
56 changes: 0 additions & 56 deletions crates/tabby-inference/src/imp.rs

This file was deleted.

(Diffs for the remaining changed files are not shown.)
