From dbd57133b8aab27a8e44f02b5b0762353481ccf5 Mon Sep 17 00:00:00 2001 From: Lingyi Zhang Date: Thu, 11 Jul 2024 03:09:56 -0400 Subject: [PATCH] convert df_metadata to dataclass --- .../vizro_ai/dashboard/graph/code_generation.py | 14 +++++--------- vizro-ai/src/vizro_ai/dashboard/nodes/_model.py | 3 ++- .../vizro_ai/dashboard/nodes/data_summary.py | 4 +++- vizro-ai/src/vizro_ai/dashboard/nodes/plan.py | 9 +++++---- vizro-ai/src/vizro_ai/dashboard/utils.py | 17 ++++++++++++++++- 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py b/vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py index 7f858299d..7199e65aa 100644 --- a/vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py +++ b/vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py @@ -18,6 +18,7 @@ PagePlanner, _get_dashboard_plan, ) +from vizro_ai.dashboard.utils import DataFrameMetadata, DfMetadata try: from pydantic.v1 import BaseModel, validator @@ -29,8 +30,6 @@ logger.setLevel(logging.INFO) -DfMetadata = Dict[str, Dict[str, Union[Dict[str, str], pd.DataFrame]]] -"""Cleaned dataframe names and their metadata.""" Messages = List[BaseMessage] """List of messages.""" @@ -87,15 +86,12 @@ def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, DfMet df_name = data_sum_chain.invoke( {"messages": messages, "df_schema": df_schema, "df_sample": df_sample, "current_df_names": current_df_names} - ) + ).dataset_name current_df_names.append(df_name) - cleaned_df_name = df_name.dataset_name.lower() - cleaned_df_name = re.sub(r"\W+", "_", cleaned_df_name) - df_id = cleaned_df_name.strip("_") - logger.info(f"df_name: {df_name} --> df_id: {df_id}") - df_metadata[df_id] = {"df_schema": df_schema, "df": df} # TODO: shall be a dataclass + logger.info(f"df_name: {df_name}") + df_metadata.metadata[df_name] = DataFrameMetadata(df_schema=df_schema, df=df) return {"df_metadata": df_metadata} @@ -121,7 +117,7 @@ class BuildPageState(BaseModel): """ - df_metadata: Dict[str, Dict[str, Any]] + df_metadata: DfMetadata page_plan: PagePlanner = None diff --git a/vizro-ai/src/vizro_ai/dashboard/nodes/_model.py b/vizro-ai/src/vizro_ai/dashboard/nodes/_model.py index df0687054..36e9c8bba 100644 --- a/vizro-ai/src/vizro_ai/dashboard/nodes/_model.py +++ b/vizro-ai/src/vizro_ai/dashboard/nodes/_model.py @@ -10,6 +10,7 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate +from vizro_ai.dashboard.utils import DfMetadata SINGLE_MODEL_PROMPT = ChatPromptTemplate.from_messages( [ @@ -45,7 +46,7 @@ def _get_proxy_model( query: str, llm_model: BaseChatModel, result_model: BaseModel, - df_metadata: Dict[str, Dict[str, str]], + df_metadata: DfMetadata, max_retry: int = 2, ) -> BaseModel: for i in range(max_retry): diff --git a/vizro-ai/src/vizro_ai/dashboard/nodes/data_summary.py b/vizro-ai/src/vizro_ai/dashboard/nodes/data_summary.py index e7f996f0e..b5a1d58b4 100644 --- a/vizro-ai/src/vizro_ai/dashboard/nodes/data_summary.py +++ b/vizro-ai/src/vizro_ai/dashboard/nodes/data_summary.py @@ -38,4 +38,6 @@ def _get_df_info(df: pd.DataFrame) -> Tuple[Dict[str, str], str]: class DfInfo(BaseModel): """Data Info output.""" - dataset_name: str = Field(description="Name of the dataset") + dataset_name: str = Field( + pattern=r"^[a-z]+(_[a-z]+)?$", description="Small snake case name of the dataset." + ) diff --git a/vizro-ai/src/vizro_ai/dashboard/nodes/plan.py b/vizro-ai/src/vizro_ai/dashboard/nodes/plan.py index c6a82c2ca..10203a8c1 100644 --- a/vizro-ai/src/vizro_ai/dashboard/nodes/plan.py +++ b/vizro-ai/src/vizro_ai/dashboard/nodes/plan.py @@ -6,6 +6,7 @@ import vizro.models as vm from langchain_openai import ChatOpenAI from vizro.models.types import ComponentType +from vizro_ai.dashboard.utils import DfMetadata try: from pydantic.v1 import BaseModel, Field, ValidationError, create_model, validator @@ -63,7 +64,7 @@ def create(self, model, df_metadata) -> Union[ComponentType, None]: if self.component_type == "Graph": return vm.Graph( id=self.component_id+"_"+self.page_id, - figure=vizro_ai.plot(df=df_metadata[self.data_frame]["df"], user_input=self.component_description) + figure=vizro_ai.plot(df=df_metadata.metadata[self.data_frame].df, user_input=self.component_description) ) elif self.component_type == "AgGrid": return vm.AgGrid( @@ -140,8 +141,8 @@ def create(self, model, available_components, df_metadata): ) try: _df_schema, _df = ( - df_metadata[self.data_frame]["df_schema"], - df_metadata[self.data_frame]["df"], + df_metadata.metadata[self.data_frame].df_schema, + df_metadata.metadata[self.data_frame].df, ) _df_cols = list(_df_schema.keys()) # when wrong dataframe name is given @@ -246,7 +247,7 @@ class DashboardPlanner(BaseModel): def _get_dashboard_plan( query: str, model: Union[ChatOpenAI], - df_metadata: Dict[str, Dict[str, str]], + df_metadata: DfMetadata, ) -> DashboardPlanner: return _get_proxy_model(query=query, llm_model=model, result_model=DashboardPlanner, df_metadata=df_metadata) diff --git a/vizro-ai/src/vizro_ai/dashboard/utils.py b/vizro-ai/src/vizro_ai/dashboard/utils.py index 3559e5cda..fcb4460b3 100644 --- a/vizro-ai/src/vizro_ai/dashboard/utils.py +++ b/vizro-ai/src/vizro_ai/dashboard/utils.py @@ -1,6 +1,9 @@ """Helper Functions For Vizro AI dashboard.""" -from dataclasses import dataclass +from typing import Any, Dict +from dataclasses import dataclass, field + +import pandas as pd # import black from typing import Any @@ -15,6 +18,18 @@ "import vizro.models as vm\n" ) +@dataclass +class DataFrameMetadata: + """Dataclass containing metadata for a dataframe.""" + + df_schema: Dict[str, str] + df: pd.DataFrame + +@dataclass +class DfMetadata: + """Dataclass containing metadata for all dataframes.""" + + metadata: Dict[str, DataFrameMetadata] = field(default_factory=dict) @dataclass class DashboardOutputs: