Skip to content

Commit

Permalink
Add Streamlit UI
Browse files Browse the repository at this point in the history
  • Loading branch information
Zochory committed Aug 29, 2024
1 parent 28824f5 commit df508a8
Show file tree
Hide file tree
Showing 49 changed files with 382,870 additions and 240 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
"ms-toolsai.vscode-jupyter-cell-tags",
"ms-toolsai.jupyter",
"VisualStudioExptTeam.vscodeintellicode",
"github.vscode-github-actions"
"github.vscode-github-actions",
"mikestead.dotenv"
]
}
},
Expand Down
4 changes: 3 additions & 1 deletion .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"ms-azuretools.vscode-azurefunctions",
"genaiscript.genaiscript-vscode",
"ms-vscode.azurecli",
"github.vscode-github-actions"
"github.vscode-github-actions",
"mikestead.dotenv",
"bartosz-dude.folder-scopes"
]
}
56 changes: 45 additions & 11 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,33 @@
local_search_engine,
global_search_engine,
)
import logging
import numpy as np
from app.utils import convert_numpy

# ... existing imports ...

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


async def global_search(query: str) -> Dict[str, Any]:
    """Run a global (community-level) GraphRAG search for *query*.

    Args:
        query: Free-text user query forwarded to the global search engine.

    Returns:
        Dict with the LLM "response" text and the JSON-ready "context_data"
        that was used to produce it.

    Raises:
        HTTPException: 500 wrapping any underlying search failure.
    """
    try:
        # Lazy %r formatting neutralises newlines in the user-supplied query
        # (log injection, flagged by CodeQL on the original f-string).
        logger.debug("Starting global search with query: %r", query)
        result = await global_search_engine.asearch(query)
        logger.debug("Global search completed successfully")
        context_data = _reformat_context_data(result.context_data)
        logger.debug("Context data reformatted: %s", list(context_data.keys()))
        return {
            "response": result.response,
            "context_data": context_data,
        }
    except Exception as e:
        logger.error("Global search error: %s", e, exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Global search error: {str(e)}"
        )


async def global_search_streaming(query: str) -> AsyncGenerator[str, None]:
Expand All @@ -32,12 +48,17 @@ async def global_search_streaming(query: str) -> AsyncGenerator[str, None]:

async def local_search(query: str) -> Dict[str, Any]:
    """Run a local (entity-level) GraphRAG search for *query*.

    Args:
        query: Free-text user query forwarded to the local search engine.

    Returns:
        Dict with the LLM "response" text and the JSON-ready "context_data"
        that was used to produce it.

    Raises:
        HTTPException: 500 wrapping any underlying search failure.
    """
    try:
        # Lazy %r formatting neutralises newlines in the user-supplied query
        # (log injection, flagged by CodeQL on the original f-string).
        logger.debug("Starting local search with query: %r", query)
        result = await local_search_engine.asearch(query)
        logger.debug("Local search completed successfully")
        context_data = _reformat_context_data(result.context_data)
        logger.debug("Context data reformatted: %s", list(context_data.keys()))
        return {
            "response": result.response,
            "context_data": context_data,
        }
    except Exception as e:
        logger.error("Local search error: %s", e, exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Local search error: {str(e)}",
        )
Expand All @@ -50,14 +71,27 @@ async def local_search_streaming(query: str) -> AsyncGenerator[str, None]:
yield chunk
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Local search streaming error: {str(e)}"
)
status_code=500,
detail=f"Local search streaming error: {str(e)}"
)


def _reformat_context_data(context_data: Dict[str, Any]) -> Dict[str, List[Any]]:
    """Normalise GraphRAG context data into JSON-serialisable lists.

    DataFrame-like values (anything with ``to_dict``) become lists of record
    dicts; lists pass through; any other value is wrapped in a one-element
    list. Everything is run through ``convert_numpy`` so numpy scalars and
    arrays serialise cleanly. Values that fail conversion fall back to their
    ``str()`` representation rather than aborting the whole response.
    """
    reformatted_data = {}
    for key, value in context_data.items():
        logger.debug("Processing key: %s, Type: %s", key, type(value))
        try:
            if hasattr(value, "to_dict"):
                # pandas DataFrame (or compatible) -> list of row dicts
                reformatted_data[key] = convert_numpy(value.to_dict(orient="records"))
            elif isinstance(value, list):
                reformatted_data[key] = convert_numpy(value)
            else:
                # Single scalar/object: wrap so every value is a list
                reformatted_data[key] = convert_numpy([value])
            logger.debug("Converted %s, Type: %s", key, type(reformatted_data[key]))
        except Exception as e:
            logger.error("Error converting %s: %s", key, e, exc_info=True)
            reformatted_data[key] = str(value)  # best-effort fallback
    return reformatted_data
19 changes: 8 additions & 11 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,15 @@


