fix(core): add suffix support for openai completion legacy interface (#2825)

* chore: add default api base for mistral fim completion

* fix: add suffix support for openai completion legacy interface
wsxiaoys authored Aug 10, 2024
1 parent 9dec617 commit c3fb6f9
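
In Tabby's FIM flow, the completion prompt is rendered from a template that joins the code before and after the cursor with a sentinel token; each engine then splits that string back apart and sends the two halves as the `prompt` and `suffix` fields of the request. A minimal, runnable sketch of the splitting step this commit introduces (`split_fim_prompt` is a hypothetical helper name; the engines inline this logic):

```rust
// The sentinel used by the new shared template "{prefix}<|FIM|>{suffix}".
const FIM_TOKEN: &str = "<|FIM|>";

/// Hypothetical helper mirroring the logic added to both engines.
fn split_fim_prompt(prompt: &str) -> (String, Option<String>) {
    // At most two parts: everything before the first token, everything after it.
    let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
    let prefix = parts[0].to_owned();
    // A missing or empty suffix becomes None rather than "".
    let suffix = parts
        .get(1)
        .map(|x| x.to_string())
        .filter(|x| !x.is_empty());
    (prefix, suffix)
}

fn main() {
    let (prefix, suffix) = split_fim_prompt("fn add(a: i32, b: i32) -> i32 {<|FIM|>}");
    assert_eq!(prefix, "fn add(a: i32, b: i32) -> i32 {");
    assert_eq!(suffix.as_deref(), Some("}"));
}
```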
Showing 3 changed files with 35 additions and 15 deletions.
crates/http-api-bindings/src/completion/mistral.rs (19 additions, 5 deletions)
```diff
@@ -5,23 +5,34 @@ use reqwest_eventsource::{Event, EventSource};
 use serde::{Deserialize, Serialize};
 use tabby_inference::{CompletionOptions, CompletionStream};
 
+use super::FIM_TOKEN;
+
 pub struct MistralFIMEngine {
     client: reqwest::Client,
     api_endpoint: String,
     api_key: String,
     model_name: String,
 }
 
+const DEFAULT_API_ENDPOINT: &str = "https://api.mistral.ai";
+
 impl MistralFIMEngine {
-    pub fn create(api_endpoint: &str, api_key: Option<String>, model_name: Option<String>) -> Self {
+    pub fn create(
+        api_endpoint: Option<&str>,
+        api_key: Option<String>,
+        model_name: Option<String>,
+    ) -> Self {
         let client = reqwest::Client::new();
         let model_name = model_name.unwrap_or("codestral-latest".into());
         let api_key = api_key.expect("API key is required for mistral/completion");
 
         Self {
             client,
             model_name,
-            api_endpoint: format!("{}/v1/fim/completions", api_endpoint),
+            api_endpoint: format!(
+                "{}/v1/fim/completions",
+                api_endpoint.unwrap_or(DEFAULT_API_ENDPOINT)
+            ),
             api_key,
         }
     }
```
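
The only functional change in `create` is the endpoint fallback. A runnable sketch of just that logic (`fim_endpoint` is a hypothetical helper name; the engine builds the same string inline):

```rust
// Standalone sketch of the endpoint fallback introduced above.
const DEFAULT_API_ENDPOINT: &str = "https://api.mistral.ai";

fn fim_endpoint(api_endpoint: Option<&str>) -> String {
    format!(
        "{}/v1/fim/completions",
        api_endpoint.unwrap_or(DEFAULT_API_ENDPOINT)
    )
}

fn main() {
    // No endpoint configured: fall back to the hosted Mistral API.
    assert_eq!(
        fim_endpoint(None),
        "https://api.mistral.ai/v1/fim/completions"
    );
    // An explicitly configured endpoint (e.g. a proxy) still wins.
    assert_eq!(
        fim_endpoint(Some("http://localhost:8080")),
        "http://localhost:8080/v1/fim/completions"
    );
}
```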
```diff
@@ -30,7 +41,7 @@ impl MistralFIMEngine {
 #[derive(Serialize)]
 struct FIMRequest {
     prompt: String,
-    suffix: String,
+    suffix: Option<String>,
     model: String,
     temperature: f32,
     max_tokens: i32,
@@ -57,10 +68,13 @@ struct FIMResponseDelta {
 #[async_trait]
 impl CompletionStream for MistralFIMEngine {
     async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
-        let parts: Vec<&str> = prompt.split("<FIM>").collect();
+        let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
         let request = FIMRequest {
             prompt: parts[0].to_owned(),
-            suffix: parts[1].to_owned(),
+            suffix: parts
+                .get(1)
+                .map(|x| x.to_string())
+                .filter(|x| !x.is_empty()),
             model: self.model_name.clone(),
             max_tokens: options.max_decoding_tokens,
             temperature: options.sampling_temperature,
```
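
Beyond renaming the token, the switch from `split("<FIM>")` to `splitn(2, FIM_TOKEN)` with `get(1)` hardens the parsing: a prompt without the token no longer panics at `parts[1]`, and a second occurrence of the token stays inside the suffix. A runnable illustration of those edge cases:

```rust
fn main() {
    let token = "<|FIM|>";

    // splitn(2, ..) keeps everything after the first token in one piece,
    // so a token appearing inside the suffix is preserved verbatim.
    let parts = "before<|FIM|>middle<|FIM|>after"
        .splitn(2, token)
        .collect::<Vec<_>>();
    assert_eq!(parts, ["before", "middle<|FIM|>after"]);

    // With no token present there is no second part; get(1) yields None
    // where indexing parts[1] (as the old code did) would have panicked.
    let parts = "no token here".splitn(2, token).collect::<Vec<_>>();
    assert_eq!(parts.get(1), None);

    // A trailing token produces an empty suffix, which the filter maps to None.
    let parts = "prefix only<|FIM|>".splitn(2, token).collect::<Vec<_>>();
    let suffix = parts.get(1).map(|x| x.to_string()).filter(|x| !x.is_empty());
    assert_eq!(suffix, None);
}
```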
crates/http-api-bindings/src/completion/mod.rs (6 additions, 8 deletions)
```diff
@@ -23,13 +23,9 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
             Arc::new(engine)
         }
         "ollama/completion" => ollama_api_bindings::create_completion(model).await,
-
         "mistral/completion" => {
             let engine = MistralFIMEngine::create(
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
+                model.api_endpoint.as_deref(),
                 model.api_key.clone(),
                 model.model_name.clone(),
             );
@@ -46,17 +42,19 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
             );
             Arc::new(engine)
         }
-
         unsupported_kind => panic!(
             "Unsupported model kind for http completion: {}",
             unsupported_kind
         ),
     }
 }
 
+const FIM_TOKEN: &str = "<|FIM|>";
+const FIM_TEMPLATE: &str = "{prefix}<|FIM|>{suffix}";
+
 pub fn build_completion_prompt(model: &HttpModelConfig) -> (Option<String>, Option<String>) {
-    if model.kind == "mistral/completion" {
-        (Some("{prefix}<FIM>{suffix}".to_owned()), None)
+    if model.kind == "mistral/completion" || model.kind == "openai/completion" {
+        (Some(FIM_TEMPLATE.to_owned()), None)
     } else {
         (model.prompt_template.clone(), model.chat_template.clone())
     }
```
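
`build_completion_prompt` now routes both FIM-capable kinds through the shared template, ignoring any user-configured `prompt_template`. A self-contained sketch of that dispatch, using a simplified stand-in for `HttpModelConfig` (only the three fields used here; the real type lives in tabby's config crate):

```rust
const FIM_TEMPLATE: &str = "{prefix}<|FIM|>{suffix}";

// Stand-in for HttpModelConfig, reduced to the fields this function reads.
struct ModelConfig {
    kind: String,
    prompt_template: Option<String>,
    chat_template: Option<String>,
}

fn build_completion_prompt(model: &ModelConfig) -> (Option<String>, Option<String>) {
    if model.kind == "mistral/completion" || model.kind == "openai/completion" {
        // FIM-capable kinds always use the shared template.
        (Some(FIM_TEMPLATE.to_owned()), None)
    } else {
        (model.prompt_template.clone(), model.chat_template.clone())
    }
}

fn main() {
    let openai = ModelConfig {
        kind: "openai/completion".into(),
        prompt_template: Some("<PRE>{prefix}<SUF>{suffix}<MID>".into()),
        chat_template: None,
    };
    // The configured template is overridden for openai/completion.
    assert_eq!(
        build_completion_prompt(&openai),
        (Some(FIM_TEMPLATE.to_owned()), None)
    );
}
```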
crates/http-api-bindings/src/completion/openai.rs (10 additions, 2 deletions)
```diff
@@ -5,6 +5,8 @@ use reqwest_eventsource::{Event, EventSource};
 use serde::{Deserialize, Serialize};
 use tabby_inference::{CompletionOptions, CompletionStream};
 
+use super::FIM_TOKEN;
+
 pub struct OpenAICompletionEngine {
     client: reqwest::Client,
     model_name: String,
@@ -14,7 +16,7 @@ pub struct OpenAICompletionEngine {
 
 impl OpenAICompletionEngine {
     pub fn create(model_name: Option<String>, api_endpoint: &str, api_key: Option<String>) -> Self {
-        let model_name = model_name.unwrap();
+        let model_name = model_name.expect("model_name is required for openai/completion");
         let client = reqwest::Client::new();
 
         Self {
@@ -30,6 +32,7 @@
 struct CompletionRequest {
     model: String,
     prompt: String,
+    suffix: Option<String>,
     max_tokens: i32,
     temperature: f32,
     stream: bool,
@@ -50,9 +53,14 @@ struct CompletionResponseChoice {
 #[async_trait]
 impl CompletionStream for OpenAICompletionEngine {
     async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
+        let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
         let request = CompletionRequest {
             model: self.model_name.clone(),
-            prompt: prompt.to_owned(),
+            prompt: parts[0].to_owned(),
+            suffix: parts
+                .get(1)
+                .map(|x| x.to_string())
+                .filter(|x| !x.is_empty()),
             max_tokens: options.max_decoding_tokens,
             temperature: options.sampling_temperature,
             stream: true,
```
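
Put together, a completion request to the legacy `/v1/completions` endpoint now carries both halves of the FIM prompt. A rough picture of one request body, built with `serde_json` purely for illustration (the model name and sampling values are invented; as declared above, `suffix` has no skip attribute, so a `None` suffix would serialize as `null`):

```rust
use serde_json::json;

fn main() {
    // Prompt arriving from the template: "def add(a, b):<|FIM|>\nprint(add(1, 2))".
    // After the splitn above, prefix and suffix land in separate fields:
    let body = json!({
        "model": "gpt-3.5-turbo-instruct", // example model name
        "prompt": "def add(a, b):",
        "suffix": "\nprint(add(1, 2))",
        "max_tokens": 64,                  // options.max_decoding_tokens
        "temperature": 0.1,                // options.sampling_temperature
        "stream": true
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```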
