-
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
141 changed files
with
235,600 additions
and
294,561 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Any, Dict, List, Tuple, AsyncGenerator | ||
from fastapi import HTTPException | ||
from app.services.search_engine import local_search_engine, global_search_engine | ||
|
||
async def global_search(query: str) -> Tuple[str, Dict[str, List[Dict[str, Any]]]]:
    """Run a global search and return the answer together with its context.

    Awaits ``global_search_engine.asearch(query)`` and passes the result's
    ``context_data`` through ``_reformat_context_data``.

    Args:
        query: Text handed to the global search engine.

    Returns:
        A ``(response, context_data)`` tuple, where ``response`` is
        ``result.response`` and ``context_data`` is the reformatted mapping.

    Raises:
        HTTPException: Status 500 when the search or the reformatting step
            raises; the original exception is attached as ``__cause__``.
    """
    try:
        result = await global_search_engine.asearch(query)
        response = result.response
        context_data = _reformat_context_data(result.context_data)
        return response, context_data
    except Exception as e:
        # Chain explicitly so the underlying failure stays visible in tracebacks.
        raise HTTPException(status_code=500, detail=f"Global search error: {str(e)}") from e
|
||
async def global_search_streaming(query: str) -> AsyncGenerator[str, None]:
    """Yield the items produced by ``global_search_engine.astream(query)``.

    Args:
        query: Text handed to the global search engine.

    Yields:
        Each item from the engine's async stream, unchanged.

    Raises:
        HTTPException: Status 500 when the stream raises; the original
            exception is attached as ``__cause__``.
    """
    try:
        async for token in global_search_engine.astream(query):
            yield token
    except Exception as e:
        # NOTE(review): if this generator is consumed by a streaming HTTP
        # response, an error raised after the first chunk presumably cannot
        # change the status code that was already sent -- confirm how callers
        # are expected to surface mid-stream failures.
        raise HTTPException(status_code=500, detail=f"Global search streaming error: {str(e)}") from e
|
||
async def local_search(query: str) -> Tuple[str, Dict[str, List[Dict[str, Any]]]]:
    """Run a local search and return the answer together with its context.

    Awaits ``local_search_engine.asearch(query)`` and passes the result's
    ``context_data`` through ``_reformat_context_data``.

    Args:
        query: Text handed to the local search engine.

    Returns:
        A ``(response, context_data)`` tuple, where ``response`` is
        ``result.response`` and ``context_data`` is the reformatted mapping.

    Raises:
        HTTPException: Status 500 when the search or the reformatting step
            raises; the original exception is attached as ``__cause__``.
    """
    try:
        result = await local_search_engine.asearch(query)
        response = result.response
        context_data = _reformat_context_data(result.context_data)
        return response, context_data
    except Exception as e:
        # Chain explicitly so the underlying failure stays visible in tracebacks.
        raise HTTPException(status_code=500, detail=f"Local search error: {str(e)}") from e
|
||
async def local_search_streaming(query: str) -> AsyncGenerator[str, None]:
    """Yield the items produced by ``local_search_engine.astream(query)``.

    Args:
        query: Text handed to the local search engine.

    Yields:
        Each item from the engine's async stream, unchanged.

    Raises:
        HTTPException: Status 500 when the stream raises; the original
            exception is attached as ``__cause__``.
    """
    try:
        async for token in local_search_engine.astream(query):
            yield token
    except Exception as e:
        # NOTE(review): if this generator is consumed by a streaming HTTP
        # response, an error raised after the first chunk presumably cannot
        # change the status code that was already sent -- confirm how callers
        # are expected to surface mid-stream failures.
        raise HTTPException(status_code=500, detail=f"Local search streaming error: {str(e)}") from e