class Settings(BaseSettings):
API_KEY: str = os.getenv("GRAPHRAG_API_KEY")
LLM_MODEL: str = os.getenv("GRAPHRAG_LLM_MODEL")
EMBEDDING_MODEL: str = os.getenv("GRAPHRAG_EMBEDDING_MODEL")
API_BASE: str = os.getenv("GRAPHRAG_API_BASE")
API_VERSION: str = os.getenv("GRAPHRAG_API_VERSION")
INPUT_DIR: str = os.getenv("GRAPHRAG_INPUT_DIR")
LANCEDB_URI: str = os.getenv("GRAPHRAG_LANCEDB_URI")
API_KEY: str = os.getenv("GRAPHRAG_API_KEY", "default_api_key")
LLM_MODEL: str = os.getenv("GRAPHRAG_LLM_MODEL", "default_llm_model")
EMBEDDING_MODEL: str = os.getenv("GRAPHRAG_EMBEDDING_MODEL", "default_embedding_model")
API_BASE: str = os.getenv("GRAPHRAG_API_BASE", "default_api_base")
API_VERSION: str = os.getenv("GRAPHRAG_API_VERSION", "default_api_version")
INPUT_DIR: str = os.getenv("GRAPHRAG_INPUT_DIR", "graphfleet/output/20240829-184001/artifacts")
LANCEDB_URI: str = os.getenv("GRAPHRAG_LANCEDB_URI", "graphfleet/output/20240829-184001/artifacts/lancedb")
COMMUNITY_LEVEL: int = int(os.getenv("GRAPHRAG_COMMUNITY_LEVEL", 2))
MAX_TOKENS: int = int(os.getenv("GRAPHRAG_MAX_TOKENS", 12000))
INPUT_DIR: str = "graphfleet/output/20240828-113421/artifacts"
LANCEDB_URI: str = ""
COMMUNITY_LEVEL: int = 2

class Config:
env_prefix = "GRAPHRAG_"
Expand All @@ -28,4 +25,4 @@ def lancedb_uri(self) -> str:
return f"{self.INPUT_DIR}/lancedb"


settings = Settings()
settings = Settings()
112 changes: 111 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,118 @@
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.services.question_generator import create_question_generator
from app.services.search_engine import create_search_engines, LocalSearchWrapper, GlobalSearchWrapper
from app.routers import search
import pandas as pd
import logging
import numpy as np
from app.utils import convert_numpy
from app.api import _reformat_context_data

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI(title="GraphFleet API")

class SearchQuery(BaseModel):
    """Request body shared by the search and question-generation endpoints."""

    # Free-text user query forwarded verbatim to the GraphRAG engines.
    query: str

@app.post("/search/local")
async def local_search(search_query: SearchQuery):
    """POST /search/local — run a local (entity-level) GraphRAG search.

    Returns the LLM response, the reformatted context data, and the first
    five "reports" rows (``reports_head``) for quick inspection.

    Raises:
        HTTPException: 500 wrapping any failure during engine creation,
            search, or context reformatting.
    """
    try:
        # Lazy %r formatting neutralises newlines in the user-supplied query
        # (log injection, flagged by CodeQL on the original f-string).
        logger.debug("Received local search query: %r", search_query.query)

        search_engines = create_search_engines()
        logger.debug("Search engines created successfully")

        # Index 0 is the local engine; index 1 is the global engine.
        local_search_engine = search_engines[0]
        logger.debug("Local search engine retrieved")

        logger.debug("Starting local search")
        result = await local_search_engine.asearch(search_query.query)
        logger.debug("Local search completed")

        logger.debug("Raw context_data keys: %s", result.context_data.keys())
        for key, value in result.context_data.items():
            logger.debug("Key: %s, Type: %s", key, type(value))

        logger.debug("Starting context data reformatting")
        context_data = _reformat_context_data(result.context_data)
        logger.debug("Context data reformatted: %s", list(context_data.keys()))

        logger.debug("Preparing response data")
        response_data = {
            "response": result.response,
            "context_data": context_data,
            "reports_head": context_data.get("reports", [])[:5] if "reports" in context_data else []
        }
        logger.debug("Response data prepared")

        return response_data
    except Exception as e:
        logger.error("Error in local_search: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred during local search: {str(e)}")

