Skip to content

Commit

Permalink
API refactor and functional
Browse files Browse the repository at this point in the history
  • Loading branch information
Zochory committed Aug 28, 2024
1 parent 5c278d8 commit 8b158b9
Show file tree
Hide file tree
Showing 141 changed files with 235,600 additions and 294,561 deletions.
7 changes: 7 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
"features": {
"ghcr.io/devcontainers-contrib/features/poetry:2": {}
},
"customizations": {
"vscode": {
"extensions": [
"ms-toolsai.prompty"
]
}
},

// Uncomment and adjust the following lines as needed
// "forwardPorts": [8000],
Expand Down
59 changes: 0 additions & 59 deletions app.py

This file was deleted.

46 changes: 46 additions & 0 deletions app/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Any, Dict, List, Tuple, AsyncGenerator
from fastapi import HTTPException
from app.services.search_engine import local_search_engine, global_search_engine

async def global_search(query: str) -> Tuple[str, Dict[str, List[Dict[str, Any]]]]:
    """Run a global (community-level) search for *query*.

    Returns:
        A ``(response, context_data)`` tuple where ``response`` is the
        engine's answer text and ``context_data`` maps source names to
        lists of JSON-serializable records.

    Raises:
        HTTPException: 500 wrapping the underlying engine error.
    """
    try:
        result = await global_search_engine.asearch(query)
        response = result.response
        context_data = _reformat_context_data(result.context_data)
        return response, context_data
    except Exception as e:
        # Chain the original exception so the real traceback is preserved
        # in logs instead of being replaced by the HTTPException (PEP 3134).
        raise HTTPException(status_code=500, detail=f"Global search error: {str(e)}") from e

async def global_search_streaming(query: str) -> AsyncGenerator[str, None]:
    """Yield global-search response tokens as the engine produces them.

    NOTE(review): if the engine fails after the first token has been sent,
    the HTTP status can no longer be changed; the raised HTTPException will
    surface as a broken stream rather than a 500 body — confirm this is the
    desired client-facing behavior.
    """
    try:
        async for token in global_search_engine.astream(query):
            yield token
    except Exception as e:
        # Chain the cause so the engine's traceback is not lost.
        raise HTTPException(status_code=500, detail=f"Global search streaming error: {str(e)}") from e

async def local_search(query: str) -> Tuple[str, Dict[str, List[Dict[str, Any]]]]:
    """Run a local (entity-level) search for *query*.

    Returns:
        A ``(response, context_data)`` tuple; see :func:`global_search`.

    Raises:
        HTTPException: 500 wrapping the underlying engine error.
    """
    try:
        result = await local_search_engine.asearch(query)
        response = result.response
        context_data = _reformat_context_data(result.context_data)
        return response, context_data
    except Exception as e:
        # Chain the original exception to keep the real traceback (PEP 3134).
        raise HTTPException(status_code=500, detail=f"Local search error: {str(e)}") from e

async def local_search_streaming(query: str) -> AsyncGenerator[str, None]:
    """Yield local-search response tokens as the engine produces them.

    NOTE(review): same caveat as global_search_streaming — an exception
    raised after streaming has started cannot change the HTTP status.
    """
    try:
        async for token in local_search_engine.astream(query):
            yield token
    except Exception as e:
        # Chain the cause so the engine's traceback is not lost.
        raise HTTPException(status_code=500, detail=f"Local search streaming error: {str(e)}") from e

