Skip to content

Commit

Permalink
refactor: optimize env variable reading and split code into more read…
Browse files Browse the repository at this point in the history
…able functions

initialize clients for ollama and paperless in separate methods
  • Loading branch information
Björn Urban committed Feb 27, 2024
1 parent 92e70c3 commit 1132a36
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 81 deletions.
125 changes: 45 additions & 80 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@ use std::result::Result;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::{Deserialize, Serialize};
use serde_json::{Value};
use std::collections::HashMap;
use std::env;
use std::error::Error as StdError;
use crate::llm_api::generate_response;
use crate::paperless::{get_data_from_paperless, query_custom_fields};
use crate::paperless::{get_data_from_paperless, query_custom_fields, update_document_fields};

#[derive(Serialize, Deserialize, Debug, Clone)]
struct Document {
Expand Down Expand Up @@ -64,54 +62,41 @@ struct Field {
}




#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let base_url = std::env::var("PAPERLESS_BASE_URL").unwrap();

let token = env::var("PAPERLESS_TOKEN").expect("TOKEN is not set in .env file");
// Create HeaderMap and add Authorization header
// Initialize the HTTP client with Paperless API token and base URL
fn init_paperless_client(token: &str) -> Client {
let mut headers = HeaderMap::new();
let header_value = HeaderValue::from_str(&format!("Token {}", token)).unwrap();
let header_value = HeaderValue::from_str(&format!("Token {}", token))
.expect("Invalid header value for TOKEN");
headers.insert(AUTHORIZATION, header_value);
let client = Client::builder().default_headers(headers).build().unwrap();

// Create a Client with the default headers
let ollama = Ollama::new("http://localhost".to_string(), 11434);
//let model = "mistral:latest".to_string();
let model = "llama2:13b".to_string();
let prompt_base = "Please extract metadata from the provided document and return it in JSON format. The fields I need are: title,topic,sender,recipient,urgency(with value either n/a or low or medium or high),date_received,category. Analyze the document to find the values for these fields and format the response as a JSON object. Use the most likely answer for each field. The response should contain only JSON data where the key and values are all in simple string format(no nested object) for direct parsing by another program. So now additional text or explanation, no introtext, the answer should start and end with curly brackets delimiting the json object ".to_string();

let fields = query_custom_fields(&client, &base_url).await?;
//let res = ollama.generate(GenerationRequest::new(model, prompt)).await;
Client::builder()
.default_headers(headers)
.build()
.expect("Failed to build client")
}

// if let Ok(res) = res {
// println!("{}", res.response);
// }
/// Builds the Ollama client from its connection settings.
///
/// The base URL handed to `Ollama::new` is `<scheme>://<host>`, where the
/// scheme is `https` when `secure_endpoint` is `true` and `http` otherwise.
/// `port` is passed through unchanged.
fn init_ollama_client(host: &str, port: u16, secure_endpoint: bool) -> Ollama {
    let scheme = match secure_endpoint {
        true => "https",
        false => "http",
    };
    Ollama::new(format!("{}://{}", scheme, host), port)
}

// Query data from paperless-ngx endpoint
// Refactor the main process into a function for better readability
async fn process_documents(client: &Client, ollama: &Ollama, model: &str, base_url: &str) -> Result<(), Box<dyn std::error::Error>> {
let prompt_base = "Please extract metadata from the provided document and return it in JSON format. The fields I need are: title,topic,sender,recipient,urgency(with value either n/a or low or medium or high),date_received,category. Analyze the document to find the values for these fields and format the response as a JSON object. Use the most likely answer for each field. The response should contain only JSON data where the key and values are all in simple string format(no nested object) for direct parsing by another program. So now additional text or explanation, no introtext, the answer should start and end with curly brackets delimiting the json object ".to_string();
let fields = query_custom_fields(client, base_url).await?;
match get_data_from_paperless(&client, &base_url).await {
Ok(data) => {
for document in data {
let res = generate_response(&ollama, &model, &prompt_base, &document).await;
if let Ok(res) = res {
println!("Response: {}", res.response);
if let Some(json_str) = extract_json_object(&res.response) {
println!("JSON: {}", json_str);
let parsed_json = serde_json::from_str(&json_str);
match parsed_json {
Ok(json) => {
update_document_fields(&client, document.id, &fields, &json, &base_url).await;
// Use the parsed JSON here
}
Err(e) => {
eprintln!("Error parsing JSON: {}", e);
}
}
} else {
eprintln!("No JSON object found in the response");
let res = generate_response(ollama, &model.to_string(), &prompt_base.to_string(), &document).await?;
if let Some(json_str) = extract_json_object(&res.response) {
match serde_json::from_str(&json_str) {
Ok(json) => update_document_fields(client, document.id, &fields, &json, base_url).await?,
Err(e) => eprintln!("Error parsing JSON: {}", e),
}
} else {
eprintln!("No JSON object found in the response");
}
}
}
Expand All @@ -120,44 +105,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

async fn update_document_fields(
client: &Client,
document_id: u32,
fields: &Vec<Field>,
metadata: &HashMap<String, Option<Value>>,
base_url: &str
) -> std::result::Result<(), Box<dyn std::error::Error>> {
let mut custom_fields = Vec::new();

for (key, value) in metadata {
if key == "title" {
continue;
}
if let Some(field) = fields.iter().find(|&f| f.name == *key) {
let custom_field = CustomField {
field: field.id.clone(),
value: value.as_ref().cloned(),
};
custom_fields.push(custom_field);
}
}
// Add the tagged field, to indicate that the document has been processed
let custom_field = CustomField {
field: 1,
value: Some(serde_json::json!(true)),
};
custom_fields.push(custom_field);
let mut payload = serde_json::Map::new();

payload.insert("custom_fields".to_string(), serde_json::json!(custom_fields));
if let Some(value) = metadata.get("title").and_then(|v| v.as_ref().and_then(|v| v.as_str())) {
payload.insert("title".to_string(), serde_json::json!(value));
}
let url = format!("{}/api/documents/{}/", base_url, document_id);
let res = client.patch(&url).json(&payload).send().await?;
let body = res.text().await?;
println!("{}", body);
Ok(())
/// Program entry point: reads the configuration from environment variables,
/// builds the Paperless and Ollama clients and runs `process_documents`.
///
/// Environment variables:
/// - `PAPERLESS_TOKEN` (required)
/// - `PAPERLESS_BASE_URL` (required)
/// - `OLLAMA_HOST` (optional, defaults to `localhost`)
/// - `OLLAMA_PORT` (optional, defaults to `11434`)
/// - `OLLAMA_SECURE_ENDPOINT` (optional, defaults to `false`)
/// - `OLLAMA_MODEL` (optional, defaults to `llama2:13b`)
///
/// # Errors
/// Returns an error when a required variable is missing, or when
/// `process_documents` fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    const DEFAULT_OLLAMA_PORT: u16 = 11434;

    // Missing required configuration is reported through the `Result` that
    // `main` already returns instead of panicking.
    let token = env::var("PAPERLESS_TOKEN")
        .map_err(|_| "PAPERLESS_TOKEN is not set in .env file")?;
    let base_url = env::var("PAPERLESS_BASE_URL")
        .map_err(|_| "PAPERLESS_BASE_URL is not set in .env file")?;
    let client = init_paperless_client(&token);

    let ollama_host = env::var("OLLAMA_HOST").unwrap_or_else(|_| "localhost".to_string());

    // An unset variable quietly uses the default; a value that is set but
    // cannot be parsed still falls back, but is reported so that a typo in
    // the configuration does not go unnoticed.
    let ollama_port = match env::var("OLLAMA_PORT") {
        Ok(raw) => raw.parse::<u16>().unwrap_or_else(|err| {
            eprintln!(
                "Ignoring invalid OLLAMA_PORT {:?} ({}); using {}",
                raw, err, DEFAULT_OLLAMA_PORT
            );
            DEFAULT_OLLAMA_PORT
        }),
        Err(_) => DEFAULT_OLLAMA_PORT,
    };
    let ollama_secure_endpoint = match env::var("OLLAMA_SECURE_ENDPOINT") {
        Ok(raw) => raw.parse::<bool>().unwrap_or_else(|err| {
            eprintln!(
                "Ignoring invalid OLLAMA_SECURE_ENDPOINT {:?} ({}); using false",
                raw, err
            );
            false
        }),
        Err(_) => false,
    };

    let ollama = init_ollama_client(&ollama_host, ollama_port, ollama_secure_endpoint);
    let model = env::var("OLLAMA_MODEL").unwrap_or_else(|_| "llama2:13b".to_string());

    process_documents(&client, &ollama, &model, &base_url).await
}

fn extract_json_object(input: &str) -> Option<String> {
Expand Down
44 changes: 43 additions & 1 deletion src/paperless.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::HashMap;
use reqwest::Client;
use serde::de::StdError;
use crate::{Document, Field, Response};
use serde_json::Value;
use crate::{CustomField, Document, Field, Response};

pub async fn get_data_from_paperless(
client: &Client,
Expand Down Expand Up @@ -60,4 +62,44 @@ pub async fn query_custom_fields(
Err(e.into()) // Remove the semicolon here
}
}
}

pub async fn update_document_fields(
client: &Client,
document_id: u32,
fields: &Vec<Field>,
metadata: &HashMap<String, Option<Value>>,
base_url: &str
) -> std::result::Result<(), Box<dyn std::error::Error>> {
let mut custom_fields = Vec::new();

for (key, value) in metadata {
if key == "title" {
continue;
}
if let Some(field) = fields.iter().find(|&f| f.name == *key) {
let custom_field = CustomField {
field: field.id.clone(),
value: value.as_ref().cloned(),
};
custom_fields.push(custom_field);
}
}
// Add the tagged field, to indicate that the document has been processed
let custom_field = CustomField {
field: 1,
value: Some(serde_json::json!(true)),
};
custom_fields.push(custom_field);
let mut payload = serde_json::Map::new();

payload.insert("custom_fields".to_string(), serde_json::json!(custom_fields));
if let Some(value) = metadata.get("title").and_then(|v| v.as_ref().and_then(|v| v.as_str())) {
payload.insert("title".to_string(), serde_json::json!(value));
}
let url = format!("{}/api/documents/{}/", base_url, document_id);
let res = client.patch(&url).json(&payload).send().await?;
let body = res.text().await?;
println!("{}", body);
Ok(())
}

0 comments on commit 1132a36

Please sign in to comment.