Skip to content

Commit

Permalink
codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
psytraxx committed Nov 6, 2024
1 parent 896e0df commit 5e70459
Show file tree
Hide file tree
Showing 9 changed files with 494 additions and 148 deletions.
30 changes: 22 additions & 8 deletions src/bin/query.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use anyhow::Result;
use anyhow::{anyhow, Result};
use photo_scanner_rust::domain::models::VectorOutputListUtils;
use photo_scanner_rust::domain::ports::{Chat, VectorDB};
use photo_scanner_rust::outbound::openai::OpenAI;
use photo_scanner_rust::outbound::qdrant::QdrantClient;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::info;
use tracing::{debug, info, warn};

const QDRANT_GRPC: &str = "http://dot.dynamicflash.de:6334";

Expand All @@ -24,22 +25,35 @@ async fn main() -> Result<()> {

let vector_db = Arc::new(QdrantClient::new(QDRANT_GRPC, 1024)?);

// what is our favorite beach holiday destination in europe
// which festivals has annina visited in the last years
let question = "which cities did we visit in japan";

// Get the question from the command line arguments.
let args: Vec<String> = std::env::args().collect();
if args.len() != 2 {
return Err(anyhow!("Please provide question"));
}
let question = &args[1];
let embeddings = chat.get_embeddings(vec![question.to_string()]).await?;

let result = vector_db
let mut result = vector_db
.search_points("photos", embeddings[0].as_slice(), HashMap::new())
.await?;

// Sort the results by score.
result.sort_by_score();

if result.is_empty() {
warn!(
"{:?}",
"Please check your search input - no matching documents found"
);
return Ok(());
}

let result: Vec<String> = result
.iter()
.map(|r| r.payload.get("description").cloned().unwrap_or_default())
.collect();

info!("{:?}", result);
debug!("{:?}", result);

let result = chat.process_search_result(question, &result).await?;

Expand Down
6 changes: 3 additions & 3 deletions src/domain/descriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ impl DescriptionService {
#[cfg(test)]
mod tests {
use crate::{
domain::{
descriptions::DescriptionService, embeddings::tests::ChatMock, ports::XMPMetadata,
domain::{descriptions::DescriptionService, ports::XMPMetadata},
outbound::{
image_provider::ImageCrateEncoder, test_mocks::tests::ChatMock, xmp::XMPToolkitMetadata,
},
outbound::{image_provider::ImageCrateEncoder, xmp::XMPToolkitMetadata},
};
use anyhow::Result;
use std::{
Expand Down
135 changes: 7 additions & 128 deletions src/domain/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,26 +173,20 @@ impl EmbeddingsService {
}

#[cfg(test)]
pub(super) mod tests {
pub mod tests {
use crate::{
domain::{
embeddings::EmbeddingsService,
models::{VectorInput, VectorOutput},
ports::{Chat, VectorDB},
domain::{embeddings::EmbeddingsService, ports::VectorDB},
outbound::{
test_mocks::tests::{ChatMock, VectorDBMock},
xmp::XMPToolkitMetadata,
},
outbound::xmp::XMPToolkitMetadata,
};
use anyhow::Result;
use async_trait::async_trait;
use rand::Rng;
use std::collections::HashMap;
use std::sync::Mutex;
use std::{
fs::{copy, remove_file},
path::PathBuf,
sync::Arc,
};
use tracing::debug;
#[tokio::test]
async fn test_generate_embeddings() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
Expand All @@ -205,8 +199,9 @@ pub(super) mod tests {
// Initialize dependencies
let chat = Arc::new(ChatMock);
let xmp_metadata = Arc::new(XMPToolkitMetadata::new());
let vector_db = Arc::new(VectorDBMock::new());

let vector_db = Arc::new(VectorDBMock::new());
vector_db.create_collection("photos").await?;
// Create the EmbeddingsService instance
let service = EmbeddingsService::new(chat, xmp_metadata.clone(), vector_db);

Expand All @@ -220,120 +215,4 @@ pub(super) mod tests {

Ok(())
}

/// Test double for the `Chat` port: no external service is contacted.
pub struct ChatMock;

#[async_trait]
impl Chat for ChatMock {
    /// Always answers with the fixed text `"description"`; every argument is ignored.
    async fn get_image_description(
        &self,
        _image_base64: &str,
        _persons: &[String],
        _folder_name: &Option<String>,
    ) -> Result<String> {
        Ok(String::from("description"))
    }

    /// Returns exactly one embedding made of 1536 random `f32` values.
    ///
    /// The input texts are ignored, so the result always holds a single vector
    /// regardless of how many texts were passed in.
    async fn get_embeddings(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        let mut generator = rand::thread_rng();
        let mut values: Vec<f32> = Vec::with_capacity(1536);
        for _ in 0..1536 {
            values.push(generator.gen());
        }
        Ok(vec![values])
    }

    /// Not provided by this mock; calling it panics via `unimplemented!()`.
    async fn process_search_result(
        &self,
        _question: &str,
        _options: &[String],
    ) -> Result<String> {
        unimplemented!()
    }
}

/// In-memory stand-in for a vector database.
struct VectorDBMock {
    /// Stored points grouped by collection name; the mutex allows mutation through `&self`.
    store_embeddings: Mutex<HashMap<String, Vec<VectorInput>>>,
}

impl VectorDBMock {
    /// Creates a mock that holds no collections yet.
    pub fn new() -> Self {
        let store_embeddings = Mutex::new(HashMap::new());
        Self { store_embeddings }
    }
}

/// In-memory `VectorDB` implementation for tests.
///
/// Points are kept in `store_embeddings`, grouped by collection name. No
/// similarity computation is performed: searching returns everything stored
/// in the requested collection.
#[async_trait]
impl VectorDB for VectorDBMock {
    /// Not provided by this mock; calling it panics via `unimplemented!()`.
    async fn create_collection(&self, _collection: &str) -> Result<bool> {
        unimplemented!()
    }

    /// Not provided by this mock; calling it panics via `unimplemented!()`.
    async fn delete_collection(&self, _text: &str) -> Result<bool> {
        unimplemented!()
    }

    /// Always reports that no point exists, whatever collection or id is given.
    async fn find_by_id(
        &self,
        _collection_name: &str,
        _id: &u64,
    ) -> Result<Option<VectorOutput>> {
        Ok(None)
    }

    /// Inserts `inputs` into `collection_name`, creating the collection on first use.
    ///
    /// A stored point with the same `id` as an incoming one is replaced; the
    /// new point is appended at the end of the collection. Always returns
    /// `Ok(true)`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    async fn upsert_points(
        &self,
        collection_name: &str,
        inputs: &[VectorInput],
    ) -> Result<bool> {
        let mut store = self.store_embeddings.lock().unwrap();
        // The entry API creates the collection on demand with a single lookup.
        let collection = store.entry(collection_name.to_string()).or_default();

        for input in inputs {
            // `retain` is a no-op when the id is absent, so no prior scan is needed.
            collection.retain(|entry| entry.id != input.id);
            collection.push(VectorInput {
                id: input.id,
                embedding: input.embedding.clone(),
                payload: input.payload.clone(),
            });
        }
        Ok(true)
    }

    /// Returns every point stored in `collection_name` with `score` set to `None`.
    ///
    /// The query vector and the payload filter are ignored. An unknown
    /// collection yields an empty list rather than an error.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    async fn search_points(
        &self,
        collection_name: &str,
        _input_vectors: &[f32],
        _payload_required: HashMap<String, String>,
    ) -> Result<Vec<VectorOutput>> {
        let store = self.store_embeddings.lock().unwrap();
        match store.get(collection_name) {
            Some(stored) => {
                debug!(
                    "Found {:?} entries in collection {}",
                    stored, collection_name
                );

                Ok(stored
                    .iter()
                    .map(|entry| VectorOutput {
                        id: entry.id,
                        score: None,
                        payload: entry.payload.clone(),
                    })
                    .collect())
            }
            None => Ok(Vec::new()),
        }
    }
}
}
101 changes: 100 additions & 1 deletion src/domain/models.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,114 @@
use std::collections::HashMap;

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
/// A point read back from the vector store, e.g. one element of a search result.
pub struct VectorOutput {
/// Identifier of the point.
pub id: u64,
/// Score attached to the point; `None` when no score is available.
/// NOTE(review): presumably the similarity score reported by the search — confirm.
pub score: Option<f32>,
/// String key/value data stored alongside the point.
pub payload: HashMap<String, String>,
}

/// Utility methods for working with a list of `VectorOutput` values.
pub trait VectorOutputListUtils {
/// Sorts the list in place by score, highest score first.
///
/// Scores are compared as `Option<f32>`, where `None` orders below any
/// `Some`, so entries without a score end up after all scored entries.
fn sort_by_score(&mut self);

/// Removes every entry whose score does not exceed a threshold.
///
/// Only entries with a score strictly greater than `score` are kept.
/// Entries whose score equals the threshold, and entries with no score
/// (`None`), are removed as well.
///
/// # Arguments
///
/// * `score` - The threshold score. Only results scoring strictly above this value are kept.
fn limit_results(&mut self, score: f32);
}
/// A list of `VectorOutput` values; `VectorOutputListUtils` is implemented for this type.
pub type VectorOutputList = Vec<VectorOutput>;

impl VectorOutputListUtils for VectorOutputList {
// A method to sort the outputs in descending order of score
fn sort_by_score(&mut self) {
// Sort the VectorOutputList in-place by score in descending order
// The `sort_by` method sorts the elements in-place based on the result of a comparison function
// The `partial_cmp` method compares two Option<f32> values in a way that treats None as less than Some
self.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
}
// A method to filter out results with scores below a given threshold
fn limit_results(&mut self, score: f32) {
// Filter out results with scores below the threshold
// The `retain` method keeps only the elements specified by the predicate
// The `map_or` method returns the provided value if the `Option` is `None`, or applies a function to the contained value if `Some`
// In this case, it checks if the score is `Some` and if it's greater than the threshold
self.retain(|output| output.score.map_or(false, |s| s > score));
}
}

/// A point to be written to the vector store.
#[derive(Debug, Clone)]
pub struct VectorInput {
/// Identifier of the point.
pub id: u64,
/// The embedding vector stored for this point.
pub embedding: Vec<f32>,
/// String key/value data stored alongside the point.
pub payload: HashMap<String, String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an output with the given id and score and an empty payload.
    fn with_id_and_score(id: u64, score: f32) -> VectorOutput {
        VectorOutput {
            id,
            score: Some(score),
            payload: HashMap::new(),
        }
    }

    /// Builds an output that only sets the score; all other fields are defaults.
    fn with_score(score: Option<f32>) -> VectorOutput {
        VectorOutput {
            score,
            ..VectorOutput::default()
        }
    }

    /// Sorting must order the entries from the highest score to the lowest.
    #[test]
    fn test_sort_by_score() {
        let mut outputs = vec![
            with_id_and_score(1, 0.3),
            with_id_and_score(2, 0.5),
            with_id_and_score(3, 0.1),
        ];

        outputs.sort_by_score();

        let order: Vec<u64> = outputs.iter().map(|output| output.id).collect();
        assert_eq!(order, vec![2, 1, 3]);
    }

    /// Limiting must drop entries at or below the threshold and entries without a score.
    #[test]
    fn test_limit_results() {
        let mut output_list: VectorOutputList = vec![Some(0.5), Some(0.8), Some(0.3), None]
            .into_iter()
            .map(with_score)
            .collect();

        output_list.limit_results(0.4);

        let remaining: Vec<Option<f32>> = output_list.iter().map(|output| output.score).collect();
        assert_eq!(remaining, vec![Some(0.5), Some(0.8)]);
    }
}
Loading

0 comments on commit 5e70459

Please sign in to comment.