def _reformat_context_data(context_data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
reformatted_data = {}
for key, value in context_data.items():
if hasattr(value, 'to_dict'):
reformatted_data[key] = value.to_dict(orient='records')
elif isinstance(value, list):
reformatted_data[key] = value
else:
reformatted_data[key] = [value]
return reformatted_data
15 changes: 15 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os

class Settings:
    """Runtime configuration read from ``GRAPHRAG_*`` environment variables.

    String settings default to ``""`` when unset; ``INPUT_DIR`` defaults to
    a specific indexing-run artifact directory.
    """

    def __init__(self):
        # Credentials and model endpoints (empty string when unset).
        self.API_KEY: str = os.environ.get("GRAPHRAG_API_KEY", "")
        self.LLM_MODEL: str = os.environ.get("GRAPHRAG_LLM_MODEL", "")
        self.EMBEDDING_MODEL: str = os.environ.get("GRAPHRAG_EMBEDDING_MODEL", "")
        self.API_BASE: str = os.environ.get("GRAPHRAG_API_BASE", "")
        self.API_VERSION: str = os.environ.get("GRAPHRAG_API_VERSION", "")
        # Artifact directory produced by a specific indexing run.
        self.INPUT_DIR: str = os.environ.get("GRAPHRAG_INPUT_DIR", "graphfleet/output/20240828-113421/artifacts")
        self.LANCEDB_URI: str = f"{self.INPUT_DIR}/lancedb"
        # Previously hard-coded; now overridable via environment with the
        # same defaults, so existing deployments are unaffected.
        self.COMMUNITY_LEVEL: int = int(os.environ.get("GRAPHRAG_COMMUNITY_LEVEL", "2"))
        self.MAX_TOKENS: int = int(os.environ.get("GRAPHRAG_MAX_TOKENS", "12000"))

# Module-level singleton used throughout the app.
settings = Settings()
10 changes: 10 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from fastapi import FastAPI
from app.routers import search

# Application entry point: creates the FastAPI app and mounts routers.
app = FastAPI(title="GraphRAG API")

# All search endpoints are served under the /search prefix
# (e.g. /search/local, /search/global — see app/routers/search.py).
app.include_router(search.router, prefix="/search", tags=["search"])

@app.get("/")
async def root():
    """Landing endpoint; returns a static welcome message."""
    return {"message": "Welcome to GraphRAG API"}
17 changes: 17 additions & 0 deletions app/middleware/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import secrets

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings

# Declared for OpenAPI/security tooling; the middleware below reads the
# header directly.  NOTE(review): not referenced by any code in this file.
api_key_header = APIKeyHeader(name="X-API-Key")

class APIKeyMiddleware(BaseHTTPMiddleware):
    """Reject any request that lacks the configured ``X-API-Key`` header."""

    async def dispatch(self, request: Request, call_next):
        # Public endpoints that never require authentication.
        if request.url.path in ["/", "/health"]:
            return await call_next(request)

        api_key = request.headers.get("X-API-Key")
        # Constant-time comparison avoids leaking key prefixes via timing.
        if api_key is None or not secrets.compare_digest(api_key, settings.API_KEY):
            # An HTTPException raised inside BaseHTTPMiddleware is NOT routed
            # through FastAPI's exception handlers (it surfaces as a 500);
            # return the 403 response directly instead, matching the JSON
            # shape FastAPI would have produced.
            return JSONResponse(status_code=403, content={"detail": "Invalid API Key"})

        return await call_next(request)
28 changes: 28 additions & 0 deletions app/middleware/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import time

class RateLimiter(BaseHTTPMiddleware):
    """Naive in-memory, per-IP sliding-window rate limiter.

    NOTE(review): state is per-process — with multiple workers each process
    has its own budget — and buckets for idle IPs are never evicted, so the
    dict can grow with the number of distinct client IPs. Confirm whether a
    shared store (e.g. Redis) is needed for production.
    """

    def __init__(self, app, max_requests: int = 10, window_seconds: int = 60):
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # Maps client IP -> list of request timestamps within the window.
        self.requests = {}

    async def dispatch(self, request: Request, call_next):
        # request.client can be None (e.g. some test clients / unusual ASGI
        # servers); fall back to a shared bucket instead of crashing.
        client_ip = request.client.host if request.client else "unknown"
        current_time = time.time()

        # Keep only timestamps still inside the sliding window.
        window_start = current_time - self.window_seconds
        recent = [t for t in self.requests.get(client_ip, []) if t > window_start]

        if len(recent) >= self.max_requests:
            self.requests[client_ip] = recent
            return JSONResponse(status_code=429, content={"error": "Too many requests"})

        recent.append(current_time)
        self.requests[client_ip] = recent

        response = await call_next(request)
        return response
12 changes: 12 additions & 0 deletions app/middleware/request_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from app.utils.logging import logger
import time

class RequestLoggerMiddleware(BaseHTTPMiddleware):
    """Log method, path, status code and elapsed time for every request."""

    async def dispatch(self, request: Request, call_next):
        # perf_counter is monotonic; time.time() can jump on clock
        # adjustments and produce negative or wildly wrong durations.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        logger.info(f"{request.method} {request.url.path} - Status: {response.status_code} - Time: {process_time:.2f}s")
        return response
17 changes: 17 additions & 0 deletions app/models/entity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pydantic import BaseModel
from typing import Optional

class Entity(BaseModel):
    """Graph entity record.

    NOTE(review): field meanings are inferred from names — confirm against
    the GraphRAG artifact schema.
    """

    id: str
    title: str
    type: str
    # Pydantic v2 (allowed by requirements: pydantic<3) treats Optional[X]
    # WITHOUT a default as a *required* field; explicit `= None` defaults
    # make these fields genuinely optional under both v1 and v2.
    description: Optional[str] = None
    source_id: str
    degree: int
    human_readable_id: str
    community: Optional[int] = None
    size: Optional[float] = None
    entity_type: Optional[str] = None
    top_level_node_id: str
    x: Optional[float] = None
    y: Optional[float] = None
12 changes: 12 additions & 0 deletions app/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fastapi>=0.68.0,<0.111.0
uvicorn[standard]>=0.15.0,<0.26.0
pydantic>=1.8.2,<3.0.0
pydantic-settings>=2.0.0,<3.0.0
pandas>=2.2.2,<3.0.0
tiktoken>=0.3.3,<0.8.0
graphrag>=0.3.2,<0.4.0
lancedb>=0.11.0,<0.12.0
python-dotenv>=0.19.0,<1.1.0
starlette>=0.14.2,<0.35.0
requests>=2.26.0,<3.0.0
numpy>=1.25.2,<2.0.0
20 changes: 20 additions & 0 deletions app/routers/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from fastapi import APIRouter, HTTPException, BackgroundTasks
from app.services.indexer import run_indexer, run_prompt_tuning

router = APIRouter()

@router.post("/index")
async def trigger_indexing(background_tasks: BackgroundTasks):
    """Schedule a full indexing run to execute after the response is sent."""
    try:
        background_tasks.add_task(run_indexer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to start indexing process: {str(e)}")
    return {"message": "Indexing process started in the background"}

@router.post("/prompt-tune")
async def trigger_prompt_tuning(background_tasks: BackgroundTasks):
    """Schedule a prompt-tuning run to execute after the response is sent."""
    try:
        background_tasks.add_task(run_prompt_tuning)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to start prompt tuning process: {str(e)}")
    return {"message": "Prompt tuning process started in the background"}
18 changes: 18 additions & 0 deletions app/routers/question_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List
from app.api import generate_questions

router = APIRouter()

class QuestionGenRequest(BaseModel):
    # Prior questions used as context for generating follow-ups.
    question_history: List[str]
    # Number of follow-up questions to generate.
    question_count: int = 5

class QuestionGenResponse(BaseModel):
    # Generated follow-up questions, in engine order.
    questions: List[str]

@router.post("/generate_questions", response_model=QuestionGenResponse)
async def api_generate_questions(request: QuestionGenRequest):
    """Generate follow-up questions from the supplied question history.

    NOTE(review): `generate_questions` is imported from app.api, but
    app/api.py in this change does not define it — verify the import
    resolves, otherwise this module fails at import time.
    """
    questions = await generate_questions(request.question_history, request.question_count)
    return QuestionGenResponse(questions=questions)
31 changes: 31 additions & 0 deletions app/routers/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.api import global_search, global_search_streaming, local_search, local_search_streaming

router = APIRouter()

class SearchQuery(BaseModel):
    # Free-text query string passed to the search engine.
    query: str

class SearchResponse(BaseModel):
    # Answer text returned by the search engine.
    response: str
    # Supporting context records keyed by source name (see
    # _reformat_context_data in app/api.py).
    context_data: dict

@router.post("/local", response_model=SearchResponse)
async def api_local_search(query: SearchQuery):
    """Run a local search and return the answer with its context data."""
    answer, context = await local_search(query.query)
    return SearchResponse(response=answer, context_data=context)

@router.post("/global", response_model=SearchResponse)
async def api_global_search(query: SearchQuery):
    """Run a global search and return the answer with its context data."""
    answer, context = await global_search(query.query)
    return SearchResponse(response=answer, context_data=context)

@router.post("/local/stream")
async def api_local_search_stream(query: SearchQuery):
    """Stream local-search tokens to the client as server-sent events."""
    token_stream = local_search_streaming(query.query)
    return StreamingResponse(token_stream, media_type="text/event-stream")

@router.post("/global/stream")
async def api_global_search_stream(query: SearchQuery):
    """Stream global-search tokens to the client as server-sent events."""
    token_stream = global_search_streaming(query.query)
    return StreamingResponse(token_stream, media_type="text/event-stream")
60 changes: 60 additions & 0 deletions app/services/indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import asyncio
import sys

from app.config import settings
from app.utils.logging import logger

async def run_indexer(verbose: bool = True):
    """Run the GraphRAG indexing pipeline as a subprocess.

    Args:
        verbose: pass ``--verbose`` to the indexer CLI.

    Returns:
        The decoded subprocess stdout on success.

    Raises:
        RuntimeError: if the indexer exits non-zero (stderr included).
    """
    try:
        cmd = [
            # sys.executable guarantees the same interpreter/venv as this
            # process; a bare "python" may resolve to a different install
            # that lacks the graphrag package.
            sys.executable, "-m", "graphrag.index",
            "--root", settings.INPUT_DIR,
            "--config", f"{settings.INPUT_DIR}/settings.yaml",
        ]
        if verbose:
            cmd.append("--verbose")

        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdout, stderr = await process.communicate()

        if process.returncode == 0:
            logger.info("Indexing completed successfully")
            return stdout.decode()
        logger.error(f"Indexing failed: {stderr.decode()}")
        # Include stderr so callers see the actual cause, not just "failed".
        raise RuntimeError(f"Indexing process failed: {stderr.decode()}")
    except Exception as e:
        logger.error(f"Error during indexing: {str(e)}")
        raise

async def run_prompt_tuning(no_entity_types: bool = True):
    """Run GraphRAG prompt tuning as a subprocess.

    Args:
        no_entity_types: pass ``--no-entity-types`` to the CLI.

    Returns:
        The decoded subprocess stdout on success.

    Raises:
        RuntimeError: if prompt tuning exits non-zero (stderr included).
    """
    try:
        cmd = [
            # sys.executable guarantees the same interpreter/venv as this
            # process; a bare "python" may resolve to a different install.
            sys.executable, "-m", "graphrag.prompt_tune",
            "--config", f"{settings.INPUT_DIR}/settings.yaml",
            "--root", settings.INPUT_DIR,
            "--output", f"{settings.INPUT_DIR}/prompts",
        ]
        if no_entity_types:
            cmd.append("--no-entity-types")

        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdout, stderr = await process.communicate()

        if process.returncode == 0:
            logger.info("Prompt tuning completed successfully")
            return stdout.decode()
        logger.error(f"Prompt tuning failed: {stderr.decode()}")
        # Include stderr so callers see the actual cause, not just "failed".
        raise RuntimeError(f"Prompt tuning process failed: {stderr.decode()}")
    except Exception as e:
        logger.error(f"Error during prompt tuning: {str(e)}")
        raise
Loading

0 comments on commit 8b158b9

Please sign in to comment.