diff --git a/vizro-ai/changelog.d/20240603_173503_anna_xiong_get_vizro_ai_customized_output.md b/vizro-ai/changelog.d/20240603_173503_anna_xiong_get_vizro_ai_customized_output.md new file mode 100644 index 000000000..0366db564 --- /dev/null +++ b/vizro-ai/changelog.d/20240603_173503_anna_xiong_get_vizro_ai_customized_output.md @@ -0,0 +1,47 @@ + + + + + +### Added + +- Enable feature to get PlotOutputs dataclass from `VizroAI.plot()`. Add argument `return_elements`to`VizroAI.plot()`, when it is set to`True`, the return type will be changed to a `dataclass` containing the code string, figure object, business insights, and code explanation. ([#488](https://github.com/mckinsey/vizro/pull/488)) + + + + + diff --git a/vizro-ai/src/vizro_ai/_vizro_ai.py b/vizro-ai/src/vizro_ai/_vizro_ai.py index 8417ca583..fabc07f12 100644 --- a/vizro-ai/src/vizro_ai/_vizro_ai.py +++ b/vizro-ai/src/vizro_ai/_vizro_ai.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Dict, Optional, Union +from dataclasses import asdict +from typing import Any, Optional, Union import pandas as pd import plotly.graph_objects as go @@ -10,9 +11,10 @@ from vizro_ai.task_pipeline._pipeline_manager import PipelineManager from vizro_ai.utils.helper import ( DebugFailure, + PlotOutputs, _debug_helper, + _display_markdown, _exec_code_and_retrieve_fig, - _exec_fig_code_display_markdown, _is_jupyter, ) @@ -23,7 +25,7 @@ class VizroAI: """Vizro-AI main class.""" pipeline_manager: PipelineManager = PipelineManager() - _return_all_text: bool = False + _return_all_text: bool = False # TODO deleted after adding new integration test def __init__(self, model: Optional[Union[ChatOpenAI, str]] = None): """Initialization of VizroAI. @@ -57,7 +59,7 @@ def _lazy_get_component(self, component_class: Any) -> Any: # TODO configure co def _run_plot_tasks( self, df: pd.DataFrame, user_input: str, max_debug_retry: int = 3, explain: bool = False - ) -> Dict[str, Any]: + ) -> PlotOutputs: """Task execution.""" chart_type_pipeline = self.pipeline_manager.chart_type_pipeline chart_types = chart_type_pipeline.run(initial_args={"chain_input": user_input, "df": df}) @@ -76,18 +78,29 @@ def _run_plot_tasks( pass_validation = validated_code_dict.get("debug_status") code_string = validated_code_dict.get("code_string") - business_insights, code_explanation = None, None - if explain and pass_validation: + if not pass_validation: + raise DebugFailure( + "Chart creation failed. Retry debugging has reached maximum limit. Try to rephrase the prompt, " + "or try to select a different model. Fallout response is provided: \n\n" + code_string + ) + + fig_object = _exec_code_and_retrieve_fig( + code=code_string, local_args={"df": df}, show_fig=_is_jupyter(), is_notebook_env=_is_jupyter() + ) + if explain: business_insights, code_explanation = self._lazy_get_component(GetCodeExplanation).run( chain_input=user_input, code_snippet=code_string ) - return { - "business_insights": business_insights, - "code_explanation": code_explanation, - "code_string": code_string, - } + return PlotOutputs( + code=code_string, + figure=fig_object, + business_insights=business_insights, + code_explanation=code_explanation, + ) + + return PlotOutputs(code=code_string, figure=fig_object) def _get_chart_code(self, df: pd.DataFrame, user_input: str) -> str: """Get Chart code of vizro via english descriptions, English to chart translation. @@ -99,12 +112,17 @@ def _get_chart_code(self, df: pd.DataFrame, user_input: str) -> str: user_input: User questions or descriptions of the desired visual """ - # TODO refine and update error handling - return self._run_plot_tasks(df, user_input, explain=False).get("code_string") - - def plot( - self, df: pd.DataFrame, user_input: str, explain: bool = False, max_debug_retry: int = 3 - ) -> Union[go.Figure, Dict[str, Any]]: + # TODO retained for some chat application integration, need deprecation handling + return self._run_plot_tasks(df, user_input, explain=False).code + + def plot( # pylint: disable=too-many-arguments # noqa: PLR0913 + self, + df: pd.DataFrame, + user_input: str, + explain: bool = False, + max_debug_retry: int = 3, + return_elements: bool = False, + ) -> Union[go.Figure, PlotOutputs]: """Plot visuals using vizro via english descriptions, english to chart translation. Args: @@ -112,28 +130,33 @@ def plot( user_input: User questions or descriptions of the desired visual. explain: Flag to include explanation in response. max_debug_retry: Maximum number of retries to debug errors. Defaults to `3`. + return_elements: Flag to return PlotOutputs dataclass that includes all possible elements generated. Returns: - Plotly Figure object or a dictionary containing data + go.Figure or PlotOutputs dataclass """ - output_dict = self._run_plot_tasks(df, user_input, explain=explain, max_debug_retry=max_debug_retry) - code_string = output_dict.get("code_string") - business_insights = output_dict.get("business_insights") - code_explanation = output_dict.get("code_explanation") + vizro_plot = self._run_plot_tasks( + df=df, user_input=user_input, explain=explain, max_debug_retry=max_debug_retry + ) - if code_string.startswith("Failed to debug code"): - raise DebugFailure( - "Chart creation failed. Retry debugging has reached maximum limit. Try to rephrase the prompt, " - "or try to select a different model. Fallout response is provided: \n\n" + code_string + if not explain: + logger.info( + "Flag explain is set to False. business_insights and code_explanation will not be included in " + "the output dataclass." + ) + + else: + _display_markdown( + code_snippet=vizro_plot.code, + biz_insights=vizro_plot.business_insights, + code_explain=vizro_plot.code_explanation, ) - # TODO Tentative for integration test + # TODO Tentative for integration test, will be updated/removed for new tests if self._return_all_text: + output_dict = asdict(vizro_plot) + output_dict["code_string"] = vizro_plot.code return output_dict - if not explain: - return _exec_code_and_retrieve_fig(code=code_string, local_args={"df": df}, is_notebook_env=_is_jupyter()) - if explain: - return _exec_fig_code_display_markdown( - df=df, code_snippet=code_string, biz_insights=business_insights, code_explain=code_explanation - ) + + return vizro_plot if return_elements else vizro_plot.figure diff --git a/vizro-ai/src/vizro_ai/utils/helper.py b/vizro-ai/src/vizro_ai/utils/helper.py index c7b6744f7..2de34b845 100644 --- a/vizro-ai/src/vizro_ai/utils/helper.py +++ b/vizro-ai/src/vizro_ai/utils/helper.py @@ -1,6 +1,7 @@ """Helper Functions For Vizro AI.""" import traceback +from dataclasses import dataclass, field from typing import Callable, Dict, Optional import pandas as pd @@ -9,6 +10,16 @@ from .safeguard import _safeguard_check +@dataclass +class PlotOutputs: + """Dataclass containing all possible `VizroAI.plot()` output.""" + + code: str + figure: go.Figure + business_insights: Optional[str] = field(default=None) + code_explanation: Optional[str] = field(default=None) + + # Taken from rich.console. See https://github.com/Textualize/rich. def _is_jupyter() -> bool: # pragma: no cover """Checks if we're running in a Jupyter notebook.""" @@ -49,13 +60,14 @@ def _debug_helper( def _exec_code_and_retrieve_fig( - code: str, local_args: Optional[Dict] = None, is_notebook_env: bool = True + code: str, local_args: Optional[Dict] = None, show_fig: bool = False, is_notebook_env: bool = True ) -> go.Figure: """Execute code in notebook with correct namespace and return fig object. Args: code: code string to be executed local_args: additional local arguments + show_fig: boolean flag indicating if fig will be rendered automatically is_notebook_env: boolean flag indicating if code is run in Jupyter notebook Returns: @@ -64,7 +76,13 @@ def _exec_code_and_retrieve_fig( """ from IPython import get_ipython + if show_fig and "\nfig.show()" not in code: + code += "\nfig.show()" + elif not show_fig: + code = code.replace("fig.show()", "") + namespace = get_ipython().user_ns if is_notebook_env else globals() + if local_args: namespace.update(local_args) _safeguard_check(code) @@ -75,21 +93,14 @@ def _exec_code_and_retrieve_fig( return dashboard_ready_fig -def _exec_fig_code_display_markdown( - df: pd.DataFrame, code_snippet: str, biz_insights: str, code_explain: str -) -> go.Figure: - # TODO change default test str to other +def _display_markdown(code_snippet: str, biz_insights: str, code_explain: str) -> None: """Display chart and Markdown format description in jupyter and returns fig object. Args: - df: The dataframe to be analyzed. code_snippet: code string to be executed biz_insights: business insights to be displayed in markdown cell code_explain: code explanation to be displayed in markdown cell - Returns: - go.Figure - """ try: # pylint: disable=import-outside-toplevel @@ -100,7 +111,6 @@ def _exec_fig_code_display_markdown( markdown_code = f"```\n{code_snippet}\n```" output_text = f"

Insights:

\n\n{biz_insights}\n

Code:

\n\n{code_explain}\n{markdown_code}" display(Markdown(output_text)) - return _exec_code_and_retrieve_fig(code_snippet, local_args={"df": df}, is_notebook_env=_is_jupyter()) class DebugFailure(Exception):