@app.post("/search/global")
async def global_search(search_query: SearchQuery):
    """POST /search/global — run a global (community-level) GraphRAG search.

    Returns the LLM response, the reformatted context data, the first five
    "reports" rows, the full reports list, and report counts.

    Raises:
        HTTPException: 500 wrapping any failure during search or
            context reformatting.
    """
    search_engines = create_search_engines()
    # Index 1 is the global engine; index 0 is the local engine.
    global_search_engine = search_engines[1]
    try:
        # Lazy %r formatting neutralises newlines in the user-supplied query
        # (log injection, flagged by CodeQL on the original f-string).
        logger.debug("Received global search query: %r", search_query.query)
        result = await global_search_engine.asearch(search_query.query)
        context_data = _reformat_context_data(result.context_data)
        logger.debug("Context data reformatted: %s", list(context_data.keys()))

        reports = context_data.get("reports", [])
        # NOTE(review): no filtering is applied yet, so both counts are the
        # same value — confirm whether filtered_report_count should differ.
        total_report_count = len(reports)
        filtered_report_count = len(reports)

        logger.debug("Global search completed successfully")
        return {
            "response": result.response,
            "context_data": context_data,
            "reports_head": context_data.get("reports", [])[:5] if "reports" in context_data else [],
            "total_report_count": total_report_count,
            "filtered_report_count": filtered_report_count,
            "reports": context_data.get("reports", [])
        }
    except Exception as e:
        logger.error("Error in global_search: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred during global search: {str(e)}")

@app.post("/search/local/stream")
async def local_search_stream(search_query: SearchQuery):
    """POST /search/local/stream — stream local search output as
    server-sent events, one ``data:`` frame per chunk."""
    engines = create_search_engines()
    engine = engines[0]  # local engine
    try:
        async def sse_events():
            # Wrap each streamed chunk in the SSE wire format.
            async for piece in engine.astream(search_query.query):
                yield f"data: {piece}\n\n"

        return StreamingResponse(sse_events(), media_type="text/event-stream")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/search/global/stream")
async def global_search_stream(search_query: SearchQuery):
    """POST /search/global/stream — stream global search output as
    server-sent events, one ``data:`` frame per chunk."""
    engines = create_search_engines()
    engine = engines[1]  # global engine
    try:
        async def sse_events():
            # Wrap each streamed chunk in the SSE wire format.
            async for piece in engine.astream(search_query.query):
                yield f"data: {piece}\n\n"

        return StreamingResponse(sse_events(), media_type="text/event-stream")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/generate_questions")
async def generate_questions(search_query: SearchQuery):
    """POST /generate_questions — propose five follow-up questions seeded
    by the submitted query."""
    generator = create_question_generator()
    try:
        outcome = await generator.agenerate(
            question_history=[search_query.query],
            context_data=None,
            question_count=5,
        )
        return {"questions": outcome.response}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

app.include_router(search.router, prefix="/search", tags=["search"])

@app.get("/")
Expand Down
4 changes: 3 additions & 1 deletion app/services/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from app.config import settings
from app.utils.logging import logger


async def run_indexer(verbose: bool = True):
try:
cmd = [
Expand Down Expand Up @@ -30,6 +31,7 @@ async def run_indexer(verbose: bool = True):
logger.error(f"Error during indexing: {str(e)}")
raise


async def run_prompt_tuning(no_entity_types: bool = True):
try:
cmd = [
Expand Down Expand Up @@ -57,4 +59,4 @@ async def run_prompt_tuning(no_entity_types: bool = True):
raise RuntimeError("Prompt tuning process failed")
except Exception as e:
logger.error(f"Error during prompt tuning: {str(e)}")
raise
raise
28 changes: 23 additions & 5 deletions app/services/question_generator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from app.services.search_engine import create_search_engine
from graphrag.query.question_gen.local_gen import (
LocalQuestionGen,
LocalContextBuilder,
)
from app.services.search_engine import create_search_engines


class ConcreteLocalContextBuilder(LocalContextBuilder):
    """Fallback context builder used when the search engine's own builder
    is not a LocalContextBuilder.

    NOTE(review): this is a stub — build_context returns None, which will
    presumably break LocalQuestionGen if this fallback path is ever taken;
    confirm the expected return type against graphrag's LocalContextBuilder.
    """

    def build_context(self, *args, **kwargs):
        # Implement the build_context method
        pass


def create_question_generator():
    """Build a LocalQuestionGen wired to the local search engine.

    Reuses the local engine's LLM, token encoder, and parameter sets.
    Falls back to ConcreteLocalContextBuilder when the engine's context
    builder is not a LocalContextBuilder instance.

    Returns:
        A configured LocalQuestionGen.
    """
    search_engines = create_search_engines()
    search_engine = search_engines[0]  # index 0 is the local engine

    context_builder = (
        search_engine.context_builder
        if isinstance(search_engine.context_builder, LocalContextBuilder)
        else ConcreteLocalContextBuilder()
    )

    return LocalQuestionGen(
        llm=search_engine.llm,
        context_builder=context_builder,
        token_encoder=search_engine.token_encoder,
        llm_params=search_engine.llm_params,
        context_builder_params=search_engine.context_builder_params,
    )

question_generator = create_question_generator()

question_generator = create_question_generator()
Loading

0 comments on commit df508a8

Please sign in to comment.