feature: local llm support (#16)
* refactor: creating an LLM module in libmemex and moving embedding module inside

* adding a basic run prompt function for llm testing inside memex

* wip: sherpa - guiding llms using logit biasing, templates, etc.

* wip: tweaking sampler to output only what we want

* updating .env.template file

* removing sherpa stuff for now

* Creating LLM trait and sharing structs between local LLM & OpenAI impls

* Using LLM trait in API to switch between local/OpenAI when configured

* load llm client based on whether `OPENAI_API_KEY` or `LOCAL_LLM_CONFIG`
is set

* update README to point that out

* Create samplers from config and pass into `LocalLLM` struct

* add basic support for other architectures, but focus on llama for now

* add LLM ask example to README

* output debug messages for local llm responses

* correctly capture server errors & send back the error messages

* removing old dep

* ignore model test

* cargo fmt
a5huynh authored Oct 9, 2023
1 parent 9b15fee commit 2f9cea6
Showing 21 changed files with 754 additions and 203 deletions.
10 changes: 9 additions & 1 deletion .env.template
@@ -4,4 +4,12 @@ PORT=8181
# Use postgres for "production"
DATABASE_CONNECTION=sqlite://data/sqlite.db
# Use qdrant/etc. for "production"
VECTOR_CONNECTION=hnsw://data/vdb
# When using OpenSearch as the vector backend
# VECTOR_CONNECTION=opensearch+https://admin:admin@localhost:9200

# If using OpenAI, set up your API key here
OPENAI_API_KEY=
# Or point to a local LLM configuration file. By default, memex will use
# llama2
LOCAL_LLM_CONFIG=resources/config.llama2.toml
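
Note: this diff does not include the contents of `resources/config.llama2.toml`, so the sketch below only illustrates the general idea of a local LLM configuration file and how it might be deserialized. The field names are assumptions, not the real schema.

```rust
use serde::Deserialize;
use std::path::{Path, PathBuf};

/// Hypothetical shape of a LOCAL_LLM_CONFIG file -- the field names here are
/// illustrative guesses only; see resources/config.llama2.toml for the real schema.
#[derive(Debug, Deserialize)]
struct LocalLlmConfig {
    /// Path to the model weights on disk.
    model_path: PathBuf,
    /// Architecture understood by the `llm` crate, e.g. "llama" or "gptj".
    model_architecture: String,
    /// Free-form sampler settings (temperature, top-p, ...), if any.
    sampler: Option<toml::Value>,
}

/// Read and parse the TOML file (assumes the serde, toml, and anyhow crates).
fn load_config(path: &Path) -> anyhow::Result<LocalLlmConfig> {
    let raw = std::fs::read_to_string(path)?;
    Ok(toml::from_str(&raw)?)
}
```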
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

35 changes: 33 additions & 2 deletions README.md
@@ -23,7 +23,22 @@ since Linux ARM builds are very finicky.
2023-06-13T05:04:21.518732Z INFO memex: starting server with roles: [Api, Worker]
```

## Add a document
## Using an LLM
You can use either OpenAI or a local LLM for LLM-based functionality (such as the
summarization or extraction APIs).

Set `OPENAI_API_KEY` to your API key in the `.env` file, or set `LOCAL_LLM_CONFIG` to
an LLM configuration file. See `resources/config.llama2.toml` for an example. By
default, a base memex will use the llama-2 configuration file.

### Supported local models

We currently support (and have tested) the following models:
- Llama-based models (Llama 1 & 2, Mistral, etc.) - *recommended*
- GPT-J based models (e.g. GPT4All)


## Adding a document

NOTE: If the `test` collection does not initially exist, it'll be created.

@@ -80,7 +95,7 @@ Or if it's finished, something like so:
Once the task is shown as "Completed", you can run a query against the doc(s)
you've just added.

## Run a query
## Run a search query

``` bash
> curl http://localhost:8181/api/collections/test/search \
@@ -100,6 +115,22 @@ you've just added.
}
```

## Ask a question
```bash
> curl http://localhost:8181/api/action/ask \
-H "Content-Type: application/json" \
-X POST \
-d "{\"text\": \"<context if any>\", \"query\": \"What is the airspeed velocity of an unladen swallow?\", "json_schema": { .. }}"
{
"time": 1.234,
"status": "ok",
"result": {
"answer": "The airspeed velocity of an unladen swallow is..."
}
}

```

## Env variables

- `HOST`: Defaults to `127.0.0.1`
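As a companion to the curl command in the new "Ask a question" section above, here is a rough sketch of the same request made from Rust with reqwest (assuming the reqwest `json` feature, tokio, and serde_json). The endpoint and payload fields mirror the README example; the optional `json_schema` field is omitted.

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();
    // POST the question to the ask endpoint on the default host/port.
    let response = client
        .post("http://localhost:8181/api/action/ask")
        .json(&json!({
            "text": "<context if any>",
            "query": "What is the airspeed velocity of an unladen swallow?"
        }))
        .send()
        .await?
        .json::<serde_json::Value>()
        .await?;
    // Expected shape: { "time": ..., "status": "ok", "result": { "answer": ... } }
    println!("{response}");
    Ok(())
}
```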
19 changes: 18 additions & 1 deletion bin/memex/src/main.rs
@@ -1,3 +1,4 @@
use api::ApiConfig;
use clap::{Parser, Subcommand};
use futures::future::join_all;
use std::{net::Ipv4Addr, process::ExitCode};
@@ -25,6 +26,10 @@ pub struct Args {
database_connection: Option<String>,
#[clap(long, value_parser, value_name = "VECTOR_CONNECTION", env)]
vector_connection: Option<String>,
#[clap(long, value_parser, value_name = "OPENAI_API_KEY", env)]
openai_api_key: Option<String>,
#[clap(long, value_parser, value_name = "LOCAL_LLM_CONFIG", env)]
local_llm_config: Option<String>,
}

#[derive(Debug, Display, Clone, PartialEq, EnumString)]
@@ -100,9 +105,21 @@ async fn main() -> ExitCode {

let _vector_store_uri = args.vector_connection.expect("VECTOR_CONNECTION not set");

if args.openai_api_key.is_none() && args.local_llm_config.is_none() {
log::error!("Must set either OPENAI_API_KEY or LOCAL_LLM_CONFIG");
return ExitCode::FAILURE;
}

if roles.contains(&Roles::Api) {
let db_uri = db_uri.clone();
handles.push(tokio::spawn(api::start(host, port, db_uri)));
let cfg = ApiConfig {
host,
port,
db_uri,
open_ai_key: args.openai_api_key,
local_llm_config: args.local_llm_config,
};
handles.push(tokio::spawn(api::start(cfg)));
}

if roles.contains(&Roles::Worker) {
8 changes: 5 additions & 3 deletions lib/api/src/endpoints/actions/filters.rs
@@ -1,5 +1,7 @@
use std::sync::Arc;

use crate::{endpoints::json_body, with_db, with_llm};
use libmemex::llm::openai::OpenAIClient;
use libmemex::llm::LLM;
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -24,7 +26,7 @@ pub struct SummarizeRequest {
}

fn extract(
llm: &OpenAIClient,
llm: &Arc<Box<dyn LLM>>,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
warp::path!("action" / "ask")
.and(warp::post())
Expand All @@ -44,7 +46,7 @@ fn summarize(
}

pub fn build(
llm: &OpenAIClient,
llm: &Arc<Box<dyn LLM>>,
db: &DatabaseConnection,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
extract(llm).or(summarize(db))
14 changes: 7 additions & 7 deletions lib/api/src/endpoints/actions/handlers.rs
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::{
schema::{ApiResponse, TaskResult},
ServerError,
@@ -9,19 +11,16 @@ use warp::reject::Rejection;
use super::filters;
use libmemex::{
db::queue,
llm::{
openai::{truncate_text, OpenAIClient},
prompter,
},
llm::{prompter, LLM},
};

pub async fn handle_extract(
llm: OpenAIClient,
llm: Arc<Box<dyn LLM>>,
request: filters::AskRequest,
) -> Result<impl warp::Reply, Rejection> {
let time = std::time::Instant::now();

let (content, model) = truncate_text(&request.text);
let (content, model) = llm.truncate_text(&request.text);

// Build prompt
let prompt = if let Some(schema) = &request.json_schema {
@@ -34,10 +33,11 @@ pub async fn handle_extract(
};

let response = llm
.chat_completion(&model, &prompt)
.chat_completion(model.as_ref(), &prompt)
.await
.map_err(|err| ServerError::Other(err.to_string()))?;

log::debug!("llm response: {response}");
let val = serde_json::from_str::<serde_json::Value>(&response)
.map_err(|err| ServerError::Other(err.to_string()))?;

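For orientation: the handlers above now call `truncate_text` and `chat_completion` on a shared `Arc<Box<dyn LLM>>` rather than the concrete `OpenAIClient`. The trait itself isn't shown in this excerpt, so the sketch below is only inferred from those call sites (the `async_trait` usage, exact return types, and error type are assumptions); the real definition lives under `lib/libmemex/src/llm`.

```rust
use async_trait::async_trait;

/// Rough sketch of the shared LLM abstraction, inferred from the call sites
/// in handlers.rs above; the actual trait in libmemex may differ.
#[async_trait]
pub trait LLM: Send + Sync {
    /// Clip the input to the backend's context window, returning the clipped
    /// text together with the model identifier to use for the request.
    fn truncate_text(&self, text: &str) -> (String, String);

    /// Run a single completion request against the configured backend
    /// (OpenAI or a local model loaded via the `llm` crate).
    async fn chat_completion(&self, model: &str, prompt: &str) -> anyhow::Result<String>;
}
```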
2 changes: 1 addition & 1 deletion lib/api/src/endpoints/collections/handlers.rs
@@ -4,7 +4,7 @@ use crate::{
};
use libmemex::{
db::{embedding, queue},
embedding::{ModelConfig, SentenceEmbedder},
llm::embedding::{ModelConfig, SentenceEmbedder},
storage::get_vector_storage,
};
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter};
6 changes: 4 additions & 2 deletions lib/api/src/endpoints/mod.rs
@@ -1,4 +1,6 @@
use libmemex::llm::openai::OpenAIClient;
use std::sync::Arc;

use libmemex::llm::LLM;
use sea_orm::DatabaseConnection;
use serde::de::DeserializeOwned;
use warp::Filter;
@@ -24,7 +26,7 @@ pub fn json_body<T: std::marker::Send + DeserializeOwned>(

pub fn build(
db: &DatabaseConnection,
llm: &OpenAIClient,
llm: &Arc<Box<dyn LLM>>,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
actions::filters::build(llm, db)
.or(collections::filters::build(db))
48 changes: 37 additions & 11 deletions lib/api/src/lib.rs
@@ -1,8 +1,11 @@
use dotenv_codegen::dotenv;
use libmemex::{db::create_connection_by_uri, llm::openai::OpenAIClient};
use libmemex::{
db::create_connection_by_uri,
llm::{local::load_from_cfg, openai::OpenAIClient, LLM},
};
use sea_orm::DatabaseConnection;
use serde_json::json;
use std::{convert::Infallible, net::Ipv4Addr, path::PathBuf};
use std::{convert::Infallible, net::Ipv4Addr, path::PathBuf, sync::Arc};
use thiserror::Error;
use warp::{hyper::StatusCode, reject::Reject, Filter, Rejection, Reply};

@@ -22,6 +25,14 @@ pub enum ServerError {

impl Reject for ServerError {}

pub struct ApiConfig {
pub host: Ipv4Addr,
pub port: u16,
pub db_uri: String,
pub open_ai_key: Option<String>,
pub local_llm_config: Option<String>,
}

// Handle custom errors/rejections
async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> {
let code;
@@ -35,6 +46,12 @@ async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> {
// and render it however we want
code = StatusCode::METHOD_NOT_ALLOWED;
message = "METHOD_NOT_ALLOWED".into();
} else if let Some(err) = err.find::<ServerError>() {
(code, message) = match err {
ServerError::ClientRequestError(err) => (StatusCode::BAD_REQUEST, err.to_string()),
ServerError::DatabaseError(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
ServerError::Other(err) => (StatusCode::BAD_REQUEST, err.to_string()),
};
} else {
// We should have expected this... Just log and say its a 500
eprintln!("unhandled rejection: {:?}", err);
@@ -59,8 +76,8 @@ pub fn health_check() -> impl Filter<Extract = (impl warp::Reply,), Error = warp
.map(move || warp::reply::json(&json!({ "version": version })))
}

pub async fn start(host: Ipv4Addr, port: u16, db_uri: String) {
log::info!("starting api server @ {}:{}", host, port);
pub async fn start(config: ApiConfig) {
log::info!("starting api server @ {}:{}", config.host, config.port);

log::info!("checking for upload directory...");
let data_dir_path: PathBuf = endpoints::UPLOAD_DATA_DIR.into();
@@ -70,12 +87,21 @@ pub async fn start(host: Ipv4Addr, port: u16, db_uri: String) {
}

// Attempt to connect to db
let db_connection = create_connection_by_uri(&db_uri, true)
let db_connection = create_connection_by_uri(&config.db_uri, true)
.await
.unwrap_or_else(|err| panic!("Unable to connect to database: {} - {err}", db_uri));
.unwrap_or_else(|err| panic!("Unable to connect to database: {} - {err}", config.db_uri));

let llm_client: Arc<Box<dyn LLM>> = if let Some(openai_key) = config.open_ai_key {
Arc::new(Box::new(OpenAIClient::new(&openai_key)))
} else if let Some(llm_config_path) = config.local_llm_config {
let llm = load_from_cfg(llm_config_path.into(), true)
.await
.expect("Unable to load local LLM");
Arc::new(llm)
} else {
panic!("Please setup OPENAI_API_KEY or LOCAL_LLM_CONFIG");
};

let llm_client =
OpenAIClient::new(&std::env::var("OPENAI_API_KEY").expect("OpenAI API key not set"));
let cors = warp::cors()
.allow_any_origin()
.allow_methods(vec!["GET", "POST", "PUT", "PATCH", "DELETE"])
@@ -88,7 +114,7 @@ pub async fn start(config: ApiConfig) {
let filters = health_check().or(api).with(cors).recover(handle_rejection);

let (_addr, handle) =
warp::serve(filters).bind_with_graceful_shutdown((host, port), async move {
warp::serve(filters).bind_with_graceful_shutdown((config.host, config.port), async move {
tokio::signal::ctrl_c()
.await
.expect("failed to listen to shutdown signal");
Expand All @@ -105,7 +131,7 @@ pub fn with_db(
}

pub fn with_llm(
llm: OpenAIClient,
) -> impl Filter<Extract = (OpenAIClient,), Error = std::convert::Infallible> + Clone {
llm: Arc<Box<dyn LLM>>,
) -> impl Filter<Extract = (Arc<Box<dyn LLM>>,), Error = std::convert::Infallible> + Clone {
warp::any().map(move || llm.clone())
}
6 changes: 5 additions & 1 deletion lib/libmemex/Cargo.toml
@@ -12,20 +12,24 @@ chrono = { workspace = true }
dotenv = { workspace = true }
handlebars = "4.4.0"
hnsw_rs = { git = "https://github.com/jean-pierreBoth/hnswlib-rs", rev = "52a7f9174e002820d168fa65ca7303364ee3ac33" }
llm = { git = "https://github.com/rustformers/llm.git", rev = "84800b02a7a96f62c0c9c03a38c36cb23bf4b2ec" }
log = { workspace = true }
migration ={ path = "../../migration" }
opensearch = "2.1.0"
qdrant-client = "1.2.0"
reqwest = { version = "0.11.16", features = ["stream" ] }
rand = "0.8.5"
rust-bert = { version = "0.21.0", features= ["download-libtorch"] }
sea-orm = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
strum = "0.25"
strum_macros = "0.25"
tera = "1.19.0"
thiserror = "1.0"
tiktoken-rs = "0.5.4"
tokenizers = { version = "0.14", features = ["http"] }
tokio = { workspace = true }
toml = "0.7.4"
url = "2.4.0"
uuid = { workspace = true }
1 change: 0 additions & 1 deletion lib/libmemex/src/lib.rs
@@ -1,5 +1,4 @@
pub mod db;
pub mod embedding;
pub mod llm;
pub mod storage;

File renamed without changes.