refactor: extract language related data into languages.rs (#518)
* refactor: extract language related data into languages.rs

* fix

* cleanup index

* fix

* further sanitize

* add a score threshold
wsxiaoys authored Oct 7, 2023
1 parent d85a789 commit 8c09f75
Showing 7 changed files with 76 additions and 69 deletions.
6 changes: 3 additions & 3 deletions crates/tabby-inference/src/decoding.rs
@@ -5,7 +5,7 @@ use regex::Regex;
use tokenizers::tokenizer::Tokenizer;

pub struct DecodingFactory {
-stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
+stop_regex_cache: DashMap<&'static [&'static str], Regex>,
}

fn reverse<T>(s: T) -> String
@@ -28,12 +28,12 @@ impl DecodingFactory {
&self,
tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32],
-stop_words: &'static Vec<&'static str>,
+stop_words: &'static [&'static str],
) -> IncrementalDecoding {
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
}

-fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
+fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> {
if stop_words.is_empty() {
None
} else {
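A quick note on the signature change above: a `&'static [&'static str]` can point at a plain `static` array with no allocation, whereas `&'static Vec<&'static str>` forces every caller to keep a (usually `lazy_static`) `Vec` alive forever. A minimal sketch, with hypothetical names rather than code from this commit:

```rust
// Hypothetical example (names not from the commit): a static slice needs
// neither a heap allocation nor lazy_static.
static STOP_WORDS: &[&str] = &["\n\n", "\nfn "];

// Stand-in for an API like DecodingFactory::get_re that now takes a slice.
fn count_stop_words(stop_words: &'static [&'static str]) -> usize {
    stop_words.len()
}

fn main() {
    assert_eq!(count_stop_words(STOP_WORDS), 2);
    // A &'static Vec still works at such call sites, because &Vec<T>
    // deref-coerces to &[T] with the same lifetime.
}
```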
2 changes: 1 addition & 1 deletion crates/tabby-inference/src/lib.rs
@@ -16,7 +16,7 @@ pub struct TextGenerationOptions {
pub sampling_temperature: f32,

#[builder(default = "&EMPTY_STOP_WORDS")]
-pub stop_words: &'static Vec<&'static str>,
+pub stop_words: &'static [&'static str],
}

static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
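The `&EMPTY_STOP_WORDS` builder default keeps compiling after this change because `&'static Vec<T>` deref-coerces to `&'static [T]`. As a hypothetical follow-up (not part of this commit), the `Vec` could be dropped entirely:

```rust
// Sketch of an alternative default, not the committed code: an empty
// &'static slice needs neither a Vec nor deref coercion.
static EMPTY_STOP_WORDS: &[&str] = &[];
```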
15 changes: 10 additions & 5 deletions crates/tabby-scheduler/src/index.rs
@@ -17,6 +17,7 @@ use tantivy::{
// Magic numbers
static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
+static MAX_BODY_LINES_THRESHOLD: usize = 15;

pub fn index_repositories(_config: &Config) -> Result<()> {
let mut builder = Schema::builder();
@@ -82,19 +83,23 @@ struct IndexedDocument {
}

fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> {
-file.tags.into_iter().map(move |tag| {
+file.tags.into_iter().filter_map(move |tag| {
let name = file.content.get(tag.name_range).unwrap().to_owned();
let body = file.content.get(tag.range).unwrap().to_owned();

+if body.lines().collect::<Vec<_>>().len() > MAX_BODY_LINES_THRESHOLD {
+return None;
+}

let language = reduce_language_if_needed(&file.language).to_owned();
-IndexedDocument {
+Some(IndexedDocument {
git_url: file.git_url.clone(),
filepath: file.filepath.clone(),
language,
name,
body,
kind: tag.syntax_type_name,
-}
+})
})
}

@@ -126,7 +131,7 @@ mod tests {
{
"range": {
"start": 290,
"end": 3094
"end": 320
},
"name_range": {
"start": 296,
@@ -142,7 +147,7 @@
{
"range": {
"start": 953,
"end": 1507
"end": 970
},
"name_range": {
"start": 957,
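The `map` → `filter_map` switch is what lets the indexer drop tags whose body exceeds `MAX_BODY_LINES_THRESHOLD` by returning `None` (the shrunken `end` offsets in the test fixture keep those tag bodies under the new limit). One small observation: `body.lines().collect::<Vec<_>>().len()` allocates a `Vec` only to read its length. A lazy count behaves identically; this is a sketch of that assumption, not the committed code:

```rust
// Hypothetical allocation-free variant of the committed check: count lines
// lazily and stop as soon as the threshold is crossed.
fn exceeds_line_limit(body: &str, max_lines: usize) -> bool {
    body.lines().take(max_lines + 1).count() > max_lines
}

fn main() {
    assert!(!exceeds_line_limit("fn main() {}\n", 15));
    assert!(exceeds_line_limit(&"line\n".repeat(16), 15));
}
```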
4 changes: 2 additions & 2 deletions crates/tabby/src/serve/completions.rs
@@ -11,7 +11,7 @@ use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;

-use self::languages::get_stop_words;
+use self::languages::get_language;
use super::search::IndexServer;

#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@@ -81,7 +81,7 @@ pub async fn completions(
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
-.stop_words(get_stop_words(&language))
+.stop_words(get_language(&language).stop_words)
.build()
.unwrap();

74 changes: 46 additions & 28 deletions crates/tabby/src/serve/completions/languages.rs
@@ -1,7 +1,10 @@
-use std::collections::HashMap;

use lazy_static::lazy_static;

+pub struct Language {
+pub stop_words: &'static [&'static str],
+pub line_comment: &'static str,
+}

lazy_static! {
static ref DEFAULT: Vec<&'static str> = vec![
"\n\n",
@@ -20,29 +23,48 @@ lazy_static! {
"\n\n\t\t\t\t\t\t",
"\n\n\t\t\t\t\t\t\t",
];
-static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
-let mut map = HashMap::new();
-map.insert(
-"python",
-vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(),
-);
-map.insert(
-"javascript",
-vec!["\nfunction", "\n//", "\nimport", "\nclass"],
-);
-map.insert(
-"typescript",
-vec![
-"\nfunction",
-"\n//",
-"\nimport",
-"\nclass",
-"\ninterface",
-"\ntype",
-],
-);
-map
+static ref UNKONWN: Language = Language {
+stop_words: &DEFAULT,
+line_comment: "#"
};
+static ref PYTHON_STOP_WORDS: Vec<&'static str> =
+vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
+static ref PYTHON: Language = Language {
+stop_words: &PYTHON_STOP_WORDS,
+line_comment: "#",
+};
+static ref RUST_STOP_WORDS: Vec<&'static str> =
+vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
+static ref RUST: Language = Language {
+stop_words: &RUST_STOP_WORDS,
+line_comment: "//",
+};
+static ref JAVASCRIPT_STOP_WORDS: Vec<&'static str> =
+vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
+static ref JAVASCRIPT: Language = Language {
+stop_words: &JAVASCRIPT_STOP_WORDS,
+line_comment: "",
+};
+static ref TYPESCRIPT_STOP_WORDS: Vec<&'static str> =
+vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
+static ref TYPESCRIPT: Language = Language {
+stop_words: &TYPESCRIPT_STOP_WORDS,
+line_comment: "",
+};
}

+pub fn get_language(language: &str) -> &'static Language {
+if language == "python" {
+&PYTHON
+} else if language == "rust" {
+&RUST
+} else if language == "javascript" {
+&JAVASCRIPT
+} else if language == "typescript" {
+&TYPESCRIPT
+} else {
+&UNKONWN
+}
+}

trait WithDefault {
@@ -56,7 +78,3 @@ impl WithDefault for Vec<&'static str> {
self
}
}

-pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
-LANGUAGES.get(language).unwrap_or(&DEFAULT)
-}
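Since `get_language` dispatches on a handful of fixed strings, the if/else chain could also be written as a `match`. A sketch only — it reuses the statics defined in the diff above, including the committed `UNKONWN` spelling:

```rust
// Equivalent match form of the committed if/else chain (a sketch, not the
// committed code); PYTHON, RUST, etc. are the lazy_static refs above.
pub fn get_language(language: &str) -> &'static Language {
    match language {
        "python" => &PYTHON,
        "rust" => &RUST,
        "javascript" => &JAVASCRIPT,
        "typescript" => &TYPESCRIPT,
        _ => &UNKONWN,
    }
}
```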
22 changes: 12 additions & 10 deletions crates/tabby/src/serve/completions/prompt.rs
@@ -1,14 +1,14 @@
-use std::{collections::HashMap, env, sync::Arc};
+use std::{env, sync::Arc};

-use lazy_static::lazy_static;
use strfmt::strfmt;
use tracing::{info, warn};

use super::Segments;
-use crate::serve::search::IndexServer;
+use crate::serve::{completions::languages::get_language, search::IndexServer};

static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
+static SNIPPET_SCORE_THRESHOLD: f32 = 5.0;

pub struct PromptBuilder {
prompt_template: Option<String>,
@@ -84,7 +84,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
return prefix.to_owned();
}

-let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
+let comment_char = get_language(language).line_comment;
let mut lines: Vec<String> = vec![
format!(
"Below are some relevant {} snippets found in the repository:",
@@ -142,6 +142,10 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V
};

for hit in serp.hits {
+if hit.score < SNIPPET_SCORE_THRESHOLD {
+break;
+}

let body = hit.doc.body;

if text.contains(&body) {
@@ -161,15 +165,13 @@ fn sanitize_text(text: &str) -> String {
|c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-',
" ",
);
-let tokens: Vec<&str> = x.split(' ').filter(|x| x.len() > 5).collect();
+let tokens: Vec<&str> = x
+.split(' ')
+.filter(|x| *x != "AND" && *x != "NOT" && *x != "OR" && x.len() > 5)
+.collect();
tokens.join(" ")
}

-lazy_static! {
-static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
-HashMap::from([("python", "#"), ("rust", "//"),]);
-}

#[cfg(test)]
mod tests {
use super::*;
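Two details in this hunk are easy to miss. First, `sanitize_text` now strips `AND`, `NOT`, and `OR`, which a boolean query parser would otherwise treat as operators rather than search terms. Second, the new threshold check uses `break` rather than `continue`, which is only correct because search hits arrive sorted by descending score. A self-contained illustration of that logic, with `Hit` as a hypothetical stand-in for the search result type:

```rust
// Illustration of the committed threshold-and-break logic. `Hit` is a
// hypothetical stand-in; the break is safe because hits are assumed to be
// sorted by descending score, so nothing after the first miss can qualify.
struct Hit {
    score: f32,
    body: String,
}

const SNIPPET_SCORE_THRESHOLD: f32 = 5.0;

fn collect_bodies(hits: Vec<Hit>) -> Vec<String> {
    let mut snippets = Vec::new();
    for hit in hits {
        if hit.score < SNIPPET_SCORE_THRESHOLD {
            break; // every remaining hit is scored even lower
        }
        snippets.push(hit.body);
    }
    snippets
}

fn main() {
    let hits = vec![
        Hit { score: 9.2, body: "fn a() {}".into() },
        Hit { score: 4.1, body: "fn b() {}".into() },
    ];
    assert_eq!(collect_bodies(hits), vec!["fn a() {}".to_string()]);
}
```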
22 changes: 2 additions & 20 deletions experimental/scheduler/completion.py
@@ -3,36 +3,18 @@
import streamlit as st
from typing import NamedTuple

-class Doc(NamedTuple):
-name: str
-body: str
-score: float
-filepath: str
-
-@staticmethod
-def from_json(json: dict):
-doc = json["doc"]
-return Doc(
-name=doc["name"][0],
-body=doc["body"][0],
-score=json["score"],
-filepath=doc["filepath"][0],
-)

# force wide mode
st.set_page_config(layout="wide")

language = st.text_input("Language", "rust")

query = st.text_area("Query", "get")
tokens = re.findall(r"\w+", query)
+tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]

query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language

if query:
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
hits = r.json()["hits"]
for x in hits:
-doc = Doc.from_json(x)
-st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score))
-st.code(doc.body)
+st.write(x)

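The script-side change mirrors `sanitize_text` in prompt.rs: both now drop `AND`, `OR`, and `NOT` before building a search query, since those tokens read as boolean operators to the query parser. The same filter rendered in Rust, as a sketch rather than code from the commit:

```rust
// Hypothetical Rust rendering of the operator filter applied in both the
// Python script and sanitize_text: drop tokens a boolean query parser
// would interpret as operators.
fn strip_query_operators<'a>(tokens: impl Iterator<Item = &'a str>) -> Vec<&'a str> {
    tokens
        .filter(|t| !matches!(*t, "AND" | "OR" | "NOT"))
        .collect()
}

fn main() {
    let tokens = strip_query_operators("get AND index OR handler".split(' '));
    assert_eq!(tokens, vec!["get", "index", "handler"]);
}
```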