diff --git a/src/llm_api.rs b/src/llm_api.rs new file mode 100644 index 0000000..b6239dd --- /dev/null +++ b/src/llm_api.rs @@ -0,0 +1,17 @@ +use ollama_rs::generation::completion::GenerationResponse; +use ollama_rs::generation::completion::request::GenerationRequest; +use ollama_rs::Ollama; +use crate::Document; + +pub async fn generate_response( + ollama: &Ollama, + model: &String, + prompt_base: &String, + document: &Document, +) -> std::result::Result> { + let prompt = format!("{} {}", document.content, prompt_base); + let res = ollama + .generate(GenerationRequest::new(model.clone(), prompt)) + .await; + res.map_err(|e| e.into()) // Map the Err variant to a Box +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 1723d79..1abf916 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,25 @@ +mod llm_api; +mod paperless; + use ollama_rs::{ - generation::completion::{request::GenerationRequest, GenerationContext, GenerationResponse}, Ollama, }; use substring::Substring; -use reqwest::{Client, Error}; +use reqwest::{Client, }; use std::result::Result; -use tokio::io::stdout; -use tokio::runtime::Runtime; + //function that fetches data from the endpoint //write function that queries a rest endpoint for a given url use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, Map}; +use serde_json::{Value}; use std::collections::HashMap; use std::env; use std::error::Error as StdError; +use crate::llm_api::generate_response; +use crate::paperless::{get_data_from_paperless, query_custom_fields}; + #[derive(Serialize, Deserialize, Debug, Clone)] struct Document { id: u32, @@ -59,80 +63,9 @@ struct Field { data_type: String, } -async fn get_data_from_paperless( - client: &Client, - url: &str, -) -> std::result::Result, Box> { - // Read token from environment - let token = env::var("TOKEN").expect("TOKEN is not set in .env file"); - - //Define filter string - let filter = "NOT tagged=true".to_string(); - - let response = client.get(format!("{}/api/documents/?query={}", url, filter)).send().await?; - let body = response.text().await?; - // Remove the "Document content: " prefix - let json = body.trim_start_matches("Document content: "); - //println!("{}",json); - // Parse the JSON string into a generic JSON structure - //let value: serde_json::Value = serde_json::from_str(json).unwrap(); - // Print the part of the JSON structure that's causing the error - //let error_part = value.pointer("/results/0").unwrap(); - //println!("Error part: {}", error_part); - // Parse the JSON string into the Response struct - let data: std::result::Result, _> = serde_json::from_str(json); - match data { - Ok(data) => Ok(data.results), - Err(e) => { - let column = e.column(); - let start = (column as isize - 30).max(0) as usize; - let end = (column + 30).min(json.len()); - println!("Error at column {}: {}", column, &json[start..end]); - Err(e.into()) // Remove the semicolon here - } - } -} -async fn query_custom_fields( - client: &Client, - base_url: &str, -) -> std::result::Result, Box> { - let res = client - .get(format!("{}/api/custom_fields/", base_url)) - .send() - .await?; - let body = res.text().await?; - // Remove the "Document content: " prefix - let json = body.trim_start_matches("Field: "); - let data: std::result::Result, _> = serde_json::from_str(json); - match data { - Ok(data) => { - println!("Fields: {:?}", data.results); - Ok(data.results) - }, - Err(e) => { - let column = e.column(); - let start = (column as isize - 30).max(0) as usize; - let end = (column + 30).min(json.len()); - println!("Error at column {}: {}", column, &json[start..end]); - Err(e.into()) // Remove the semicolon here - } - } -} -async fn generate_response( - ollama: &Ollama, - model: &String, - prompt_base: &String, - document: &Document, -) -> std::result::Result> { - let prompt = format!("{} {}", document.content, prompt_base); - let res = ollama - .generate(GenerationRequest::new(model.clone(), prompt)) - .await; - res.map_err(|e| e.into()) // Map the Err variant to a Box -} #[tokio::main] async fn main() -> Result<(), Box> { let base_url = std::env::var("PAPERLESS_BASE_URL").unwrap(); diff --git a/src/paperless.rs b/src/paperless.rs new file mode 100644 index 0000000..143fa1b --- /dev/null +++ b/src/paperless.rs @@ -0,0 +1,63 @@ +use reqwest::Client; +use serde::de::StdError; +use crate::{Document, Field, Response}; + +pub async fn get_data_from_paperless( + client: &Client, + url: &str, +) -> std::result::Result, Box> { + // Read token from environment + //Define filter string + let filter = "NOT tagged=true".to_string(); + + let response = client.get(format!("{}/api/documents/?query={}", url, filter)).send().await?; + let body = response.text().await?; + + // Remove the "Document content: " prefix + let json = body.trim_start_matches("Document content: "); + //println!("{}",json); + // Parse the JSON string into a generic JSON structure + //let value: serde_json::Value = serde_json::from_str(json).unwrap(); + + // Print the part of the JSON structure that's causing the error + //let error_part = value.pointer("/results/0").unwrap(); + //println!("Error part: {}", error_part); + // Parse the JSON string into the Response struct + let data: std::result::Result, _> = serde_json::from_str(json); + match data { + Ok(data) => Ok(data.results), + Err(e) => { + let column = e.column(); + let start = (column as isize - 30).max(0) as usize; + let end = (column + 30).min(json.len()); + println!("Error at column {}: {}", column, &json[start..end]); + Err(e.into()) // Remove the semicolon here + } + } +} +pub async fn query_custom_fields( + client: &Client, + base_url: &str, +) -> std::result::Result, Box> { + let res = client + .get(format!("{}/api/custom_fields/", base_url)) + .send() + .await?; + let body = res.text().await?; + // Remove the "Document content: " prefix + let json = body.trim_start_matches("Field: "); + let data: std::result::Result, _> = serde_json::from_str(json); + match data { + Ok(data) => { + println!("Fields: {:?}", data.results); + Ok(data.results) + }, + Err(e) => { + let column = e.column(); + let start = (column as isize - 30).max(0) as usize; + let end = (column + 30).min(json.len()); + println!("Error at column {}: {}", column, &json[start..end]); + Err(e.into()) // Remove the semicolon here + } + } +} \ No newline at end of file