Skip to content

Commit

Permalink
address comments
Browse files: browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jul 16, 2024
1 parent a17bfbb commit 7e0cbef
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 42 deletions.
2 changes: 1 addition & 1 deletion vizro-ai/snyk/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ tabulate
openai>=1.0.0
langchain>=0.1.0, <0.3.0
langchain-openai
langgraph
langgraph>=0.1.2
python-dotenv>=1.0.0
vizro>=0.1.4
ipython>=8.10.0
Expand Down
4 changes: 2 additions & 2 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from vizro_ai.chains._llm_models import _get_llm_model, _get_model_name
from vizro_ai.components import GetCodeExplanation, GetDebugger
from vizro_ai.dashboard.graph.code_generation import _create_and_compile_graph
from vizro_ai.dashboard.graph.dashboard_creation import _create_and_compile_graph
from vizro_ai.dashboard.utils import DashboardOutputs, _dashboard_code
from vizro_ai.task_pipeline._pipeline_manager import PipelineManager
from vizro_ai.utils.helper import (
Expand Down Expand Up @@ -166,7 +166,7 @@ def dashboard(
user_input: str,
return_elements: bool = False,
) -> Union[DashboardOutputs, vm.Dashboard]:
"""Create dashboard using vizro via english descriptions, english to dashboard translation.
"""Create a Vizro dashboard using English descriptions.
Args:
dfs: The dataframes to be analyzed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import operator
from typing import Annotated, Dict, List
from typing import Annotated, Dict, List, Optional

import pandas as pd
import vizro.models as vm
Expand All @@ -11,14 +11,14 @@
from langgraph.constants import END, Send
from langgraph.graph import StateGraph
from tqdm.auto import tqdm
from vizro_ai.dashboard.nodes._pydantic_output import _get_pydantic_output
from vizro_ai.dashboard.nodes.build import PageBuilder
from vizro_ai.dashboard.nodes.data_summary import DfInfo, _get_df_info, df_sum_prompt
from vizro_ai.dashboard.nodes.plan import (
DashboardPlanner,
PagePlanner,
_get_dashboard_plan,
)
from vizro_ai.dashboard.utils import DataFrameMetadata, DfMetadata, _execute_step
from vizro_ai.dashboard.utils import DfMetadata, MetadataContent, _execute_step

try:
from pydantic.v1 import BaseModel, validator
Expand Down Expand Up @@ -49,9 +49,9 @@ class GraphState(BaseModel):
messages: List[BaseMessage]
dfs: List[pd.DataFrame]
df_metadata: DfMetadata
dashboard_plan: DashboardPlanner = None
dashboard_plan: Optional[DashboardPlanner] = None
pages: Annotated[List, operator.add]
dashboard: vm.Dashboard = None
dashboard: Optional[vm.Dashboard] = None

class Config:
"""Pydantic configuration."""
Expand Down Expand Up @@ -95,7 +95,7 @@ def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, DfMet

pbar.write(f"df_name: {df_name}")
pbar.update(1)
df_metadata.metadata[df_name] = DataFrameMetadata(df_schema=df_schema, df=df, df_sample=df_sample)
df_metadata.metadata[df_name] = MetadataContent(df_schema=df_schema, df=df, df_sample=df_sample)

return {"df_metadata": df_metadata}

Expand All @@ -110,7 +110,9 @@ def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, Dash
llm = config["configurable"].get("model", None)

_execute_step(pbar, node_desc + " --> in progress", None)
dashboard_plan = _get_dashboard_plan(query=query, model=llm, df_metadata=df_metadata)
dashboard_plan = _get_pydantic_output(
query=query, llm_model=llm, result_model=DashboardPlanner, df_info=df_metadata.get_schemas_and_samples()
)
_execute_step(pbar, node_desc + " --> done", None)
pbar.close()

Expand Down Expand Up @@ -146,7 +148,7 @@ def _build_page(state: BuildPageState, config: RunnableConfig) -> Dict[str, List


def _continue_to_pages(state: GraphState) -> List[Send]:
"""Continue to build pages."""
"""Map-reduce logic to build pages in parallel."""
df_metadata = state.df_metadata
return [
Send(node="_build_page", arg={"page_plan": v, "df_metadata": df_metadata}) for v in state.dashboard_plan.pages
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Contains the _get_structured_output for the Vizro AI dashboard."""
"""Contains the _get_pydantic_output for the Vizro AI dashboard."""

# ruff: noqa: F821
try:
Expand Down Expand Up @@ -40,7 +40,7 @@
)


def _get_structured_output(
def _get_pydantic_output(
query: str,
llm_model: BaseChatModel,
result_model: BaseModel,
Expand Down
25 changes: 13 additions & 12 deletions vizro-ai/src/vizro_ai/dashboard/nodes/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ def components(self):
def _build_components(self):
components = []
logger.info(f"Building components of page: {self._page_plan.title}")
# Could potentially be parallelized or sent as a batch to the API
for i in trange(
len(self._page_plan.components), desc=f"Building components of page: {self._page_plan.title}", leave=False
len(self._page_plan.components_plan),
desc=f"Building components of page: {self._page_plan.title}",
leave=False,
):
logger.info(f"{self._page_plan.title} -> Building component {self._page_plan.components[i]}")
logger.info(f"{self._page_plan.title} -> Building component {self._page_plan.components_plan[i]}")
try:
components.append(
self._page_plan.components[i].create(df_metadata=self._df_metadata, model=self._model)
self._page_plan.components_plan[i].create(df_metadata=self._df_metadata, model=self._model)
)
except DebugFailure as e: # TODO: check - does this ever get raised?
except DebugFailure as e:
components.append(
vm.Card(id=self._page_plan.components[i].component_id, text=f"Failed to build component: {e}")
vm.Card(id=self._page_plan.components_plan[i].component_id, text=f"Failed to build component: {e}")
)
return components

Expand All @@ -56,10 +57,10 @@ def layout(self):
return self._layout

def _build_layout(self):
if self._page_plan.layout is None:
if self._page_plan.layout_plan is None:
return None
logger.info(f"{self._page_plan.title} -> Building layout {self._page_plan.layout}")
return self._page_plan.layout.create(model=self._model)
logger.info(f"{self._page_plan.title} -> Building layout {self._page_plan.layout_plan}")
return self._page_plan.layout_plan.create(model=self._model)

@property
def controls(self):
Expand All @@ -78,10 +79,10 @@ def _build_controls(self):
logger.info(f"Building controls of page: {self._page_plan.title}")
# Could potentially be parallelized or sent as a batch to the API
for i in trange(
len(self._page_plan.controls), desc=f"Building controls of page: {self._page_plan.title}", leave=False
len(self._page_plan.controls_plan), desc=f"Building controls of page: {self._page_plan.title}", leave=False
):
logger.info(f"{self._page_plan.title} -> Building control {self._page_plan.controls[i]}")
control = self._page_plan.controls[i].create(
logger.info(f"{self._page_plan.title} -> Building control {self._page_plan.controls_plan[i]}")
control = self._page_plan.controls_plan[i].create(
model=self._model, available_components=self.available_components, df_metadata=self._df_metadata
)
if control:
Expand Down
22 changes: 11 additions & 11 deletions vizro-ai/src/vizro_ai/dashboard/nodes/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np
from vizro.models._layout import _get_grid_lines, _get_unique_grid_component_ids, _validate_grid_areas
from vizro.tables import dash_ag_grid
from vizro_ai.dashboard.nodes._model import _get_structured_output
from vizro_ai.dashboard.nodes._pydantic_output import _get_pydantic_output

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +32,7 @@
# For other models, like ["Accordion", "NavBar"], how to handle them?


class Component(BaseModel):
class ComponentPlan(BaseModel):
"""Component plan model."""

component_type: component_type
Expand Down Expand Up @@ -66,7 +66,7 @@ def create(self, model, df_metadata) -> Union[ComponentType, None]:
elif self.component_type == "AgGrid":
return vm.AgGrid(id=self.component_id + "_" + self.page_id, figure=dash_ag_grid(data_frame=self.df_name))
elif self.component_type == "Card":
return _get_structured_output(
return _get_pydantic_output(
query=self.component_description, llm_model=model, result_model=vm.Card, df_info=None
)

Expand Down Expand Up @@ -110,7 +110,7 @@ def validate_column(v):
)


class Control(BaseModel):
class ControlPlan(BaseModel):
"""Control plan model."""

control_type: control_type
Expand Down Expand Up @@ -144,7 +144,7 @@ def create(self, model, available_components, df_metadata):

try:
result_proxy = create_filter_proxy(df_cols=_df_cols, available_components=available_components)
proxy = _get_structured_output(
proxy = _get_pydantic_output(
query=filter_prompt, llm_model=model, result_model=result_proxy, df_info=_df_schema
)
logger.info(
Expand Down Expand Up @@ -187,7 +187,7 @@ def validate_grid(cls, grid):
return grid


class Layout(BaseModel):
class LayoutPlan(BaseModel):
"""Layout plan model, which only applies to Vizro components (Graph, AgGrid, Card)."""

layout_description: str = Field(
Expand Down Expand Up @@ -215,7 +215,7 @@ def create(self, model) -> Union[vm.Layout, None]:
return None

try:
proxy = _get_structured_output(
proxy = _get_pydantic_output(
query=layout_prompt, llm_model=model, result_model=LayoutProxyModel, df_info=None
)
actual = vm.Layout.parse_obj(proxy.dict(exclude={}))
Expand All @@ -234,9 +234,9 @@ class PagePlanner(BaseModel):
description="Title of the page. If no description is provided, "
"make a short and concise title from the components.",
)
components: List[Component]
controls: List[Control] = Field([], description="Controls of the page.")
layout: Layout = Field(None, description="Layout of the page.")
components_plan: List[ComponentPlan]
controls_plan: List[ControlPlan] = Field([], description="Controls of the page.")
layout_plan: LayoutPlan = Field(None, description="Layout of the page.")


class DashboardPlanner(BaseModel):
Expand All @@ -255,6 +255,6 @@ def _get_dashboard_plan(
model: Union[ChatOpenAI],
df_metadata: DfMetadata,
) -> DashboardPlanner:
return _get_structured_output(
return _get_pydantic_output(
query=query, llm_model=model, result_model=DashboardPlanner, df_info=df_metadata.get_schemas_and_samples()
)
6 changes: 3 additions & 3 deletions vizro-ai/src/vizro_ai/dashboard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@


@dataclass
class DataFrameMetadata:
"""Dataclass containing metadata for a dataframe."""
class MetadataContent:
"""Dataclass containing metadata content for a dataframe."""

df_schema: Dict[str, str]
df: pd.DataFrame
Expand All @@ -31,7 +31,7 @@ class DataFrameMetadata:
class DfMetadata:
"""Dataclass containing metadata for all dataframes."""

metadata: Dict[str, DataFrameMetadata] = field(default_factory=dict)
metadata: Dict[str, MetadataContent] = field(default_factory=dict)

def get_schemas_and_samples(self) -> Dict[str, Dict[str, str]]:
"""Retrieve only the df_schema and df_sample for all datasets."""
Expand Down
6 changes: 3 additions & 3 deletions vizro-ai/tests/unit/vizro-ai/dashboard/nodes/test_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import vizro.models as vm
from vizro_ai.dashboard.nodes._model import _get_structured_output
from vizro_ai.dashboard.nodes._pydantic_output import _get_pydantic_output


def test_get_structured_output(component_description, fake_llm):
structured_output = _get_structured_output(
def test_get_pydantic_output(component_description, fake_llm):
structured_output = _get_pydantic_output(
query=component_description, llm_model=fake_llm, result_model=vm.Card, df_info=None
)
assert structured_output.dict(exclude={"id": True}) == vm.Card(text="this is a card", href="").dict(
Expand Down

0 comments on commit 7e0cbef

Please sign in to comment.