diff --git a/.gitignore b/.gitignore
index 891226ed..9da7f837 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@
 .env
 __pycache__
 uploads/
-myenv/
\ No newline at end of file
+myenv/
+venv/
\ No newline at end of file
diff --git a/README.md b/README.md
index 991f4534..f76140b4 100644
Binary files a/README.md and b/README.md differ
diff --git a/config.py b/config.py
index b1938ca2..1e5ff507 100644
--- a/config.py
+++ b/config.py
@@ -16,6 +16,10 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
         return default_value
     return value
 
+RAG_UPLOAD_DIR = get_env_variable("RAG_UPLOAD_DIR", "./uploads/")
+if not os.path.exists(RAG_UPLOAD_DIR):
+    os.makedirs(RAG_UPLOAD_DIR, exist_ok=True)
+
 POSTGRES_DB = get_env_variable("POSTGRES_DB")
 POSTGRES_USER = get_env_variable("POSTGRES_USER")
 POSTGRES_PASSWORD = get_env_variable("POSTGRES_PASSWORD")
@@ -25,7 +29,7 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
 CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500"))
 CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100"))
 
-UPLOAD_DIR = get_env_variable("UPLOAD_DIR", "./uploads/")
+
 env_value = get_env_variable("PDF_EXTRACT_IMAGES", "False").lower()
 PDF_EXTRACT_IMAGES = True if env_value == "true" else False
 
diff --git a/main.py b/main.py
index 62567dc4..07e606a3 100644
--- a/main.py
+++ b/main.py
@@ -1,9 +1,10 @@
 import os
 import hashlib
+from shutil import copyfileobj
 from langchain.schema import Document
 from contextlib import asynccontextmanager
 from dotenv import find_dotenv, load_dotenv
-from fastapi import FastAPI, HTTPException, status
+from fastapi import FastAPI, File, Form, UploadFile, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from langchain_core.runnables.config import run_in_executor
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -34,6 +35,7 @@
     CHUNK_SIZE,
     CHUNK_OVERLAP,
     vector_store,
+    RAG_UPLOAD_DIR,
     known_source_ext,
     PDF_EXTRACT_IMAGES,
     # RAG_EMBEDDING_MODEL,
@@ -121,29 +123,26 @@ async def delete_documents(ids: list[str]):
 @app.post("/query")
 async def query_embeddings_by_file_id(body: QueryRequestBody):
     try:
-        # Get the embedding of the query text
         embedding = vector_store.embedding_function.embed_query(body.query)
 
-        # Perform similarity search with the query embedding and filter by the file_id in metadata
         if isinstance(vector_store, AsyncPgVector):
             documents = await run_in_executor(
                 None,
                 vector_store.similarity_search_with_score_by_vector,
                 embedding,
                 k=body.k,
-                filter={"custom_id": body.file_id}
+                filter={"file_id": body.file_id}
             )
         else:
             documents = vector_store.similarity_search_with_score_by_vector(
                 embedding, k=body.k,
-                filter={"custom_id": body.file_id}
+                filter={"file_id": body.file_id}
             )
 
         return documents
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
-
 def generate_digest(page_content: str):
     hash_obj = hashlib.md5(page_content.encode())
@@ -264,6 +263,42 @@ async def embed_file(document: StoreDocument):
             detail=ERROR_MESSAGES.DEFAULT(e),
         )
 
+@app.post("/embed-upload")
+async def embed_file_upload(file_id: str = Form(...), uploaded_file: UploadFile = File(...)):
+    temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename)
+
+    try:
+        with open(temp_file_path, 'wb') as temp_file:
+            copyfileobj(uploaded_file.file, temp_file)
+    except Exception as e:
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                            detail=f"Failed to save the uploaded file. Error: {str(e)}")
+
+    try:
+        loader, known_type = get_loader(uploaded_file.filename, uploaded_file.content_type, temp_file_path)
+
+        data = loader.load()
+        result = await store_data_in_vector_db(data, file_id)
+
+        if not result:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="Failed to process/store the file data.",
+            )
+    except Exception as e:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
+                            detail=f"Error during file processing: {str(e)}")
+    finally:
+        os.remove(temp_file_path)
+
+    return {
+        "status": True,
+        "message": "File processed successfully.",
+        "file_id": file_id,
+        "filename": uploaded_file.filename,
+        "known_type": known_type,
+    }
+
 @app.post("/query_multiple")
 async def query_embeddings_by_file_ids(body: QueryMultipleBody):
     try:
diff --git a/pgvector_routes.py b/pgvector_routes.py
index b1994e8f..32bada0c 100644
--- a/pgvector_routes.py
+++ b/pgvector_routes.py
@@ -74,3 +74,20 @@ async def get_all_records(table_name: str):
     records_json = [dict(record) for record in records]
 
     return records_json
+
+@router.get("/records")
+async def get_records_filtered_by_custom_id(custom_id: str, table_name: str = "langchain_pg_embedding"):
+    # Validate that the table name is one of the expected ones to prevent SQL injection
+    if table_name not in ["langchain_pg_collection", "langchain_pg_embedding"]:
+        raise HTTPException(status_code=400, detail="Invalid table name")
+
+    pool = await PSQLDatabase.get_pool()
+    async with pool.acquire() as conn:
+        # Use parameterized queries to prevent SQL injection
+        query = f"SELECT * FROM {table_name} WHERE custom_id=$1;"
+        records = await conn.fetch(query, custom_id)
+
+    # Convert records to JSON serializable format, assuming the Record class has a dict method.
+    records_json = [dict(record) for record in records]
+
+    return records_json
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 1a1ac8c5..a0c1e154 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,3 +18,4 @@ docx2txt==0.8
 pypandoc==1.13
 python-jose==3.3.0
 asyncpg==0.29.0
+python-multipart==0.0.9