-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathapi_server.py
69 lines (55 loc) · 1.98 KB
/
api_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from data_processing.pdf_to_markdown import convert_all_pdfs
from data_processing.text_splitter import split_text_files
from embedding.sentence_encoder import get_encoder, generate_embeddings
from vector_db.qdrant_client import init_qdrant_client, create_collection, prepare_data_for_upload, upload_data
from chat_model.chat_groq import init_chat_model, generate_response
from flashrank import Ranker
# Initialize logging
# Root logger is configured at INFO; `logger` is this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Application object that the route decorator further down registers on.
app = FastAPI()
# Set up paths
# INPUT_FOLDER is the first argument to convert_all_pdfs; OUTPUT_FOLDER is its
# second argument and also the argument to split_text_files. COLLECTION_NAME is
# the collection name handed to every Qdrant helper and to generate_response.
INPUT_FOLDER = 'docs'
OUTPUT_FOLDER = 'parsed_docs'
COLLECTION_NAME = 'my_text_chunks'
# NOTE: every statement from here to the Ranker construction is a module-level
# statement, so the whole ingestion pipeline runs each time this module is
# imported, before any request can be handled.
# Convert PDFs to Markdown
# presumably reads PDFs from INPUT_FOLDER and writes Markdown into
# OUTPUT_FOLDER -- confirm in data_processing.pdf_to_markdown.
convert_all_pdfs(INPUT_FOLDER, OUTPUT_FOLDER)
# Split Markdown files into chunks
chunks = split_text_files(OUTPUT_FOLDER)
# Generate embeddings
# The same `encoder` is reused later by the chat endpoint.
encoder = get_encoder()
embeddings = generate_embeddings(encoder, chunks)
# Initialise Qdrant client and create collection
# NOTE(review): whether create_collection tolerates an already-existing
# collection is not visible here -- confirm in vector_db.qdrant_client.
client = init_qdrant_client()
create_collection(client, COLLECTION_NAME)
# Prepare and upload data
# `chunks` and `embeddings` are passed together; presumably paired by
# position -- verify in prepare_data_for_upload.
points = prepare_data_for_upload(chunks, embeddings)
upload_data(client, COLLECTION_NAME, points)
# Initialise ChatGroq model
chat_model = init_chat_model()
# Initialise Ranker
# Constructed with no arguments, i.e. whatever defaults flashrank provides.
ranker = Ranker()
class Query(BaseModel):
    """Query model for user input."""

    # The only field of the request body: the user's query text, which the
    # chat endpoint reads as `query.user_query`.
    user_query: str
@app.post("/chat")
async def chat(query: Query):
    """
    Endpoint to handle chat queries.

    Passes the user's query, together with the module-level chat model,
    encoder, ranker and Qdrant client, to ``generate_response`` and returns
    whatever it produces.

    Args:
        query: Request body carrying the ``user_query`` string.

    Returns:
        A dict of the form ``{"response": <generated response>}``.

    Raises:
        HTTPException: Re-raised unchanged when the call itself raised one;
            for any other exception, status 500 with the exception text as
            ``detail``.
    """
    try:
        logger.info("Received query: %s", query.user_query)
        # NOTE(review): this is a plain (non-awaited) call inside an async
        # handler; if generate_response blocks on network or model work it
        # will hold the event loop for its duration -- confirm and consider
        # offloading.
        response = generate_response(chat_model, encoder, ranker, query.user_query, COLLECTION_NAME, client)
        logger.info("Generated response: %s", response)
        return {"response": response}
    except HTTPException:
        # Already a well-formed HTTP error: let it through instead of
        # masking its status code behind a generic 500.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which logger.error(..., str(e)) discarded.
        logger.exception("Error generating response: %s", e)
        # NOTE(review): `detail` exposes the raw exception text to the
        # client; kept as-is so existing consumers see the same payload.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: uvicorn is imported only on this path, so merely
    # importing the module does not import it here.
    import uvicorn

    # Serve the app on the loopback interface, port 8000.
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=8000,
    )