Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Demo] Add VizroAI dashboard UI #690

Closed
wants to merge 15 commits into from
Closed
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
20 changes: 20 additions & 0 deletions vizro-ai/examples/chart_ui/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.12

RUN useradd -m -u 1000 user

WORKDIR /app

COPY --chown=user ./requirements.txt requirements.txt

RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app

EXPOSE 7860

USER user

CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:7860", "app:server"]
8 changes: 8 additions & 0 deletions vizro-ai/examples/chart_ui/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Utils file."""


def check_file_extension(filename):
filename = filename.lower()

# Check if the filename ends with .csv or .xls
return filename.endswith(".csv") or filename.endswith(".xls") or filename.endswith(".xlsx")
129 changes: 129 additions & 0 deletions vizro-ai/examples/chart_ui/actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Custom actions used within a dashboard."""

import base64
import io
import logging

import black
import pandas as pd
from _utils import check_file_extension
from dash.exceptions import PreventUpdate
from langchain_openai import ChatOpenAI
from plotly import graph_objects as go
from vizro.models.types import capture
from vizro_ai import VizroAI

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled

SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI}


def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input): # noqa: PLR0913
"""VizroAi plot configuration."""
vendor = SUPPORTED_VENDORS[vendor_input]
llm = vendor(model_name=model, openai_api_key=api_key, openai_api_base=api_base)
vizro_ai = VizroAI(model=llm)
ai_outputs = vizro_ai.plot(df, user_prompt, explain=False, return_elements=True)

return ai_outputs


@capture("action")
def run_vizro_ai(user_prompt, n_clicks, data, model, api_key, api_base, vendor_input): # noqa: PLR0913
"""Gets the AI response and adds it to the text window."""

def create_response(ai_response, figure, user_prompt, filename):
plotly_fig = figure.to_json()
return (
ai_response,
figure,
{"ai_response": ai_response, "figure": plotly_fig, "prompt": user_prompt, "filename": filename},
)

if not n_clicks:
raise PreventUpdate

if not data:
ai_response = "Please upload data to proceed!"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, None)

if not api_key:
ai_response = "API key not found. Make sure you enter your API key!"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, data["filename"])

if api_key.startswith('"'):
ai_response = "Make sure you enter your API key without quotes!"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, data["filename"])

if api_base is not None and api_base.startswith('"'):
ai_response = "Make sure you enter your API base without quotes!"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, data["filename"])

try:
logger.info("Attempting chart code.")
df = pd.DataFrame(data["data"])
ai_outputs = get_vizro_ai_plot(
user_prompt=user_prompt,
df=df,
model=model,
api_key=api_key,
api_base=api_base,
vendor_input=vendor_input,
)
ai_code = ai_outputs.code
figure = ai_outputs.figure
formatted_code = black.format_str(ai_code, mode=black.Mode(line_length=100))

ai_response = "\n".join(["```python", formatted_code, "```"])
logger.info("Successful query produced.")
return create_response(ai_response, figure, user_prompt, data["filename"])

except Exception as exc:
logger.debug(exc)
logger.info("Chart creation failed.")
ai_response = f"Sorry, I can't do that. Following Error occurred: {exc}"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, data["filename"])


@capture("action")
def data_upload_action(contents, filename):
"""Custom data upload action."""
if not contents:
raise PreventUpdate

if not check_file_extension(filename=filename):
return {"error_message": "Unsupported file extension.. Make sure to upload either csv or an excel file."}

content_type, content_string = contents.split(",")

try:
decoded = base64.b64decode(content_string)
if filename.endswith(".csv"):
# Handle CSV file
df = pd.read_csv(io.StringIO(decoded.decode("utf-8")))
else:
# Handle Excel file
df = pd.read_excel(io.BytesIO(decoded))

data = df.to_dict("records")
return {"data": data, "filename": filename}

except Exception as e:
logger.debug(e)
return {"error_message": "There was an error processing this file."}


@capture("action")
def display_filename(data):
"""Custom action to display uploaded filename."""
if data is None:
raise PreventUpdate

display_message = data.get("filename") or data.get("error_message")
return f"Uploaded file name: '{display_message}'" if "filename" in data else display_message
Loading
Loading