From 1c5e36f00da05ca0f40c7768a5ec31facd571dac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:14:03 +0200 Subject: [PATCH 1/6] [Bot] Bump actions/labeler from 4 to 5 (#608) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/labeler.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index dc101a2d0..ecaeb9036 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -21,6 +21,6 @@ jobs: pull-requests: write steps: - - uses: actions/labeler@v4 + - uses: actions/labeler@v5 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" From 4187e951147c89e29a079a4bca3e5506b9e0d613 Mon Sep 17 00:00:00 2001 From: Lingyi Zhang Date: Fri, 2 Aug 2024 01:09:39 -0400 Subject: [PATCH 2/6] [POC] Dashboard generator (#522) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Schulz --- ...020634_lingyi_zhang_dashboard_generator.md | 48 +++ vizro-ai/examples/example_dashboard.ipynb | 283 ++++++++++++++++++ vizro-ai/examples/example_dashboard.py | 33 ++ vizro-ai/hatch.toml | 1 + vizro-ai/pyproject.toml | 3 +- vizro-ai/snyk/requirements.txt | 1 + vizro-ai/src/vizro_ai/_llm_models.py | 18 +- vizro-ai/src/vizro_ai/_vizro_ai.py | 52 +++- .../dashboard/_graph/dashboard_creation.py | 190 ++++++++++++ .../vizro_ai/dashboard/_pydantic_output.py | 98 ++++++ .../dashboard/_response_models/components.py | 98 ++++++ .../dashboard/_response_models/controls.py | 170 +++++++++++ .../dashboard/_response_models/dashboard.py | 25 ++ .../dashboard/_response_models/df_info.py | 46 +++ .../dashboard/_response_models/layout.py | 92 ++++++ .../dashboard/_response_models/page.py | 223 ++++++++++++++ .../dashboard/_response_models/types.py | 13 + vizro-ai/src/vizro_ai/dashboard/utils.py | 63 ++++ vizro-ai/src/vizro_ai/py.typed | 1 + 19 files changed, 1453 insertions(+), 5 deletions(-) create mode 100644 vizro-ai/changelog.d/20240613_020634_lingyi_zhang_dashboard_generator.md create mode 100644 vizro-ai/examples/example_dashboard.ipynb create mode 100644 vizro-ai/examples/example_dashboard.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_graph/dashboard_creation.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_response_models/components.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_response_models/dashboard.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_response_models/df_info.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_response_models/page.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/_response_models/types.py create mode 100644 vizro-ai/src/vizro_ai/dashboard/utils.py create mode 100644 vizro-ai/src/vizro_ai/py.typed diff --git a/vizro-ai/changelog.d/20240613_020634_lingyi_zhang_dashboard_generator.md b/vizro-ai/changelog.d/20240613_020634_lingyi_zhang_dashboard_generator.md new file mode 100644 index 000000000..f1f65e73c --- /dev/null +++ b/vizro-ai/changelog.d/20240613_020634_lingyi_zhang_dashboard_generator.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-ai/examples/example_dashboard.ipynb b/vizro-ai/examples/example_dashboard.ipynb 
new file mode 100644 index 000000000..c918bfca5 --- /dev/null +++ b/vizro-ai/examples/example_dashboard.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "53e857ce-22bc-49de-9adc-9a2e7c9829cf", + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "load_dotenv()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a25acdd-20c3-4762-b97f-254de1586aeb", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import vizro.plotly.express as px\n", + "\n", + "from vizro import Vizro\n", + "from vizro_ai import VizroAI\n", + "\n", + "# vizro_ai = VizroAI(model=\"gpt-4-turbo\")\n", + "vizro_ai = VizroAI(model=\"gpt-4o\")\n", + "# vizro_ai = VizroAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5e24f1b-e698-40e5-be00-c3a59c53ec65", + "metadata": {}, + "outputs": [], + "source": [ + "df1 = px.data.gapminder()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "449da2ee-c754-420a-ba2e-c9b0ef62d934", + "metadata": {}, + "outputs": [], + "source": [ + "df2 = px.data.stocks()" + ] + }, + { + "cell_type": "markdown", + "id": "ec46d4d1-d20b-4351-831d-d3d8ddc5cb70", + "metadata": {}, + "source": [ + "# Example: Simple dashboard request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "820a5d0f-a31e-4bbd-a924-9629631cc291", + "metadata": {}, + "outputs": [], + "source": [ + "user_question_2_data = \"\"\"\n", + "I need a page with 1 table.\n", + "The table shows the tech companies stock data.\n", + "\n", + "I need a second page showing 2 cards and one chart.\n", + "The first card says 'The Gapminder dataset provides historical data on countries' development indicators.'\n", + "The chart is a scatter plot showing life expectancy vs. GDP per capita by country. Life expectancy on the y axis, GDP per capita on the x axis, and colored by continent.\n", + "The second card says 'Data spans from 1952 to 2007 across various countries'\n", + "The layout uses a grid of 3 columns and 2 rows.\n", + "\n", + "Row 1: The first row has three columns:\n", + "The first column is occupied by the first card.\n", + "The second and third columns are spanned by the chart.\n", + "\n", + "Row 2: The second row mirrors the layout of the first row with respect to chart, but the first column is occupied by the second card.\n", + "\n", + "Add a filter to filter the scatter plot by continent.\n", + "Add a second filter to filter the chart by year.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d71e089-8c94-4d12-87bd-d803552acb32", + "metadata": {}, + "outputs": [], + "source": [ + "dashboard = vizro_ai.dashboard([df1, df2], user_question_2_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14477c56-54e9-43a5-9136-25bc950fdf3a", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro().build(dashboard).run()" + ] + }, + { + "cell_type": "markdown", + "id": "747964b9-fd05-4c5a-a73a-79dae82320b3", + "metadata": {}, + "source": [ + "# Example: 5-page dashboard request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "967ff6a4-f138-4643-b993-a72e5cc26de2", + "metadata": {}, + "outputs": [], + "source": [ + "df3 = px.data.tips()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb9347f8", + "metadata": {}, + "outputs": [], + "source": [ + "user_question_3_data = \"\"\"\n", + "\n", + "I need a page with 1 table and 1 line chart. 
\n", + "The chart shows the stock price trends of GOOG and AAPL.\n", + "The table shows the stock prices data details.\n", + "\n", + "\n", + "I need a second page showing 1 card and 1 chart.\n", + "The card says 'The Gapminder dataset provides historical data on countries' development indicators.'\n", + "The chart is a scatter plot showing GDP per capita vs. life expectancy. GDP per capita on the x axis, life expectancy on the y axis, and colored by continent.\n", + "Layout the card on the left and the chart on the right. The card takes 1/3 of the whole space on the left.\n", + "The chart takes 2/3 of the whole space and is on the right.\n", + "Add a filter to filter the scatter plot by continent.\n", + "Add a second filter to filter the chart by year.\n", + "\n", + "\n", + "This page displays the tips dataset. use two different charts to show data\n", + "distributions. one chart should be a bar chart and the other should be a scatter plot.\n", + "first chart is on the left and the second chart is on the right.\n", + "Add a filter to filter data in the scatter plot by smoker.\n", + "\n", + "\n", + "Create 3 cards on this page:\n", + "1. The first card on top says \"This page combines data from various sources including tips, stock prices, and global indicators.\"\n", + "2. The second card says \"Insights from Gapminder dataset.\"\n", + "3. The third card says \"Stock price trends over time.\"\n", + "\n", + "Layout these 3 cards in this way:\n", + "create a grid with 3 columns and 2 rows.\n", + "Row 1: The first row has three columns:\n", + "- The first column is empty.\n", + "- The second and third columns span the area for card 1.\n", + "\n", + "Row 2: The second row also has three columns:\n", + "- The first column is empty.\n", + "- The second column is occupied by the area for card 2.\n", + "- The third column is occupied by the area for card 3.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0a0cdfa", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro._reset()\n", + "dashboard = vizro_ai.dashboard([df1, df2, df3], user_question_3_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3167e996", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro().build(dashboard).run()" + ] + }, + { + "cell_type": "markdown", + "id": "bbf5c920-0432-4415-996f-1acb9d7b6b8a", + "metadata": {}, + "source": [ + "# Example: Request with unsupported features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12d5976e", + "metadata": {}, + "outputs": [], + "source": [ + "user_question_2_data = \"\"\"\n", + "\n", + "I need a page showing 2 cards, one chart, and 1 button.\n", + "The first card says 'The Tips dataset provides insights into customer tipping behavior.'\n", + "The chart is a bar chart showing the total bill amount by day. Day on the x axis, total bill amount on the y axis, and colored by time of day.\n", + "The second card says 'Data collected from various days and times.'\n", + "Layout the two cards on the left and the chart on the right. 
Two cards take 1/3 of the whole space on the left in total.\n", + "The first card is on top of the second card vertically.\n", + "The chart takes 2/3 of the whole space and is on the right.\n", + "The button would trigger a download action to download the Tips dataset.\n", + "Add a filter to filter the bar chart by `size`.\n", + "Make another tab on this page,\n", + "In this tab, create a card saying \"Tipping patterns and trends.\"\n", + "Group all the above content into the first NavLink.\n", + "\n", + "\n", + "Create two pages:\n", + "1. The first page has a card saying \"Analyzing global development trends.\"\n", + "2. The second page has a scatter plot showing GDP per capita vs. life expectancy. GDP per capita on the x axis, life expectancy on the y axis, and colored by continent.\n", + "Add a parameter to control the title of the scatter plot, with title options \"Economic Growth vs. Health\" and \"Development Indicators.\"\n", + "Also create a button and a spinning circle on the right-hand side of the page.\n", + "\n", + "\n", + "Create one page:\n", + "1. The first page has a card saying \"Stock price trends over time.\"\n", + "Create a button and a spinning circle on the right-hand side of the page.\n", + "\n", + "For hosting the dashboard on AWS, which service should I use?\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b4838d1", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro._reset()\n", + "dashboard = vizro_ai.dashboard([df3, df2, df1], user_question_2_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f055bec1", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro().build(dashboard).run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/vizro-ai/examples/example_dashboard.py b/vizro-ai/examples/example_dashboard.py new file mode 100644 index 000000000..b6dc55839 --- /dev/null +++ b/vizro-ai/examples/example_dashboard.py @@ -0,0 +1,33 @@ +"""Example of creating a dashboard using VizroAI.""" + +import vizro.plotly.express as px +from dotenv import load_dotenv +from vizro import Vizro +from vizro_ai import VizroAI + +load_dotenv() + +vizro_ai = VizroAI(model="gpt-4o") +# vizro_ai = VizroAI() + +gapminder_data = px.data.gapminder() +tips_data = px.data.tips() + +dfs = [gapminder_data, tips_data] +input_text = ( + "Create a dashboard that displays the Gapminder dataset and the tips dataset. " + "page1 displays the Gapminder dataset. create a bar chart for average GDP per capita of each continent. " + "add a filter to filter by continent. " + "Use a card to explain what Gapminder dataset is about. " + "The card should only take 1/6 of the whole page. " + "The rest of the page should be the graph or table. Don't create empty space." + "page2 displays the tips dataset. use two different charts to help me understand the data " + "distributions. one chart should be a bar chart and the other should be a scatter plot. " + "first chart is on the left and the second chart is on the right. " + "add a filter to filter data in the scatter plot by smoker." 
+) + +dashboard = vizro_ai.dashboard(dfs=dfs, user_input=input_text) + +if __name__ == "__main__": + Vizro().build(dashboard).run() diff --git a/vizro-ai/hatch.toml b/vizro-ai/hatch.toml index 66b69e519..258e6c479 100644 --- a/vizro-ai/hatch.toml +++ b/vizro-ai/hatch.toml @@ -25,6 +25,7 @@ VIZRO_AI_LOG_LEVEL = "DEBUG" [envs.default.scripts] example = "cd examples; python example.py" +example-create-dashboard = "cd examples; python example_dashboard.py" lint = "hatch run lint:lint {args:--all-files}" prep-release = [ "hatch version release", diff --git a/vizro-ai/pyproject.toml b/vizro-ai/pyproject.toml index 2136e2b88..d6c027b84 100644 --- a/vizro-ai/pyproject.toml +++ b/vizro-ai/pyproject.toml @@ -17,8 +17,9 @@ dependencies = [ "pandas", "tabulate", "openai>=1.0.0", - "langchain>=0.1.0, <0.3.0", # TODO update all LLMChain class and remove upper bound + "langchain>=0.1.0, <0.3.0", # TODO update all LLMChain class, update to pydantic v2 and remove upper bound "langchain-openai", + "langgraph>=0.1.2", "python-dotenv>=1.0.0", # TODO decide env var management to see if we need this "vizro>=0.1.4", # TODO set upper bound later "ipython>=8.10.0", # not directly required, pinned by Snyk to avoid a vulnerability: https://app.snyk.io/vuln/SNYK-PYTHON-IPYTHON-3318382 diff --git a/vizro-ai/snyk/requirements.txt b/vizro-ai/snyk/requirements.txt index b2588a330..19da50b49 100644 --- a/vizro-ai/snyk/requirements.txt +++ b/vizro-ai/snyk/requirements.txt @@ -3,6 +3,7 @@ tabulate openai>=1.0.0 langchain>=0.1.0, <0.3.0 langchain-openai +langgraph>=0.1.2 python-dotenv>=1.0.0 vizro>=0.1.4 ipython>=8.10.0 diff --git a/vizro-ai/src/vizro_ai/_llm_models.py b/vizro-ai/src/vizro_ai/_llm_models.py index 9014ad17c..b9a955a8b 100644 --- a/vizro-ai/src/vizro_ai/_llm_models.py +++ b/vizro-ai/src/vizro_ai/_llm_models.py @@ -1,3 +1,4 @@ +from contextlib import suppress from typing import Dict, Optional, Union from langchain_core.language_models.chat_models import BaseChatModel @@ -17,7 +18,7 @@ "gpt-3.5-turbo", "gpt-4o-2024-05-13", "gpt-4o", - ] + ], } DEFAULT_WRAPPER_MAP: Dict[str, BaseChatModel] = {"OpenAI": ChatOpenAI} @@ -49,6 +50,8 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo if isinstance(model, str): if any(model in model_list for model_list in SUPPORTED_MODELS.values()): vendor = model_to_vendor[model] + if DEFAULT_WRAPPER_MAP.get(vendor) is None: + raise ValueError(f"Additional library to support {vendor} models is not installed.") return DEFAULT_WRAPPER_MAP.get(vendor)(model_name=model, temperature=DEFAULT_TEMPERATURE) raise ValueError( @@ -56,6 +59,19 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo ) +def _get_model_name(model: BaseChatModel) -> str: + methods = [ + lambda: model.model_name, # OpenAI models + lambda: model.model, # Anthropic models + ] + + for method in methods: + with suppress(AttributeError): + return method() + + raise ValueError("Model name could not be retrieved") + + if __name__ == "__main__": llm_chat_openai = _get_llm_model(model="gpt-3.5-turbo") print(repr(llm_chat_openai)) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/_vizro_ai.py b/vizro-ai/src/vizro_ai/_vizro_ai.py index c0866aa41..61e2dc48b 100644 --- a/vizro-ai/src/vizro_ai/_vizro_ai.py +++ b/vizro-ai/src/vizro_ai/_vizro_ai.py @@ -1,11 +1,15 @@ import logging -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import pandas as pd import plotly.graph_objects as go +import vizro.models as vm +from langchain_core.messages 
import HumanMessage from langchain_openai import ChatOpenAI -from vizro_ai._llm_models import _get_llm_model +from vizro_ai._llm_models import _get_llm_model, _get_model_name +from vizro_ai.dashboard._graph.dashboard_creation import _create_and_compile_graph +from vizro_ai.dashboard.utils import DashboardOutputs, _register_data from vizro_ai.plot.components import GetCodeExplanation, GetDebugger from vizro_ai.plot.task_pipeline._pipeline_manager import PipelineManager from vizro_ai.utils.helper import ( @@ -36,8 +40,9 @@ def __init__(self, model: Optional[Union[ChatOpenAI, str]] = None): self.components_instances = {} # TODO add pending URL link to docs + model_name = _get_model_name(self.model) logger.info( - f"You have selected {self.model.model_name}," + f"You have selected {model_name}," f"Engaging with LLMs (Large Language Models) carries certain risks. " f"Users are advised to become familiar with these risks to make informed decisions, " f"and visit this page for detailed information: " @@ -154,3 +159,44 @@ def plot( # pylint: disable=too-many-arguments # noqa: PLR0913 ) return vizro_plot if return_elements else vizro_plot.figure + + def dashboard( + self, + dfs: List[pd.DataFrame], + user_input: str, + return_elements: bool = False, + ) -> Union[DashboardOutputs, vm.Dashboard]: + """Creates a Vizro dashboard using english descriptions. + + Args: + dfs: The dataframes to be analyzed. + user_input: User questions or descriptions of the desired visual. + return_elements: Flag to return DashboardOutputs dataclass that includes all possible elements generated. + + Returns: + vm.Dashboard or DashboardOutputs dataclass. + + """ + runnable = _create_and_compile_graph() + + config = {"configurable": {"model": self.model}} + message_res = runnable.invoke( + { + "dfs": dfs, + "all_df_metadata": {}, + "dashboard_plan": None, + "pages": [], + "dashboard": None, + "messages": [HumanMessage(content=user_input)], + }, + config=config, + ) + dashboard = message_res["dashboard"] + _register_data(all_df_metadata=message_res["all_df_metadata"]) + + if return_elements: + # code = _dashboard_code(dashboard) # TODO: `_dashboard_code` to be implemented + dashboard_output = DashboardOutputs(dashboard=dashboard) + return dashboard_output + else: + return dashboard diff --git a/vizro-ai/src/vizro_ai/dashboard/_graph/dashboard_creation.py b/vizro-ai/src/vizro_ai/dashboard/_graph/dashboard_creation.py new file mode 100644 index 000000000..e11caf5da --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_graph/dashboard_creation.py @@ -0,0 +1,190 @@ +"""Code generation graph for dashboard generation.""" + +import logging +import operator +from typing import Annotated, Dict, List, Optional + +import pandas as pd +import vizro.models as vm +from langchain_core.messages import BaseMessage +from langchain_core.runnables import RunnableConfig +from langgraph.constants import END, Send +from langgraph.graph import StateGraph +from tqdm.auto import tqdm +from vizro_ai.dashboard._pydantic_output import _get_pydantic_model +from vizro_ai.dashboard._response_models.dashboard import DashboardPlan +from vizro_ai.dashboard._response_models.df_info import DfInfo, _create_df_info_content, _get_df_info +from vizro_ai.dashboard._response_models.page import PagePlan +from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata, _execute_step +from vizro_ai.utils.helper import DebugFailure + +try: + from pydantic.v1 import BaseModel, ValidationError +except ImportError: # pragma: no cov + from pydantic import BaseModel, 
ValidationError + + +logger = logging.getLogger(__name__) + + +Messages = List[BaseMessage] +"""List of messages.""" + + +class GraphState(BaseModel): + """Represents the state of the dashboard graph. + + Attributes + messages: With user question, error messages, reasoning + dfs: Dataframes + all_df_metadata: Cleaned dataframe names and their metadata + dashboard_plan: Plan for the dashboard + pages: Vizro pages + dashboard: Vizro dashboard + + """ + + messages: List[BaseMessage] + dfs: List[pd.DataFrame] + all_df_metadata: AllDfMetadata + dashboard_plan: Optional[DashboardPlan] = None + pages: Annotated[List, operator.add] + dashboard: Optional[vm.Dashboard] = None + + class Config: + """Pydantic configuration.""" + + arbitrary_types_allowed = True + + +def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, AllDfMetadata]: + """Store information about the dataframes.""" + dfs = state.dfs + all_df_metadata = state.all_df_metadata + query = state.messages[0].content + current_df_names = [] + with tqdm(total=len(dfs), desc="Store df info") as pbar: + for df in dfs: + df_schema, df_sample = _get_df_info(df) + df_info = _create_df_info_content( + df_schema=df_schema, df_sample=df_sample, current_df_names=current_df_names + ) + + llm = config["configurable"].get("model", None) + try: + df_name = _get_pydantic_model( + query=query, + llm_model=llm, + response_model=DfInfo, + df_info=df_info, + ).dataset + except DebugFailure as e: + logger.warning(f"Failed in name generation {e}") + df_name = f"df_{len(current_df_names)}" + + current_df_names.append(df_name) + + pbar.write(f"df_name: {df_name}") + pbar.update(1) + all_df_metadata.all_df_metadata[df_name] = DfMetadata(df_schema=df_schema, df=df, df_sample=df_sample) + + return {"all_df_metadata": all_df_metadata} + + +def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, DashboardPlan]: + """Generate a dashboard plan.""" + node_desc = "Generate dashboard plan" + pbar = tqdm(total=2, desc=node_desc) + query = state.messages[0].content + all_df_metadata = state.all_df_metadata + + llm = config["configurable"].get("model", None) + + _execute_step( + pbar, + node_desc + " --> in progress \n(this step could take longer when more complex requirements are given)", + None, + ) + try: + dashboard_plan = _get_pydantic_model( + query=query, + llm_model=llm, + response_model=DashboardPlan, + df_info=all_df_metadata.get_schemas_and_samples(), + ) + except (DebugFailure, ValidationError) as e: + raise ValueError( + f""" + Failed to create a valid dashboard plan. Try rephrase the prompt or select a different + model. Error details: + {e} + """ + ) + + _execute_step(pbar, node_desc + " --> done", None) + pbar.close() + + return {"dashboard_plan": dashboard_plan} + + +class BuildPageState(BaseModel): + """Represents the state of building the page. 
+ + Attributes + all_df_metadata: Cleaned dataframe names and their metadata + page_plan: Plan for the dashboard page + + """ + + all_df_metadata: AllDfMetadata + page_plan: Optional[PagePlan] = None + + +def _build_page(state: BuildPageState, config: RunnableConfig) -> Dict[str, List[vm.Page]]: + """Build a page.""" + all_df_metadata = state["all_df_metadata"] + page_plan = state["page_plan"] + + llm = config["configurable"].get("model", None) + page = page_plan.create(model=llm, all_df_metadata=all_df_metadata) + + return {"pages": [page]} + + +def _continue_to_pages(state: GraphState) -> List[Send]: + """Map-reduce logic to build pages in parallel.""" + all_df_metadata = state.all_df_metadata + return [ + Send(node="_build_page", arg={"page_plan": v, "all_df_metadata": all_df_metadata}) + for v in state.dashboard_plan.pages + ] + + +def _build_dashboard(state: GraphState) -> Dict[str, vm.Dashboard]: + """Build a dashboard.""" + dashboard_plan = state.dashboard_plan + pages = state.pages + + dashboard = vm.Dashboard(title=dashboard_plan.title, pages=pages) + + return {"dashboard": dashboard} + + +def _create_and_compile_graph(): + graph = StateGraph(GraphState) + + graph.add_node("_store_df_info", _store_df_info) + graph.add_node("_dashboard_plan", _dashboard_plan) + graph.add_node("_build_page", _build_page) + graph.add_node("_build_dashboard", _build_dashboard) + + graph.add_edge("_store_df_info", "_dashboard_plan") + graph.add_conditional_edges("_dashboard_plan", _continue_to_pages) + graph.add_edge("_build_page", "_build_dashboard") + graph.add_edge("_build_dashboard", END) + + graph.set_entry_point("_store_df_info") + + runnable = graph.compile() + + return runnable diff --git a/vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py b/vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py new file mode 100644 index 000000000..c1511b8e3 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py @@ -0,0 +1,98 @@ +"""Contains the _get_pydantic_model for the Vizro AI dashboard.""" + +# ruff: noqa: F821 + +import logging + +try: + from pydantic.v1 import BaseModel, ValidationError +except ImportError: # pragma: no cov + from pydantic import BaseModel, ValidationError + +from typing import Any, Optional + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import HumanMessage +from langchain_core.prompts import ChatPromptTemplate + +logger = logging.getLogger(__name__) + +BASE_PROMPT = """ +You are a front-end developer with expertise in Plotly, Dash, and the visualization library named Vizro. +Your goal is to summarize the given specifications into the given Pydantic schema. +IMPORTANT: Please always output your response by using a tool. 
+ +This is the task context: +{df_info} + +Additional information: +{additional_info} + +Here is the user request: +""" + + +def _create_prompt_template(additional_info: str) -> ChatPromptTemplate: + """Create the ChatPromptTemplate from the base prompt and additional info.""" + return ChatPromptTemplate.from_messages( + [ + ("system", BASE_PROMPT.format(df_info="{df_info}", additional_info=additional_info)), + ("placeholder", "{message}"), + ] + ) + + +SINGLE_MODEL_PROMPT = _create_prompt_template("") +MODEL_REPROMPT = _create_prompt_template("Pay special attention to the following error: {validation_error}") + + +def _create_prompt(retry: bool = False) -> ChatPromptTemplate: + """Create the prompt message for the LLM model.""" + return MODEL_REPROMPT if retry else SINGLE_MODEL_PROMPT + + +def _create_message_content( + query: str, df_info: Any, validation_error: Optional[str] = None, retry: bool = False +) -> dict: + """Create the message content for the LLM model.""" + message_content = {"message": [HumanMessage(content=query)], "df_info": df_info} + + if retry: + message_content["validation_error"] = validation_error + + return message_content + + +def _get_pydantic_model( + query: str, + llm_model: BaseChatModel, + response_model: BaseModel, + df_info: Optional[Any] = None, + max_retry: int = 2, +) -> BaseModel: + """Get the pydantic output from the LLM model with retry logic.""" + for attempt in range(max_retry): + attempt_is_retry = attempt > 0 + prompt = _create_prompt(retry=attempt_is_retry) + message_content = _create_message_content( + query, df_info, str(last_validation_error) if attempt_is_retry else None, retry=attempt_is_retry + ) + pydantic_llm = prompt | llm_model.with_structured_output(response_model) + try: + res = pydantic_llm.invoke(message_content) + except ValidationError as validation_error: + last_validation_error = validation_error + else: + return res + + raise last_validation_error + + +if __name__ == "__main__": + import vizro.models as vm + from vizro_ai._llm_models import _get_llm_model + + model = _get_llm_model() + component_description = "Create a card with the following content: 'Hello, world!'" + res = _get_pydantic_model(query=component_description, llm_model=model, response_model=vm.Card) + print(res) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py new file mode 100644 index 000000000..feb0dfde8 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py @@ -0,0 +1,98 @@ +"""Component plan model.""" + +import logging +from typing import Union + +import vizro.models as vm + +try: + from pydantic.v1 import BaseModel, Field +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field +from vizro.tables import dash_ag_grid +from vizro_ai.dashboard._pydantic_output import _get_pydantic_model +from vizro_ai.dashboard._response_models.types import ComponentType +from vizro_ai.utils.helper import DebugFailure + +logger = logging.getLogger(__name__) + + +class ComponentPlan(BaseModel): + """Component plan model.""" + + component_type: ComponentType + component_description: str = Field( + ..., + description=""" + Description of the component. Include everything that relates to this component. + Be as specific and detailed as possible. + Keep the original relevant description AS IS. Keep any links exactly as provided. + Remember: Accuracy and completeness are key. 
Do not omit any relevant information provided about the component. + """, + ) + component_id: str = Field( + pattern=r"^[a-z]+(_[a-z]+)?$", description="Small snake case description of this component." + ) + df_name: str = Field( + ..., + description=""" + The name of the dataframe that this component will use. If no dataframe is + used, please specify that as N/A. + """, + ) + + def create(self, model, all_df_metadata) -> Union[vm.Card, vm.AgGrid, vm.Figure]: + """Create the component.""" + from vizro_ai import VizroAI + + vizro_ai = VizroAI(model=model) + + try: + if self.component_type == "Graph": + return vm.Graph( + id=self.component_id, + figure=vizro_ai.plot( + df=all_df_metadata.get_df(self.df_name), user_input=self.component_description + ), + ) + elif self.component_type == "AgGrid": + return vm.AgGrid(id=self.component_id, figure=dash_ag_grid(data_frame=self.df_name)) + elif self.component_type == "Card": + card_prompt = f""" + The Card uses the dcc.Markdown component from Dash as its underlying text component. + Create a card based on the card description: {self.component_description}. + """ + result_proxy = _get_pydantic_model(query=card_prompt, llm_model=model, response_model=vm.Card) + proxy_dict = result_proxy.dict() + proxy_dict["id"] = self.component_id + return vm.Card.parse_obj(proxy_dict) + + except DebugFailure as e: + logger.warning( + f""" +[FALLBACK] Failed to build `Component`: {self.component_id}. +Reason: {e} +Relevant prompt: {self.component_description} +""" + ) + return vm.Card(id=self.component_id, text=f"Failed to build component: {self.component_id}") + + +if __name__ == "__main__": + from dotenv import load_dotenv + from vizro_ai._llm_models import _get_llm_model + from vizro_ai.dashboard.utils import AllDfMetadata + + load_dotenv() + + model = _get_llm_model() + + all_df_metadata = AllDfMetadata({}) + component_plan = ComponentPlan( + component_type="Card", + component_description="Create a card says 'this is worldwide GDP'.", + component_id="gdp_card", + df_name="N/A", + ) + component = component_plan.create(model, all_df_metadata) + print(component.__repr__()) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py new file mode 100644 index 000000000..43bdcf62f --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py @@ -0,0 +1,170 @@ +"""Controls plan model.""" + +import logging +from typing import List, Optional + +import pandas as pd +import vizro.models as vm + +try: + from pydantic.v1 import BaseModel, Field, ValidationError, create_model, root_validator, validator +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field, ValidationError, create_model, root_validator, validator +from vizro_ai.dashboard._pydantic_output import _get_pydantic_model +from vizro_ai.dashboard._response_models.types import ControlType + +logger = logging.getLogger(__name__) + + +def _create_filter_proxy(df_cols, df_schema, controllable_components) -> BaseModel: + """Create a filter proxy model.""" + + def validate_targets(v): + """Validate the targets.""" + if v not in controllable_components: + raise ValueError(f"targets must be one of {controllable_components}") + return v + + def validate_targets_not_empty(v): + """Validate the targets not empty.""" + if not controllable_components: + raise ValueError( + """ + This might be due to the filter target is not found in the controllable components. + returning default values. 
+ """ + ) + return v + + def validate_column(v): + """Validate the column.""" + if v not in df_cols: + raise ValueError(f"column must be one of {df_cols}") + return v + + @root_validator(allow_reuse=True) + def validate_date_picker_column(cls, values): + """Validate the column for date picker.""" + column = values.get("column") + selector = values.get("selector") + if selector and selector.type == "date_picker": + if not pd.api.types.is_datetime64_any_dtype(df_schema[column]): + raise ValueError( + f""" + The column '{column}' is not of datetime type. Selector type 'date_picker' is + not allowed. Use 'dropdown' instead. + """ + ) + return values + + return create_model( + "FilterProxy", + targets=( + List[str], + Field( + ..., + description=f""" + Target component to be affected by filter. + Must be one of {controllable_components}. ALWAYS REQUIRED. + """, + ), + ), + column=(str, Field(..., description="Column name of DataFrame to filter. ALWAYS REQUIRED.")), + __validators__={ + "validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets), + "validator2": validator("column", allow_reuse=True)(validate_column), + "validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty), + "validator4": validate_date_picker_column, + }, + __base__=vm.Filter, + ) + + +def _create_filter(filter_prompt, model, df_cols, df_schema, controllable_components) -> vm.Filter: + result_proxy = _create_filter_proxy( + df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components + ) + proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=result_proxy, df_info=df_schema) + return vm.Filter.parse_obj(proxy.dict(exclude_unset=True)) + + +class ControlPlan(BaseModel): + """Control plan model.""" + + control_type: ControlType + control_description: str = Field( + ..., + description=""" + Description of the control. Include everything that seems to relate to this control. + Be as detailed as possible. Keep the original relevant description AS IS. If this control is used + to control a specific component, include the relevant component details. + """, + ) + df_name: str = Field( + ..., + description=""" + The name of the dataframe that the target component will use. + If the dataframe is not used, please specify that. + """, + ) + + def create(self, model, controllable_components, all_df_metadata) -> Optional[vm.Filter]: + """Create the control.""" + filter_prompt = f""" + Create a filter from the following instructions: <{self.control_description}>. Do not make up + things that are optional and DO NOT configure actions, action triggers or action chains. + If no options are specified, leave them out. + """ + try: + _df_schema = all_df_metadata.get_df_schema(self.df_name) + _df_cols = list(_df_schema.keys()) + except KeyError: + logger.warning(f"Dataframe {self.df_name} not found in metadata, returning default values.") + return None + + try: + if self.control_type == "Filter": + res = _create_filter( + filter_prompt=filter_prompt, + model=model, + df_cols=_df_cols, + df_schema=_df_schema, + controllable_components=controllable_components, + ) + return res + + except ValidationError as e: + logger.warning( + f""" +[FALLBACK] Build failed for `Control`, returning default values. Try rephrase the prompt or select a different model. 
+Error details: {e} +Relevant prompt: {self.control_description} +""" + ) + return None + + +if __name__ == "__main__": + import pandas as pd + from dotenv import load_dotenv + from vizro_ai._llm_models import _get_llm_model + from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata + + load_dotenv() + + model = _get_llm_model() + + all_df_metadata = AllDfMetadata({}) + all_df_metadata.all_df_metadata["gdp_chart"] = DfMetadata( + df_schema={"a": "int64", "b": "int64"}, + df=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), + df_sample=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), + ) + control_plan = ControlPlan( + control_type="Filter", + control_description="Create a filter that filters the data by column 'a'.", + df_name="gdp_chart", + ) + control = control_plan.create( + model, ["gdp_chart"], all_df_metadata + ) # error: Target gdp_chart not found in model_manager. diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/dashboard.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/dashboard.py new file mode 100644 index 000000000..a96550b61 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/dashboard.py @@ -0,0 +1,25 @@ +"""Dashboard plan model.""" + +import logging +from typing import List + +try: + from pydantic.v1 import BaseModel, Field +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field +from vizro_ai.dashboard._response_models.page import PagePlan + +logger = logging.getLogger(__name__) + + +class DashboardPlan(BaseModel): + """Dashboard plan model.""" + + title: str = Field( + ..., + description=""" + Title of the dashboard. If no description is provided, + make a short and concise title from the content of the pages. + """, + ) + pages: List[PagePlan] diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/df_info.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/df_info.py new file mode 100644 index 000000000..0ea59395b --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/df_info.py @@ -0,0 +1,46 @@ +"""Data Summary Node.""" + +from typing import Dict, List, Tuple + +import pandas as pd + +try: + from pydantic.v1 import BaseModel, Field +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field + + +DF_SUMMARY_PROMPT = """ +Inspect the provided data and give a short unique name to the dataset. \n +dataframe sample: \n ------- \n {df_sample} \n ------- \n +Here is the data schema: \n ------- \n {df_schema} \n ------- \n +AVOID the following names: \n ------- \n {current_df_names} \n ------- \n +Provide descriptive name mainly based on the data context above. +User request content is just for context. 
+""" + + +class DfInfo(BaseModel): + """Data Info output.""" + + dataset: str = Field(pattern=r"^[a-z]+(_[a-z]+)?$", description="Small snake case name of the dataset.") + + +def _get_df_info(df: pd.DataFrame) -> Tuple[Dict[str, str], pd.DataFrame]: + """Get the dataframe schema and sample.""" + formatted_pairs = dict(df.dtypes.astype(str)) + df_sample = df.sample(5, replace=True, random_state=19) + return formatted_pairs, df_sample + + +def _create_df_info_content(df_schema: Dict[str, str], df_sample: pd.DataFrame, current_df_names: List[str]) -> dict: + """Create the message content for the dataframe summarization.""" + return DF_SUMMARY_PROMPT.format(df_sample=df_sample, df_schema=df_schema, current_df_names=current_df_names) + + +if __name__ == "__main__": + df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}) + df_schema, df_sample = _get_df_info(df) + current_df_names = ["df1", "df2"] + print(_create_df_info_content(df_schema, df_sample, current_df_names)) # noqa: T201 + print(DfInfo(dataset="test").dict()) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py new file mode 100644 index 000000000..dbec8b3d6 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py @@ -0,0 +1,92 @@ +"""Layout plan model.""" + +import logging +from typing import List, Optional + +import vizro.models as vm + +try: + from pydantic.v1 import BaseModel, Field, ValidationError +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field, ValidationError + +logger = logging.getLogger(__name__) + + +def _convert_to_grid(layout_grid_template_areas: List[str], component_ids: List[str]) -> List[List[int]]: + component_map = {component: index for index, component in enumerate(component_ids)} + grid = [] + + for row in layout_grid_template_areas: + grid_row = [] + for cell in row.split(): + if cell == ".": + grid_row.append(-1) + else: + try: + grid_row.append(component_map[cell]) + except KeyError: + logger.warning( + f""" +[FALLBACK] Component {cell} not found in component_ids: {component_ids}. +Returning default values. +""" + ) + return [] + grid.append(grid_row) + + return grid + + +class LayoutPlan(BaseModel): + """Layout plan model, which only applies to Vizro Components(Graph, AgGrid, Card).""" + + layout_grid_template_areas: List[str] = Field( + [], + description=""" + Generate grid template areas for the layout adhering to the grid-template-areas CSS property syntax. + If no layout requested, return an empty list. + If requested, represent each component by 'component_id'. + IMPORTANT: Ensure that the `component_id` matches the `component_id` in the ComponentPlan. + If a grid area is empty, use a dot ('.') to represent it. + Ensure that each row of the grid layout is represented by a string, with each grid area separated by a space. + Return the grid template areas as a list of strings, where each string corresponds to a row in the grid. + No more than 600 characters in total. + """, + ) + + def create(self, component_ids: List[str]) -> Optional[vm.Layout]: + """Create the layout.""" + if not self.layout_grid_template_areas: + return None + + try: + grid = _convert_to_grid( + layout_grid_template_areas=self.layout_grid_template_areas, component_ids=component_ids + ) + actual = vm.Layout(grid=grid) + except ValidationError as e: + logger.warning( + f""" +[FALLBACK] Build failed for `Layout`, returning default values. 
Try rephrase the prompt or select a different model. +Error details: {e} +Relevant layout_grid_template_areas: +{self.layout_grid_template_areas} +""" + ) + if grid: + logger.warning(f"Calculated grid which caused the error: {grid}") + actual = None + + return actual + + +if __name__ == "__main__": + from vizro_ai._llm_models import _get_llm_model + + model = _get_llm_model() + layout_plan = LayoutPlan( + layout_grid_template_areas=["graph1 card2 card2", "graph1 . card1"], + ) + layout = layout_plan.create(component_ids=["graph1", "card1", "card2"]) + print(layout) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py new file mode 100644 index 000000000..de37b7db1 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py @@ -0,0 +1,223 @@ +"""Page plan model.""" + +import logging +from collections import Counter +from typing import List, Union + +try: + from pydantic.v1 import BaseModel, Field, PrivateAttr, ValidationError, root_validator, validator +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field, PrivateAttr, ValidationError, root_validator, validator +import vizro.models as vm +from tqdm.auto import tqdm +from vizro_ai.dashboard._response_models.components import ComponentPlan +from vizro_ai.dashboard._response_models.controls import ControlPlan +from vizro_ai.dashboard._response_models.layout import LayoutPlan +from vizro_ai.dashboard.utils import _execute_step + +logger = logging.getLogger(__name__) + + +class PagePlan(BaseModel): + """Page plan model.""" + + title: str = Field( + ..., + description=""" + Title of the page. If no description is provided, + make a concise and descriptive title from the components. + """, + ) + components_plan: List[ComponentPlan] = Field( + ..., description="List of components. Must contain at least one component." + ) + controls_plan: List[ControlPlan] = Field([], description="Controls of the page.") + layout_plan: LayoutPlan = Field(None, description="Layout of components on the page.") + unsupported_specs: List[str] = Field( + [], + description=""" + List of unsupported specs. If there are any unsupported specs, + list them here. If not, leave this as an empty list. + """, + ) + + _components: List[Union[vm.Card, vm.AgGrid, vm.Figure]] = PrivateAttr() + _controls: List[vm.Filter] = PrivateAttr() + _layout: vm.Layout = PrivateAttr() + + @validator("components_plan") + def _check_components_plan(cls, v): + if not v: + raise ValueError("A page must contain at least one component.") + return v + + @validator("unsupported_specs") + def _check_unsupported_specs(cls, v, values): + title = values.get("title", "Unknown Title") + if v: + logger.warning(f"\n ------- \n Unsupported specs on page <{title}>: \n {v}") + return [] + + @root_validator(allow_reuse=True) + def validate_component_id_unique(cls, values): + """Validate the component id is unique.""" + components = values.get("components_plan", []) + component_ids = [comp.component_id for comp in components] + duplicates = [id for id, count in Counter(component_ids).items() if count > 1] + if duplicates: + raise ValidationError(f"Component ids must be unique. 
Duplicated component ids: {duplicates}") + return values + + def __init__(self, **data): + """Initialize the page plan.""" + super().__init__(**data) + self._components = None + self._controls = None + self._layout = None + + def _get_components(self, model, all_df_metadata): + if self._components is None: + self._components = self._build_components(model=model, all_df_metadata=all_df_metadata) + return self._components + + def _build_components(self, model, all_df_metadata): + components = [] + component_log = tqdm(total=0, bar_format="{desc}", leave=False) + with tqdm( + total=len(self.components_plan), + desc=f"Currently Building ... [Page] <{self.title}> components", + leave=False, + ) as pbar: + for component_plan in self.components_plan: + component_log.set_description_str(f"[Page] <{self.title}>: [Component] {component_plan.component_id}") + pbar.update(1) + components.append(component_plan.create(model=model, all_df_metadata=all_df_metadata)) + component_log.close() + return components + + def _get_layout(self, model, all_df_metadata): + if self._layout is None: + self._layout = self._build_layout(model, all_df_metadata) + return self._layout + + def _build_layout(self, model, all_df_metadata): + if self.layout_plan is None: + return None + return self.layout_plan.create( + component_ids=self._get_component_ids(model=model, all_df_metadata=all_df_metadata), + ) + + def _get_controls(self, model, all_df_metadata): + if self._controls is None: + self._controls = self._build_controls(model=model, all_df_metadata=all_df_metadata) + return self._controls + + def _controllable_components(self, model, all_df_metadata): + return [ + comp.id + for comp in self._get_components(model=model, all_df_metadata=all_df_metadata) + if isinstance(comp, (vm.Graph, vm.AgGrid)) + ] + + def _get_component_ids(self, model, all_df_metadata): + return [comp.id for comp in self._get_components(model=model, all_df_metadata=all_df_metadata)] + + def _build_controls(self, model, all_df_metadata): + controls = [] + with tqdm( + total=len(self.controls_plan), + desc=f"Currently Building ... [Page] <{self.title}> controls", + leave=False, + ) as pbar: + for control_plan in self.controls_plan: + pbar.update(1) + control = control_plan.create( + model=model, + controllable_components=self._controllable_components(model=model, all_df_metadata=all_df_metadata), + all_df_metadata=all_df_metadata, + ) + if control: + controls.append(control) + + return controls + + def create(self, model, all_df_metadata) -> Union[vm.Page, None]: + """Create the page.""" + page_desc = f"Building page: {self.title}" + logger.info(page_desc) + pbar = tqdm(total=5, desc=page_desc) + + title = _execute_step(pbar, page_desc + " --> add title", self.title) + components = _execute_step( + pbar, page_desc + " --> add components", self._get_components(model=model, all_df_metadata=all_df_metadata) + ) + controls = _execute_step( + pbar, page_desc + " --> add controls", self._get_controls(model=model, all_df_metadata=all_df_metadata) + ) + layout = _execute_step( + pbar, page_desc + " --> add layout", self._get_layout(model=model, all_df_metadata=all_df_metadata) + ) + + try: + page = vm.Page(title=title, components=components, controls=controls, layout=layout) + except Exception as e: + # TODO: This Exception might be redundant. Check if it can be removed. 
+ if any("Number of page and grid components need to be the same" in error["msg"] for error in e.errors()): + logger.warning( + """ +[FALLBACK] Number of page and grid components provided are not the same. +Build page without layout. +""" + ) + page = vm.Page(title=title, components=components, controls=controls, layout=None) + else: + logger.warning(f"[FALLBACK] Failed to build page: {self.title}. Reason: {e}") + page = None + _execute_step(pbar, page_desc + " --> done", None) + pbar.close() + return page + + +if __name__ == "__main__": + import pandas as pd + from dotenv import load_dotenv + from vizro_ai._llm_models import _get_llm_model + from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata + + load_dotenv() + + model = _get_llm_model() + + all_df_metadata = AllDfMetadata( + all_df_metadata={ + "gdp_chart": DfMetadata( + df_schema={"a": "int64", "b": "int64"}, + df=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), + df_sample=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), + ) + } + ) + page_plan = PagePlan( + title="Worldwide GDP", + components_plan=[ + ComponentPlan( + component_type="Card", + component_description="Create a card says 'this is worldwide GDP'.", + component_id="gdp_card", + df_name="N/A", + ) + ], + controls_plan=[ + ControlPlan( + control_type="Filter", + control_description="Create a filter that filters the data by column 'a'.", + df_name="gdp_chart", + ) + ], + layout_plan=LayoutPlan( + layout_grid_template_areas=[], + ), + unsupported_specs=[], + ) + page = page_plan.create(model, all_df_metadata) + print(page) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/types.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/types.py new file mode 100644 index 000000000..56cc2023f --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/types.py @@ -0,0 +1,13 @@ +"""Types for response models.""" + +from typing import Literal + +# TODO make available in documentation + +# Complete list: ["AgGrid", "Button", "Card", "Container", "Graph", "Table", "Tabs"] +ComponentType = Literal["AgGrid", "Card", "Graph"] +"""Component types currently supported by Vizro-AI.""" + +# Complete list: ["Filter", "Parameter"] +ControlType = Literal["Filter"] +"""Control types currently supported by Vizro-AI.""" diff --git a/vizro-ai/src/vizro_ai/dashboard/utils.py b/vizro-ai/src/vizro_ai/dashboard/utils.py new file mode 100644 index 000000000..7276a57bb --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/utils.py @@ -0,0 +1,63 @@ +"""Helper Functions For Vizro AI dashboard.""" + +from dataclasses import dataclass, field +from typing import Any, Dict + +import pandas as pd +import tqdm.std as tsd +import vizro.models as vm + + +@dataclass +class DfMetadata: + """Dataclass containing metadata content for a dataframe.""" + + df_schema: Dict[str, str] + df: pd.DataFrame + df_sample: pd.DataFrame + + +@dataclass +class AllDfMetadata: + """Dataclass containing metadata for all dataframes.""" + + all_df_metadata: Dict[str, DfMetadata] = field(default_factory=dict) + + def get_schemas_and_samples(self) -> Dict[str, Dict[str, str]]: + """Retrieve only the df_schema and df_sample for all datasets.""" + return { + name: {"df_schema": metadata.df_schema, "df_sample": metadata.df_sample} + for name, metadata in self.all_df_metadata.items() + } + + def get_df(self, name: str) -> pd.DataFrame: + """Retrieve the dataframe by name.""" + try: + return self.all_df_metadata[name].df + except KeyError: + raise KeyError("Dataframe not found in 
metadata. Please ensure that the correct dataframe is provided.") + + def get_df_schema(self, name: str) -> Dict[str, str]: + """Retrieve the schema of the dataframe by name.""" + return self.all_df_metadata[name].df_schema + + +@dataclass +class DashboardOutputs: + """Dataclass containing all possible `VizroAI.dashboard()` output.""" + + dashboard: vm.Dashboard + + +def _execute_step(pbar: tsd.tqdm, description: str, value: Any) -> Any: + pbar.update(1) + pbar.set_description_str(description) + return value + + +def _register_data(all_df_metadata: AllDfMetadata) -> vm.Dashboard: + """Register the dashboard data in data manager.""" + from vizro.managers import data_manager + + for name, metadata in all_df_metadata.all_df_metadata.items(): + data_manager[name] = metadata.df diff --git a/vizro-ai/src/vizro_ai/py.typed b/vizro-ai/src/vizro_ai/py.typed new file mode 100644 index 000000000..512ec7cb8 --- /dev/null +++ b/vizro-ai/src/vizro_ai/py.typed @@ -0,0 +1 @@ + # Marker file for PEP 561 From 3d230675cfcebd92c3454cd2f1fc2ce34baab0e5 Mon Sep 17 00:00:00 2001 From: Maximilian Schulz <83698606+maxschulz-COL@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:25:49 +0200 Subject: [PATCH 3/6] [Feat] Improve background color and loading spinner on loading layout (#598) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ...946_maximilian_schulz_improved_load_CSS.md | 47 +++++++++++++++++++ vizro-core/src/vizro/static/css/loading.css | 43 +++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 vizro-core/changelog.d/20240725_105946_maximilian_schulz_improved_load_CSS.md create mode 100644 vizro-core/src/vizro/static/css/loading.css diff --git a/vizro-core/changelog.d/20240725_105946_maximilian_schulz_improved_load_CSS.md b/vizro-core/changelog.d/20240725_105946_maximilian_schulz_improved_load_CSS.md new file mode 100644 index 000000000..56bee372f --- /dev/null +++ b/vizro-core/changelog.d/20240725_105946_maximilian_schulz_improved_load_CSS.md @@ -0,0 +1,47 @@ + + + + + +### Added + +- Add dark mode and loading spinner to the layout loading screen (before Vizro app is shown) ([#598](https://github.com/mckinsey/vizro/pull/598)) + + + + + diff --git a/vizro-core/src/vizro/static/css/loading.css b/vizro-core/src/vizro/static/css/loading.css new file mode 100644 index 000000000..fd5eac3c8 --- /dev/null +++ b/vizro-core/src/vizro/static/css/loading.css @@ -0,0 +1,43 @@ +/* Inspired by https://github.com/facultyai/dash-bootstrap-components/blob/5c8f4b40f1100fc00bf2d5c1d671a7815a6b2910/docs/static/loading.css */ + +/* This creates a dark background in situations where neither dash-loading nor the Vizro app are displayed */ +html { + background: rgba(20, 23, 33, 1); + min-height: 100vh; +} + +/* The dash-loading Div is present when Dash is initially loading, before the layout is built */ + +/* The dash-loading-callback Div is present when Dash has loaded, but the layout is still building */ + +/* Note that the dash-loading-callback Div is present until all elements are loaded, but as soon as the initial page +elements (before on-page-load) are rendered, it gets pushed outside the viewable area, hence the spinner is not visible +which is good, as we have individual loading spinners for elements. At the moment, we are not using this class. +TODO: If we want to use this class, we need to evaluate if this is the best approach. 
+ */ + +/* ._dash-loading-callback, */ +._dash-loading { + align-items: center; + background: rgba(20, 23, 33, 1); + color: transparent; + display: flex; + height: 100%; + justify-content: center; + position: fixed; + width: 100%; +} + +/* Loading spinner */ + +/* ._dash-loading-callback::after, */ +._dash-loading::after { + animation: spinner-border 0.75s linear infinite; + border: 0.5rem solid lightgrey; + border-radius: 50%; + border-right-color: transparent; + content: ""; + display: inline-block; + height: 8rem; + width: 8rem; +} From dd125def3749220d607d7f69f94cc9082fde48e3 Mon Sep 17 00:00:00 2001 From: Li Nguyen <90609403+huong-li-nguyen@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:01:20 +0200 Subject: [PATCH 4/6] [Bug] Fix display of marks in `vm.Slider` and `vm.RangeSlider` (#613) Co-authored-by: Antony Milne <49395058+antonymilne@users.noreply.github.com> --- ..._152834_huong_li_nguyen_fix_slider_type.md | 47 +++++++++ vizro-core/examples/scratch_dev/app.py | 95 +++++++++---------- vizro-core/schemas/0.1.20.dev0.json | 18 +--- .../models/_components/form/_form_utils.py | 6 ++ .../models/_components/form/range_slider.py | 4 +- .../vizro/models/_components/form/slider.py | 4 +- .../_components/form/test_range_slider.py | 12 ++- .../models/_components/form/test_slider.py | 12 ++- 8 files changed, 122 insertions(+), 76 deletions(-) create mode 100644 vizro-core/changelog.d/20240801_152834_huong_li_nguyen_fix_slider_type.md diff --git a/vizro-core/changelog.d/20240801_152834_huong_li_nguyen_fix_slider_type.md b/vizro-core/changelog.d/20240801_152834_huong_li_nguyen_fix_slider_type.md new file mode 100644 index 000000000..07c5ee233 --- /dev/null +++ b/vizro-core/changelog.d/20240801_152834_huong_li_nguyen_fix_slider_type.md @@ -0,0 +1,47 @@ + + + + + + + + +### Fixed + +- Fix display of marks in `vm.Slider` and `vm.RangeSlider` by converting floats to integers when possible. 
([#613](https://github.com/mckinsey/vizro/pull/613)) + + diff --git a/vizro-core/examples/scratch_dev/app.py b/vizro-core/examples/scratch_dev/app.py index b4bb910c1..443fccf46 100644 --- a/vizro-core/examples/scratch_dev/app.py +++ b/vizro-core/examples/scratch_dev/app.py @@ -1,64 +1,59 @@ """Dev app to try things out.""" -from typing import List +from typing import Any, Literal -import pandas as pd -import plotly.graph_objects as go import vizro.models as vm +from dash import dcc from vizro import Vizro -from vizro.models.types import capture - -sankey_data = pd.DataFrame( - { - "Origin": [0, 1, 2, 1, 2, 4, 0], # indices inside labels - "Destination": [1, 2, 3, 4, 5, 5, 6], # indices inside labels - "Value": [10, 4, 8, 6, 4, 8, 8], - } -) -@capture("graph") -def sankey( - data_frame: pd.DataFrame, - source: str, - target: str, - value: str, - labels: List[str], -) -> go.Figure: - """Creates a sankey diagram based on a go.Figure.""" - fig = go.Figure( - data=[ - go.Sankey( - node={ - "pad": 16, - "thickness": 16, - "label": labels, - }, - link={ - "source": data_frame[source], - "target": data_frame[target], - "value": data_frame[value], - "label": labels, - "color": "rgba(205, 209, 228, 0.4)", - }, - ) - ] - ) - fig.update_layout(barmode="relative") - return fig +class PureDashSlider(vm.VizroBaseModel): + """Simple Dash Slider.""" + + type: Literal["simple_slider"] = "simple_slider" + kwargs: Any + + def build(self): + """Pure Slider component.""" + return dcc.Slider(**self.kwargs) + + +vm.Container.add_type("components", vm.Slider) +vm.Container.add_type("components", PureDashSlider) +# All floats: This works on Vizro, while it does not work in Dash +a = dict(min=0, max=2, step=0.1, marks={0.0: "a", 1.0: "x", 2.0: "y"}) # noqa: C408 + +# All int: This works in both Vizro and Dash +b = dict(min=0, max=2, step=0.1, marks={0: "a", 1: "x", 2: "y"}) # noqa: C408 + +# Mixed float and int: This works in Vizro, while it only partially works in Dash +c = dict(min=0, max=1, step=0.1, marks={0: "a", 0.5: "x", 1.0: "y"}) # noqa: C408 + +# User example +e = dict(min=0, max=1, step=0.01, marks={0: "0%", 0.21: "MARKET", 1.0: "100%"}) # noqa: C408 + +# Other test example: https://github.com/mckinsey/vizro/pull/266 +d = dict(min=2, max=5, step=1, value=3) # noqa: C408 page = vm.Page( - title="Sankey", + title="Vizro on PyCafe", components=[ - vm.Graph( - figure=sankey( - data_frame=sankey_data, - labels=["A1", "A2", "B1", "B2", "C1", "C2", "D1"], - source="Origin", - target="Destination", - value="Value", - ), + vm.Container( + title="vm.Sliders", + layout=vm.Layout(grid=[[0, 1, 2, 3, 4]]), + components=[vm.Slider(**a), vm.Slider(**b), vm.Slider(**c), vm.Slider(**d), vm.Slider(**e)], + ), + vm.Container( + title="dcc.Sliders", + layout=vm.Layout(grid=[[0, 1, 2, 3, 4]]), + components=[ + PureDashSlider(kwargs=a), + PureDashSlider(kwargs=b), + PureDashSlider(kwargs=c), + PureDashSlider(kwargs=d), + PureDashSlider(kwargs=e), + ], ), ], ) diff --git a/vizro-core/schemas/0.1.20.dev0.json b/vizro-core/schemas/0.1.20.dev0.json index 99492cc9f..89b8fe8fc 100644 --- a/vizro-core/schemas/0.1.20.dev0.json +++ b/vizro-core/schemas/0.1.20.dev0.json @@ -857,14 +857,7 @@ "default": {}, "type": "object", "additionalProperties": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "object" - } - ] + "type": "string" } }, "value": { @@ -932,14 +925,7 @@ "default": {}, "type": "object", "additionalProperties": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "object" - } - ] + "type": "string" } }, "value": { 
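The scratch app cases a-e above all hinge on one Dash quirk: marks whose keys are floats with integral values (0.0, 1.0, 2.0) are silently dropped, while plain integer keys render fine. The validator change below fixes this by normalising the keys; a minimal standalone sketch of that normalisation (illustrative only, not the Vizro source) could look like this:

# Sketch of the mark-key normalisation: float keys that represent whole
# numbers become ints, everything else is left untouched.
def normalise_marks(marks):
    if not marks:
        return marks
    return {int(k) if isinstance(k, float) and k.is_integer() else k: v for k, v in marks.items()}

print(normalise_marks({0.0: "a", 0.5: "x", 1.0: "y"}))  # {0: 'a', 0.5: 'x', 1: 'y'}
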
diff --git a/vizro-core/src/vizro/models/_components/form/_form_utils.py b/vizro-core/src/vizro/models/_components/form/_form_utils.py index 2b86cb341..39de02881 100644 --- a/vizro-core/src/vizro/models/_components/form/_form_utils.py +++ b/vizro-core/src/vizro/models/_components/form/_form_utils.py @@ -105,6 +105,12 @@ def validate_step(cls, step, values): def set_default_marks(cls, marks, values): if not marks and values.get("step") is None: marks = None + + # Dash has a bug where marks provided as floats that can be converted to integers are not displayed. + # So we need to convert the floats to integers if possible. + # https://github.com/plotly/dash-core-components/issues/159#issuecomment-380581043 + if marks: + marks = {int(k) if k.is_integer() else k: v for k, v in marks.items()} return marks diff --git a/vizro-core/src/vizro/models/_components/form/range_slider.py b/vizro-core/src/vizro/models/_components/form/range_slider.py index 34c1f98f4..503ba87c8 100644 --- a/vizro-core/src/vizro/models/_components/form/range_slider.py +++ b/vizro-core/src/vizro/models/_components/form/range_slider.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional from dash import ClientsideFunction, Input, Output, State, clientside_callback, dcc, html @@ -43,7 +43,7 @@ class RangeSlider(VizroBaseModel): min: Optional[float] = Field(None, description="Start value for slider.") max: Optional[float] = Field(None, description="End value for slider.") step: Optional[float] = Field(None, description="Step-size for marks on slider.") - marks: Optional[Dict[int, Union[str, Dict[str, Any]]]] = Field({}, description="Marks to be displayed on slider.") + marks: Optional[Dict[float, str]] = Field({}, description="Marks to be displayed on slider.") value: Optional[List[float]] = Field( None, description="Default start and end value for slider", min_items=2, max_items=2 ) diff --git a/vizro-core/src/vizro/models/_components/form/slider.py b/vizro-core/src/vizro/models/_components/form/slider.py index 33dcc4057..f3854e2b1 100644 --- a/vizro-core/src/vizro/models/_components/form/slider.py +++ b/vizro-core/src/vizro/models/_components/form/slider.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional from dash import ClientsideFunction, Input, Output, State, clientside_callback, dcc, html @@ -43,7 +43,7 @@ class Slider(VizroBaseModel): min: Optional[float] = Field(None, description="Start value for slider.") max: Optional[float] = Field(None, description="End value for slider.") step: Optional[float] = Field(None, description="Step-size for marks on slider.") - marks: Optional[Dict[int, Union[str, Dict[str, Any]]]] = Field({}, description="Marks to be displayed on slider.") + marks: Optional[Dict[float, str]] = Field({}, description="Marks to be displayed on slider.") value: Optional[float] = Field(None, description="Default value for slider.") title: str = Field("", description="Title to be displayed.") actions: List[Action] = [] diff --git a/vizro-core/tests/unit/vizro/models/_components/form/test_range_slider.py b/vizro-core/tests/unit/vizro/models/_components/form/test_range_slider.py index 80a23c9cc..e0c9a9f13 100644 --- a/vizro-core/tests/unit/vizro/models/_components/form/test_range_slider.py +++ b/vizro-core/tests/unit/vizro/models/_components/form/test_range_slider.py @@ -229,16 +229,22 @@ def test_validate_step_invalid(self): "marks, expected", [ ({i: str(i) 
for i in range(0, 10, 5)}, {i: str(i) for i in range(0, 10, 5)}), - ({15: 15, 25: 25}, {15.0: "15", 25.0: "25"}), - ({"15": 15, "25": 25}, {15.0: "15", 25.0: "25"}), + ({15: 15, 25: 25}, {15: "15", 25: "25"}), # all int + ({15.5: 15.5, 25.5: 25.5}, {15.5: "15.5", 25.5: "25.5"}), # all floats + ({15.0: 15, 25.5: 25.5}, {15: "15", 25.5: "25.5"}), # mixed floats + ({"15": 15, "25": 25}, {15: "15", 25: "25"}), # all string (None, None), ], ) def test_valid_marks(self, marks, expected): range_slider = vm.RangeSlider(min=0, max=10, marks=marks) - assert range_slider.marks == expected + if marks: + assert [type(result_key) for result_key in range_slider.marks] == [ + type(expected_key) for expected_key in expected + ] + def test_invalid_marks(self): with pytest.raises(ValidationError, match="2 validation errors for RangeSlider"): vm.RangeSlider(min=1, max=10, marks={"start": 0, "end": 10}) diff --git a/vizro-core/tests/unit/vizro/models/_components/form/test_slider.py b/vizro-core/tests/unit/vizro/models/_components/form/test_slider.py index 7ce15011d..56d2c5f24 100755 --- a/vizro-core/tests/unit/vizro/models/_components/form/test_slider.py +++ b/vizro-core/tests/unit/vizro/models/_components/form/test_slider.py @@ -120,16 +120,22 @@ def test_valid_marks_with_step(self): "marks, expected", [ ({i: str(i) for i in range(0, 10, 5)}, {i: str(i) for i in range(0, 10, 5)}), - ({15: 15, 25: 25}, {15.0: "15", 25.0: "25"}), - ({"15": 15, "25": 25}, {15.0: "15", 25.0: "25"}), + ({15: 15, 25: 25}, {15: "15", 25: "25"}), # all int + ({15.5: 15.5, 25.5: 25.5}, {15.5: "15.5", 25.5: "25.5"}), # all floats + ({15.0: 15, 25.5: 25.5}, {15: "15", 25.5: "25.5"}), # mixed floats + ({"15": 15, "25": 25}, {15: "15", 25: "25"}), # all string (None, None), ], ) def test_valid_marks(self, marks, expected): slider = vm.Slider(min=0, max=10, marks=marks) - assert slider.marks == expected + if marks: + assert [type(result_key) for result_key in slider.marks] == [ + type(expected_key) for expected_key in expected + ] + def test_invalid_marks(self): with pytest.raises(ValidationError, match="2 validation errors for Slider"): vm.Slider(min=1, max=10, marks={"start": 0, "end": 10}) From 8eb353e41d79a94a6903ac17a17c599dfdab6a54 Mon Sep 17 00:00:00 2001 From: Antony Milne <49395058+antonymilne@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:29:45 +0100 Subject: [PATCH 5/6] [Bug] Make figure models subclassable (#606) Co-authored-by: petar-qb --- ...petar_pejovic_make_figures_subclassable.md | 47 +++++++ vizro-core/examples/scratch_dev/app.py | 124 ++++++++++++------ .../src/vizro/models/_action/_action.py | 3 +- .../src/vizro/models/_components/ag_grid.py | 3 +- .../src/vizro/models/_components/figure.py | 3 +- .../src/vizro/models/_components/graph.py | 5 +- .../src/vizro/models/_components/table.py | 3 +- vizro-core/src/vizro/models/types.py | 11 +- vizro-core/tests/integration/test_examples.py | 12 +- .../unit/vizro/models/_action/test_action.py | 12 ++ .../vizro/models/_components/test_ag_grid.py | 11 ++ .../vizro/models/_components/test_figure.py | 10 ++ .../vizro/models/_components/test_graph.py | 11 ++ .../vizro/models/_components/test_table.py | 11 ++ .../tests/unit/vizro/models/test_types.py | 17 ++- 15 files changed, 221 insertions(+), 62 deletions(-) create mode 100644 vizro-core/changelog.d/20240801_113636_petar_pejovic_make_figures_subclassable.md diff --git a/vizro-core/changelog.d/20240801_113636_petar_pejovic_make_figures_subclassable.md 
b/vizro-core/changelog.d/20240801_113636_petar_pejovic_make_figures_subclassable.md new file mode 100644 index 000000000..74c27f31f --- /dev/null +++ b/vizro-core/changelog.d/20240801_113636_petar_pejovic_make_figures_subclassable.md @@ -0,0 +1,47 @@ + + + + + + + + +### Fixed + +- Fix subclassing of `vm.Graph`, `vm.Table`, `vm.AgGrid`, `vm.Figure` and `vm.Action` models. ([#606](https://github.com/mckinsey/vizro/pull/606)) + + diff --git a/vizro-core/examples/scratch_dev/app.py b/vizro-core/examples/scratch_dev/app.py index 443fccf46..8cebd9d33 100644 --- a/vizro-core/examples/scratch_dev/app.py +++ b/vizro-core/examples/scratch_dev/app.py @@ -1,61 +1,111 @@ """Dev app to try things out.""" -from typing import Any, Literal - import vizro.models as vm -from dash import dcc +import vizro.plotly.express as px from vizro import Vizro +from vizro.figures import kpi_card +from vizro.models.types import capture +from vizro.tables import dash_ag_grid, dash_data_table + +df = px.data.iris() + +# Graph +@capture("graph") +def my_graph_figure(data_frame, **kwargs): + """My custom figure.""" + return px.scatter(data_frame, **kwargs) -class PureDashSlider(vm.VizroBaseModel): - """Simple Dash Slider.""" - type: Literal["simple_slider"] = "simple_slider" - kwargs: Any +class MyGraph(vm.Graph): + """My custom class.""" def build(self): - """Pure Slider component.""" - return dcc.Slider(**self.kwargs) + """Custom build.""" + graph_build_obj = super().build() + # DO SOMETHING + return graph_build_obj + + +# Table +@capture("table") +def my_table_figure(data_frame, **kwargs): + """My custom figure.""" + return dash_data_table(data_frame, **kwargs)() + + +class MyTable(vm.Table): + """My custom class.""" + + pass + +# AgGrid +@capture("ag_grid") +def my_ag_grid_figure(data_frame, **kwargs): + """My custom figure.""" + return dash_ag_grid(data_frame, **kwargs)() -vm.Container.add_type("components", vm.Slider) -vm.Container.add_type("components", PureDashSlider) -# All floats: This works on Vizro, while it does not work in Dash -a = dict(min=0, max=2, step=0.1, marks={0.0: "a", 1.0: "x", 2.0: "y"}) # noqa: C408 +class MyAgGrid(vm.AgGrid): + """My custom class.""" -# All int: This works in both Vizro and Dash -b = dict(min=0, max=2, step=0.1, marks={0: "a", 1: "x", 2: "y"}) # noqa: C408 + pass -# Mixed float and int: This works in Vizro, while it only partially works in Dash -c = dict(min=0, max=1, step=0.1, marks={0: "a", 0.5: "x", 1.0: "y"}) # noqa: C408 -# User example -e = dict(min=0, max=1, step=0.01, marks={0: "0%", 0.21: "MARKET", 1.0: "100%"}) # noqa: C408 +# Figure +@capture("figure") +def my_kpi_card_figure(data_frame, **kwargs): + """My custom figure.""" + return kpi_card(data_frame, **kwargs)() + + +class MyFigure(vm.Figure): + """My custom class.""" + + pass + + +# Action +@capture("action") +def my_action_function(): + """My custom action.""" + pass + + +class MyAction(vm.Action): + """My custom class.""" + + pass -# Other test example: https://github.com/mckinsey/vizro/pull/266 -d = dict(min=2, max=5, step=1, value=3) # noqa: C408 page = vm.Page( - title="Vizro on PyCafe", + title="Test", + layout=vm.Layout( + grid=[[0, 1], [2, 3], [4, 5], [6, 7], [8, -1]], + col_gap="50px", + row_gap="50px", + ), components=[ - vm.Container( - title="vm.Sliders", - layout=vm.Layout(grid=[[0, 1, 2, 3, 4]]), - components=[vm.Slider(**a), vm.Slider(**b), vm.Slider(**c), vm.Slider(**d), vm.Slider(**e)], - ), - vm.Container( - title="dcc.Sliders", - layout=vm.Layout(grid=[[0, 1, 2, 3, 4]]), - components=[ - 
PureDashSlider(kwargs=a), - PureDashSlider(kwargs=b), - PureDashSlider(kwargs=c), - PureDashSlider(kwargs=d), - PureDashSlider(kwargs=e), - ], + # Graph + MyGraph(figure=px.scatter(df, x="sepal_width", y="sepal_length", title="My Graph")), + MyGraph(figure=my_graph_figure(df, x="sepal_width", y="sepal_length", title="My Graph Custom Figure")), + # Table + MyTable(figure=dash_data_table(df), title="My Table"), + MyTable(figure=my_table_figure(df), title="My Table Custom Figure"), + # AgGrid + MyAgGrid(figure=dash_ag_grid(df), title="My AgGrid"), + MyAgGrid(figure=my_ag_grid_figure(df), title="My AgGrid Custom Figure"), + # Figure + MyFigure(figure=kpi_card(df, value_column="sepal_width", title="KPI Card")), + MyFigure(figure=my_kpi_card_figure(df, value_column="sepal_width", title="KPI Card Custom Figure")), + # Action + MyGraph( + figure=my_graph_figure(df, x="sepal_width", y="sepal_length", title="My Graph Custom Figure"), + actions=[MyAction(function=my_action_function())], ), ], + controls=[vm.Filter(column="species")], ) dashboard = vm.Dashboard(pages=[page]) diff --git a/vizro-core/src/vizro/models/_action/_action.py b/vizro-core/src/vizro/models/_action/_action.py index 3d9c1c665..0d0bffef4 100644 --- a/vizro-core/src/vizro/models/_action/_action.py +++ b/vizro-core/src/vizro/models/_action/_action.py @@ -11,7 +11,6 @@ except ImportError: # pragma: no cov from pydantic import Field, validator -import vizro.actions from vizro.managers._model_manager import ModelID from vizro.models import VizroBaseModel from vizro.models._models_utils import _log_call @@ -32,7 +31,7 @@ class Action(VizroBaseModel): """ - function: CapturedCallable = Field(..., import_path=vizro.actions, mode="action", description="Action function.") + function: CapturedCallable = Field(..., import_path="vizro.actions", mode="action", description="Action function.") inputs: List[str] = Field( [], description="Inputs in the form `.` passed to the action function.", diff --git a/vizro-core/src/vizro/models/_components/ag_grid.py b/vizro-core/src/vizro/models/_components/ag_grid.py index 8a7b5e721..80cbc02a4 100644 --- a/vizro-core/src/vizro/models/_components/ag_grid.py +++ b/vizro-core/src/vizro/models/_components/ag_grid.py @@ -10,7 +10,6 @@ from pydantic import Field, PrivateAttr, validator from dash import ClientsideFunction, Input, Output, clientside_callback -import vizro.tables as vt from vizro.actions._actions_utils import CallbackTriggerDict, _get_component_actions, _get_parent_vizro_model from vizro.managers import data_manager from vizro.models import Action, VizroBaseModel @@ -35,7 +34,7 @@ class AgGrid(VizroBaseModel): type: Literal["ag_grid"] = "ag_grid" figure: CapturedCallable = Field( - ..., import_path=vt, mode="ag_grid", description="Function that returns a Dash AgGrid." + ..., import_path="vizro.tables", mode="ag_grid", description="Function that returns a `Dash AG Grid`." 
) title: str = Field("", description="Title of the AgGrid") actions: List[Action] = [] diff --git a/vizro-core/src/vizro/models/_components/figure.py b/vizro-core/src/vizro/models/_components/figure.py index cb5128918..f4cf0c52e 100644 --- a/vizro-core/src/vizro/models/_components/figure.py +++ b/vizro-core/src/vizro/models/_components/figure.py @@ -7,7 +7,6 @@ except ImportError: # pragma: no cov from pydantic import Field, PrivateAttr, validator -import vizro.figures as vf from vizro.managers import data_manager from vizro.models import VizroBaseModel from vizro.models._components._components_utils import _process_callable_data_frame @@ -26,7 +25,7 @@ class Figure(VizroBaseModel): type: Literal["figure"] = "figure" figure: CapturedCallable = Field( - import_path=vf, + import_path="vizro.figures", mode="figure", description="Function that returns a figure-like object.", ) diff --git a/vizro-core/src/vizro/models/_components/graph.py b/vizro-core/src/vizro/models/_components/graph.py index 75670cd52..0f11a5d2c 100644 --- a/vizro-core/src/vizro/models/_components/graph.py +++ b/vizro-core/src/vizro/models/_components/graph.py @@ -12,7 +12,6 @@ import pandas as pd -import vizro.plotly.express as px from vizro import _themes as themes from vizro.actions._actions_utils import CallbackTriggerDict, _get_component_actions from vizro.managers import data_manager, model_manager @@ -38,7 +37,9 @@ class Graph(VizroBaseModel): """ type: Literal["graph"] = "graph" - figure: CapturedCallable = Field(..., import_path=px, mode="graph", description="Function that returns a graph.") + figure: CapturedCallable = Field( + ..., import_path="vizro.plotly.express", mode="graph", description="Function that returns a plotly `go.Figure`" + ) actions: List[Action] = [] # Component properties for actions and interactions diff --git a/vizro-core/src/vizro/models/_components/table.py b/vizro-core/src/vizro/models/_components/table.py index 6aefb7b9a..2018aedb8 100644 --- a/vizro-core/src/vizro/models/_components/table.py +++ b/vizro-core/src/vizro/models/_components/table.py @@ -9,7 +9,6 @@ except ImportError: # pragma: no cov from pydantic import Field, PrivateAttr, validator -import vizro.tables as vt from vizro.actions._actions_utils import CallbackTriggerDict, _get_component_actions, _get_parent_vizro_model from vizro.managers import data_manager from vizro.models import Action, VizroBaseModel @@ -34,7 +33,7 @@ class Table(VizroBaseModel): type: Literal["table"] = "table" figure: CapturedCallable = Field( - ..., import_path=vt, mode="table", description="Function that returns a Dash DataTable." + ..., import_path="vizro.tables", mode="table", description="Function that returns a `Dash DataTable`." 
) title: str = Field("", description="Title of the table") actions: List[Action] = [] diff --git a/vizro-core/src/vizro/models/types.py b/vizro-core/src/vizro/models/types.py index 8dce74d74..bbe0209f7 100644 --- a/vizro-core/src/vizro/models/types.py +++ b/vizro-core/src/vizro/models/types.py @@ -4,6 +4,7 @@ from __future__ import annotations import functools +import importlib import inspect from datetime import date from typing import Any, Dict, List, Literal, Protocol, Union, runtime_checkable @@ -203,9 +204,9 @@ def _parse_json( import_path = field.field_info.extra["import_path"] try: - function = getattr(import_path, function_name) - except AttributeError as exc: - raise ValueError(f"_target_={function_name} cannot be imported from {import_path.__name__}.") from exc + function = getattr(importlib.import_module(import_path), function_name) + except (AttributeError, ModuleNotFoundError) as exc: + raise ValueError(f"_target_={function_name} cannot be imported from {import_path}.") from exc # All the other items in figure are the keyword arguments to pass into function. function_kwargs = captured_callable_config @@ -230,11 +231,11 @@ def _extract_from_attribute( def _check_type(cls, captured_callable: CapturedCallable, field: ModelField) -> CapturedCallable: """Checks captured_callable is right type and mode.""" expected_mode = field.field_info.extra["mode"] - import_path_name = field.field_info.extra["import_path"].__name__ + import_path = field.field_info.extra["import_path"] if not isinstance(captured_callable, CapturedCallable): raise ValueError( - f"Invalid CapturedCallable. Supply a function imported from {import_path_name} or defined with " + f"Invalid CapturedCallable. Supply a function imported from {import_path} or defined with " f"decorator @capture('{expected_mode}')." ) diff --git a/vizro-core/tests/integration/test_examples.py b/vizro-core/tests/integration/test_examples.py index fe39c5a8b..8201f0c71 100644 --- a/vizro-core/tests/integration/test_examples.py +++ b/vizro-core/tests/integration/test_examples.py @@ -1,7 +1,6 @@ # ruff: noqa: F403, F405 import os import runpy -import sys from pathlib import Path import chromedriver_autoinstaller @@ -30,12 +29,11 @@ def dashboard(request, monkeypatch): example_directory = request.getfixturevalue("example_path") / request.getfixturevalue("version") monkeypatch.chdir(example_directory) monkeypatch.syspath_prepend(example_directory) - old_sys_modules = set(sys.modules) - yield runpy.run_path("app.py")["dashboard"] + return runpy.run_path("app.py")["dashboard"] # Both run_path and run_module contaminate sys.modules, so we need to undo this in order to avoid interference - # between tests. - for key in set(sys.modules) - old_sys_modules: - del sys.modules[key] + # between tests. However, if you do this then importlib.import_module seems to cause the problem due to mysterious + # reasons. The current system should work well so long as there's no sub-packages with clashing names in the + # examples. examples_path = Path(__file__).parents[2] / "examples" @@ -52,13 +50,13 @@ def dashboard(request, monkeypatch): @pytest.mark.parametrize( "example_path, version", [ + # KPI example is not included as it will be moved to HuggingFace over time. # Chart gallery is not included since it means installing black in the testing environment. # It will move to HuggingFace in due course anyway. 
(examples_path / "scratch_dev", ""), (examples_path / "scratch_dev", "yaml_version"), (examples_path / "dev", ""), (examples_path / "dev", "yaml_version"), - (examples_path / "kpi", ""), ], ids=str, ) diff --git a/vizro-core/tests/unit/vizro/models/_action/test_action.py b/vizro-core/tests/unit/vizro/models/_action/test_action.py index c145424f3..f42a63542 100644 --- a/vizro-core/tests/unit/vizro/models/_action/test_action.py +++ b/vizro-core/tests/unit/vizro/models/_action/test_action.py @@ -64,6 +64,18 @@ def test_create_action_mandatory_and_optional(self, identity_action_function): assert action.inputs == inputs assert action.outputs == outputs + def test_is_model_inheritable(self, identity_action_function): + class MyAction(vm.Action): + pass + + function = identity_action_function() + my_action = MyAction(function=function) + + assert hasattr(my_action, "id") + assert my_action.function is function + assert my_action.inputs == [] + assert my_action.outputs == [] + @pytest.mark.parametrize( "inputs, outputs", [ diff --git a/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py b/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py index 906ccdb1b..dd2c1116f 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py @@ -69,6 +69,17 @@ def test_captured_callable_wrong_mode(self, standard_dash_table): ): vm.AgGrid(figure=standard_dash_table) + def test_is_model_inheritable(self, standard_ag_grid): + class MyAgGrid(vm.AgGrid): + pass + + my_ag_grid = MyAgGrid(figure=standard_ag_grid) + + assert hasattr(my_ag_grid, "id") + assert my_ag_grid.type == "ag_grid" + assert my_ag_grid.figure == standard_ag_grid + assert my_ag_grid.actions == [] + def test_set_action_via_validator(self, standard_ag_grid, identity_action_function): ag_grid = vm.AgGrid(figure=standard_ag_grid, actions=[Action(function=identity_action_function())]) actions_chain = ag_grid.actions[0] diff --git a/vizro-core/tests/unit/vizro/models/_components/test_figure.py b/vizro-core/tests/unit/vizro/models/_components/test_figure.py index 1c1a5e9d5..20bbc3b87 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_figure.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_figure.py @@ -54,6 +54,16 @@ def test_captured_callable_wrong_mode(self, standard_dash_table): ): vm.Figure(figure=standard_dash_table) + def test_is_model_inheritable(self, standard_kpi_card): + class MyFigure(vm.Figure): + pass + + my_figure = MyFigure(figure=standard_kpi_card) + + assert hasattr(my_figure, "id") + assert my_figure.type == "figure" + assert my_figure.figure == standard_kpi_card + class TestDunderMethodsFigure: def test_getitem_known_args(self, standard_kpi_card): diff --git a/vizro-core/tests/unit/vizro/models/_components/test_graph.py b/vizro-core/tests/unit/vizro/models/_components/test_graph.py index ddc62fdc3..71e74480e 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_graph.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_graph.py @@ -75,6 +75,17 @@ def test_captured_callable_wrong_mode(self, standard_ag_grid): ): vm.Graph(figure=standard_ag_grid) + def test_is_model_inheritable(self, standard_px_chart): + class MyGraph(vm.Graph): + pass + + my_graph = MyGraph(figure=standard_px_chart) + + assert hasattr(my_graph, "id") + assert my_graph.type == "graph" + assert my_graph.figure == standard_px_chart._captured_callable + assert my_graph.actions == [] + class TestDunderMethodsGraph: def 
test_getitem_known_args(self, standard_px_chart): diff --git a/vizro-core/tests/unit/vizro/models/_components/test_table.py b/vizro-core/tests/unit/vizro/models/_components/test_table.py index a44fab6b6..82e716fab 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_table.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_table.py @@ -69,6 +69,17 @@ def test_captured_callable_wrong_mode(self, standard_ag_grid): ): vm.Table(figure=standard_ag_grid) + def test_is_model_inheritable(self, standard_dash_table): + class MyTable(vm.Table): + pass + + my_table = MyTable(figure=standard_dash_table) + + assert hasattr(my_table, "id") + assert my_table.type == "table" + assert my_table.figure == standard_dash_table + assert my_table.actions == [] + def test_set_action_via_validator(self, standard_dash_table, identity_action_function): table = vm.Table(figure=standard_dash_table, actions=[Action(function=identity_action_function())]) actions_chain = table.actions[0] diff --git a/vizro-core/tests/unit/vizro/models/test_types.py b/vizro-core/tests/unit/vizro/models/test_types.py index 0e51cf9bf..baa977c7f 100644 --- a/vizro-core/tests/unit/vizro/models/test_types.py +++ b/vizro-core/tests/unit/vizro/models/test_types.py @@ -1,4 +1,3 @@ -import importlib import re import plotly.graph_objects as go @@ -162,12 +161,12 @@ def invalid_decorated_graph_function(): class ModelWithAction(VizroBaseModel): # The import_path here makes it possible to import the above function using getattr(import_path, _target_). - function: CapturedCallable = Field(..., import_path=importlib.import_module(__name__), mode="action") + function: CapturedCallable = Field(..., import_path=__name__, mode="action") class ModelWithGraph(VizroBaseModel): # The import_path here makes it possible to import the above function using getattr(import_path, _target_). - function: CapturedCallable = Field(..., import_path=importlib.import_module(__name__), mode="graph") + function: CapturedCallable = Field(..., import_path=__name__, mode="graph") class TestModelFieldPython: @@ -270,3 +269,15 @@ def test_wrong_mode(self): ), ): ModelWithGraph(function=config) + + def test_invalid_import_path(self): + class ModelWithInvalidModule(VizroBaseModel): + # The import_path doesn't exist. + function: CapturedCallable = Field(..., import_path="invalid.module", mode="graph") + + config = {"_target_": "decorated_graph_function", "data_frame": None} + + with pytest.raises( + ValueError, match="_target_=decorated_graph_function cannot be imported from invalid.module." 
+ ): + ModelWithInvalidModule(function=config) From b445dee9288aa88d247f573a1c80c5d26f5c6d6d Mon Sep 17 00:00:00 2001 From: Alexey Snigir <35569332+l0uden@users.noreply.github.com> Date: Fri, 2 Aug 2024 17:40:09 +0200 Subject: [PATCH 6/6] [CI] Move tests from CircleCI to GitHub (#558) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/checks-workflows.yml | 29 +++ .github/workflows/circleci-trigger.yml | 74 ------ .../workflows/test-integration-vizro-ai.yml | 13 + .github/workflows/vizro-qa-tests-trigger.yml | 50 ++++ tools/scan_yaml_for_risky_text.py | 19 ++ tools/trigger-workflow-and-wait.sh | 244 ++++++++++++++++++ 6 files changed, 355 insertions(+), 74 deletions(-) create mode 100644 .github/workflows/checks-workflows.yml delete mode 100644 .github/workflows/circleci-trigger.yml create mode 100644 .github/workflows/vizro-qa-tests-trigger.yml create mode 100644 tools/scan_yaml_for_risky_text.py create mode 100755 tools/trigger-workflow-and-wait.sh diff --git a/.github/workflows/checks-workflows.yml b/.github/workflows/checks-workflows.yml new file mode 100644 index 000000000..9b034dba4 --- /dev/null +++ b/.github/workflows/checks-workflows.yml @@ -0,0 +1,29 @@ +name: Checks for GitHub workflows + +on: + push: + branches: [main] + pull_request: + branches: + - main + +env: + PYTHONUNBUFFERED: 1 + FORCE_COLOR: 1 + PYTHON_VERSION: "3.11" + +jobs: + checks-workflows: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ env.PYTHON_VERSION }} + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Checks for GitHub workflows + run: | + python tools/scan_yaml_for_risky_text.py .github/workflows diff --git a/.github/workflows/circleci-trigger.yml b/.github/workflows/circleci-trigger.yml deleted file mode 100644 index 9b8713c48..000000000 --- a/.github/workflows/circleci-trigger.yml +++ /dev/null @@ -1,74 +0,0 @@ -name: CircleCI tests trigger - -on: - push: - branches: [main] - pull_request: - branches: - - main - -env: - PYTHONUNBUFFERED: 1 - FORCE_COLOR: 1 - -jobs: - circleci-trigger-fork: - if: ${{ github.event.pull_request.head.repo.fork }} - name: CircleCI tests trigger - runs-on: ubuntu-latest - steps: - - name: Passed fork step - run: echo "Success!" - - circleci-trigger: - if: ${{ ! github.event.pull_request.head.repo.fork }} - name: CircleCI tests trigger - runs-on: ubuntu-latest - steps: - - name: Start CircleCI pipeline - run: | - create_circleci_pipeline() { - local branch=$1 - - local json_data=$(jq -n --arg branch "$branch" --arg vizro_branch "${{ github.head_ref }}" '{branch: $branch, parameters: {branch: $branch, vizro_branch: $vizro_branch}}') - - curl --silent --request POST \ - --url "${{ secrets.QA_PIPELINE_URL }}" \ - --header "Circle-Token: ${{ secrets.CIRCLECI_API_KEY }}" \ - --header "content-type: application/json" \ - --data "$json_data" \ - | jq -r '.id' - } - - PIPELINE=$(create_circleci_pipeline "${{ github.head_ref }}") - - # If the above returns null then the QA repo doesn't contain current dev branch, so we use main branch. 
- if [[ "$PIPELINE" == "null" ]]; then - PIPELINE=$(create_circleci_pipeline "main") - fi - echo "Started pipeline with id $PIPELINE" - - echo "PIPELINE=$PIPELINE" >> $GITHUB_ENV - - - name: Wait for pipeline to run - run: sleep 60 - - - name: Check pipeline status - run: | - get_pipeline_status() { - curl --silent --request GET \ - --url "https://circleci.com/api/v2/pipeline/$PIPELINE/workflow" \ - --header "Circle-Token: ${{ secrets.CIRCLECI_API_KEY }}" \ - --header "content-type: application/json" \ - | jq -r '.items[0].status' - } - - while pipeline_status=$(get_pipeline_status); [[ "$pipeline_status" == "running" ]]; do - echo $pipeline_status - sleep 15 - done - - if [[ "$pipeline_status" != "success" ]]; then - echo "Pipeline not completed successfully - status was ${pipeline_status}" - exit 1 - fi diff --git a/.github/workflows/test-integration-vizro-ai.yml b/.github/workflows/test-integration-vizro-ai.yml index bb5917e16..758cd04eb 100644 --- a/.github/workflows/test-integration-vizro-ai.yml +++ b/.github/workflows/test-integration-vizro-ai.yml @@ -93,3 +93,16 @@ jobs: cd ../vizro-ai hatch run ${{ matrix.hatch-env }}:pip install ../vizro-core/dist/vizro*.tar.gz hatch run ${{ matrix.hatch-env }}:test-integration + + - name: Send custom JSON data to Slack + id: slack + uses: slackapi/slack-github-action@v1.26.0 + if: failure() + with: + payload: | + { + "text": "Vizro-ai ${{ matrix.hatch-env }} integration tests build result: ${{ job.status }}\nBranch: ${{ github.head_ref }}\n${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" + } + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} + SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK diff --git a/.github/workflows/vizro-qa-tests-trigger.yml b/.github/workflows/vizro-qa-tests-trigger.yml new file mode 100644 index 000000000..1d223a6de --- /dev/null +++ b/.github/workflows/vizro-qa-tests-trigger.yml @@ -0,0 +1,50 @@ +name: Vizro QA tests trigger + +on: + push: + branches: [main] + pull_request: + branches: + - main + +env: + PYTHONUNBUFFERED: 1 + FORCE_COLOR: 1 + +jobs: + vizro-qa-test-trigger-fork: + if: ${{ github.event.pull_request.head.repo.fork }} + name: Vizro QA ${{ matrix.label }} trigger + runs-on: ubuntu-latest + strategy: + matrix: + include: + - label: integration tests + - label: notebooks tests + steps: + - name: Passed fork step + run: echo "Success!" + + vizro-qa-tests-trigger: + if: ${{ ! 
github.event.pull_request.head.repo.fork }} + name: Vizro QA ${{ matrix.label }} trigger + runs-on: ubuntu-latest + strategy: + matrix: + include: + - label: integration tests + - label: notebooks test + steps: + - uses: actions/checkout@v4 + - name: Tests trigger + run: | + export INPUT_OWNER=${{ secrets.VIZRO_QA_ORG }} + export INPUT_REPO=${{ secrets.VIZRO_QA_REPO }} + if [ "${{ matrix.label }}" == "integration tests" ]; then + export INPUT_WORKFLOW_FILE_NAME=${{ secrets.VIZRO_QA_INTEGRATION_TESTS_WORKFLOW }} + elif [ "${{ matrix.label }}" == "notebooks test" ]; then + export INPUT_WORKFLOW_FILE_NAME=${{ secrets.VIZRO_QA_NOTEBOOKS_TESTS_WORKFLOW }} + fi + export INPUT_GITHUB_TOKEN=${{ secrets.VIZRO_SVC_PAT }} + export INPUT_REF=${{ github.head_ref }} + tools/trigger-workflow-and-wait.sh diff --git a/tools/scan_yaml_for_risky_text.py b/tools/scan_yaml_for_risky_text.py new file mode 100644 index 000000000..70240f872 --- /dev/null +++ b/tools/scan_yaml_for_risky_text.py @@ -0,0 +1,19 @@ +"""Check for security issues in workflows files.""" + +import sys +from pathlib import Path + +# according to this article: https://nathandavison.com/blog/github-actions-and-the-threat-of-malicious-pull-requests +# we should avoid using `pull_request_target` for security reasons +risky_text = "pull_request_target" + + +def find_risky_files(path: str): + """Searching for risky text in yml files for given path.""" + return {file for file in Path(path).rglob("*.yml") if risky_text in file.read_text()} + + +if __name__ == "__main__": + risky_files = find_risky_files(sys.argv[1]) + if risky_files: + sys.exit(f"{risky_text} found in files {risky_files}.") diff --git a/tools/trigger-workflow-and-wait.sh b/tools/trigger-workflow-and-wait.sh new file mode 100755 index 000000000..4447efb65 --- /dev/null +++ b/tools/trigger-workflow-and-wait.sh @@ -0,0 +1,244 @@ +#!/usr/bin/env bash + +#MIT License +# +#Copyright (c) 2020 Convictional, Inc. +# +#Permission is hereby granted, free of charge, to any person obtaining a copy +#of this software and associated documentation files (the "Software"), to deal +#in the Software without restriction, including without limitation the rights +#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +#copies of the Software, and to permit persons to whom the Software is +#furnished to do so, subject to the following conditions: +# +#The above copyright notice and this permission notice shall be included in all +#copies or substantial portions of the Software. +# +#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +#SOFTWARE. 
+ +set -e + +GITHUB_API_URL="${API_URL:-https://api.github.com}" +GITHUB_SERVER_URL="${SERVER_URL:-https://github.com}" + +validate_args() { + wait_interval=10 # Waits for 10 seconds + if [ "${INPUT_WAIT_INTERVAL}" ] + then + wait_interval=${INPUT_WAIT_INTERVAL} + fi + + propagate_failure=true + if [ -n "${INPUT_PROPAGATE_FAILURE}" ] + then + propagate_failure=${INPUT_PROPAGATE_FAILURE} + fi + + trigger_workflow=true + if [ -n "${INPUT_TRIGGER_WORKFLOW}" ] + then + trigger_workflow=${INPUT_TRIGGER_WORKFLOW} + fi + + wait_workflow=true + if [ -n "${INPUT_WAIT_WORKFLOW}" ] + then + wait_workflow=${INPUT_WAIT_WORKFLOW} + fi + + if [ -z "${INPUT_OWNER}" ] + then + echo "Error: Owner is a required argument." + exit 1 + fi + + if [ -z "${INPUT_REPO}" ] + then + echo "Error: Repo is a required argument." + exit 1 + fi + + if [ -z "${INPUT_GITHUB_TOKEN}" ] + then + echo "Error: Github token is required. You can head over settings and" + echo "under developer, you can create a personal access tokens. The" + echo "token requires repo access." + exit 1 + fi + + if [ -z "${INPUT_WORKFLOW_FILE_NAME}" ] + then + echo "Error: Workflow File Name is required" + exit 1 + fi + + client_payload=$(echo '{}' | jq -c) + if [ "${INPUT_CLIENT_PAYLOAD}" ] + then + client_payload=$(echo "${INPUT_CLIENT_PAYLOAD}" | jq -c) + fi + + ref="main" + if [ "$INPUT_REF" ] + then + ref="${INPUT_REF}" + fi +} + +lets_wait() { + echo "Sleeping for ${wait_interval} seconds" + sleep "$wait_interval" +} + +api() { + path=$1; shift + if response=$(curl --fail-with-body -sSL \ + "${GITHUB_API_URL}/repos/${INPUT_OWNER}/${INPUT_REPO}/actions/$path" \ + -H "Authorization: Bearer ${INPUT_GITHUB_TOKEN}" \ + -H 'Accept: application/vnd.github.v3+json' \ + -H 'Content-Type: application/json' \ + "$@") + then + echo "$response" + else + echo >&2 "api failed:" + echo >&2 "path: $path" + echo >&2 "response: $response" + if [[ "$response" == *'"Server Error"'* ]]; then + echo "Server error - trying again" + else + exit 1 + fi + fi +} + +lets_wait() { + local interval=${1:-$wait_interval} + echo >&2 "Sleeping for $interval seconds" + sleep "$interval" +} + +# Return the ids of the most recent workflow runs, optionally filtered by user +get_workflow_runs() { + since=${1:?} + + query="event=workflow_dispatch&created=>=$since${INPUT_GITHUB_USER+&actor=}${INPUT_GITHUB_USER}&per_page=100" + + echo "Getting workflow runs using query: ${query}" >&2 + + api "workflows/${INPUT_WORKFLOW_FILE_NAME}/runs?${query}" | + jq -r '.workflow_runs[].id' | + sort # Sort to ensure repeatable order, and lexicographically for compatibility with join +} + +trigger_workflow() { + START_TIME=$(date +%s) + SINCE=$(date -u -Iseconds -d "@$((START_TIME - 120))") # Two minutes ago, to overcome clock skew + + OLD_RUNS=$(get_workflow_runs "$SINCE") + + echo >&2 "Triggering workflow:" + echo >&2 " workflows/${INPUT_WORKFLOW_FILE_NAME}/dispatches" + echo >&2 " {\"ref\":\"${ref}\",\"inputs\":${client_payload}}" + + api "workflows/${INPUT_WORKFLOW_FILE_NAME}/dispatches" \ + --data "{\"ref\":\"${ref}\",\"inputs\":${client_payload}}" + + NEW_RUNS=$OLD_RUNS + while [ "$NEW_RUNS" = "$OLD_RUNS" ] + do + lets_wait + NEW_RUNS=$(get_workflow_runs "$SINCE") + done + + # Return new run ids + join -v2 <(echo "$OLD_RUNS") <(echo "$NEW_RUNS") +} + +comment_downstream_link() { + if response=$(curl --fail-with-body -sSL -X POST \ + "${INPUT_COMMENT_DOWNSTREAM_URL}" \ + -H "Authorization: Bearer ${INPUT_COMMENT_GITHUB_TOKEN}" \ + -H 'Accept: application/vnd.github.v3+json' \ + -d "{\"body\": 
\"Running downstream job at $1\"}") + then + echo "$response" + else + echo >&2 "failed to comment to ${INPUT_COMMENT_DOWNSTREAM_URL}:" + fi +} + +wait_for_workflow_to_finish() { + last_workflow_id=${1:?} + last_workflow_url="${GITHUB_SERVER_URL}/${INPUT_OWNER}/${INPUT_REPO}/actions/runs/${last_workflow_id}" + + echo "Waiting for workflow to finish:" + echo "The workflow id is [${last_workflow_id}]." + echo "The workflow logs can be found at ${last_workflow_url}" + echo "workflow_id=${last_workflow_id}" >> $GITHUB_OUTPUT + echo "workflow_url=${last_workflow_url}" >> $GITHUB_OUTPUT + echo "" + + if [ -n "${INPUT_COMMENT_DOWNSTREAM_URL}" ]; then + comment_downstream_link ${last_workflow_url} + fi + + conclusion=null + status= + + while [[ "${conclusion}" == "null" && "${status}" != "completed" ]] + do + lets_wait + + workflow=$(api "runs/$last_workflow_id") + conclusion=$(echo "${workflow}" | jq -r '.conclusion') + status=$(echo "${workflow}" | jq -r '.status') + + echo "Checking conclusion [${conclusion}]" + echo "Checking status [${status}]" + echo "conclusion=${conclusion}" >> $GITHUB_OUTPUT + done + + if [[ "${conclusion}" == "success" && "${status}" == "completed" ]] + then + echo "Yes, success" + else + # Alternative "failure" + echo "Conclusion is not success, it's [${conclusion}]." + + if [ "${propagate_failure}" = true ] + then + echo "Propagating failure to upstream job" + exit 1 + fi + fi +} + +main() { + validate_args + + if [ "${trigger_workflow}" = true ] + then + run_ids=$(trigger_workflow) + else + echo "Skipping triggering the workflow." + fi + + if [ "${wait_workflow}" = true ] + then + for run_id in $run_ids + do + wait_for_workflow_to_finish "$run_id" + done + else + echo "Skipping waiting for workflow." + fi +} + +main