Skip to content

Commit

Permalink
fix[chat]: metadata and filtering error margins (#22)
Browse files: browse the repository at this point in the history
* fix[chat]: metadata and filtering error margins

* dev: add vscode file

* fix: error in condition

---------

Co-authored-by: Gabriele Venturi <[email protected]>
  • Loading branch information
ArslanSaleem and gventuri authored Oct 14, 2024
1 parent 0340ba5 commit 19eed60
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 68 deletions.
61 changes: 44 additions & 17 deletions backend/app/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
user_repository,
)
from app.requests import chat_query
from app.utils import clean_text
from app.vectorstore.chroma import ChromaDB
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
Expand Down Expand Up @@ -91,7 +92,10 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d
conversation_id = str(conversation.id)

content = response["response"]
content_length = len(content)
clean_content = clean_text(content)
text_references = []
not_exact_matched_refs = []

for reference in response["references"]:
sentence = reference["sentence"]
Expand All @@ -102,28 +106,51 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d

doc_sent, doc_ids, doc_metadata = vectorstore.get_relevant_segments(
original_sentence,
k=1,
k=5,
num_surrounding_sentences=0,
# TODO: uncomment this when we fix the filename in the metadata
# metadata_filter={"filename": original_filename}
metadata_filter={"filename": original_filename},
)

for sent, id, metadata in zip(doc_sent, doc_ids, doc_metadata):
index = content.find(sentence)
if index != -1:
text_reference = {
"asset_id": metadata["asset_id"],
"project_id": metadata["project_id"],
"page_number": metadata["page_number"],
"filename": original_filename,
"source": [sent],
"start": index,
"end": index + len(sentence),
}
text_references.append(text_reference)
# Search for exact match
best_match_index = 0

for index, sent in enumerate(doc_sent):
if clean_text(original_sentence) in clean_text(sent):
best_match_index = index

metadata = doc_metadata[best_match_index]
sent = doc_sent[best_match_index]

index = clean_content.find(clean_text(sentence))

if index != -1:
text_reference = {
"asset_id": metadata["asset_id"],
"project_id": metadata["project_id"],
"page_number": metadata["page_number"],
"filename": original_filename,
"source": [sent],
"start": index,
"end": index + len(sentence),
}
text_references.append(text_reference)
else:
no_exact_reference = {
"asset_id": metadata["asset_id"],
"project_id": metadata["project_id"],
"page_number": metadata["page_number"],
"filename": original_filename,
"source": [sent],
"start": 0,
"end": content_length,
}
not_exact_matched_refs.append(no_exact_reference)

# group text references based on start and end
refs = group_by_start_end(text_references)
if len(text_references) > 0:
refs = group_by_start_end(text_references)
else:
refs = group_by_start_end(not_exact_matched_refs)

conversation_repository.create_conversation_message(
db,
Expand Down
117 changes: 89 additions & 28 deletions backend/app/api/v1/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
@project_router.post("/", status_code=201)
def create_project(project: ProjectCreate, db: Session = Depends(get_db)):
if not project.name.strip():
raise HTTPException(status_code=400, detail="Project name is required and cannot be empty.")
raise HTTPException(
status_code=400, detail="Project name is required and cannot be empty."
)

db_project = project_repository.create_project(db=db, project=project)
return {
Expand All @@ -39,7 +41,7 @@ def create_project(project: ProjectCreate, db: Session = Depends(get_db)):
@project_router.get("/")
def get_projects(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
page_size: int = Query(100, ge=1, le=100),
db: Session = Depends(get_db),
):
try:
Expand Down Expand Up @@ -67,15 +69,20 @@ def get_projects(
}
except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="An internal server error occurred while processing your request. Please try again later.")
raise HTTPException(
status_code=500,
detail="An internal server error occurred while processing your request. Please try again later.",
)


@project_router.get("/{id}")
def get_project(id: int, db: Session = Depends(get_db)):
try:
project = project_repository.get_project(db=db, project_id=id)
if project is None:
raise HTTPException(status_code=404, detail="The requested project could not be found.")
raise HTTPException(
status_code=404, detail="The requested project could not be found."
)

return {
"status": "success",
Expand All @@ -92,7 +99,10 @@ def get_project(id: int, db: Session = Depends(get_db)):
raise
except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="An internal server error occurred while retrieving the project. Please try again later.")
raise HTTPException(
status_code=500,
detail="An internal server error occurred while retrieving the project. Please try again later.",
)


@project_router.get("/{id}/assets")
Expand Down Expand Up @@ -127,7 +137,10 @@ def get_assets(
}
except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="An internal server error occurred while retrieving assets. Please try again later.")
raise HTTPException(
status_code=500,
detail="An internal server error occurred while retrieving assets. Please try again later.",
)


