Skip to content

Commit

Permalink
refactor page build
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jul 18, 2024
1 parent 30282ce commit 6b678a2
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 147 deletions.
2 changes: 1 addition & 1 deletion vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,4 @@ def _get_pydantic_output(
return res
except ValidationError as validation_error:
last_validation_error = validation_error
return last_validation_error
raise last_validation_error
34 changes: 19 additions & 15 deletions vizro-ai/src/vizro_ai/dashboard/graph/dashboard_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from vizro_ai.dashboard.response_models.dashboard import DashboardPlanner
from vizro_ai.dashboard.response_models.df_info import DfInfo, _create_df_info_content, _get_df_info
from vizro_ai.dashboard.response_models.page import PagePlanner
from vizro_ai.dashboard.response_models.page_build import PageBuilder
from vizro_ai.dashboard.utils import DfMetadata, MetadataContent, _execute_step
from vizro_ai.utils.helper import DebugFailure

try:
from pydantic.v1 import BaseModel, validator
Expand Down Expand Up @@ -81,12 +81,16 @@ def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, DfMet
)

llm = config["configurable"].get("model", None)
df_name = _get_pydantic_output(
query=query,
llm_model=llm,
result_model=DfInfo,
df_info=df_info,
).dataset_name
try:
df_name = _get_pydantic_output(
query=query,
llm_model=llm,
result_model=DfInfo,
df_info=df_info,
).dataset_name
except DebugFailure as e:
logger.warning(f"Failed in name generation {e}")
df_name = f"df_{len(current_df_names)}"

current_df_names.append(df_name)

Expand All @@ -111,9 +115,13 @@ def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, Dash
node_desc + " --> in progress \n(this step could take longer " "when more complex requirements are given)",
None,
)
dashboard_plan = _get_pydantic_output(
query=query, llm_model=llm, result_model=DashboardPlanner, df_info=df_metadata.get_schemas_and_samples()
)
try:
dashboard_plan = _get_pydantic_output(
query=query, llm_model=llm, result_model=DashboardPlanner, df_info=df_metadata.get_schemas_and_samples()
)
except DebugFailure as e:
logger.error(f"Error in dashboard plan generation: {e}", exc_info=True)
raise
_execute_step(pbar, node_desc + " --> done", None)
pbar.close()

Expand All @@ -139,11 +147,7 @@ def _build_page(state: BuildPageState, config: RunnableConfig) -> Dict[str, List
page_plan = state["page_plan"]

llm = config["configurable"].get("model", None)
page = PageBuilder(
model=llm,
df_metadata=df_metadata,
page_plan=page_plan,
).page
page = page_plan.create(model=llm, df_metadata=df_metadata)

return {"pages": [page]}

Expand Down
27 changes: 17 additions & 10 deletions vizro-ai/src/vizro_ai/dashboard/response_models/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vizro.tables import dash_ag_grid
from vizro_ai.dashboard._constants import component_type
from vizro_ai.dashboard._pydantic_output import _get_pydantic_output
from vizro_ai.utils.helper import DebugFailure

logger = logging.getLogger(__name__)

Expand All @@ -37,18 +38,24 @@ class ComponentPlan(BaseModel):
"used, please specify that as N/A.",
)

def create(self, model, df_metadata) -> Union[ComponentType, None]:
def create(self, model, df_metadata) -> ComponentType:
"""Create the component."""
from vizro_ai import VizroAI

vizro_ai = VizroAI(model=model)

if self.component_type == "Graph":
return vm.Graph(
id=self.component_id + "_" + self.page_id,
figure=vizro_ai.plot(df=df_metadata.get_df(self.df_name), user_input=self.component_description),
)
elif self.component_type == "AgGrid":
return vm.AgGrid(id=self.component_id + "_" + self.page_id, figure=dash_ag_grid(data_frame=self.df_name))
elif self.component_type == "Card":
return _get_pydantic_output(query=self.component_description, llm_model=model, result_model=vm.Card)
try:
if self.component_type == "Graph":
return vm.Graph(
id=self.component_id + "_" + self.page_id,
figure=vizro_ai.plot(df=df_metadata.get_df(self.df_name), user_input=self.component_description),
)
elif self.component_type == "AgGrid":
return vm.AgGrid(id=self.component_id + "_" + self.page_id, figure=dash_ag_grid(data_frame=self.df_name))
elif self.component_type == "Card":
return _get_pydantic_output(query=self.component_description, llm_model=model, result_model=vm.Card)
except DebugFailure as e:
logger.warning(
f"Failed to build component: {self.component_id}.\n ------- \n "
f"Reason: {e} \n ------- \n Relevant prompt: `{self.component_description}`")
return vm.Card(id=self.component_id, text=f"Failed to build component: {self.component_id}")
4 changes: 3 additions & 1 deletion vizro-ai/src/vizro_ai/dashboard/response_models/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def create(self, model, available_components, df_metadata):
)

except ValidationError as e:
logger.info(f"Build failed for `Control`, returning default values. Error details: {e}")
logger.warning(f"Build failed for `Control`, returning default values. Try rephrase the prompt or "
f"select a different model. \n ------- \n Error details: {e} \n ------- \n "
f"Relevant prompt: `{self.control_description}`")
return None

return actual
1 change: 0 additions & 1 deletion vizro-ai/src/vizro_ai/dashboard/response_models/df_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field

