Skip to content

Commit

Permalink
restructure
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jul 17, 2024
1 parent a97e19c commit 30282ce
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 29 deletions.
18 changes: 12 additions & 6 deletions vizro-ai/src/vizro_ai/dashboard/graph/dashboard_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from langgraph.graph import StateGraph
from tqdm.auto import tqdm
from vizro_ai.dashboard._pydantic_output import _get_pydantic_output
from vizro_ai.dashboard.data_preprocess.df_info import _get_df_info, _get_df_sum_output
from vizro_ai.dashboard.page_build.page import PageBuilder
from vizro_ai.dashboard.plan.dashboard import DashboardPlanner
from vizro_ai.dashboard.plan.page import PagePlanner
from vizro_ai.dashboard.response_models.dashboard import DashboardPlanner
from vizro_ai.dashboard.response_models.df_info import DfInfo, _create_df_info_content, _get_df_info
from vizro_ai.dashboard.response_models.page import PagePlanner
from vizro_ai.dashboard.response_models.page_build import PageBuilder
from vizro_ai.dashboard.utils import DfMetadata, MetadataContent, _execute_step

try:
Expand Down Expand Up @@ -76,10 +76,16 @@ def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, DfMet
with tqdm(total=len(dfs), desc="Store df info") as pbar:
for df in dfs:
df_schema, df_sample = _get_df_info(df)
df_info = _create_df_info_content(
df_schema=df_schema, df_sample=df_sample, current_df_names=current_df_names
)

llm = config["configurable"].get("model", None)
df_name = _get_df_sum_output(
df_schema=df_schema, df_sample=df_sample, current_df_names=current_df_names, query=query, llm_model=llm
df_name = _get_pydantic_output(
query=query,
llm_model=llm,
result_model=DfInfo,
df_info=df_info,
).dataset_name

current_df_names.append(df_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic.v1 import BaseModel, Field
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field
from vizro_ai.dashboard.plan.page import PagePlanner
from vizro_ai.dashboard.response_models.page import PagePlanner

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Data Summary Node."""

from typing import Any, Dict, Tuple
from typing import Any, Dict, Tuple, List

import pandas as pd
from langchain_core.language_models.chat_models import BaseChatModel
Expand Down Expand Up @@ -38,15 +38,3 @@ def _get_df_info(df: pd.DataFrame) -> Tuple[Dict[str, str], pd.DataFrame]:
def _create_df_info_content(df_schema: Any, df_sample: Any, current_df_names: list) -> dict:
"""Create the message content for the dataframe summarization."""
return DF_SUM_PROMPT.format(df_sample=df_sample, df_schema=df_schema, current_df_names=current_df_names)


def _get_df_sum_output(
df_schema: Any, df_sample: Any, current_df_names: list, llm_model: BaseChatModel, query: str
) -> BaseModel:
"""Get the dataframe summary output from the LLM model with retry logic."""
return _get_pydantic_output(
query=query,
llm_model=llm_model,
result_model=DfInfo,
df_info=_create_df_info_content(df_schema=df_schema, df_sample=df_sample, current_df_names=current_df_names),
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from pydantic.v1 import BaseModel, Field, validator
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field, validator
from vizro_ai.dashboard.plan.components import ComponentPlan
from vizro_ai.dashboard.plan.controls import ControlPlan
from vizro_ai.dashboard.plan.layout import LayoutPlan
from vizro_ai.dashboard.response_models.components import ComponentPlan
from vizro_ai.dashboard.response_models.controls import ControlPlan
from vizro_ai.dashboard.response_models.layout import LayoutPlan

logger = logging.getLogger(__name__)

Expand All @@ -34,5 +34,5 @@ def _check_components_plan(cls, v):
raise ValueError("A page must contain at least one component.")
return v

# def create():
# pass
def create():
pass
2 changes: 1 addition & 1 deletion vizro-ai/tests/unit/vizro-ai/dashboard/nodes/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from langchain.output_parsers import PydanticOutputParser
from langchain_community.llms.fake import FakeListLLM
from vizro_ai.dashboard.plan.components import ComponentPlan
from vizro_ai.dashboard.response_models.components import ComponentPlan


class FakeListLLM(FakeListLLM):
Expand Down
6 changes: 3 additions & 3 deletions vizro-ai/tests/unit/vizro-ai/dashboard/nodes/test_plan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from vizro_ai.dashboard.plan.controls import create_filter_proxy
from vizro_ai.dashboard.plan.dashboard import DashboardPlanner
from vizro_ai.dashboard.plan.page import PagePlanner
from vizro_ai.dashboard.response_models.controls import create_filter_proxy
from vizro_ai.dashboard.response_models.dashboard import DashboardPlanner
from vizro_ai.dashboard.response_models.page import PagePlanner

try:
from pydantic.v1 import ValidationError
Expand Down

0 comments on commit 30282ce

Please sign in to comment.