Skip to content

Commit

Permalink
fix(structure): split code into different modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Björn Urban committed Feb 27, 2024
1 parent e364833 commit 3d7ebf6
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 76 deletions.
17 changes: 17 additions & 0 deletions src/llm_api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use ollama_rs::generation::completion::GenerationResponse;
use ollama_rs::generation::completion::request::GenerationRequest;
use ollama_rs::Ollama;
use crate::Document;

/// Sends a completion request to Ollama for a single document.
///
/// The prompt is the document's `content` followed by a space and
/// `prompt_base`.
///
/// # Arguments
/// * `ollama` - client used to issue the generation request.
/// * `model` - name of the model to run; copied into the request.
/// * `prompt_base` - instruction text appended after the document content.
/// * `document` - document whose `content` forms the start of the prompt.
///
/// # Errors
/// Returns the error reported by `Ollama::generate`, boxed as a
/// `std::error::Error` trait object.
pub async fn generate_response(
    ollama: &Ollama,
    model: &str,
    prompt_base: &str,
    document: &Document,
) -> std::result::Result<GenerationResponse, Box<dyn std::error::Error>> {
    let prompt = format!("{} {}", document.content, prompt_base);
    // `GenerationRequest::new` takes ownership of the model name, so the
    // borrowed name has to be copied exactly once here.
    let request = GenerationRequest::new(model.to_owned(), prompt);
    Ok(ollama.generate(request).await?)
}
85 changes: 9 additions & 76 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
mod llm_api;
mod paperless;

use ollama_rs::{
generation::completion::{request::GenerationRequest, GenerationContext, GenerationResponse},
Ollama,
};
use substring::Substring;

use reqwest::{Client, Error};
use reqwest::{Client, };
use std::result::Result;
use tokio::io::stdout;
use tokio::runtime::Runtime;

//function that fetches data from the endpoint
//write function that queries a rest endpoint for a given url
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::{Deserialize, Serialize};
use serde_json::{Value, Map};
use serde_json::{Value};
use std::collections::HashMap;
use std::env;
use std::error::Error as StdError;
use crate::llm_api::generate_response;
use crate::paperless::{get_data_from_paperless, query_custom_fields};

