From a698dfc6aaa7776661050b1cec22822de2c86f11 Mon Sep 17 00:00:00 2001
From: JoyboyBrian
Date: Wed, 28 Aug 2024 03:31:04 +0000
Subject: [PATCH] rename, remove langchain dependency

---
 examples/financial-advisor/README.md           |   5 ++---
 examples/financial-advisor/app.py              |  60 ++++----
 .../utils/financial_analyzer.py                | 111 +++++++++++++++
 .../financial-advisor/utils/pdf_processor.py   | 118 ----------------
 .../financial-advisor/utils/text_generator.py  | 130 ------------------
 5 files changed, 143 insertions(+), 281 deletions(-)
 create mode 100644 examples/financial-advisor/utils/financial_analyzer.py
 delete mode 100644 examples/financial-advisor/utils/pdf_processor.py
 delete mode 100644 examples/financial-advisor/utils/text_generator.py

diff --git a/examples/financial-advisor/README.md b/examples/financial-advisor/README.md
index 8af14ac8..1b222b87 100644
--- a/examples/financial-advisor/README.md
+++ b/examples/financial-advisor/README.md
@@ -12,7 +12,6 @@
 
 - File structure:
   - `app.py`: main Streamlit application
-  - `utils/pdf_processor.py`: processes PDF files and creates embeddings
-  - `utils/text_generator.py`: handles similarity search and text generation
+  - `utils/financial_analyzer.py`: handles similarity search and text generation
   - `assets/fake_bank_statements`: fake bank statement for testing purpose
 
@@ -27,7 +26,7 @@ pip install -r requirements.txt
 2. Usage:
 
 - Run the Streamlit app: `streamlit run app.py`
-- Upload PDF financial docs (bank statements, SEC filings, etc.) and process them.
+- Upload PDF financial docs (bank statements, SEC filings, etc.) and process them
 - Use the chat interface to query your financial data
 
 ### Resources:
diff --git a/examples/financial-advisor/app.py b/examples/financial-advisor/app.py
index 2e9938bc..0043633b 100644
--- a/examples/financial-advisor/app.py
+++ b/examples/financial-advisor/app.py
@@ -1,9 +1,6 @@
 import sys
 import os
 import streamlit as st
-from typing import Iterator
-import subprocess
-import json
 import shutil
 import pdfplumber
 from sentence_transformers import SentenceTransformer
@@ -12,8 +9,7 @@
 import re
 import traceback
 import logging
-from nexa.gguf import NexaTextInference
-import utils.text_generator as tg
+from utils.financial_analyzer import FinancialAnalyzer
 
 # set up logging:
 logging.basicConfig(level=logging.INFO)
@@ -26,12 +22,10 @@
 
 @st.cache_resource
 def load_model(model_path):
-    st.session_state.messages = []
-    nexa_model = NexaTextInference(model_path)
-    return nexa_model
+    return FinancialAnalyzer(model_path)
 
 def generate_response(query: str) -> str:
-    result = tg.financial_analysis(query)
+    result = st.session_state.nexa_model.financial_analysis(query)
     if isinstance(result, dict) and "error" in result:
         return f"An error occurred: {result['error']}"
     return result
@@ -59,7 +53,7 @@ def chunk_text(text, model, max_tokens=256, overlap=20):
     current_tokens = 0
 
     for sentence in sentences:
-        sentence_tokens = len(model.tokenizer.tokenize(sentence))
+        sentence_tokens = len(model.tokenize(sentence))
         if current_tokens + sentence_tokens > max_tokens:
             if current_chunk:
                 chunks.append(' '.join(current_chunk))
@@ -165,11 +159,11 @@ def process_pdfs(uploaded_files):
 
         # verify files were saved & reload the FAISS index:
         if os.path.exists(os.path.join(output_dir, 'pdf_index.faiss')) and \
-           os.path.exists(os.path.join(output_dir, 'pdf_chunks.npy')):
-            # Reload the FAISS index
-            tg.embeddings, tg.index, tg.stored_docs = tg.load_faiss_index()
-            st.success("PDFs processed and FAISS index reloaded successfully!")
-            return True
+           os.path.exists(os.path.join(output_dir, 'pdf_chunks.npy')):
+            # Reload the FAISS index
+            st.session_state.nexa_model.load_faiss_index()
+            st.success("PDFs processed and FAISS index reloaded successfully!")
+            return True
         else:
             st.error("Error: Processed files not found after saving.")
             return False
@@ -181,9 +175,11 @@ def process_pdfs(uploaded_files):
         return False
 
 def check_faiss_index():
-    if tg.embeddings is None or tg.index is None or tg.stored_docs is None:
-        tg.embeddings, tg.index, tg.stored_docs = tg.load_faiss_index()
-    return tg.embeddings is not None and tg.index is not None and tg.stored_docs is not None
+    if "nexa_model" not in st.session_state:
+        return False
+    return (st.session_state.nexa_model.embeddings_model is not None and
+            st.session_state.nexa_model.index is not None and
+            st.session_state.nexa_model.stored_docs is not None)
 
 # Streamlit app:
 def main():
@@ -193,6 +189,9 @@ def main():
     # add an empty line:
     st.markdown("<br>", unsafe_allow_html=True)
", unsafe_allow_html=True) + if "nexa_model" not in st.session_state: + st.session_state.nexa_model = load_model(default_model) + # check if FAISS index exists: if not check_faiss_index(): st.info("No processed financial documents found. Please upload and process PDFs.") @@ -220,24 +219,25 @@ def main(): st.warning("Please enter a valid path or identifier for the model in Nexa Model Hub to proceed.") st.stop() - if "current_model_path" not in st.session_state or st.session_state.current_model_path != model_path: + if "nexa_model" not in st.session_state or "current_model_path" not in st.session_state or st.session_state.current_model_path != model_path: st.session_state.current_model_path = model_path st.session_state.nexa_model = load_model(model_path) if st.session_state.nexa_model is None: st.stop() st.sidebar.header("Generation Parameters") - temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.nexa_model.params["temperature"]) - max_new_tokens = st.sidebar.slider("Max New Tokens", 1, 500, st.session_state.nexa_model.params["max_new_tokens"]) - top_k = st.sidebar.slider("Top K", 1, 100, st.session_state.nexa_model.params["top_k"]) - top_p = st.sidebar.slider("Top P", 0.0, 1.0, st.session_state.nexa_model.params["top_p"]) - - st.session_state.nexa_model.params.update({ - "temperature": temperature, - "max_new_tokens": max_new_tokens, - "top_k": top_k, - "top_p": top_p, - }) + params = st.session_state.nexa_model.get_params() + temperature = st.sidebar.slider("Temperature", 0.0, 1.0, params["temperature"]) + max_new_tokens = st.sidebar.slider("Max New Tokens", 1, 500, params["max_new_tokens"]) + top_k = st.sidebar.slider("Top K", 1, 100, params["top_k"]) + top_p = st.sidebar.slider("Top P", 0.0, 1.0, params["top_p"]) + + st.session_state.nexa_model.set_params( + temperature=temperature, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p + ) # step 3 - interactive financial analysis chat: st.header("Let's discuss your finances🧑‍💼") diff --git a/examples/financial-advisor/utils/financial_analyzer.py b/examples/financial-advisor/utils/financial_analyzer.py new file mode 100644 index 00000000..5723bd01 --- /dev/null +++ b/examples/financial-advisor/utils/financial_analyzer.py @@ -0,0 +1,111 @@ +import os +import faiss +import numpy as np +import logging +from nexa.gguf import NexaTextInference +from sentence_transformers import SentenceTransformer +from langchain_core.documents import Document + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class FinancialAnalyzer: + def __init__(self, model_path="gemma"): + self.model_path = model_path + self.inference = NexaTextInference( + model_path=self.model_path, + stop_words=[], + temperature=0.7, + max_new_tokens=256, + top_k=50, + top_p=0.9, + profiling=False + ) + self.embeddings_model = SentenceTransformer('all-MiniLM-L6-v2') + self.index = None + self.stored_docs = None + self.load_faiss_index() + + def get_params(self): + return self.inference.params + + def set_params(self, **kwargs): + self.inference.params.update(kwargs) + + def load_faiss_index(self): + try: + faiss_index_dir = "./assets/output/processed_data" + if not os.path.exists(faiss_index_dir): + logger.warning(f"FAISS index directory not found: {faiss_index_dir}") + return + + index_file = os.path.join(faiss_index_dir, "pdf_index.faiss") + if not os.path.exists(index_file): + logger.warning(f"FAISS index file not found: {index_file}") + return + + self.index = faiss.read_index(index_file) + logger.info(f"FAISS index 
loaded successfully.") + + doc_file = os.path.join(faiss_index_dir, "pdf_chunks.npy") + self.stored_docs = np.load(doc_file, allow_pickle=True) + logger.info(f"Loaded {len(self.stored_docs)} documents") + + if not isinstance(self.stored_docs[0], Document): + self.stored_docs = [Document(page_content=doc) for doc in self.stored_docs] + + except Exception as e: + logger.error(f"Error loading FAISS index: {str(e)}") + + def custom_search(self, query, k=3): + if self.embeddings_model is None or self.index is None or self.stored_docs is None: + logger.error("FAISS index or embeddings model not properly loaded") + return [] + try: + query_vector = self.embeddings_model.encode([query])[0] + scores, indices = self.index.search(np.array([query_vector]), k) + docs = [self.stored_docs[i] for i in indices[0]] + return list(zip(docs, scores[0])) + except Exception as e: + logger.error(f"Error in custom_search: {str(e)}") + return [] + + def truncate_text(self, text, max_tokens=256): + tokens = text.split() + if len(tokens) <= max_tokens: + return text + return ' '.join(tokens[:max_tokens]) + + def financial_analysis(self, query): + try: + if self.embeddings_model is None or self.index is None or self.stored_docs is None: + logger.error("FAISS index not loaded. Please process PDF files first.") + return {"error": "FAISS index not loaded. Please process PDF files first."} + + relevant_docs = self.custom_search(query, k=1) + if not relevant_docs: + logger.warning("No relevant documents found for the query.") + return {"error": "No relevant documents found for the query."} + + context = "\n".join([doc.page_content for doc, _ in relevant_docs]) + truncated_context = self.truncate_text(context) + prompt = f"Financial context: {truncated_context}\n\nAnalyze: {query}" + + prompt_tokens = len(prompt.split()) + logger.info(f"Prompt length: {prompt_tokens} tokens") + + if prompt_tokens > 250: + prompt = self.truncate_text(prompt, 250) + logger.info(f"Truncated prompt length: {len(prompt.split())} tokens") + + llm_input = [ + {"role": "user", "content": prompt} + ] + + return self.inference.create_chat_completion(llm_input, stream=True) + + except Exception as e: + logger.error(f"Error in financial_analysis: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return {"error": str(e)} \ No newline at end of file diff --git a/examples/financial-advisor/utils/pdf_processor.py b/examples/financial-advisor/utils/pdf_processor.py deleted file mode 100644 index 4f6788cc..00000000 --- a/examples/financial-advisor/utils/pdf_processor.py +++ /dev/null @@ -1,118 +0,0 @@ -import warnings -warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*") - -import os -import pdfplumber -from sentence_transformers import SentenceTransformer -import faiss -import numpy as np -import re - -input_dir = './assets/input/' -output_dir = './assets/output/processed_data/' - -# extract pdf file one by one: -def extract_text_from_pdf(pdf_path): - print(f"1️⃣ Extracting text from {pdf_path}") - with pdfplumber.open(pdf_path) as pdf: - text = '' - for i, page in enumerate(pdf.pages): - page_text = page.extract_text() - text += page_text + '\n' - print(f"Processed page {i+1}/{len(pdf.pages)}") - return text - -# chunk the file by tokens: -def chunk_text(text, model, max_tokens=256, overlap=20): - print("2️⃣ Chunking extracted text ...") - - # split the text into definitions and other parts - definitions = re.findall(r'"\w+(?:\s+\w+)*"\s+means.*?(?="\w+(?:\s+\w+)*"\s+means|\Z)', text, re.DOTALL) - other_parts = 
re.split(r'"\w+(?:\s+\w+)*"\s+means.*?(?="\w+(?:\s+\w+)*"\s+means|\Z)', text) - - chunks = [] - - for definition in definitions: - if len(model.tokenizer.tokenize(definition)) <= max_tokens: - chunks.append(definition.strip()) - else: - # if a definition is too long, split it into smaller parts - sentences = re.split(r'(?<=[.!?])\s+', definition) - current_chunk = [] - current_tokens = 0 - for sentence in sentences: - sentence_tokens = len(model.tokenizer.tokenize(sentence)) - if current_tokens + sentence_tokens > max_tokens: - chunks.append(' '.join(current_chunk).strip()) - current_chunk = [sentence] - current_tokens = sentence_tokens - else: - current_chunk.append(sentence) - current_tokens += sentence_tokens - if current_chunk: - chunks.append(' '.join(current_chunk).strip()) - - # process other parts - for part in other_parts: - if part.strip(): - sentences = re.split(r'(?<=[.!?])\s+', part) - current_chunk = [] - current_tokens = 0 - for sentence in sentences: - sentence_tokens = len(model.tokenizer.tokenize(sentence)) - if current_tokens + sentence_tokens > max_tokens: - chunks.append(' '.join(current_chunk).strip()) - current_chunk = [sentence] - current_tokens = sentence_tokens - else: - current_chunk.append(sentence) - current_tokens += sentence_tokens - if current_chunk: - chunks.append(' '.join(current_chunk).strip()) - - chunks = [chunk for chunk in chunks if chunk.strip()] # remove empty chunks, if any - - chunk_sizes = [len(model.tokenizer.tokenize(chunk)) for chunk in chunks] - print(f"👉 Created {len(chunks)} chunks") - print(f" Chunk sizes: min={min(chunk_sizes)}, max={max(chunk_sizes)}, avg={sum(chunk_sizes)/len(chunk_sizes):.1f}") - # print(f"👀{chunks}") - - return chunks - -# create embeddings for all chunks at once: -def create_embeddings(chunks, model): - print("3️⃣ Creating embeddings ...") - embeddings = model.encode(chunks) - print(f"👉 Created embeddings of shape: {embeddings.shape}") - return embeddings - -# add embeddings to FAISS index: -def build_faiss_index(embeddings): - print("4️⃣ Building FAISS index ...") - dimension = embeddings.shape[1] - index = faiss.IndexFlatL2(dimension) - index.add(embeddings.astype('float32')) - print(f"👉 Added {len(embeddings)} vectors to FAISS index") - return index - -model = SentenceTransformer('all-MiniLM-L6-v2') - -# process the pdf files: -all_chunks = [] -for filename in os.listdir(input_dir): - if filename.endswith('.pdf'): - pdf_path = os.path.join(input_dir, filename) - text = extract_text_from_pdf(pdf_path) - file_chunks = chunk_text(text, model) # using default overlap (20) - all_chunks.extend(file_chunks) - print(f" File: {filename}, Chunks: {len(file_chunks)}") -print(f"✅ Total chunks from all PDFs: {len(all_chunks)}") - -embeddings = create_embeddings(all_chunks, model) -faiss_index = build_faiss_index(embeddings) - -# save the index and chunks: -print("5️⃣ Saving FAISS index and chunks ...") -os.makedirs(output_dir, exist_ok=True) -faiss.write_index(faiss_index, os.path.join(output_dir, 'pdf_index.faiss')) -np.save(os.path.join(output_dir, 'pdf_chunks.npy'), all_chunks) diff --git a/examples/financial-advisor/utils/text_generator.py b/examples/financial-advisor/utils/text_generator.py deleted file mode 100644 index f792916e..00000000 --- a/examples/financial-advisor/utils/text_generator.py +++ /dev/null @@ -1,130 +0,0 @@ -import os -import faiss -import numpy as np -import logging -from nexa.gguf import NexaTextInference -from langchain_community.embeddings import HuggingFaceEmbeddings -from 
langchain_core.documents import Document - -# set up logging: -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# model initialization: -model_path = "gemma" -inference = NexaTextInference( - model_path=model_path, - stop_words=[], - temperature=0.7, - max_new_tokens=256, - top_k=50, - top_p=0.9, - profiling=False -) - -print(f"Model loaded: {inference.downloaded_path}") -print(f"Chat format: {inference.chat_format}") - -# global variables: -embeddings = None -index = None -stored_docs = None - -# load FAISS index: -def load_faiss_index(): - global embeddings, index, stored_docs - try: - embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") - - faiss_index_dir = "./assets/output/processed_data" - - if not os.path.exists(faiss_index_dir): - logger.warning(f"FAISS index directory not found: {faiss_index_dir}") - return None, None, None - - index_file = os.path.join(faiss_index_dir, "pdf_index.faiss") - if not os.path.exists(index_file): - logger.warning(f"FAISS index file not found: {index_file}") - return None, None, None - - index = faiss.read_index(index_file) - logger.info(f"FAISS index loaded successfully.") - - # load the chunks: - doc_file = os.path.join(faiss_index_dir, "pdf_chunks.npy") - stored_docs = np.load(doc_file, allow_pickle=True) - logger.info(f"Loaded {len(stored_docs)} documents") - - # convert stored_docs to a list of Document objects: - if not isinstance(stored_docs[0], Document): - stored_docs = [Document(page_content=doc) for doc in stored_docs] - - return embeddings, index, stored_docs - - except Exception as e: - logger.error(f"Error loading FAISS index: {str(e)}") - return None, None, None - -# load the index at module level: -embeddings, index, stored_docs = load_faiss_index() - -def custom_search(query, k=3): - global embeddings, index, stored_docs - if embeddings is None or index is None or stored_docs is None: - logger.error("FAISS index or embeddings not properly loaded") - return [] - try: - query_vector = embeddings.embed_query(query) - scores, indices = index.search(np.array([query_vector]), k) - docs = [stored_docs[i] for i in indices[0]] - return list(zip(docs, scores[0])) - except Exception as e: - logger.error(f"Error in custom_search: {str(e)}") - return [] - -# truncate text to a specific token limit: -def truncate_text(text, max_tokens=256): - tokens = text.split() - if len(tokens) <= max_tokens: - return text - return ' '.join(tokens[:max_tokens]) - -# query FAISS and generate LLM response: -def financial_analysis(query): - global embeddings, index, stored_docs - try: - if embeddings is None or index is None or stored_docs is None: - logger.error("FAISS index not loaded. Please process PDF files first.") - return {"error": "FAISS index not loaded. 
-
-        relevant_docs = custom_search(query, k=1)
-        if not relevant_docs:
-            logger.warning("No relevant documents found for the query.")
-            return {"error": "No relevant documents found for the query."}
-
-        context = "\n".join([doc.page_content for doc, _ in relevant_docs])
-
-        # truncate the context if it's too long:
-        truncated_context = truncate_text(context)
-
-        prompt = f"Financial context: {truncated_context}\n\nAnalyze: {query}"
-
-        prompt_tokens = len(prompt.split())
-        logger.info(f"Prompt length: {prompt_tokens} tokens")
-
-        if prompt_tokens > 250:
-            prompt = truncate_text(prompt, 250)
-            logger.info(f"Truncated prompt length: {len(prompt.split())} tokens")
-
-        llm_input = [
-            {"role": "user", "content": prompt}
-        ]
-
-        # return the iterator
-        return inference.create_chat_completion(llm_input, stream=True)
-
-    except Exception as e:
-        logger.error(f"Error in financial_analysis: {str(e)}")
-        import traceback
-        logger.error(traceback.format_exc())
-        return {"error": str(e)}
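
-- 
Addendum (not part of the commit): a minimal sketch of how the new FinancialAnalyzer
class is driven end to end, assuming the FAISS index under
./assets/output/processed_data has already been built by the app's PDF-processing
step, and assuming create_chat_completion(..., stream=True) yields OpenAI-style
chunk dicts (true of llama-cpp-python, which nexa.gguf wraps; verify against the
installed SDK). The query string and sampling values below are arbitrary.

    from utils.financial_analyzer import FinancialAnalyzer

    analyzer = FinancialAnalyzer(model_path="gemma")  # __init__ also tries to load the FAISS index
    analyzer.set_params(temperature=0.3, max_new_tokens=200)

    result = analyzer.financial_analysis("What were my largest expenses last month?")
    if isinstance(result, dict) and "error" in result:
        print(result["error"])  # index missing or no matching chunks
    else:
        for chunk in result:  # assumed OpenAI-style streaming schema
            delta = chunk["choices"][0].get("delta", {})
            print(delta.get("content", ""), end="", flush=True)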