Skip to content

Commit

Permalink
update progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jul 16, 2024
1 parent e0fc9e9 commit a17bfbb
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 19 deletions.
2 changes: 1 addition & 1 deletion vizro-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"openai>=1.0.0",
"langchain>=0.1.0, <0.3.0", # TODO update all LLMChain class, update to pydantic v2 and remove upper bound
"langchain-openai",
"langgraph",
"langgraph>=0.1.2",
"python-dotenv>=1.0.0", # TODO decide env var management to see if we need this
"vizro>=0.1.4", # TODO set upper bound later
"ipython>=8.10.0", # not directly required, pinned by Snyk to avoid a vulnerability: https://app.snyk.io/vuln/SNYK-PYTHON-IPYTHON-3318382
Expand Down
1 change: 1 addition & 0 deletions vizro-ai/src/vizro_ai/chains/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"Anthropic": [
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
],
}

Expand Down
41 changes: 25 additions & 16 deletions vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from langchain_core.runnables import RunnableConfig
from langgraph.constants import END, Send
from langgraph.graph import StateGraph
from tqdm.auto import tqdm
from vizro_ai.dashboard.nodes.build import PageBuilder
from vizro_ai.dashboard.nodes.data_summary import DfInfo, _get_df_info, df_sum_prompt
from vizro_ai.dashboard.nodes.plan import (
DashboardPlanner,
PagePlanner,
_get_dashboard_plan,
)
from vizro_ai.dashboard.utils import DataFrameMetadata, DfMetadata
from vizro_ai.dashboard.utils import DataFrameMetadata, DfMetadata, _execute_step

try:
from pydantic.v1 import BaseModel, validator
Expand All @@ -26,7 +27,6 @@


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


Messages = List[BaseMessage]
Expand Down Expand Up @@ -71,37 +71,48 @@ def check_dataframes(cls, v):

def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, DfMetadata]:
"""Store information about the dataframes."""
logger.info("*** _store_df_info ***")
dfs = state.dfs
df_metadata = state.df_metadata
messages = state.messages
current_df_names = []
for df in dfs:
df_schema, df_sample = _get_df_info(df)
with tqdm(total=len(dfs), desc="Store df info") as pbar:
for df in dfs:
df_schema, df_sample = _get_df_info(df)

llm = config["configurable"].get("model", None)
data_sum_chain = df_sum_prompt | llm.with_structured_output(DfInfo)
llm = config["configurable"].get("model", None)
data_sum_chain = df_sum_prompt | llm.with_structured_output(DfInfo)

df_name = data_sum_chain.invoke(
{"messages": messages, "df_schema": df_schema, "df_sample": df_sample, "current_df_names": current_df_names}
).dataset_name
df_name = data_sum_chain.invoke(
{
"messages": messages,
"df_schema": df_schema,
"df_sample": df_sample,
"current_df_names": current_df_names,
}
).dataset_name

current_df_names.append(df_name)
current_df_names.append(df_name)

logger.info(f"df_name: {df_name}")
df_metadata.metadata[df_name] = DataFrameMetadata(df_schema=df_schema, df=df, df_sample=df_sample)
pbar.write(f"df_name: {df_name}")
pbar.update(1)
df_metadata.metadata[df_name] = DataFrameMetadata(df_schema=df_schema, df=df, df_sample=df_sample)

return {"df_metadata": df_metadata}


def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, DashboardPlanner]:
"""Generate a dashboard plan."""
logger.info("*** _dashboard_plan ***")
node_desc = "Generate dashboard plan"
pbar = tqdm(total=2, desc=node_desc)
query = state.messages[0].content
df_metadata = state.df_metadata

llm = config["configurable"].get("model", None)

_execute_step(pbar, node_desc + " --> in progress", None)
dashboard_plan = _get_dashboard_plan(query=query, model=llm, df_metadata=df_metadata)
_execute_step(pbar, node_desc + " --> done", None)
pbar.close()

return {"dashboard_plan": dashboard_plan}

Expand Down Expand Up @@ -136,7 +147,6 @@ def _build_page(state: BuildPageState, config: RunnableConfig) -> Dict[str, List

def _continue_to_pages(state: GraphState) -> List[Send]:
"""Continue to build pages."""
logger.info("*** build_page ***")
df_metadata = state.df_metadata
return [
Send(node="_build_page", arg={"page_plan": v, "df_metadata": df_metadata}) for v in state.dashboard_plan.pages
Expand All @@ -145,7 +155,6 @@ def _continue_to_pages(state: GraphState) -> List[Send]:

def _build_dashboard(state: GraphState) -> Dict[str, vm.Dashboard]:
"""Build a dashboard."""
logger.info("*** build_dashboard ***")
dashboard_plan = state.dashboard_plan
pages = state.pages

Expand Down
2 changes: 1 addition & 1 deletion vizro-ai/src/vizro_ai/dashboard/nodes/data_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _get_df_info(df: pd.DataFrame) -> Tuple[Dict[str, str], str]:
Inspect the provided data and give a short unique name to the dataset. \n
Here is the dataframe sample: \n ------- \n {df_sample} \n ------- \n
Here is the schema: \n ------- \n {df_schema} \n ------- \n
Avoid the following names currently in use: \n ------- \n {current_df_names} \n ------- \n
AVOID the following names: \n ------- \n {current_df_names} \n ------- \n
\n ------- \n
""",
),
Expand Down
2 changes: 1 addition & 1 deletion vizro-ai/src/vizro_ai/dashboard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class DashboardOutputs:


def _execute_step(pbar: tqdm.std.tqdm, description: str, value: Any) -> Any:
pbar.set_description_str(description)
pbar.update(1)
pbar.set_description_str(description)
return value


Expand Down

0 comments on commit a17bfbb

Please sign in to comment.