#[derive(Serialize, Deserialize, Debug, Clone)]
struct Document {
id: u32,
Expand Down Expand Up @@ -59,80 +63,9 @@ struct Field {
data_type: String,
}

/// Returns the slice of `json` surrounding a serde_json error position.
///
/// `line` and `column` are the values reported by `serde_json::Error`: the
/// line is 1-based and the column is relative to that line, not to the whole
/// input. Up to 30 bytes on either side of the column are returned; the
/// window is widened to the nearest UTF-8 character boundaries so that
/// slicing can never panic on multi-byte characters.
fn json_error_context(json: &str, line: usize, column: usize) -> &str {
    const RADIUS: usize = 30;

    let text = json.lines().nth(line.saturating_sub(1)).unwrap_or("");
    let mut start = column.saturating_sub(RADIUS).min(text.len());
    let mut end = column.saturating_add(RADIUS).min(text.len());
    // Index 0 and `text.len()` are always boundaries, so both loops terminate.
    while !text.is_char_boundary(start) {
        start -= 1;
    }
    while !text.is_char_boundary(end) {
        end += 1;
    }
    &text[start..end]
}

/// Fetches the documents matching the `NOT tagged=true` query from the
/// `/api/documents/` endpoint below `url`.
///
/// # Errors
/// Returns an error when the `TOKEN` environment variable is not set, when
/// the request fails or the server answers with an error status, or when the
/// body cannot be deserialized into `Response<Document>`. On a parse failure
/// the text around the offending position is printed to stdout first.
async fn get_data_from_paperless(
    client: &Client,
    url: &str,
) -> std::result::Result<Vec<Document>, Box<dyn StdError + Send + Sync>> {
    // The token value itself is not used by this function; only its presence
    // is checked, and a missing token is reported as an error instead of a panic.
    if env::var("TOKEN").is_err() {
        return Err("TOKEN is not set in .env file".into());
    }

    // Filter expression sent as the `query` parameter.
    let filter = "NOT tagged=true";

    let response = client
        .get(format!("{}/api/documents/?query={}", url, filter))
        .send()
        .await?
        // Surface HTTP failures directly instead of as a confusing JSON error.
        .error_for_status()?;
    let body = response.text().await?;

    // Strip any leading "Document content: " marker(s) before parsing.
    let json = body.trim_start_matches("Document content: ");

    match serde_json::from_str::<Response<Document>>(json) {
        Ok(data) => Ok(data.results),
        Err(e) => {
            println!(
                "Error at column {}: {}",
                e.column(),
                json_error_context(json, e.line(), e.column())
            );
            Err(e.into())
        }
    }
}

/// Fetches the custom field definitions from the `/api/custom_fields/`
/// endpoint below `base_url` and prints them to stdout.
///
/// # Errors
/// Returns an error when the request fails or the server answers with an
/// error status, or when the body cannot be deserialized into
/// `Response<Field>`. On a parse failure the text around the offending
/// position is printed to stdout first.
async fn query_custom_fields(
    client: &Client,
    base_url: &str,
) -> std::result::Result<Vec<Field>, Box<dyn std::error::Error>> {
    let res = client
        .get(format!("{}/api/custom_fields/", base_url))
        .send()
        .await?
        // Surface HTTP failures directly instead of as a confusing JSON error.
        .error_for_status()?;
    let body = res.text().await?;

    // Strip any leading "Field: " marker(s) before parsing.
    let json = body.trim_start_matches("Field: ");

    match serde_json::from_str::<Response<Field>>(json) {
        Ok(data) => {
            println!("Fields: {:?}", data.results);
            Ok(data.results)
        }
        Err(e) => {
            println!(
                "Error at column {}: {}",
                e.column(),
                json_error_context(json, e.line(), e.column())
            );
            Err(e.into())
        }
    }
}

/// Sends a completion request to Ollama for a single document.
///
/// The prompt is the document's `content` followed by a space and
/// `prompt_base`.
///
/// # Arguments
/// * `ollama` - client used to issue the generation request.
/// * `model` - name of the model to run; copied into the request.
/// * `prompt_base` - instruction text appended after the document content.
/// * `document` - document whose `content` forms the start of the prompt.
///
/// # Errors
/// Returns the error reported by `Ollama::generate`, boxed as a
/// `std::error::Error` trait object.
async fn generate_response(
    ollama: &Ollama,
    model: &str,
    prompt_base: &str,
    document: &Document,
) -> std::result::Result<GenerationResponse, Box<dyn std::error::Error>> {
    let prompt = format!("{} {}", document.content, prompt_base);
    // `GenerationRequest::new` takes ownership of the model name, so the
    // borrowed name has to be copied exactly once here.
    let request = GenerationRequest::new(model.to_owned(), prompt);
    Ok(ollama.generate(request).await?)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let base_url = std::env::var("PAPERLESS_BASE_URL").unwrap();
Expand Down
63 changes: 63 additions & 0 deletions src/paperless.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use reqwest::Client;
use serde::de::StdError;
use crate::{Document, Field, Response};

/// Returns the slice of `json` surrounding a serde_json error position.
///
/// `line` and `column` are the values reported by `serde_json::Error`: the
/// line is 1-based and the column is relative to that line, not to the whole
/// input. Up to 30 bytes on either side of the column are returned; the
/// window is widened to the nearest UTF-8 character boundaries so that
/// slicing can never panic on multi-byte characters.
fn json_error_context(json: &str, line: usize, column: usize) -> &str {
    const RADIUS: usize = 30;

    let text = json.lines().nth(line.saturating_sub(1)).unwrap_or("");
    let mut start = column.saturating_sub(RADIUS).min(text.len());
    let mut end = column.saturating_add(RADIUS).min(text.len());
    // Index 0 and `text.len()` are always boundaries, so both loops terminate.
    while !text.is_char_boundary(start) {
        start -= 1;
    }
    while !text.is_char_boundary(end) {
        end += 1;
    }
    &text[start..end]
}

/// Fetches the documents matching the `NOT tagged=true` query from the
/// `/api/documents/` endpoint below `url`.
///
/// NOTE(review): no credentials are attached here; presumably `client` is
/// built with the required authorization headers by the caller. Confirm.
///
/// # Errors
/// Returns an error when the request fails or the server answers with an
/// error status, or when the body cannot be deserialized into
/// `Response<Document>`. On a parse failure the text around the offending
/// position is printed to stdout first.
pub async fn get_data_from_paperless(
    client: &Client,
    url: &str,
) -> std::result::Result<Vec<Document>, Box<dyn StdError + Send + Sync>> {
    // Filter expression sent as the `query` parameter.
    let filter = "NOT tagged=true";

    let response = client
        .get(format!("{}/api/documents/?query={}", url, filter))
        .send()
        .await?
        // Surface HTTP failures directly instead of as a confusing JSON error.
        .error_for_status()?;
    let body = response.text().await?;

    // Strip any leading "Document content: " marker(s) before parsing.
    let json = body.trim_start_matches("Document content: ");

    match serde_json::from_str::<Response<Document>>(json) {
        Ok(data) => Ok(data.results),
        Err(e) => {
            println!(
                "Error at column {}: {}",
                e.column(),
                json_error_context(json, e.line(), e.column())
            );
            Err(e.into())
        }
    }
}

/// Fetches the custom field definitions from the `/api/custom_fields/`
/// endpoint below `base_url` and prints them to stdout.
///
/// # Errors
/// Returns an error when the request fails or the server answers with an
/// error status, or when the body cannot be deserialized into
/// `Response<Field>`. On a parse failure the text around the offending
/// position is printed to stdout first.
pub async fn query_custom_fields(
    client: &Client,
    base_url: &str,
) -> std::result::Result<Vec<Field>, Box<dyn std::error::Error>> {
    let res = client
        .get(format!("{}/api/custom_fields/", base_url))
        .send()
        .await?
        // Surface HTTP failures directly instead of as a confusing JSON error.
        .error_for_status()?;
    let body = res.text().await?;

    // Strip any leading "Field: " marker(s) before parsing.
    let json = body.trim_start_matches("Field: ");

    match serde_json::from_str::<Response<Field>>(json) {
        Ok(data) => {
            println!("Fields: {:?}", data.results);
            Ok(data.results)
        }
        Err(e) => {
            println!(
                "Error at column {}: {}",
                e.column(),
                json_error_context(json, e.line(), e.column())
            );
            Err(e.into())
        }
    }
}

0 comments on commit 3d7ebf6

Please sign in to comment.