Skip to content

Commit

Permalink
refactor(query): use QueryRequestBody
Browse files Browse the repository at this point in the history
  • Loading branch information
danny-avila committed Mar 18, 2024
1 parent b8dac6c commit c03a9ec
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
14 changes: 7 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain.schema import Document
from langchain_core.runnables.config import run_in_executor

from models import DocumentModel, DocumentResponse, StoreDocument
from models import DocumentModel, DocumentResponse, StoreDocument, QueryRequestBody
from store import AsyncPgVector

load_dotenv(find_dotenv())
Expand Down Expand Up @@ -133,25 +133,25 @@ async def delete_documents(ids: list[str]):
raise HTTPException(status_code=500, detail=str(e))

@app.post("/query-embeddings-by-file-id/")
async def query_embeddings_by_file_id(file_id: str, query: str, k: int = 4):
async def query_embeddings_by_file_id(body: QueryRequestBody):
try:
# Get the embedding of the query text
embedding = pgvector_store.embedding_function.embed_query(query)
embedding = pgvector_store.embedding_function.embed_query(body.query)

# Perform similarity search with the query embedding and filter by the file_id in metadata
if isinstance(pgvector_store, AsyncPgVector):
documents = await run_in_executor(
None,
pgvector_store.similarity_search_with_score_by_vector,
embedding,
k=k,
filter={"file_id": file_id}
k=body.k,
filter={"file_id": body.file_id}
)
else:
documents = pgvector_store.similarity_search_with_score_by_vector(
embedding,
k=k,
filter={"file_id": file_id}
k=body.k,
filter={"file_id": body.file_id}
)

return documents
Expand Down
5 changes: 5 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class StoreDocument(BaseModel):
file_content_type: str
file_id: str

class QueryRequestBody(BaseModel):
file_id: str
query: str
k: int = 4

class CleanupMethod(str, Enum):
incremental = "incremental"
full = "full"

0 comments on commit c03a9ec

Please sign in to comment.