from vizro_ai.dashboard._pydantic_output import _get_pydantic_output

DF_SUM_PROMPT = """
Inspect the provided data and give a short unique name to the dataset. \n
Expand Down
7 changes: 5 additions & 2 deletions vizro-ai/src/vizro_ai/dashboard/response_models/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
from vizro.models._layout import _get_grid_lines, _get_unique_grid_component_ids, _validate_grid_areas
from vizro_ai.dashboard._pydantic_output import _get_pydantic_output
from vizro_ai.utils.helper import DebugFailure

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,8 +70,10 @@ def create(self, model) -> Union[vm.Layout, None]:
try:
proxy = _get_pydantic_output(query=layout_prompt, llm_model=model, result_model=LayoutProxyModel)
actual = vm.Layout.parse_obj(proxy.dict(exclude={}))
except (ValidationError, AttributeError) as e:
logger.info(f"Build failed for `Layout`, returning default values. Error details: {e}")
except DebugFailure as e:
logger.warning(f"Build failed for `Layout`, returning default values. Try rephrase the prompt or "
f"select a different model. \n ------- \n Error details: {e} \n ------- \n "
f"Relevant prompt: `{self.layout_description}`")
actual = None

return actual
95 changes: 90 additions & 5 deletions vizro-ai/src/vizro_ai/dashboard/response_models/page.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
"""Page plan model."""

import logging
from typing import List
from typing import List, Union

try:
from pydantic.v1 import BaseModel, Field, validator
from pydantic.v1 import BaseModel, Field, validator, PrivateAttr
except ImportError: # pragma: no cov
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, validator, PrivateAttr
from vizro_ai.dashboard.response_models.components import ComponentPlan
from vizro_ai.dashboard.response_models.controls import ControlPlan
from vizro_ai.dashboard.response_models.layout import LayoutPlan
import vizro.models as vm
from tqdm.auto import tqdm
from vizro_ai.dashboard.utils import _execute_step
from vizro_ai.utils.helper import DebugFailure

logger = logging.getLogger(__name__)

Expand All @@ -28,11 +32,92 @@ class PagePlanner(BaseModel):
controls_plan: List[ControlPlan] = Field([], description="Controls of the page.")
layout_plan: LayoutPlan = Field(None, description="Layout of the page.")

_components: List[Union[vm.Card, vm.AgGrid, vm.Figure]] = PrivateAttr()
_controls: List[vm.Filter] = PrivateAttr()
_layout: vm.Layout = PrivateAttr()

@validator("components_plan")
def _check_components_plan(cls, v):
if len(v) == 0:
raise ValueError("A page must contain at least one component.")
return v

def __init__(self, **data):
super().__init__(**data)
self._components = None
self._controls = None
self._layout = None

def _get_components(self, df_metadata, model):
if self._components is None:
self._components = self._build_components(df_metadata, model)
return self._components

def _build_components(self, df_metadata, model):
components = []
component_log = tqdm(total=0, bar_format="{desc}", leave=False)
with tqdm(
total=len(self.components_plan),
desc=f"Currently Building ... [Page] <{self.title}> components",
leave=False,
) as pbar:
for component_plan in self.components_plan:
component_log.set_description_str(
f"[Page] <{self.title}>: [Component] {component_plan.component_id}"
)
pbar.update(1)
components.append(component_plan.create(df_metadata=df_metadata, model=model))
component_log.close()
return components

def _get_layout(self, model):
if self._layout is None:
self._layout = self._build_layout(model)
return self._layout

def _build_layout(self, model):
if self.layout_plan is None:
return None
return self.layout_plan.create(model=model)

def _get_controls(self, df_metadata, model):
if self._controls is None:
self._controls = self._build_controls(df_metadata, model)
return self._controls

def _available_components(self, df_metadata, model):
return [comp.id for comp in self._get_components(df_metadata=df_metadata, model=model) if isinstance(comp, (vm.Graph, vm.AgGrid))]

def _build_controls(self, df_metadata, model):
controls = []
with tqdm(
total=len(self.controls_plan),
desc=f"Currently Building ... [Page] <{self.title}> controls",
leave=False,
) as pbar:
for control_plan in self.controls_plan:
pbar.update(1)
control = control_plan.create(
model=model, available_components=self._available_components(df_metadata, model), df_metadata=df_metadata
)
if control:
controls.append(control)

return controls



def create(self, model, df_metadata):
page_desc = f"Building page: {self.title}"
logger.info(page_desc)
pbar = tqdm(total=5, desc=page_desc)

def create():
pass
title = _execute_step(pbar, page_desc + " --> add title", self.title)
components = _execute_step(pbar, page_desc + " --> add components", self._get_components(df_metadata=df_metadata, model=model))
controls = _execute_step(pbar, page_desc + " --> add controls", self._get_controls(df_metadata, model))
layout = _execute_step(pbar, page_desc + " --> add layout", self._get_layout(model))

page = vm.Page(title=title, components=components, controls=controls, layout=layout)
_execute_step(pbar, page_desc + " --> done", None)
pbar.close()
return page
112 changes: 0 additions & 112 deletions vizro-ai/src/vizro_ai/dashboard/response_models/page_build.py

This file was deleted.

0 comments on commit 6b678a2

Please sign in to comment.