Skip to content

Commit

Permalink
[Feat] get vizro ai customized text output (#488)
Browse files Browse the repository at this point in the history
Co-authored-by: Antony Milne <[email protected]>
Co-authored-by: Maximilian Schulz
 <[email protected]>
  • Loading branch information
3 people authored Jun 6, 2024
1 parent 85ca2ec commit 3ad4e1d
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->

### Added

- Enable feature to get a `PlotOutputs` dataclass from `VizroAI.plot()`. Add argument `return_elements` to `VizroAI.plot()`; when it is set to `True`, the return type changes to a `dataclass` containing the code string, figure object, business insights, and code explanation. ([#488](https://github.com/mckinsey/vizro/pull/488))

<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
89 changes: 56 additions & 33 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, Optional, Union
from dataclasses import asdict
from typing import Any, Optional, Union

import pandas as pd
import plotly.graph_objects as go
Expand All @@ -10,9 +11,10 @@
from vizro_ai.task_pipeline._pipeline_manager import PipelineManager
from vizro_ai.utils.helper import (
DebugFailure,
PlotOutputs,
_debug_helper,
_display_markdown,
_exec_code_and_retrieve_fig,
_exec_fig_code_display_markdown,
_is_jupyter,
)

Expand All @@ -23,7 +25,7 @@ class VizroAI:
"""Vizro-AI main class."""

pipeline_manager: PipelineManager = PipelineManager()
_return_all_text: bool = False
_return_all_text: bool = False # TODO deleted after adding new integration test

def __init__(self, model: Optional[Union[ChatOpenAI, str]] = None):
"""Initialization of VizroAI.
Expand Down Expand Up @@ -57,7 +59,7 @@ def _lazy_get_component(self, component_class: Any) -> Any: # TODO configure co

def _run_plot_tasks(
self, df: pd.DataFrame, user_input: str, max_debug_retry: int = 3, explain: bool = False
) -> Dict[str, Any]:
) -> PlotOutputs:
"""Task execution."""
chart_type_pipeline = self.pipeline_manager.chart_type_pipeline
chart_types = chart_type_pipeline.run(initial_args={"chain_input": user_input, "df": df})
Expand All @@ -76,18 +78,29 @@ def _run_plot_tasks(

pass_validation = validated_code_dict.get("debug_status")
code_string = validated_code_dict.get("code_string")
business_insights, code_explanation = None, None

if explain and pass_validation:
if not pass_validation:
raise DebugFailure(
"Chart creation failed. Retry debugging has reached maximum limit. Try to rephrase the prompt, "
"or try to select a different model. Fallout response is provided: \n\n" + code_string
)

fig_object = _exec_code_and_retrieve_fig(
code=code_string, local_args={"df": df}, show_fig=_is_jupyter(), is_notebook_env=_is_jupyter()
)
if explain:
business_insights, code_explanation = self._lazy_get_component(GetCodeExplanation).run(
chain_input=user_input, code_snippet=code_string
)

return {
"business_insights": business_insights,
"code_explanation": code_explanation,
"code_string": code_string,
}
return PlotOutputs(
code=code_string,
figure=fig_object,
business_insights=business_insights,
code_explanation=code_explanation,
)

return PlotOutputs(code=code_string, figure=fig_object)

def _get_chart_code(self, df: pd.DataFrame, user_input: str) -> str:
"""Get Chart code of vizro via english descriptions, English to chart translation.
Expand All @@ -99,41 +112,51 @@ def _get_chart_code(self, df: pd.DataFrame, user_input: str) -> str:
user_input: User questions or descriptions of the desired visual
"""
# TODO refine and update error handling
return self._run_plot_tasks(df, user_input, explain=False).get("code_string")

def plot(
self, df: pd.DataFrame, user_input: str, explain: bool = False, max_debug_retry: int = 3
) -> Union[go.Figure, Dict[str, Any]]:
# TODO retained for some chat application integration, need deprecation handling
return self._run_plot_tasks(df, user_input, explain=False).code

def plot( # pylint: disable=too-many-arguments # noqa: PLR0913
self,
df: pd.DataFrame,
user_input: str,
explain: bool = False,
max_debug_retry: int = 3,
return_elements: bool = False,
) -> Union[go.Figure, PlotOutputs]:
"""Plot visuals using vizro via english descriptions, english to chart translation.
Args:
df: The dataframe to be analyzed.
user_input: User questions or descriptions of the desired visual.
explain: Flag to include explanation in response.
max_debug_retry: Maximum number of retries to debug errors. Defaults to `3`.
return_elements: Flag to return PlotOutputs dataclass that includes all possible elements generated.
Returns:
Plotly Figure object or a dictionary containing data
go.Figure or PlotOutputs dataclass
"""
output_dict = self._run_plot_tasks(df, user_input, explain=explain, max_debug_retry=max_debug_retry)
code_string = output_dict.get("code_string")
business_insights = output_dict.get("business_insights")
code_explanation = output_dict.get("code_explanation")
vizro_plot = self._run_plot_tasks(
df=df, user_input=user_input, explain=explain, max_debug_retry=max_debug_retry
)

if code_string.startswith("Failed to debug code"):
raise DebugFailure(
"Chart creation failed. Retry debugging has reached maximum limit. Try to rephrase the prompt, "
"or try to select a different model. Fallout response is provided: \n\n" + code_string
if not explain:
logger.info(
"Flag explain is set to False. business_insights and code_explanation will not be included in "
"the output dataclass."
)

else:
_display_markdown(
code_snippet=vizro_plot.code,
biz_insights=vizro_plot.business_insights,
code_explain=vizro_plot.code_explanation,
)

# TODO Tentative for integration test
# TODO Tentative for integration test, will be updated/removed for new tests
if self._return_all_text:
output_dict = asdict(vizro_plot)
output_dict["code_string"] = vizro_plot.code
return output_dict
if not explain:
return _exec_code_and_retrieve_fig(code=code_string, local_args={"df": df}, is_notebook_env=_is_jupyter())
if explain:
return _exec_fig_code_display_markdown(
df=df, code_snippet=code_string, biz_insights=business_insights, code_explain=code_explanation
)

return vizro_plot if return_elements else vizro_plot.figure
30 changes: 20 additions & 10 deletions vizro-ai/src/vizro_ai/utils/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Helper Functions For Vizro AI."""

import traceback
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

import pandas as pd
Expand All @@ -9,6 +10,16 @@
from .safeguard import _safeguard_check


@dataclass
class PlotOutputs:
"""Dataclass containing all possible `VizroAI.plot()` output.

Returned by `VizroAI.plot()` in place of the bare figure when it is called with `return_elements=True`.
"""

# Generated chart code string.
code: str
# Figure object obtained by executing the generated chart code.
figure: go.Figure
# Business insights text; defaults to `None` and is only filled in when an explanation was requested.
business_insights: Optional[str] = field(default=None)
# Explanation of the generated code; defaults to `None` and is only filled in when an explanation was requested.
code_explanation: Optional[str] = field(default=None)


# Taken from rich.console. See https://github.com/Textualize/rich.
def _is_jupyter() -> bool: # pragma: no cover
"""Checks if we're running in a Jupyter notebook."""
Expand Down Expand Up @@ -49,13 +60,14 @@ def _debug_helper(


def _exec_code_and_retrieve_fig(
code: str, local_args: Optional[Dict] = None, is_notebook_env: bool = True
code: str, local_args: Optional[Dict] = None, show_fig: bool = False, is_notebook_env: bool = True
) -> go.Figure:
"""Execute code in notebook with correct namespace and return fig object.
Args:
code: code string to be executed
local_args: additional local arguments
show_fig: boolean flag indicating if fig will be rendered automatically
is_notebook_env: boolean flag indicating if code is run in Jupyter notebook
Returns:
Expand All @@ -64,7 +76,13 @@ def _exec_code_and_retrieve_fig(
"""
from IPython import get_ipython

if show_fig and "\nfig.show()" not in code:
code += "\nfig.show()"
elif not show_fig:
code = code.replace("fig.show()", "")

namespace = get_ipython().user_ns if is_notebook_env else globals()

if local_args:
namespace.update(local_args)
_safeguard_check(code)
Expand All @@ -75,21 +93,14 @@ def _exec_code_and_retrieve_fig(
return dashboard_ready_fig


def _exec_fig_code_display_markdown(
df: pd.DataFrame, code_snippet: str, biz_insights: str, code_explain: str
) -> go.Figure:
# TODO change default test str to other
def _display_markdown(code_snippet: str, biz_insights: str, code_explain: str) -> None:
"""Display chart and Markdown format description in jupyter and returns fig object.
Args:
df: The dataframe to be analyzed.
code_snippet: code string to be executed
biz_insights: business insights to be displayed in markdown cell
code_explain: code explanation to be displayed in markdown cell
Returns:
go.Figure
"""
try:
# pylint: disable=import-outside-toplevel
Expand All @@ -100,7 +111,6 @@ def _exec_fig_code_display_markdown(
markdown_code = f"```\n{code_snippet}\n```"
output_text = f"<h4>Insights:</h4>\n\n{biz_insights}\n<br><br><h4>Code:</h4>\n\n{code_explain}\n{markdown_code}"
display(Markdown(output_text))
return _exec_code_and_retrieve_fig(code_snippet, local_args={"df": df}, is_notebook_env=_is_jupyter())


class DebugFailure(Exception):
Expand Down

0 comments on commit 3ad4e1d

Please sign in to comment.