From 73929bd8e9c3cace4326f913879e74e4e2e26b88 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Wed, 15 May 2024 22:49:36 -0700 Subject: [PATCH] refactor(core): move document trait for CodeSearchDocument to index/code/document.rs (#2148) update --- crates/tabby-common/src/api/code.rs | 128 ------------------ .../tabby-common/src/index/code/document.rs | 128 ++++++++++++++++++ .../src/index/{code.rs => code/mod.rs} | 2 + 3 files changed, 130 insertions(+), 128 deletions(-) create mode 100644 crates/tabby-common/src/index/code/document.rs rename crates/tabby-common/src/index/{code.rs => code/mod.rs} (99%) diff --git a/crates/tabby-common/src/api/code.rs b/crates/tabby-common/src/api/code.rs index e6c9acb0f683..b578530ff3c9 100644 --- a/crates/tabby-common/src/api/code.rs +++ b/crates/tabby-common/src/api/code.rs @@ -1,18 +1,9 @@ use async_trait::async_trait; use derive_builder::Builder; use serde::{Deserialize, Serialize}; -use tantivy::{ - schema::{ - document::{DeserializeError, DocumentDeserialize, ReferenceValue, ReferenceValueLeaf}, - OwnedValue, Value, - }, - Document, -}; use thiserror::Error; use utoipa::ToSchema; -use crate::index::CodeSearchSchema; - #[derive(Default, Serialize, Deserialize, Debug, ToSchema)] pub struct CodeSearchResponse { pub num_hits: usize, @@ -73,122 +64,3 @@ pub trait CodeSearch: Send + Sync { offset: usize, ) -> Result; } - -impl Document for CodeSearchDocument { - type Value<'a> = RefValue<'a>; - type FieldsValuesIter<'a> = CodeSearchDocumentFieldValueIter<'a>; - - fn iter_fields_and_values(&self) -> Self::FieldsValuesIter<'_> { - CodeSearchDocumentFieldValueIter::new(self) - } -} - -pub struct CodeSearchDocumentFieldValueIter<'a> { - field_id: i32, - doc: &'a CodeSearchDocument, -} - -impl<'a> CodeSearchDocumentFieldValueIter<'a> { - fn new(doc: &'a CodeSearchDocument) -> Self { - Self { field_id: 0, doc } - } -} - -impl<'a> Iterator for CodeSearchDocumentFieldValueIter<'a> { - type Item = (tantivy::schema::Field, RefValue<'a>); - - fn next(&mut self) -> Option { - let schema = CodeSearchSchema::instance(); - let item = match self.field_id { - 0 => Some((schema.field_body, RefValue::Str(&self.doc.body))), - 1 => Some((schema.field_filepath, RefValue::Str(&self.doc.filepath))), - 2 => Some((schema.field_git_url, RefValue::Str(&self.doc.git_url))), - 3 => Some((schema.field_language, RefValue::Str(&self.doc.language))), - 4 => Some((schema.field_file_id, RefValue::Str(&self.doc.file_id))), - 5 => Some(( - schema.field_start_line, - RefValue::Usize(self.doc.start_line), - )), - _ => None, - }; - - if item.is_some() { - self.field_id += 1; - } - - item - } -} - -#[derive(Clone, Debug)] -pub enum RefValue<'a> { - Str(&'a str), - Usize(usize), -} -impl<'a> Value<'a> for RefValue<'a> { - type ArrayIter = std::iter::Empty; - type ObjectIter = std::iter::Empty<(&'a str, Self)>; - - fn as_value(&self) -> tantivy::schema::document::ReferenceValue<'a, Self> { - match self { - RefValue::Str(s) => ReferenceValue::Leaf(ReferenceValueLeaf::Str(s)), - RefValue::Usize(i) => ReferenceValue::Leaf(ReferenceValueLeaf::U64(*i as u64)), - } - } -} - -impl DocumentDeserialize for CodeSearchDocument { - fn deserialize<'de, D>(mut deserializer: D) -> Result - where - D: tantivy::schema::document::DocumentDeserializer<'de>, - { - let code = CodeSearchSchema::instance(); - let mut builder = CodeSearchDocumentBuilder::default(); - while let Some((field, value)) = deserializer.next_field::()? { - let value = value.as_ref(); - if field == code.field_body { - builder.body(value.try_deserialize()?); - } else if field == code.field_filepath { - builder.filepath(value.try_deserialize()?); - } else if field == code.field_git_url { - builder.git_url(value.try_deserialize()?); - } else if field == code.field_language { - builder.language(value.try_deserialize()?); - } else if field == code.field_file_id { - builder.file_id(value.try_deserialize()?); - } else if field == code.field_start_line { - builder.start_line(value.try_deserialize()?); - } - } - - builder - .build() - .map_err(|e| DeserializeError::Custom(e.to_string())) - } -} - -trait TryDeserialize { - fn try_deserialize(self) -> Result; -} - -impl TryDeserialize for &OwnedValue { - fn try_deserialize(self) -> Result { - match self { - OwnedValue::U64(i) => Ok(*i as usize), - _ => Err(DeserializeError::Custom( - "Field type doesn't match".to_string(), - )), - } - } -} - -impl TryDeserialize for &OwnedValue { - fn try_deserialize(self) -> Result { - match self { - OwnedValue::Str(s) => Ok(s.to_string()), - _ => Err(DeserializeError::Custom( - "Field type doesn't match".to_string(), - )), - } - } -} diff --git a/crates/tabby-common/src/index/code/document.rs b/crates/tabby-common/src/index/code/document.rs new file mode 100644 index 000000000000..01fb54e5564f --- /dev/null +++ b/crates/tabby-common/src/index/code/document.rs @@ -0,0 +1,128 @@ +use tantivy::{ + schema::{ + document::{DeserializeError, DocumentDeserialize, ReferenceValue, ReferenceValueLeaf}, + OwnedValue, Value, + }, + Document, +}; + +use super::CodeSearchSchema; +use crate::api::code::{CodeSearchDocument, CodeSearchDocumentBuilder}; +impl Document for CodeSearchDocument { + type Value<'a> = RefValue<'a>; + type FieldsValuesIter<'a> = CodeSearchDocumentFieldValueIter<'a>; + + fn iter_fields_and_values(&self) -> Self::FieldsValuesIter<'_> { + CodeSearchDocumentFieldValueIter::new(self) + } +} + +pub struct CodeSearchDocumentFieldValueIter<'a> { + field_id: i32, + doc: &'a CodeSearchDocument, +} + +impl<'a> CodeSearchDocumentFieldValueIter<'a> { + fn new(doc: &'a CodeSearchDocument) -> Self { + Self { field_id: 0, doc } + } +} + +impl<'a> Iterator for CodeSearchDocumentFieldValueIter<'a> { + type Item = (tantivy::schema::Field, RefValue<'a>); + + fn next(&mut self) -> Option { + let schema = CodeSearchSchema::instance(); + let item = match self.field_id { + 0 => Some((schema.field_body, RefValue::Str(&self.doc.body))), + 1 => Some((schema.field_filepath, RefValue::Str(&self.doc.filepath))), + 2 => Some((schema.field_git_url, RefValue::Str(&self.doc.git_url))), + 3 => Some((schema.field_language, RefValue::Str(&self.doc.language))), + 4 => Some((schema.field_file_id, RefValue::Str(&self.doc.file_id))), + 5 => Some(( + schema.field_start_line, + RefValue::Usize(self.doc.start_line), + )), + _ => None, + }; + + if item.is_some() { + self.field_id += 1; + } + + item + } +} + +#[derive(Clone, Debug)] +pub enum RefValue<'a> { + Str(&'a str), + Usize(usize), +} +impl<'a> Value<'a> for RefValue<'a> { + type ArrayIter = std::iter::Empty; + type ObjectIter = std::iter::Empty<(&'a str, Self)>; + + fn as_value(&self) -> tantivy::schema::document::ReferenceValue<'a, Self> { + match self { + RefValue::Str(s) => ReferenceValue::Leaf(ReferenceValueLeaf::Str(s)), + RefValue::Usize(i) => ReferenceValue::Leaf(ReferenceValueLeaf::U64(*i as u64)), + } + } +} + +impl DocumentDeserialize for CodeSearchDocument { + fn deserialize<'de, D>(mut deserializer: D) -> Result + where + D: tantivy::schema::document::DocumentDeserializer<'de>, + { + let code = CodeSearchSchema::instance(); + let mut builder = CodeSearchDocumentBuilder::default(); + while let Some((field, value)) = deserializer.next_field::()? { + let value = value.as_ref(); + if field == code.field_body { + builder.body(value.try_deserialize()?); + } else if field == code.field_filepath { + builder.filepath(value.try_deserialize()?); + } else if field == code.field_git_url { + builder.git_url(value.try_deserialize()?); + } else if field == code.field_language { + builder.language(value.try_deserialize()?); + } else if field == code.field_file_id { + builder.file_id(value.try_deserialize()?); + } else if field == code.field_start_line { + builder.start_line(value.try_deserialize()?); + } + } + + builder + .build() + .map_err(|e| DeserializeError::Custom(e.to_string())) + } +} + +trait TryDeserialize { + fn try_deserialize(self) -> Result; +} + +impl TryDeserialize for &OwnedValue { + fn try_deserialize(self) -> Result { + match self { + OwnedValue::U64(i) => Ok(*i as usize), + _ => Err(DeserializeError::Custom( + "Field type doesn't match".to_string(), + )), + } + } +} + +impl TryDeserialize for &OwnedValue { + fn try_deserialize(self) -> Result { + match self { + OwnedValue::Str(s) => Ok(s.to_string()), + _ => Err(DeserializeError::Custom( + "Field type doesn't match".to_string(), + )), + } + } +} diff --git a/crates/tabby-common/src/index/code.rs b/crates/tabby-common/src/index/code/mod.rs similarity index 99% rename from crates/tabby-common/src/index/code.rs rename to crates/tabby-common/src/index/code/mod.rs index 9f7523e320db..b25e0cb550ae 100644 --- a/crates/tabby-common/src/index/code.rs +++ b/crates/tabby-common/src/index/code/mod.rs @@ -1,3 +1,5 @@ +mod document; + use lazy_static::lazy_static; use tantivy::{ query::{BooleanQuery, ConstScoreQuery, Query, TermQuery},