Skip to content

Commit

Permalink
convert df_metadata to dataclass
Browse files — browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jul 11, 2024
1 parent c99c9c2 commit dbd5713
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 16 deletions.
14 changes: 5 additions & 9 deletions vizro-ai/src/vizro_ai/dashboard/graph/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
PagePlanner,
_get_dashboard_plan,
)
from vizro_ai.dashboard.utils import DataFrameMetadata, DfMetadata

try:
from pydantic.v1 import BaseModel, validator
Expand All @@ -29,8 +30,6 @@
logger.setLevel(logging.INFO)


DfMetadata = Dict[str, Dict[str, Union[Dict[str, str], pd.DataFrame]]]
"""Cleaned dataframe names and their metadata."""

Messages = List[BaseMessage]
"""List of messages."""
Expand Down Expand Up @@ -87,15 +86,12 @@ def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, DfMet

df_name = data_sum_chain.invoke(
{"messages": messages, "df_schema": df_schema, "df_sample": df_sample, "current_df_names": current_df_names}
)
).dataset_name

current_df_names.append(df_name)

cleaned_df_name = df_name.dataset_name.lower()
cleaned_df_name = re.sub(r"\W+", "_", cleaned_df_name)
df_id = cleaned_df_name.strip("_")
logger.info(f"df_name: {df_name} --> df_id: {df_id}")
df_metadata[df_id] = {"df_schema": df_schema, "df": df} # TODO: shall be a dataclass
logger.info(f"df_name: {df_name}")
df_metadata.metadata[df_name] = DataFrameMetadata(df_schema=df_schema, df=df)

return {"df_metadata": df_metadata}

Expand All @@ -121,7 +117,7 @@ class BuildPageState(BaseModel):
"""

df_metadata: Dict[str, Dict[str, Any]]
df_metadata: DfMetadata
page_plan: PagePlanner = None


Expand Down
3 changes: 2 additions & 1 deletion vizro-ai/src/vizro_ai/dashboard/nodes/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from vizro_ai.dashboard.utils import DfMetadata

SINGLE_MODEL_PROMPT = ChatPromptTemplate.from_messages(
[
Expand Down Expand Up @@ -45,7 +46,7 @@ def _get_proxy_model(
query: str,
llm_model: BaseChatModel,
result_model: BaseModel,
df_metadata: Dict[str, Dict[str, str]],
df_metadata: DfMetadata,
max_retry: int = 2,
) -> BaseModel:
for i in range(max_retry):
Expand Down
4 changes: 3 additions & 1 deletion vizro-ai/src/vizro_ai/dashboard/nodes/data_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ def _get_df_info(df: pd.DataFrame) -> Tuple[Dict[str, str], str]:
class DfInfo(BaseModel):
"""Data Info output."""

dataset_name: str = Field(description="Name of the dataset")
dataset_name: str = Field(
pattern=r"^[a-z]+(_[a-z]+)?$", description="Small snake case name of the dataset."
)
9 changes: 5 additions & 4 deletions vizro-ai/src/vizro_ai/dashboard/nodes/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import vizro.models as vm
from langchain_openai import ChatOpenAI
from vizro.models.types import ComponentType
from vizro_ai.dashboard.utils import DfMetadata

try:
from pydantic.v1 import BaseModel, Field, ValidationError, create_model, validator
Expand Down Expand Up @@ -63,7 +64,7 @@ def create(self, model, df_metadata) -> Union[ComponentType, None]:
if self.component_type == "Graph":
return vm.Graph(
id=self.component_id+"_"+self.page_id,
figure=vizro_ai.plot(df=df_metadata[self.data_frame]["df"], user_input=self.component_description)
figure=vizro_ai.plot(df=df_metadata.metadata[self.data_frame].df, user_input=self.component_description)
)
elif self.component_type == "AgGrid":
return vm.AgGrid(
Expand Down Expand Up @@ -140,8 +141,8 @@ def create(self, model, available_components, df_metadata):
)
try:
_df_schema, _df = (
df_metadata[self.data_frame]["df_schema"],
df_metadata[self.data_frame]["df"],
df_metadata.metadata[self.data_frame].df_schema,
df_metadata.metadata[self.data_frame].df,
)
_df_cols = list(_df_schema.keys())
# when wrong dataframe name is given
Expand Down Expand Up @@ -246,7 +247,7 @@ class DashboardPlanner(BaseModel):
def _get_dashboard_plan(
query: str,
model: Union[ChatOpenAI],
df_metadata: Dict[str, Dict[str, str]],
df_metadata: DfMetadata,
) -> DashboardPlanner:
return _get_proxy_model(query=query, llm_model=model, result_model=DashboardPlanner, df_metadata=df_metadata)

Expand Down
17 changes: 16 additions & 1 deletion vizro-ai/src/vizro_ai/dashboard/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Helper Functions For Vizro AI dashboard."""

from dataclasses import dataclass
from typing import Any, Dict
from dataclasses import dataclass, field

import pandas as pd

# import black
from typing import Any
Expand All @@ -15,6 +18,18 @@
"import vizro.models as vm\n"
)

@dataclass
class DataFrameMetadata:
    """Schema and data bundled together for a single dataframe.

    Attributes:
        df_schema: Schema of the dataframe as a str-to-str mapping
            (presumably column name to dtype name; confirm against the
            code that produces it).
        df: The dataframe the schema describes.
    """

    df_schema: Dict[str, str]
    df: pd.DataFrame

@dataclass
class DfMetadata:
    """Container holding the metadata of every dataframe, keyed by name.

    Attributes:
        metadata: Maps a dataframe name to its ``DataFrameMetadata``.
            Built with ``default_factory`` so each instance starts with
            its own empty dict rather than sharing one.
    """

    metadata: Dict[str, DataFrameMetadata] = field(default_factory=dict)

@dataclass
class DashboardOutputs:
Expand Down

0 comments on commit dbd5713

Please sign in to comment.