@project_router.post("/{id}/assets")
Expand All @@ -137,7 +150,9 @@ async def upload_files(
try:
project = project_repository.get_project(db=db, project_id=id)
if project is None:
raise HTTPException(status_code=404, detail="The specified project could not be found.")
raise HTTPException(
status_code=404, detail="The specified project could not be found."
)

# Ensure the upload directory exists
os.makedirs(os.path.join(settings.upload_dir, str(id)), exist_ok=True)
Expand All @@ -147,14 +162,15 @@ async def upload_files(
# Check if the uploaded file is a PDF
if file.content_type != "application/pdf":
raise HTTPException(
status_code=400, detail=f"The file '{file.filename}' is not a valid PDF. Please upload only PDF files."
status_code=400,
detail=f"The file '{file.filename}' is not a valid PDF. Please upload only PDF files.",
)

# Check if the file size is greater than 20MB
if file.size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=400,
detail=f"The file '{file.filename}' exceeds the maximum allowed size of 20MB. Please upload a smaller file."
detail=f"The file '{file.filename}' exceeds the maximum allowed size of 20MB. Please upload a smaller file.",
)

# Generate a secure filename
Expand Down Expand Up @@ -190,7 +206,10 @@ async def upload_files(
raise
except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="An error occurred while uploading files. Please try again later.")
raise HTTPException(
status_code=500,
detail="An error occurred while uploading files. Please try again later.",
)


@project_router.post("/{id}/assets/url")
Expand All @@ -199,14 +218,22 @@ async def add_url_asset(id: int, data: UrlAssetCreate, db: Session = Depends(get
urls = data.url
project = project_repository.get_project(db=db, project_id=id)
if project is None:
raise HTTPException(status_code=404, detail="The specified project could not be found.")
raise HTTPException(
status_code=404, detail="The specified project could not be found."
)

if not urls:
raise HTTPException(status_code=400, detail="No URLs provided. Please provide at least one valid URL.")
raise HTTPException(
status_code=400,
detail="No URLs provided. Please provide at least one valid URL.",
)

for url in urls:
if not is_valid_url(url):
raise HTTPException(status_code=400, detail=f"Invalid URL format: {url}. Please provide a valid URL.")
raise HTTPException(
status_code=400,
detail=f"Invalid URL format: {url}. Please provide a valid URL.",
)

url_assets = []
for url in urls:
Expand Down Expand Up @@ -244,7 +271,10 @@ async def add_url_asset(id: int, data: UrlAssetCreate, db: Session = Depends(get
raise
except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="An error occurred while processing the URL asset. Please try again later.")
raise HTTPException(
status_code=500,
detail="An error occurred while processing the URL asset. Please try again later.",
)


@project_router.get("/{id}/assets/{asset_id}")
Expand All @@ -254,14 +284,18 @@ async def get_file(asset_id: int, db: Session = Depends(get_db)):

if asset is None:
raise HTTPException(
status_code=404, detail="The requested file could not be found in the database."
status_code=404,
detail="The requested file could not be found in the database.",
)

filepath = asset.path

# Check if the file exists
if not os.path.isfile(filepath):
raise HTTPException(status_code=404, detail="The requested file could not be found on the server.")
raise HTTPException(
status_code=404,
detail="The requested file could not be found on the server.",
)

# Return the file
return FileResponse(
Expand All @@ -272,7 +306,10 @@ async def get_file(asset_id: int, db: Session = Depends(get_db)):
raise
except Exception as e:
logger.error(f"Error retrieving file: {str(e)}\n{traceback.format_exc()}")
raise HTTPException(status_code=500, detail="An error occurred while retrieving the file. Please try again later.")
raise HTTPException(
status_code=500,
detail="An error occurred while retrieving the file. Please try again later.",
)


@project_router.get("/{id}/processes")
Expand Down Expand Up @@ -317,7 +354,9 @@ def update_project(id: int, project: ProjectUpdate, db: Session = Depends(get_db
try:
db_project = project_repository.get_project(db=db, project_id=id)
if db_project is None:
raise HTTPException(status_code=404, detail="The specified project could not be found.")
raise HTTPException(
status_code=404, detail="The specified project could not be found."
)
updated_project = project_repository.update_project(
db=db, project_id=id, project=project
)
Expand All @@ -330,15 +369,20 @@ def update_project(id: int, project: ProjectUpdate, db: Session = Depends(get_db
raise
except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="An error occurred while updating the project. Please try again later.")
raise HTTPException(
status_code=500,
detail="An error occurred while updating the project. Please try again later.",
)


@project_router.delete("/{project_id}")
async def delete_project(project_id: int, db: Session = Depends(get_db)):
try:
project = project_repository.get_project(db, project_id)
if not project:
raise HTTPException(status_code=404, detail="The specified project could not be found.")
raise HTTPException(
status_code=404, detail="The specified project could not be found."
)

project.deleted_at = datetime.now(tz=timezone.utc)

Expand All @@ -359,28 +403,39 @@ async def delete_project(project_id: int, db: Session = Depends(get_db)):

except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="An error occurred while deleting the project. Please try again later.")
raise HTTPException(
status_code=500,
detail="An error occurred while deleting the project. Please try again later.",
)


@project_router.delete("/{project_id}/assets/{asset_id}")
async def delete_asset(project_id: int, asset_id: int, db: Session = Depends(get_db)):
try:
project = project_repository.get_project(db, project_id)
if not project:
raise HTTPException(status_code=404, detail="The specified project could not be found.")
raise HTTPException(
status_code=404, detail="The specified project could not be found."
)

asset = project_repository.get_asset(db, asset_id)
if asset is None:
raise HTTPException(status_code=404, detail="The specified asset could not be found in the database.")
raise HTTPException(
status_code=404,
detail="The specified asset could not be found in the database.",
)

if asset.project_id != project_id:
raise HTTPException(status_code=400, detail="The specified asset does not belong to the given project.")
raise HTTPException(
status_code=400,
detail="The specified asset does not belong to the given project.",
)

# Store asset information before deletion
asset_info = {
"id": asset.id,
"filename": asset.filename,
"project_id": asset.project_id
"project_id": asset.project_id,
}

try:
Expand All @@ -393,8 +448,11 @@ async def delete_asset(project_id: int, asset_id: int, db: Session = Depends(get
# Soft delete the asset
asset.deleted_at = datetime.now(tz=timezone.utc)
db.commit()

logger.log("info", f"Asset {asset_info['id']} successfully marked as deleted in the database")

logger.log(
"info",
f"Asset {asset_info['id']} successfully marked as deleted in the database",
)
return {"message": "Asset deleted successfully"}

except HTTPException as http_error:
Expand All @@ -404,4 +462,7 @@ async def delete_asset(project_id: int, asset_id: int, db: Session = Depends(get
db.rollback()
error_msg = f"An error occurred while deleting the asset: {str(e)}"
logger.error(f"{error_msg}\n{traceback.format_exc()}")
raise HTTPException(status_code=500, detail="An error occurred while deleting the asset. Please try again later.")
raise HTTPException(
status_code=500,
detail="An error occurred while deleting the asset. Please try again later.",
)
Loading

0 comments on commit 19eed60

Please sign in to comment.