diff --git a/admin_app/package-lock.json b/admin_app/package-lock.json
index e2db041a6..496137d7e 100644
--- a/admin_app/package-lock.json
+++ b/admin_app/package-lock.json
@@ -4373,9 +4373,9 @@
       }
     },
     "node_modules/source-map-js": {
-      "version": "1.0.2",
-      "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz",
-      "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==",
+      "version": "1.2.0",
+      "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.0.tgz",
+      "integrity": "sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==",
       "engines": {
         "node": ">=0.10.0"
       }
diff --git a/admin_app/src/app/dashboard/api.ts b/admin_app/src/app/dashboard/api.ts
index 41c8958ac..e7126c88a 100644
--- a/admin_app/src/app/dashboard/api.ts
+++ b/admin_app/src/app/dashboard/api.ts
@@ -21,6 +21,40 @@ const getOverviewPageData = async (period: Period, token: string) => {
   });
 };
 
+const fetchTopicsData = async (period: Period, token: string) => {
+  return fetch(`${NEXT_PUBLIC_BACKEND_URL}/dashboard/insights/${period}`, {
+    method: "GET",
+    headers: {
+      "Content-Type": "application/json",
+      Authorization: `Bearer ${token}`,
+    },
+  }).then((response) => {
+    if (response.ok) {
+      return response.json();
+    } else {
+      throw new Error("Error fetching Topics data");
+    }
+  });
+};
+
+const generateNewTopics = async (period: Period, token: string) => {
+  return fetch(`${NEXT_PUBLIC_BACKEND_URL}/dashboard/insights/${period}/refresh`, {
+    method: "GET",
+    headers: {
+      "Content-Type": "application/json",
+      Authorization: `Bearer ${token}`,
+    },
+  }).then((response) => {
+    if (response.ok) {
+      return response.json();
+    } else {
+      throw new Error("Error generating Topics data");
+    }
+  });
+};
+
 const getPerformancePageData = async (period: Period, token: string) => {
   return fetch(`${NEXT_PUBLIC_BACKEND_URL}/dashboard/performance/${period}`, {
     method: "GET",
@@ -90,4 +124,6 @@ export {
   getPerformancePageData,
   getPerformanceDrawerData,
   getPerformanceDrawerAISummary,
+  fetchTopicsData,
+  generateNewTopics,
 };
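
Note on how these two helpers compose: `generateNewTopics` asks the backend to recompute and cache the topics for a period, and `fetchTopicsData` reads the cached result back. A minimal sketch of the intended flow (the function name `refreshAndReload` is illustrative, not part of the diff):

```typescript
// Sketch only: trigger a backend re-run, then re-read the cached results.
// Assumes Period, generateNewTopics and fetchTopicsData from the module above.
async function refreshAndReload(period: Period, token: string) {
  await generateNewTopics(period, token); // backend recomputes and caches topics
  return fetchTopicsData(period, token); // returns the freshly cached topics data
}
```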
diff --git a/admin_app/src/app/dashboard/components/Insights.tsx b/admin_app/src/app/dashboard/components/Insights.tsx
new file mode 100644
index 000000000..ea340db7f
--- /dev/null
+++ b/admin_app/src/app/dashboard/components/Insights.tsx
@@ -0,0 +1,135 @@
+import React from "react";
+import Grid from "@mui/material/Unstable_Grid2";
+import Topics from "./insights/Topics";
+import Queries from "./insights/Queries";
+import Box from "@mui/material/Box";
+import { useState } from "react";
+import { QueryData, Period, TopicModelingResponse } from "../types";
+import { generateNewTopics, fetchTopicsData } from "../api";
+import { useAuth } from "@/utils/auth";
+
+interface InsightProps {
+  timePeriod: Period;
+}
+
+const Insight: React.FC<InsightProps> = ({ timePeriod }) => {
+  const { token } = useAuth();
+  const [selectedTopicId, setSelectedTopicId] = useState<number | null>(null);
+  const [topicQueries, setTopicQueries] = useState<QueryData[]>([]);
+  const [refreshTimestamp, setRefreshTimestamp] = useState<string>("");
+  const [refreshing, setRefreshing] = useState<boolean>(false);
+  const [aiSummary, setAiSummary] = useState<string>("");
+
+  const [dataFromBackend, setDataFromBackend] = useState<TopicModelingResponse>({
+    data: [],
+    refreshTimeStamp: "",
+    unclustered_queries: [],
+  });
+
+  const runRefresh = () => {
+    setRefreshing(true);
+    generateNewTopics(timePeriod, token!).then((_) => {
+      const date = new Date();
+      setRefreshTimestamp(date.toLocaleString());
+      setRefreshing(false);
+    });
+  };
+
+  React.useEffect(() => {
+    if (token) {
+      fetchTopicsData(timePeriod, token).then((dataFromBackend) => {
+        setDataFromBackend(dataFromBackend);
+        if (dataFromBackend.data.length > 0) {
+          setSelectedTopicId(dataFromBackend.data[0].topic_id);
+        }
+      });
+    } else {
+      console.log("No token found");
+    }
+  }, [token, refreshTimestamp, timePeriod]);
+
+  React.useEffect(() => {
+    if (selectedTopicId !== null) {
+      const filterQueries = dataFromBackend.data.find(
+        (topic) => topic.topic_id === selectedTopicId,
+      );
+
+      if (filterQueries) {
+        setTopicQueries(filterQueries.topic_samples);
+        setAiSummary(filterQueries.topic_summary);
+      } else {
+        setTopicQueries([]);
+        setAiSummary("Not available.");
+      }
+    } else {
+      setTopicQueries([]);
+      setAiSummary("Not available.");
+    }
+  }, [dataFromBackend, selectedTopicId, refreshTimestamp, timePeriod]);
+
+  const topics = dataFromBackend.data.map(
+    ({ topic_id, topic_name, topic_popularity }) => ({
+      topic_id,
+      topic_name,
+      topic_popularity,
+    }),
+  );
+
+  return (
+    <Grid container spacing={2}>
+      <Grid xs={12} md={4}>
+        <Topics
+          data={topics}
+          selectedTopicId={selectedTopicId}
+          onClick={setSelectedTopicId}
+          topicsPerPage={7}
+        />
+      </Grid>
+      <Grid xs={12} md={8}>
+        <Queries
+          data={topicQueries}
+          onRefreshClick={runRefresh}
+          lastRefreshed={dataFromBackend.refreshTimeStamp}
+          refreshing={refreshing}
+          aiSummary={aiSummary}
+        />
+        <Box textAlign="center" py={4}>
+          -- Chart - Coming Soon! --
+        </Box>
+      </Grid>
+    </Grid>
+  );
+};
+
+export default Insight;
diff --git a/admin_app/src/app/dashboard/components/insights/Queries.tsx b/admin_app/src/app/dashboard/components/insights/Queries.tsx
new file mode 100644
index 000000000..72aab0f30
--- /dev/null
+++ b/admin_app/src/app/dashboard/components/insights/Queries.tsx
@@ -0,0 +1,171 @@
+import React from "react";
+import { Box } from "@mui/material";
+import Table from "@mui/material/Table";
+import TableBody from "@mui/material/TableBody";
+import TableCell from "@mui/material/TableCell";
+import TableContainer from "@mui/material/TableContainer";
+import TableHead from "@mui/material/TableHead";
+import TableRow from "@mui/material/TableRow";
+import { grey, orange } from "@mui/material/colors";
+import Typography from "@mui/material/Typography";
+import AutoAwesomeIcon from "@mui/icons-material/AutoAwesome";
+import Button from "@mui/material/Button";
+import { QueryData } from "../../types";
+import CircularProgress from "@mui/material/CircularProgress";
+
+interface QueriesProps {
+  data: QueryData[];
+  onRefreshClick: () => void;
+  lastRefreshed: string;
+  refreshing: boolean;
+  aiSummary: string;
+}
+
+interface AISummaryProps {
+  aiSummary: string;
+}
+
+const AISummary: React.FC<AISummaryProps> = ({ aiSummary }) => {
+  return (
+    <Box bgcolor={orange[50]} borderRadius={2} p={2} my={2}>
+      <Box display="flex" flexDirection="row" alignItems="center" gap={1}>
+        <AutoAwesomeIcon fontSize="small" sx={{ color: orange[500] }} />
+        <Typography variant="subtitle1" fontWeight={600}>
+          AI Overview
+        </Typography>
+      </Box>
+      <Typography variant="body2" color={grey[800]} mt={1}>
+        {aiSummary}
+      </Typography>
+    </Box>
+  );
+};
+
+const Queries: React.FC<QueriesProps> = ({
+  data,
+  onRefreshClick,
+  lastRefreshed,
+  refreshing,
+  aiSummary,
+}) => {
+  const formattedLastRefreshed =
+    lastRefreshed.length > 0
+      ? Intl.DateTimeFormat("en-ZA", {
+          dateStyle: "short",
+          timeStyle: "short",
+        }).format(new Date(lastRefreshed))
+      : "Never";
+
+  return (
+    <Box>
+      <Box
+        display="flex"
+        flexDirection="row"
+        justifyContent="space-between"
+        alignItems="center"
+      >
+        <Typography variant="h6">Example Queries</Typography>
+        <Box display="flex" flexDirection="row" alignItems="center" gap={2}>
+          <Typography variant="body2" color={grey[600]}>
+            Last run: {formattedLastRefreshed}
+          </Typography>
+          <Button variant="outlined" onClick={onRefreshClick} disabled={refreshing}>
+            {refreshing ? <CircularProgress size={20} /> : "Re-run Discovery"}
+          </Button>
+        </Box>
+      </Box>
+      <AISummary aiSummary={aiSummary} />
+      {data.length > 0 ? (
+        <TableContainer>
+          <Table>
+            <TableHead>
+              <TableRow>
+                <TableCell>Timestamp</TableCell>
+                <TableCell>User Question</TableCell>
+              </TableRow>
+            </TableHead>
+            <TableBody>
+              {data.map((row, index) => (
+                <TableRow key={index}>
+                  <TableCell>
+                    {Intl.DateTimeFormat("en-ZA", {
+                      dateStyle: "short",
+                      timeStyle: "short",
+                    }).format(new Date(row.query_datetime_utc))}
+                  </TableCell>
+                  <TableCell>{row.query_text}</TableCell>
+                </TableRow>
+              ))}
+            </TableBody>
+          </Table>
+        </TableContainer>
+      ) : (
+        <Typography variant="body2" color={grey[600]}>
+          No queries found. Please re-run discovery
+        </Typography>
+      )}
+    </Box>
+  );
+};
+
+export default Queries;
diff --git a/admin_app/src/app/dashboard/components/insights/Topics.tsx b/admin_app/src/app/dashboard/components/insights/Topics.tsx
new file mode 100644
index 000000000..a7e73e41b
--- /dev/null
+++ b/admin_app/src/app/dashboard/components/insights/Topics.tsx
@@ -0,0 +1,114 @@
+import React, { useEffect } from "react";
+import { Box } from "@mui/material";
+import Chip from "@mui/material/Chip";
+import ListItemButton from "@mui/material/ListItemButton";
+import { orange } from "@mui/material/colors";
+import { TopicData } from "../../types";
+import Pagination from "@mui/material/Pagination";
+import { useState } from "react";
+
+interface TopicProps {
+  data: TopicData[];
+  selectedTopicId: number | null;
+  onClick: (topic: number | null) => void;
+  topicsPerPage: number;
+}
+
+const Topics: React.FC<TopicProps> = ({
+  data,
+  selectedTopicId,
+  onClick,
+  topicsPerPage,
+}) => {
+  const [page, setPage] = useState(1);
+  const [dataToShow, setDataToShow] = useState(data.slice(0, topicsPerPage));
+
+  const handlePageChange = (_: React.ChangeEvent<unknown>, value: number) => {
+    setPage(value);
+    filterPageData(value);
+  };
+  const filterPageData = (value: number) => {
+    const start = (value - 1) * topicsPerPage;
+    const end = value * topicsPerPage;
+    setDataToShow(data.slice(start, end));
+  };
+
+  useEffect(() => {
+    setPage(1); // reset pagination whenever the topic list changes
+    filterPageData(1);
+  }, [data]);
+
+  return (
+    <Box>
+      <Box sx={{ fontWeight: 600, mb: 1 }}>Topics</Box>
+      {dataToShow.map((topic) => (
+        <Box key={topic.topic_id}>
+          <ListItemButton
+            selected={topic.topic_id === selectedTopicId}
+            onClick={() => onClick(topic.topic_id)}
+            sx={{
+              display: "flex",
+              flexDirection: "row",
+              justifyContent: "space-between",
+              borderRadius: 2,
+              my: 0.5,
+              ml: -0.5,
+            }}
+          >
+            {topic.topic_name}
+            <Chip label={topic.topic_popularity} sx={{ bgcolor: orange[100] }} />
+          </ListItemButton>
+        </Box>
+      ))}
+      <Box display="flex" justifyContent="center" mt={2}>
+        <Pagination
+          count={Math.ceil(data.length / topicsPerPage)}
+          page={page}
+          onChange={handlePageChange}
+        />
+      </Box>
+    </Box>
+  );
+};
+
+export default Topics;
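
Pagination in `Topics` is purely client-side: for page number `p`, the component shows `data.slice((p - 1) * topicsPerPage, p * topicsPerPage)`. With 23 topics and a page size of 7, for example, page 2 renders topics 8–14 and page 4 the final two.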
diff --git a/admin_app/src/app/dashboard/page.tsx b/admin_app/src/app/dashboard/page.tsx
index e6fb9f759..82b795250 100644
--- a/admin_app/src/app/dashboard/page.tsx
+++ b/admin_app/src/app/dashboard/page.tsx
@@ -7,6 +7,8 @@ import TabPanel from "@/app/dashboard/components/TabPanel";
 import { Period, drawerWidth } from "./types";
 import Overview from "@/app/dashboard/components/Overview";
 import ContentPerformance from "@/app/dashboard/components/ContentPerformance";
+import Insights from "./components/Insights";
+
 import { appColors } from "@/utils";
 
 type Page = {
@@ -46,7 +48,7 @@ const Dashboard: React.FC = () => {
     case "Content Performance":
       return <ContentPerformance timePeriod={timePeriod} />;
     case "Content Gaps":
-      return <div>Coming Soon!</div>;
+      return <Insights timePeriod={timePeriod} />;
     default:
       return <div>Page not found.</div>;
   }
diff --git a/admin_app/src/app/dashboard/types.ts b/admin_app/src/app/dashboard/types.ts
index e3f88b974..f8d227ca7 100644
--- a/admin_app/src/app/dashboard/types.ts
+++ b/admin_app/src/app/dashboard/types.ts
@@ -46,6 +46,31 @@ interface RowDataType extends ContentData {
   id: number;
 }
 
+interface QueryData {
+  query_text: string;
+  query_datetime_utc: string;
+}
+
+interface TopicModelingData {
+  topic_id: number;
+  topic_samples: QueryData[];
+  topic_summary: string;
+  topic_name: string;
+  topic_popularity: number;
+}
+
+interface TopicModelingResponse {
+  refreshTimeStamp: string;
+  data: TopicModelingData[];
+  unclustered_queries: QueryData[];
+}
+
+interface TopicData {
+  topic_id: number;
+  topic_name: string;
+  topic_popularity: number;
+}
+
 export type {
   DrawerData,
   Period,
@@ -54,6 +79,10 @@ export type {
   ApexData,
   TopContentData,
   RowDataType,
+  QueryData,
+  TopicData,
+  TopicModelingData,
+  TopicModelingResponse,
 };
 
 export { drawerWidth };
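
For orientation, here is an illustrative payload (all values invented) in the `TopicModelingResponse` shape these interfaces describe and the `/dashboard/insights/{period}` endpoint returns:

```json
{
  "refreshTimeStamp": "2024-08-01T12:00:00+00:00",
  "data": [
    {
      "topic_id": 0,
      "topic_name": "Breastfeeding support",
      "topic_summary": "Questions about latching, milk supply and lactation consultants.",
      "topic_popularity": 42,
      "topic_samples": [
        {
          "query_text": "Where can I find support for breastfeeding?",
          "query_datetime_utc": "2024-07-30 08:15:00"
        }
      ]
    }
  ],
  "unclustered_queries": []
}
```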
diff --git a/core_backend/add_dummy_data_to_db.py b/core_backend/add_dummy_data_to_db.py
index 1116a5810..e419c8852 100644
--- a/core_backend/add_dummy_data_to_db.py
+++ b/core_backend/add_dummy_data_to_db.py
@@ -39,7 +39,7 @@
 ADMIN_PASSWORD = os.environ.get("ADMIN_PASSWORD", "fullaccess")
 
 _USER_ID = 1
-N_DATAPOINTS = 100
+N_DATAPOINTS = 2000
 URGENCY_RATE = 0.1
 NEGATIVE_FEEDBACK_RATE = 0.1
 
@@ -187,7 +187,7 @@ def create_query_record(dt: datetime, session: Session) -> QueryDB:
         user_id=_USER_ID,
         session_id=1,
         feedback_secret_key="abc123",  # pragma: allowlist secret
-        query_text="test query",
+        query_text=generate_synthetic_query(),
         query_generate_llm_response=False,
         query_metadata={},
         query_datetime_utc=dt,
@@ -339,6 +339,104 @@ def add_content_data() -> None:
         session.close()
 
 
+MATERNAL_HEALTH_TERMS = [
+    # General Terms
+    "pregnancy",
+    "birth",
+    "postpartum",
+    "natal care",
+    "breastfeeding",
+    "midwife",
+    "maternal health",
+    "childbirth",
+    "labor",
+    "delivery",
+    "newborn",
+    "baby",
+    "preterm birth",
+    "gestational diabetes",
+    "ultrasound",
+    "fetal monitoring",
+    "prenatal care",
+    "maternity leave",
+    "family planning",
+    # Medical Conditions
+    "preeclampsia",
+    "eclampsia",
+    "placenta previa",
+    "placental abruption",
+    "hyperemesis gravidarum",
+    "chorioamnionitis",
+    "amniotic fluid embolism",
+    "postpartum hemorrhage",
+    "polyhydramnios",
+    "oligohydramnios",
+    "intrauterine growth restriction",
+    "stillbirth",
+    "hemolytic disease",
+    # Procedures and Tests
+    "amniocentesis",
+    "chorionic villus sampling",
+    "non-stress test",
+    "biophysical profile",
+    "doppler ultrasound",
+    "glucose tolerance test",
+    "cervical check",
+    "internal fetal monitoring",
+    # Support and Care
+    "lactation consultant",
+    "doula",
+    "support group",
+    "parenting classes",
+    "infant care",
+    "postpartum support",
+    "mental health screening",
+    "breastfeeding support",
+    "pediatric care",
+    # Wellness and Lifestyle
+    "nutrition during pregnancy",
+    "exercise during pregnancy",
+    "birth plan",
+    "home birth",
+    "hospital birth",
+    "water birth",
+    "natural birth",
+    "epidural",
+    "pain management",
+    "birthing center",
+    # Emotional and Psychological Aspects
+    "postpartum depression",
+    "anxiety",
+    "parenting stress",
+    "bonding with baby",
+    "maternal bonding",
+    "new parent support",
+    "adjustment to parenthood",
+    "family dynamics",
+]
+
+# Common query templates
+QUERY_TEMPLATES = [
+    "What are the symptoms of {term}?",
+    "How can I manage {term} during pregnancy?",
+    "What is {term} and how does it affect childbirth?",
+    "Where can I find support for {term}?",
+    "What are the latest treatments for {term}?",
+    "Is {term} common during pregnancy?",
+    "How does {term} impact postpartum recovery?",
+    "What should I know about {term} before giving birth?",
+    "Can {term} affect my baby’s health?",
+    "What are the best practices for dealing with {term}?",
+]
+
+
+def generate_synthetic_query() -> str:
+    """Generates a random human-like query related to maternal health."""
+    template = random.choice(QUERY_TEMPLATES)
+    term = random.choice(MATERNAL_HEALTH_TERMS)
+    return template.format(term=term)
+
+
 if __name__ == "__main__":
     add_content_data()
     add_year_data()
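
For example, pairing the template `"What are the symptoms of {term}?"` with the term `"preeclampsia"` yields the query `"What are the symptoms of preeclampsia?"`, giving the topic-modelling pipeline realistic text to cluster instead of the previous constant `"test query"`.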
diff --git a/core_backend/app/config.py b/core_backend/app/config.py
index 2dcadf86a..618a174ba 100644
--- a/core_backend/app/config.py
+++ b/core_backend/app/config.py
@@ -55,6 +55,10 @@
 LITELLM_MODEL_DASHBOARD_SUMMARY = os.environ.get(
     "LITELLM_MODEL_DASHBOARD_SUMMARY", "openai/dashboard-summary"
 )
+
+LITELLM_MODEL_TOPIC_MODEL = os.environ.get(
+    "LITELLM_MODEL_TOPIC_MODEL", "openai/topic-label"
+)
 # On/Off Topic variables
 SERVICE_IDENTITY = os.environ.get(
     "SERVICE_IDENTITY", "air pollution and air quality chatbot"
diff --git a/core_backend/app/dashboard/config.py b/core_backend/app/dashboard/config.py
index 9319f81c0..47ce47e26 100644
--- a/core_backend/app/dashboard/config.py
+++ b/core_backend/app/dashboard/config.py
@@ -8,3 +8,5 @@
 MAX_FEEDBACK_RECORDS_FOR_TOP_CONTENT = os.environ.get(
     "MAX_FEEDBACK_RECORDS_FOR_TOP_CONTENT", 7
 )
+
+TOPIC_MODELING_CONTEXT = os.environ.get("TOPIC_MODELING_CONTEXT", "maternal health")
diff --git a/core_backend/app/dashboard/models.py b/core_backend/app/dashboard/models.py
index 3c462e639..f54ccc387 100644
--- a/core_backend/app/dashboard/models.py
+++ b/core_backend/app/dashboard/models.py
@@ -1,6 +1,6 @@
 """This module contains functionalities for managing the dashboard statistics."""
 
-from datetime import date
+from datetime import date, datetime, timezone
 from typing import Any, Sequence, cast, get_args
 
 from sqlalchemy import Row, case, desc, func, literal_column, select, text, true
@@ -33,8 +33,11 @@
     TopContentTimeSeries,
     UrgencyStats,
     UserFeedback,
+    UserQuery,
 )
 
+N_SAMPLES_TOPIC_MODELING = 2000
+
 
 async def get_stats_cards(
     *, user_id: int, asession: AsyncSession, start_date: date, end_date: date
@@ -1245,3 +1248,52 @@ def get_percentage_increase(n_curr: int, n_prev: int) -> float:
         return 0.0
 
     return (n_curr - n_prev) / n_prev
+
+
+async def get_raw_queries(
+    asession: AsyncSession,
+    user_id: int,
+    start_date: date,
+) -> list[UserQuery]:
+    """Retrieve up to `N_SAMPLES_TOPIC_MODELING` randomly sampled raw queries
+    (query_text) and their datetime stamps between `start_date` and now.
+
+    Parameters
+    ----------
+    asession
+        `AsyncSession` object for database transactions.
+    user_id
+        The ID of the user to retrieve the queries for.
+    start_date
+        The starting date for the queries.
+
+    Returns
+    -------
+    list[UserQuery]
+        A list of `UserQuery` objects.
+    """
+
+    statement = (
+        select(QueryDB.query_text, QueryDB.query_datetime_utc, QueryDB.query_id)
+        .where(
+            (QueryDB.user_id == user_id)
+            & (QueryDB.query_datetime_utc >= start_date)
+            & (QueryDB.query_datetime_utc < datetime.now(tz=timezone.utc))
+        )
+        .order_by(func.random())
+        .limit(N_SAMPLES_TOPIC_MODELING)
+    )
+
+    result = await asession.execute(statement)
+    rows = result.fetchall()
+
+    # An empty result simply yields an empty list.
+    return [
+        UserQuery(
+            query_id=row.query_id,
+            query_text=row.query_text,
+            query_datetime_utc=row.query_datetime_utc,
+        )
+        for row in rows
+    ]
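
For reference, the statement above compiles to roughly the following SQL (PostgreSQL; the table name `query` is an assumption based on `QueryDB`, and the parameters are bound at execution time):

```sql
SELECT query_text, query_datetime_utc, query_id
FROM query
WHERE user_id = :user_id
  AND query_datetime_utc >= :start_date
  AND query_datetime_utc < NOW()
ORDER BY RANDOM()
LIMIT 2000;
```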
+ """ + + redis = request.app.state.redis + + if await redis.exists(f"{user_db.username}_insights_{time_frequency}_results"): + payload = await redis.get( + f"{user_db.username}_insights_{time_frequency}_results" + ) + parsed_payload = json.loads(payload) + topics_data = TopicsData(**parsed_payload) + return topics_data + + return TopicsData( + refreshTimeStamp="", + data=[], + unclustered_queries=[], + ) diff --git a/core_backend/app/dashboard/schemas.py b/core_backend/app/dashboard/schemas.py index 67f9d7b29..3fc0dfb4e 100644 --- a/core_backend/app/dashboard/schemas.py +++ b/core_backend/app/dashboard/schemas.py @@ -177,6 +177,49 @@ class DashboardOverview(BaseModel): top_content: list[TopContent] +class Topic(BaseModel): + """ + This class is used to define the schema for one topic + extracted from the user queries. Used for Insights page. + """ + + topic_id: int + topic_samples: list[dict[str, str]] + topic_name: str + topic_summary: str + topic_popularity: int + + +class TopicsData(BaseModel): + """ + This class is used to define the schema for the a large group + of individual Topics. Used for Insights page. + """ + + refreshTimeStamp: str + data: list[Topic] + unclustered_queries: list[dict[str, str]] + + +class UserQuery(BaseModel): + """ + This class is used to define the schema for the insights queries + """ + + query_id: int + query_text: str + query_datetime_utc: datetime + + +class QueryCollection(BaseModel): + """ + This class is used to define the schema for the insights queries data + """ + + n_queries: int + queries: list[UserQuery] + + class UserFeedback(BaseModel): """ This class is used to define the schema for the user feedback diff --git a/core_backend/app/dashboard/topic_modeling.py b/core_backend/app/dashboard/topic_modeling.py new file mode 100644 index 000000000..4c98e03af --- /dev/null +++ b/core_backend/app/dashboard/topic_modeling.py @@ -0,0 +1,108 @@ +""" +This module contains the main function for the topic modelling pipeline. +""" + +import asyncio +from datetime import datetime, timezone + +import pandas as pd +from bertopic import BERTopic +from hdbscan import HDBSCAN +from sentence_transformers import SentenceTransformer + +from ..llm_call.dashboard import generate_topic_label +from .config import TOPIC_MODELING_CONTEXT +from .schemas import Topic, TopicsData, UserQuery + + +async def topic_model_queries(user_id: int, data: list[UserQuery]) -> TopicsData: + """Turn a list of raw queries, run them through a BERTopic pipeline + and return the Data for the front end. + + Parameters + ---------- + user_id : int + The ID of the user making the request. + data + A list of UserQuery objects containing the raw queries and their + corresponding datetime stamps. + + Returns + ------- + list[tuple[str, datetime]] + A list of tuples where each tuple contains the raw query + (query_text) and its corresponding datetime stamp (query_datetime_utc). 
+ """ + + if not data: + return TopicsData(refreshTimeStamp="", data=[], unclustered_queries=[]) + + query_df = pd.DataFrame.from_records([x.model_dump() for x in data]) + query_df["query_datetime_utc"] = query_df["query_datetime_utc"].astype(str) + docs = query_df["query_text"].tolist() + + sentence_model = SentenceTransformer("all-MiniLM-L6-v2") + embeddings = sentence_model.encode(docs, show_progress_bar=False) + + hdbscan_model = HDBSCAN( + min_cluster_size=15, + metric="euclidean", + cluster_selection_method="eom", + prediction_data=True, + ) + topic_model = BERTopic(hdbscan_model=hdbscan_model).fit(docs, embeddings) + query_df["topic_id"], query_df["probs"] = topic_model.transform(docs, embeddings) + + # Queries with low probability of being in a cluster assigned -1 + query_df.loc[query_df["probs"] < 0.75, "topic_id"] = -1 + unclustered_examples = [ + { + "query_text": str(row.query_text), + "query_datetime_utc": str(row.query_datetime_utc), + } + for row in query_df.loc[query_df["topic_id"] == -1].itertuples() + ] + + query_df = query_df.loc[query_df["probs"] > 0.8] + + _idx = 0 + topic_data = [] + tasks = [] + for _, topic_df in query_df.groupby("topic_id"): + topic_samples = topic_df[["query_text", "query_datetime_utc"]][:5] + tasks.append( + generate_topic_label( + user_id, + TOPIC_MODELING_CONTEXT, + topic_samples["query_text"].tolist(), + ) + ) + + topic_dicts = await asyncio.gather(*tasks) + for topic_dict, (topic_id, topic_df) in zip( + topic_dicts, query_df.groupby("topic_id") + ): + topic_samples_slice = topic_df[["query_text", "query_datetime_utc"]][:20] + string_topic_samples = [ + { + "query_text": str(sample["query_text"]), + "query_datetime_utc": str(sample["query_datetime_utc"]), + } + for sample in topic_samples_slice.to_dict(orient="records") + ] + topic_data.append( + Topic( + topic_id=int(topic_id) if isinstance(topic_id, int) else -1, + topic_name=topic_dict["topic_title"], + topic_summary=topic_dict["topic_summary"], + topic_samples=string_topic_samples, + topic_popularity=len(topic_df), + ) + ) + + topic_data = sorted(topic_data, key=lambda x: x.topic_popularity, reverse=True) + return TopicsData( + refreshTimeStamp=datetime.now(timezone.utc).isoformat(), + data=topic_data, + unclustered_queries=unclustered_examples, + ) diff --git a/core_backend/app/llm_call/dashboard.py b/core_backend/app/llm_call/dashboard.py index fb3f77668..c6833fd9f 100644 --- a/core_backend/app/llm_call/dashboard.py +++ b/core_backend/app/llm_call/dashboard.py @@ -2,9 +2,9 @@ These are LLM functions used by the dashbaord. """ -from ..config import LITELLM_MODEL_DASHBOARD_SUMMARY +from ..config import LITELLM_MODEL_DASHBOARD_SUMMARY, LITELLM_MODEL_TOPIC_MODEL from ..utils import create_langfuse_metadata, setup_logger -from .llm_prompts import get_feedback_summary_prompt +from .llm_prompts import TopicModelLabelling, get_feedback_summary_prompt from .utils import _ask_llm_async logger = setup_logger("DASHBOARD AI SUMMARY") @@ -33,3 +33,41 @@ async def generate_ai_summary( logger.info(f"AI Summary generated for {content_title} with feedback: {feedback}") return ai_summary + + +async def generate_topic_label( + user_id: int, + context: str, + sample_texts: list[str], +) -> dict[str, str]: + """ + Generates topic labels for example queries + """ + metadata = create_langfuse_metadata(feature_name="topic-modeling", user_id=user_id) + topic_model_labelling = TopicModelLabelling(context) + + combined_texts = "\n".join( + [f"{i+1}. 
{text}" for i, text in enumerate(sample_texts)] + ) + + topic_json = await _ask_llm_async( + user_message=combined_texts, + system_message=topic_model_labelling.get_prompt(), + litellm_model=LITELLM_MODEL_TOPIC_MODEL, + metadata=metadata, + json=True, + ) + + try: + topic = topic_model_labelling.parse_json(topic_json) + except ValueError as e: + logger.warning( + ( + f"Error generating topic label for {context}: {e}. " + "Setting topic to 'Unknown'" + ) + ) + topic = {"topic_title": "Unknown", "topic_summary": "Not available."} + + logger.info(f"Topic label generated for {context}: {topic}") + return topic diff --git a/core_backend/app/llm_call/llm_prompts.py b/core_backend/app/llm_call/llm_prompts.py index 1c8b31f82..c5839e060 100644 --- a/core_backend/app/llm_call/llm_prompts.py +++ b/core_backend/app/llm_call/llm_prompts.py @@ -389,3 +389,73 @@ def get_feedback_summary_prompt(content_title: str, content: str) -> str: content_title=content_title, content=content, ) + + +class TopicModelLabelling: + """ + Topic model labelling task. + """ + + class TopicModelLabellingResult(BaseModel): + """ + Pydantic model for the output of the topic model labelling task. + """ + + topic_title: str + topic_summary: str + + _context: str + + _prompt_base: str = textwrap.dedent( + """ + You are a summarization bot designed to condense multiple + messages into a topic description specific to {context}. If unknown, respond + with topic_title as "Unknown" and topic_summary as "Not available". + + When coming up with topic_title, be very concise. + "topic_summary" should be a summary of the topics found in the + provided messages. It expands on the topic_title. Restrict it to ONLY + summarization. Do not include any additional information. + """ + ).strip() + + _response_prompt: str = textwrap.dedent( + """ + Respond in json string: + + { + topic_title: str + topic_summary: str + } + """ + ).strip() + + def __init__(self, context: str) -> None: + """ + Initialize the topic model labelling task with context. + """ + self._context = context + + def get_prompt(self) -> str: + """ + Returns the prompt for the topic model labelling task. + """ + prompt = self._prompt_base.format(context=self._context) + + return prompt + "\n\n" + self._response_prompt + + def parse_json(self, json_str: str) -> dict[str, str]: + """ + Validates the output of the topic model labelling task. 
+ """ + + json_str = remove_json_markdown(json_str) + + try: + result = TopicModelLabelling.TopicModelLabellingResult.model_validate_json( + json_str + ) + except ValueError as e: + raise ValueError(f"Error validating the output: {e}") from e + + return result.model_dump() diff --git a/core_backend/requirements.txt b/core_backend/requirements.txt index 088fc32cd..c08a40dd5 100644 --- a/core_backend/requirements.txt +++ b/core_backend/requirements.txt @@ -18,6 +18,8 @@ pandas-stubs==2.2.2.240603 types-openpyxl==3.1.4.20240621 redis==5.0.8 python-dateutil==2.8.2 +gTTS==2.5.1 +bertopic==0.16.3 google-cloud-storage==2.18.2 google-cloud-texttospeech==2.16.5 google-cloud-speech==2.27.0 diff --git a/deployment/docker-compose/litellm_proxy_config.yaml b/deployment/docker-compose/litellm_proxy_config.yaml index ad818c1f8..44a81762c 100644 --- a/deployment/docker-compose/litellm_proxy_config.yaml +++ b/deployment/docker-compose/litellm_proxy_config.yaml @@ -54,6 +54,10 @@ model_list: litellm_params: model: gpt-4o api_key: "os.environ/OPENAI_API_KEY" + - model_name: topic-label + litellm_params: + model: gpt-4o + api_key: "os.environ/OPENAI_API_KEY" - model_name: alignscore litellm_params: # Set VERTEXAI_ENDPOINT environment variable or directly enter the value: diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 000000000..8dbefdaa9 --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "aaq-core", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/package.json b/package.json new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/package.json @@ -0,0 +1 @@ +{} diff --git a/pyproject.toml b/pyproject.toml index 90f792412..ad4a9bd22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,12 +2,15 @@ disallow_untyped_defs = true [[tool.mypy.overrides]] -module = ['litellm', "nltk", "alignscore","pgvector.sqlalchemy", "google.auth.transport", "google.oauth2", "google.cloud", "pydub"] +module = ['litellm', "nltk", "alignscore","pgvector.sqlalchemy", + "google.auth.transport", "google.oauth2", + "gtts", "google.cloud", "bertopic", "hdbscan", "pydub"] ignore_missing_imports = true + [tool.ruff] lint.select = ["E", "F", "B", "Q", "I"] line-length = 88 -ignore = ["B008"] # Do not perform function calls in argument defaults. +lint.ignore = ["B008"] # Do not perform function calls in argument defaults. [tool.ruff.lint.flake8-bugbear] extend-immutable-calls = ["fastapi.Depends", "fastapi.params.Depends","typer.Option"]