|
||
def _reformat_context_data(context_data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
    """Return a copy of *context_data* in which every value is list-shaped.

    Per value, in this order:
      * anything exposing ``to_dict`` is converted with
        ``to_dict(orient='records')``;
      * a ``list`` is passed through untouched;
      * everything else is wrapped in a one-element list.

    Key order of the input mapping is preserved.
    """

    def _as_records(entry: Any) -> Any:
        # The ``to_dict`` check comes first, exactly as in the branch order above.
        if hasattr(entry, 'to_dict'):
            return entry.to_dict(orient='records')
        return entry if isinstance(entry, list) else [entry]

    return {name: _as_records(entry) for name, entry in context_data.items()}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import os | ||
|
||
class Settings:
    """Service configuration, read from the process environment on construction.

    Every ``GRAPHRAG_*`` variable falls back to an empty string when unset,
    except ``GRAPHRAG_INPUT_DIR``, which falls back to a fixed artifacts path.
    ``LANCEDB_URI`` is derived from ``INPUT_DIR``; ``COMMUNITY_LEVEL`` and
    ``MAX_TOKENS`` are fixed values.
    """

    def __init__(self):
        # Credentials and model selection.
        self.API_KEY: str = os.getenv("GRAPHRAG_API_KEY", "")
        self.LLM_MODEL: str = os.getenv("GRAPHRAG_LLM_MODEL", "")
        self.EMBEDDING_MODEL: str = os.getenv("GRAPHRAG_EMBEDDING_MODEL", "")

        # Endpoint details.
        self.API_BASE: str = os.getenv("GRAPHRAG_API_BASE", "")
        self.API_VERSION: str = os.getenv("GRAPHRAG_API_VERSION", "")

        # Data locations; the LanceDB URI is always "<INPUT_DIR>/lancedb".
        self.INPUT_DIR: str = os.getenv(
            "GRAPHRAG_INPUT_DIR",
            "graphfleet/output/20240828-113421/artifacts",
        )
        self.LANCEDB_URI: str = f"{self.INPUT_DIR}/lancedb"

        # Fixed values, not overridable through the environment.
        self.COMMUNITY_LEVEL: int = 2
        self.MAX_TOKENS: int = 12000


# Module-level instance, built once at import time.
settings = Settings()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from fastapi import FastAPI | ||
from app.routers import search | ||
|
||
# FastAPI application instance for the service.
app = FastAPI(
    title="GraphRAG API",
)

# Search endpoints are mounted under the "/search" prefix and tagged "search".
app.include_router(
    search.router,
    prefix="/search",
    tags=["search"],
)


# GET / -- static welcome payload.
@app.get("/")
async def root():
    welcome = {"message": "Welcome to GraphRAG API"}
    return welcome
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import secrets

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings
|
||
# Header scheme object for "X-API-Key". The middleware below reads the header
# straight from the request and does not use this object.
api_key_header = APIKeyHeader(name="X-API-Key")


class APIKeyMiddleware(BaseHTTPMiddleware):
    """Reject requests whose ``X-API-Key`` header does not match the configured key.

    Requests to ``/`` and ``/health`` pass through without a check. Every
    other request must send an ``X-API-Key`` header equal to
    ``settings.API_KEY``; otherwise a 403 JSON response is returned.
    """

    async def dispatch(self, request: Request, call_next):
        """Check the API key, then forward the request or answer 403.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, or a 403 ``JSONResponse`` with body
            ``{"detail": "Invalid API Key"}`` when the key is missing or wrong.
        """
        if request.url.path in ("/", "/health"):
            return await call_next(request)

        api_key = request.headers.get("X-API-Key")
        # Compare as bytes: compare_digest rejects non-ASCII ``str`` arguments,
        # and a constant-time comparison avoids leaking how much of the key matched.
        key_matches = api_key is not None and secrets.compare_digest(
            api_key.encode("utf-8"),
            settings.API_KEY.encode("utf-8"),
        )
        if not key_matches:
            # Return the response directly. An HTTPException raised from inside
            # a BaseHTTPMiddleware is not routed through the application's
            # exception handlers and would reach the client as a 500, not a 403.
            return JSONResponse(status_code=403, content={"detail": "Invalid API Key"})

        return await call_next(request)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from fastapi import Request, HTTPException | ||
from fastapi.responses import JSONResponse | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
import time | ||
|
||
class RateLimiter(BaseHTTPMiddleware):
    """Sliding-window request limiter keyed by client host.

    Timestamps of accepted requests are kept in ``self.requests``, an
    in-memory dict on this middleware instance. A request is rejected with
    HTTP 429 when the client already has ``max_requests`` accepted requests
    within the last ``window_seconds`` seconds.

    NOTE(review): keys for clients that stop sending requests are never
    removed, so the dict grows with the number of distinct clients seen.
    """

    def __init__(self, app, max_requests: int = 10, window_seconds: int = 60):
        """Store the limits and start with an empty request history.

        Args:
            app: Application wrapped by this middleware.
            max_requests: Accepted requests allowed per client per window.
            window_seconds: Window length in seconds.
        """
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.requests = {}

    async def dispatch(self, request: Request, call_next):
        """Forward the request, or answer 429 when the client is over its limit.

        Returns:
            The downstream response, or a 429 ``JSONResponse`` with body
            ``{"error": "Too many requests"}``.
        """
        # ``request.client`` can be None when the server supplies no peer
        # address; fall back to one shared bucket instead of raising.
        client = request.client
        client_ip = client.host if client is not None else "unknown"

        # A monotonic clock keeps the window arithmetic correct even if the
        # system wall clock is adjusted.
        current_time = time.monotonic()

        # Keep only the timestamps that still fall inside the window.
        recent = [
            t for t in self.requests.get(client_ip, ())
            if current_time - t < self.window_seconds
        ]
        self.requests[client_ip] = recent

        if len(recent) >= self.max_requests:
            return JSONResponse(status_code=429, content={"error": "Too many requests"})

        # Only accepted requests are recorded; rejected ones do not extend the window.
        recent.append(current_time)

        response = await call_next(request)
        return response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from fastapi import Request | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from app.utils.logging import logger | ||
import time | ||
|
||
class RequestLoggerMiddleware(BaseHTTPMiddleware):
    """Log method, path, status code and handling time of each request."""

    async def dispatch(self, request: Request, call_next):
        """Forward the request and emit one INFO log line about it.

        Nothing is logged when ``call_next`` raises, because the log call
        comes after it.

        Returns:
            The downstream response, unchanged.
        """
        # perf_counter is meant for measuring durations and is unaffected by
        # wall-clock adjustments, unlike time.time().
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        logger.info(f"{request.method} {request.url.path} - Status: {response.status_code} - Time: {process_time:.2f}s")
        return response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from pydantic import BaseModel | ||
from typing import Optional | ||
|
||
class Entity(BaseModel):
    # Schema for one entity record.
    #
    # Every Optional field carries an explicit ``None`` default. Without it,
    # pydantic v1 treats the field as optional while pydantic v2 treats it as
    # required-but-nullable, so the explicit default gives the same behaviour
    # on both major versions. Payloads that validated before still validate.
    id: str
    title: str
    type: str
    description: Optional[str] = None
    source_id: str
    degree: int
    human_readable_id: str
    community: Optional[int] = None
    size: Optional[float] = None
    entity_type: Optional[str] = None
    top_level_node_id: str
    x: Optional[float] = None
    y: Optional[float] = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
fastapi>=0.68.0,<0.111.0 | ||
uvicorn[standard]>=0.15.0,<0.26.0 | ||
pydantic>=1.8.2,<3.0.0 | ||
pydantic-settings>=2.0.0,<3.0.0 | ||
pandas>=2.2.2,<3.0.0 | ||
tiktoken>=0.3.3,<0.8.0 | ||
graphrag>=0.3.2,<0.4.0 | ||
lancedb>=0.11.0,<0.12.0 | ||
python-dotenv>=0.19.0,<1.1.0 | ||
starlette>=0.14.2,<0.35.0 | ||
requests>=2.26.0,<3.0.0 | ||
numpy>=1.25.2,<2.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from fastapi import APIRouter, HTTPException, BackgroundTasks | ||
from app.services.indexer import run_indexer, run_prompt_tuning | ||
|
||
# Router that collects the indexing and prompt-tuning endpoints of this module.
router = APIRouter()
|
||
# POST /index -- schedule ``run_indexer`` as a background task and reply
# immediately. A failure while scheduling is reported as HTTP 500.
@router.post("/index")
async def trigger_indexing(background_tasks: BackgroundTasks):
    try:
        background_tasks.add_task(run_indexer)
    except Exception as e:
        # Chain explicitly so the underlying failure stays visible in tracebacks.
        raise HTTPException(status_code=500, detail=f"Failed to start indexing process: {str(e)}") from e
    return {"message": "Indexing process started in the background"}
|
||
# POST /prompt-tune -- schedule ``run_prompt_tuning`` as a background task and
# reply immediately. A failure while scheduling is reported as HTTP 500.
@router.post("/prompt-tune")
async def trigger_prompt_tuning(background_tasks: BackgroundTasks):
    try:
        background_tasks.add_task(run_prompt_tuning)
    except Exception as e:
        # Chain explicitly so the underlying failure stays visible in tracebacks.
        raise HTTPException(status_code=500, detail=f"Failed to start prompt tuning process: {str(e)}") from e
    return {"message": "Prompt tuning process started in the background"}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from fastapi import APIRouter, HTTPException | ||
from pydantic import BaseModel | ||
from typing import List | ||
from app.api import generate_questions | ||
|
||
# Router that collects the question-generation endpoint of this module.
router = APIRouter()


# Request body: earlier questions plus how many new ones to produce
# (``question_count`` defaults to 5).
class QuestionGenRequest(BaseModel):
    question_history: List[str]
    question_count: int = 5


# Response body: the generated questions.
class QuestionGenResponse(BaseModel):
    questions: List[str]
|
||
# POST /generate_questions -- hand the history and the requested count to
# ``generate_questions`` and wrap the result in the response model.
# NOTE(review): ``generate_questions`` is imported from ``app.api``; confirm
# that module actually defines it.
@router.post("/generate_questions", response_model=QuestionGenResponse)
async def api_generate_questions(request: QuestionGenRequest):
    generated = await generate_questions(
        request.question_history,
        request.question_count,
    )
    return QuestionGenResponse(questions=generated)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from fastapi import APIRouter | ||
from fastapi.responses import StreamingResponse | ||
from pydantic import BaseModel | ||
from app.api import global_search, global_search_streaming, local_search, local_search_streaming | ||
|
||
# Router that collects the local/global search endpoints of this module.
router = APIRouter()


# Request body shared by all search endpoints: the query text.
class SearchQuery(BaseModel):
    query: str


# Response body of the non-streaming endpoints: the answer text and the
# context data that came with it.
class SearchResponse(BaseModel):
    response: str
    context_data: dict
|
||
# POST /local -- run a local search and return answer plus context.
@router.post("/local", response_model=SearchResponse)
async def api_local_search(query: SearchQuery):
    answer, context = await local_search(query.query)
    return SearchResponse(
        response=answer,
        context_data=context,
    )
|
||
# POST /global -- run a global search and return answer plus context.
@router.post("/global", response_model=SearchResponse)
async def api_global_search(query: SearchQuery):
    answer, context = await global_search(query.query)
    return SearchResponse(
        response=answer,
        context_data=context,
    )
|
||
# POST /local/stream -- stream the local search output as it is produced.
# NOTE(review): the media type is "text/event-stream", but the generator's
# items are forwarded as-is; confirm clients do not expect SSE framing.
@router.post("/local/stream")
async def api_local_search_stream(query: SearchQuery):
    token_stream = local_search_streaming(query.query)
    return StreamingResponse(token_stream, media_type="text/event-stream")
|
||
# POST /global/stream -- stream the global search output as it is produced.
# NOTE(review): the media type is "text/event-stream", but the generator's
# items are forwarded as-is; confirm clients do not expect SSE framing.
@router.post("/global/stream")
async def api_global_search_stream(query: SearchQuery):
    token_stream = global_search_streaming(query.query)
    return StreamingResponse(token_stream, media_type="text/event-stream")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import asyncio | ||
from app.config import settings | ||
from app.utils.logging import logger | ||
|
||
async def run_indexer(verbose: bool = True):
    """Run ``graphrag.index`` as a subprocess and return its standard output.

    The command is started with ``--root settings.INPUT_DIR`` and
    ``--config <INPUT_DIR>/settings.yaml``; stdout and stderr are captured.

    Args:
        verbose: When true, ``--verbose`` is appended to the command line.

    Returns:
        The subprocess's decoded standard output when it exits with code 0.

    Raises:
        RuntimeError: When the subprocess exits with a non-zero code.
        Exception: Any other error is logged and re-raised unchanged.
    """
    # Local import keeps this function self-contained.
    import sys

    try:
        cmd = [
            # Use the interpreter running this service, so the subprocess sees
            # the same environment; a bare "python" on PATH may be a different
            # installation. Fall back to "python" if it is unavailable.
            sys.executable or "python", "-m", "graphrag.index",
            "--root", settings.INPUT_DIR,
            "--config", f"{settings.INPUT_DIR}/settings.yaml",
        ]
        if verbose:
            cmd.append("--verbose")

        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdout, stderr = await process.communicate()

        if process.returncode == 0:
            logger.info("Indexing completed successfully")
            # errors="replace": undecodable bytes in the tool's output must
            # not turn a successful run into a UnicodeDecodeError.
            return stdout.decode(errors="replace")
        else:
            error_output = stderr.decode(errors="replace")
            logger.error(f"Indexing failed: {error_output}")
            raise RuntimeError("Indexing process failed")
    except Exception as e:
        logger.error(f"Error during indexing: {str(e)}")
        raise
|
||
async def run_prompt_tuning(no_entity_types: bool = True):
    """Run ``graphrag.prompt_tune`` as a subprocess and return its standard output.

    The command is started with ``--config <INPUT_DIR>/settings.yaml``,
    ``--root settings.INPUT_DIR`` and ``--output <INPUT_DIR>/prompts``;
    stdout and stderr are captured.

    Args:
        no_entity_types: When true, ``--no-entity-types`` is appended to the
            command line.

    Returns:
        The subprocess's decoded standard output when it exits with code 0.

    Raises:
        RuntimeError: When the subprocess exits with a non-zero code.
        Exception: Any other error is logged and re-raised unchanged.
    """
    # Local import keeps this function self-contained.
    import sys

    try:
        cmd = [
            # Use the interpreter running this service, so the subprocess sees
            # the same environment; a bare "python" on PATH may be a different
            # installation. Fall back to "python" if it is unavailable.
            sys.executable or "python", "-m", "graphrag.prompt_tune",
            "--config", f"{settings.INPUT_DIR}/settings.yaml",
            "--root", settings.INPUT_DIR,
            "--output", f"{settings.INPUT_DIR}/prompts",
        ]
        if no_entity_types:
            cmd.append("--no-entity-types")

        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdout, stderr = await process.communicate()

        if process.returncode == 0:
            logger.info("Prompt tuning completed successfully")
            # errors="replace": undecodable bytes in the tool's output must
            # not turn a successful run into a UnicodeDecodeError.
            return stdout.decode(errors="replace")
        else:
            error_output = stderr.decode(errors="replace")
            logger.error(f"Prompt tuning failed: {error_output}")
            raise RuntimeError("Prompt tuning process failed")
    except Exception as e:
        logger.error(f"Error during prompt tuning: {str(e)}")
        raise
Oops, something went wrong.