From 1132a3600e7b3d6d54f5a42a2051e7def15bdb1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Urban?= Date: Tue, 27 Feb 2024 14:33:13 +0100 Subject: [PATCH] refactor: optimize env variable reading and split code into more readable functions initialize clients for ollama and paperless in separate methods --- src/main.rs | 125 +++++++++++++++++------------------------------ src/paperless.rs | 44 ++++++++++++++++- 2 files changed, 88 insertions(+), 81 deletions(-) diff --git a/src/main.rs b/src/main.rs index 1abf916..0bb188a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,11 +14,9 @@ use std::result::Result; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use serde::{Deserialize, Serialize}; use serde_json::{Value}; -use std::collections::HashMap; use std::env; -use std::error::Error as StdError; use crate::llm_api::generate_response; -use crate::paperless::{get_data_from_paperless, query_custom_fields}; +use crate::paperless::{get_data_from_paperless, query_custom_fields, update_document_fields}; #[derive(Serialize, Deserialize, Debug, Clone)] struct Document { @@ -64,54 +62,41 @@ struct Field { } - - -#[tokio::main] -async fn main() -> Result<(), Box> { - let base_url = std::env::var("PAPERLESS_BASE_URL").unwrap(); - - let token = env::var("PAPERLESS_TOKEN").expect("TOKEN is not set in .env file"); - // Create HeaderMap and add Authorization header +// Initialize the HTTP client with Paperless API token and base URL +fn init_paperless_client(token: &str) -> Client { let mut headers = HeaderMap::new(); - let header_value = HeaderValue::from_str(&format!("Token {}", token)).unwrap(); + let header_value = HeaderValue::from_str(&format!("Token {}", token)) + .expect("Invalid header value for TOKEN"); headers.insert(AUTHORIZATION, header_value); - let client = Client::builder().default_headers(headers).build().unwrap(); - // Create a Client with the default headers - let ollama = Ollama::new("http://localhost".to_string(), 11434); - //let model = 
"mistral:latest".to_string(); - let model = "llama2:13b".to_string(); - let prompt_base = "Please extract metadata from the provided document and return it in JSON format. The fields I need are: title,topic,sender,recipient,urgency(with value either n/a or low or medium or high),date_received,category. Analyze the document to find the values for these fields and format the response as a JSON object. Use the most likely answer for each field. The response should contain only JSON data where the key and values are all in simple string format(no nested object) for direct parsing by another program. So now additional text or explanation, no introtext, the answer should start and end with curly brackets delimiting the json object ".to_string(); - - let fields = query_custom_fields(&client, &base_url).await?; - //let res = ollama.generate(GenerationRequest::new(model, prompt)).await; + Client::builder() + .default_headers(headers) + .build() + .expect("Failed to build client") +} - // if let Ok(res) = res { - // println!("{}", res.response); - // } +// Initialize Ollama client +fn init_ollama_client(host: &str, port: u16, secure_endpoint: bool) -> Ollama { + let protocol = if secure_endpoint { "https" } else { "http" }; + let ollama_base_url = format!("{}://{}", protocol, host); + Ollama::new(ollama_base_url, port) +} - // Query data from paperless-ngx endpoint +// Refactor the main process into a function for better readability +async fn process_documents(client: &Client, ollama: &Ollama, model: &str, base_url: &str) -> Result<(), Box> { + let prompt_base = "Please extract metadata from the provided document and return it in JSON format. The fields I need are: title,topic,sender,recipient,urgency(with value either n/a or low or medium or high),date_received,category. Analyze the document to find the values for these fields and format the response as a JSON object. Use the most likely answer for each field. 
The response should contain only JSON data where the key and values are all in simple string format(no nested object) for direct parsing by another program. So now additional text or explanation, no introtext, the answer should start and end with curly brackets delimiting the json object ".to_string(); + let fields = query_custom_fields(client, base_url).await?; match get_data_from_paperless(&client, &base_url).await { Ok(data) => { for document in data { - let res = generate_response(&ollama, &model, &prompt_base, &document).await; - if let Ok(res) = res { - println!("Response: {}", res.response); - if let Some(json_str) = extract_json_object(&res.response) { - println!("JSON: {}", json_str); - let parsed_json = serde_json::from_str(&json_str); - match parsed_json { - Ok(json) => { - update_document_fields(&client, document.id, &fields, &json, &base_url).await; - // Use the parsed JSON here - } - Err(e) => { - eprintln!("Error parsing JSON: {}", e); - } - } - } else { - eprintln!("No JSON object found in the response"); + let res = generate_response(ollama, &model.to_string(), &prompt_base.to_string(), &document).await?; + if let Some(json_str) = extract_json_object(&res.response) { + match serde_json::from_str(&json_str) { + Ok(json) => update_document_fields(client, document.id, &fields, &json, base_url).await?, + Err(e) => eprintln!("Error parsing JSON: {}", e), } + } else { + eprintln!("No JSON object found in the response"); } } } @@ -120,44 +105,24 @@ async fn main() -> Result<(), Box> { Ok(()) } -async fn update_document_fields( - client: &Client, - document_id: u32, - fields: &Vec, - metadata: &HashMap>, - base_url: &str -) -> std::result::Result<(), Box> { - let mut custom_fields = Vec::new(); - - for (key, value) in metadata { - if key == "title" { - continue; - } - if let Some(field) = fields.iter().find(|&f| f.name == *key) { - let custom_field = CustomField { - field: field.id.clone(), - value: value.as_ref().cloned(), - }; - 
custom_fields.push(custom_field); - } - } - // Add the tagged field, to indicate that the document has been processed - let custom_field = CustomField { - field: 1, - value: Some(serde_json::json!(true)), - }; - custom_fields.push(custom_field); - let mut payload = serde_json::Map::new(); - - payload.insert("custom_fields".to_string(), serde_json::json!(custom_fields)); - if let Some(value) = metadata.get("title").and_then(|v| v.as_ref().and_then(|v| v.as_str())) { - payload.insert("title".to_string(), serde_json::json!(value)); - } - let url = format!("{}/api/documents/{}/", base_url, document_id); - let res = client.patch(&url).json(&payload).send().await?; - let body = res.text().await?; - println!("{}", body); - Ok(()) +#[tokio::main] +async fn main() -> Result<(), Box> { + let token = env::var("PAPERLESS_TOKEN").expect("PAPERLESS_TOKEN is not set in .env file"); + let base_url = env::var("PAPERLESS_BASE_URL").expect("PAPERLESS_BASE_URL is not set in .env file"); + let client = init_paperless_client(&token); + + let ollama_host = env::var("OLLAMA_HOST").unwrap_or_else(|_| "localhost".to_string()); + let ollama_port = env::var("OLLAMA_PORT") + .unwrap_or_else(|_| "11434".to_string()) + .parse::().unwrap_or(11434); + let ollama_secure_endpoint = env::var("OLLAMA_SECURE_ENDPOINT") + .unwrap_or_else(|_| "false".to_string()) + .parse::().unwrap_or(false); + + let ollama = init_ollama_client(&ollama_host, ollama_port, ollama_secure_endpoint); + let model = env::var("OLLAMA_MODEL").unwrap_or_else(|_| "llama2:13b".to_string()); + + process_documents(&client, &ollama, &model, &base_url).await } fn extract_json_object(input: &str) -> Option { diff --git a/src/paperless.rs b/src/paperless.rs index 143fa1b..1cf4fbb 100644 --- a/src/paperless.rs +++ b/src/paperless.rs @@ -1,6 +1,8 @@ +use std::collections::HashMap; use reqwest::Client; use serde::de::StdError; -use crate::{Document, Field, Response}; +use serde_json::Value; +use crate::{CustomField, Document, Field, 
Response}; pub async fn get_data_from_paperless( client: &Client, @@ -60,4 +62,44 @@ pub async fn query_custom_fields( Err(e.into()) // Remove the semicolon here } } +} + +pub async fn update_document_fields( + client: &Client, + document_id: u32, + fields: &Vec, + metadata: &HashMap>, + base_url: &str +) -> std::result::Result<(), Box> { + let mut custom_fields = Vec::new(); + + for (key, value) in metadata { + if key == "title" { + continue; + } + if let Some(field) = fields.iter().find(|&f| f.name == *key) { + let custom_field = CustomField { + field: field.id.clone(), + value: value.as_ref().cloned(), + }; + custom_fields.push(custom_field); + } + } + // Add the tagged field, to indicate that the document has been processed + let custom_field = CustomField { + field: 1, + value: Some(serde_json::json!(true)), + }; + custom_fields.push(custom_field); + let mut payload = serde_json::Map::new(); + + payload.insert("custom_fields".to_string(), serde_json::json!(custom_fields)); + if let Some(value) = metadata.get("title").and_then(|v| v.as_ref().and_then(|v| v.as_str())) { + payload.insert("title".to_string(), serde_json::json!(value)); + } + let url = format!("{}/api/documents/{}/", base_url, document_id); + let res = client.patch(&url).json(&payload).send().await?; + let body = res.text().await?; + println!("{}", body); + Ok(()) } \ No newline at end of file