From ee72ff9b6364ceafd96714414c7c86349862df7c Mon Sep 17 00:00:00 2001 From: Lingyi Zhang Date: Thu, 7 Nov 2024 11:15:32 -0500 Subject: [PATCH 1/6] add xAI grok-beta to code --- vizro-ai/examples/dashboard_ui/actions.py | 7 ++++++- vizro-ai/examples/dashboard_ui/app.py | 3 ++- vizro-ai/examples/example.ipynb | 8 ++++++++ vizro-ai/src/vizro_ai/_llm_models.py | 2 ++ vizro-ai/src/vizro_ai/plot/_response_models.py | 14 +++++++++++--- 5 files changed, 29 insertions(+), 5 deletions(-) diff --git a/vizro-ai/examples/dashboard_ui/actions.py b/vizro-ai/examples/dashboard_ui/actions.py index a69893d59..ff5421d32 100644 --- a/vizro-ai/examples/dashboard_ui/actions.py +++ b/vizro-ai/examples/dashboard_ui/actions.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled -SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI} +SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI, "xAI (free API credits available)": ChatOpenAI} SUPPORTED_MODELS = { "OpenAI": [ @@ -43,6 +43,7 @@ "claude-3-haiku-20240307", ], "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"], + "xAI (free API credits available)": ["grok-beta"], } DEFAULT_TEMPERATURE = 0.1 DEFAULT_RETRY = 3 @@ -62,6 +63,10 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input): ) if vendor_input == "Mistral": llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE) + if vendor_input == "xAI (free API credits available)": + llm = vendor( + model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE + ) vizro_ai = VizroAI(model=llm) ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True) diff --git a/vizro-ai/examples/dashboard_ui/app.py b/vizro-ai/examples/dashboard_ui/app.py index 2a4d00752..be7aad14a 100644 --- a/vizro-ai/examples/dashboard_ui/app.py +++ b/vizro-ai/examples/dashboard_ui/app.py @@ -70,6 +70,7 @@ "claude-3-haiku-20240307", ], "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"], + "xAI (free API credits available)": ["grok-beta"], } @@ -180,7 +181,7 @@ MyDropdown( options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id" ), - OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral"], value="OpenAI"), + OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral", "xAI (free API credits available)"], value="OpenAI"), UserPromptTextArea(id="text-area-id"), # Modal(id="modal"), ], diff --git a/vizro-ai/examples/example.ipynb b/vizro-ai/examples/example.ipynb index 9fc56071c..f8029a472 100644 --- a/vizro-ai/examples/example.ipynb +++ b/vizro-ai/examples/example.ipynb @@ -14,6 +14,14 @@ "# llm = \"claude-3-5-sonnet-latest\"\n", "# llm = \"mistral-large-latest\"\n", "\n", + "# llm = \"grok-beta\" #xAI API is compatible with OpenAI. To use grok-beta,\n", + "# point `OPENAI_BASE_URL` to the xAI baseurl, use xAI API key for `OPENAI_API_KEY`\n", + "# when setting up the environment variables\n", + "# e.g.\n", + "# OPENAI_BASE_URL=\"https://api.x.ai/v1\"\n", + "# OPENAI_API_KEY=\n", + "# reference: https://docs.x.ai/api/integrations#openai-sdk\n", + "\n", "# from langchain_openai import ChatOpenAI\n", "# llm = ChatOpenAI(\n", "# model=\"gpt-4o\")\n", diff --git a/vizro-ai/src/vizro_ai/_llm_models.py b/vizro-ai/src/vizro_ai/_llm_models.py index c3c9858b0..62f32f2cf 100644 --- a/vizro-ai/src/vizro_ai/_llm_models.py +++ b/vizro-ai/src/vizro_ai/_llm_models.py @@ -27,12 +27,14 @@ "claude-3-haiku-20240307", ], "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"], + "xAI": ["grok-beta"], } DEFAULT_WRAPPER_MAP: dict[str, BaseChatModel] = { "OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI, + "xAI": ChatOpenAI, # xAI API is compatible with OpenAI } DEFAULT_MODEL = "gpt-4o-mini" diff --git a/vizro-ai/src/vizro_ai/plot/_response_models.py b/vizro-ai/src/vizro_ai/plot/_response_models.py index 5bdedb73e..da07fbea9 100644 --- a/vizro-ai/src/vizro_ai/plot/_response_models.py +++ b/vizro-ai/src/vizro_ai/plot/_response_models.py @@ -93,16 +93,24 @@ class ChartPlan(BaseModel): @validator("chart_code") def _check_chart_code(cls, v): + # Remove markdown code block if present + code = v + if code.startswith("```python\n") and code.endswith("```"): + code = code[len("```python\n"):-3].strip() + elif code.startswith("```\n") and code.endswith("```"): + code = code[len("```\n"):-3].strip() + # TODO: add more checks: ends with return, has return, no second function def, only one indented line - if f"def {CUSTOM_CHART_NAME}(" not in v: + if f"def {CUSTOM_CHART_NAME}(" not in code: raise ValueError(f"The chart code must be wrapped in a function named `{CUSTOM_CHART_NAME}`") - if "data_frame" not in v.split("\n")[0]: + first_line = code.split("\n")[0].strip() + if "data_frame" not in first_line: raise ValueError( """The chart code must accept a single argument `data_frame`, and it should be the first argument of the chart.""" ) - return v + return code def _get_imports(self, vizro: bool = False): imports = list(dict.fromkeys(self.imports + self._additional_vizro_imports)) # remove duplicates From 0f51c7d935e78d914064091d10cd74a4d6a0fada Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 16:22:17 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- vizro-ai/examples/dashboard_ui/actions.py | 11 +++++++---- vizro-ai/examples/dashboard_ui/app.py | 6 +++++- vizro-ai/src/vizro_ai/_llm_models.py | 2 +- vizro-ai/src/vizro_ai/plot/_response_models.py | 6 +++--- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/vizro-ai/examples/dashboard_ui/actions.py b/vizro-ai/examples/dashboard_ui/actions.py index ff5421d32..15f275eb8 100644 --- a/vizro-ai/examples/dashboard_ui/actions.py +++ b/vizro-ai/examples/dashboard_ui/actions.py @@ -28,7 +28,12 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled -SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI, "xAI (free API credits available)": ChatOpenAI} +SUPPORTED_VENDORS = { + "OpenAI": ChatOpenAI, + "Anthropic": ChatAnthropic, + "Mistral": ChatMistralAI, + "xAI (free API credits available)": ChatOpenAI, +} SUPPORTED_MODELS = { "OpenAI": [ @@ -64,9 +69,7 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input): if vendor_input == "Mistral": llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE) if vendor_input == "xAI (free API credits available)": - llm = vendor( - model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE - ) + llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE) vizro_ai = VizroAI(model=llm) ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True) diff --git a/vizro-ai/examples/dashboard_ui/app.py b/vizro-ai/examples/dashboard_ui/app.py index be7aad14a..eeadeffeb 100644 --- a/vizro-ai/examples/dashboard_ui/app.py +++ b/vizro-ai/examples/dashboard_ui/app.py @@ -181,7 +181,11 @@ MyDropdown( options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id" ), - OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral", "xAI (free API credits available)"], value="OpenAI"), + OffCanvas( + id="settings", + options=["OpenAI", "Anthropic", "Mistral", "xAI (free API credits available)"], + value="OpenAI", + ), UserPromptTextArea(id="text-area-id"), # Modal(id="modal"), ], diff --git a/vizro-ai/src/vizro_ai/_llm_models.py b/vizro-ai/src/vizro_ai/_llm_models.py index 62f32f2cf..6b572d512 100644 --- a/vizro-ai/src/vizro_ai/_llm_models.py +++ b/vizro-ai/src/vizro_ai/_llm_models.py @@ -34,7 +34,7 @@ "OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI, - "xAI": ChatOpenAI, # xAI API is compatible with OpenAI + "xAI": ChatOpenAI, # xAI API is compatible with OpenAI } DEFAULT_MODEL = "gpt-4o-mini" diff --git a/vizro-ai/src/vizro_ai/plot/_response_models.py b/vizro-ai/src/vizro_ai/plot/_response_models.py index da07fbea9..49b3e98e1 100644 --- a/vizro-ai/src/vizro_ai/plot/_response_models.py +++ b/vizro-ai/src/vizro_ai/plot/_response_models.py @@ -96,9 +96,9 @@ def _check_chart_code(cls, v): # Remove markdown code block if present code = v if code.startswith("```python\n") and code.endswith("```"): - code = code[len("```python\n"):-3].strip() + code = code[len("```python\n") : -3].strip() elif code.startswith("```\n") and code.endswith("```"): - code = code[len("```\n"):-3].strip() + code = code[len("```\n") : -3].strip() # TODO: add more checks: ends with return, has return, no second function def, only one indented line if f"def {CUSTOM_CHART_NAME}(" not in code: @@ -110,7 +110,7 @@ def _check_chart_code(cls, v): """The chart code must accept a single argument `data_frame`, and it should be the first argument of the chart.""" ) - return code + return code def _get_imports(self, vizro: bool = False): imports = list(dict.fromkeys(self.imports + self._additional_vizro_imports)) # remove duplicates From d54508c24c64a53a78df0e0acb92f39492353476 Mon Sep 17 00:00:00 2001 From: Lingyi Zhang Date: Thu, 7 Nov 2024 11:23:55 -0500 Subject: [PATCH 3/6] add changelog --- .../20241107_112343_lingyi_zhang_xai.md | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md diff --git a/vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md b/vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md new file mode 100644 index 000000000..f1f65e73c --- /dev/null +++ b/vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md @@ -0,0 +1,48 @@ + + + + + + + + + From adf9e8bcc85e06c904ad0873efd623218396ff49 Mon Sep 17 00:00:00 2001 From: Lingyi Zhang Date: Thu, 7 Nov 2024 14:16:09 -0500 Subject: [PATCH 4/6] address comments --- vizro-ai/examples/dashboard_ui/actions.py | 6 +- vizro-ai/examples/dashboard_ui/app.py | 4 +- .../dashboard_ui/assets/custom_css.css | 10 ++++ vizro-ai/examples/dashboard_ui/components.py | 59 ++++++++++++++++++- .../src/vizro_ai/plot/_response_models.py | 15 +++-- 5 files changed, 80 insertions(+), 14 deletions(-) diff --git a/vizro-ai/examples/dashboard_ui/actions.py b/vizro-ai/examples/dashboard_ui/actions.py index 15f275eb8..872ce0274 100644 --- a/vizro-ai/examples/dashboard_ui/actions.py +++ b/vizro-ai/examples/dashboard_ui/actions.py @@ -32,7 +32,7 @@ "OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI, - "xAI (free API credits available)": ChatOpenAI, + "xAI": ChatOpenAI, } SUPPORTED_MODELS = { @@ -48,7 +48,7 @@ "claude-3-haiku-20240307", ], "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"], - "xAI (free API credits available)": ["grok-beta"], + "xAI": ["grok-beta"], } DEFAULT_TEMPERATURE = 0.1 DEFAULT_RETRY = 3 @@ -68,7 +68,7 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input): ) if vendor_input == "Mistral": llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE) - if vendor_input == "xAI (free API credits available)": + if vendor_input == "xAI": llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE) vizro_ai = VizroAI(model=llm) diff --git a/vizro-ai/examples/dashboard_ui/app.py b/vizro-ai/examples/dashboard_ui/app.py index eeadeffeb..e8b29f7dc 100644 --- a/vizro-ai/examples/dashboard_ui/app.py +++ b/vizro-ai/examples/dashboard_ui/app.py @@ -70,7 +70,7 @@ "claude-3-haiku-20240307", ], "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"], - "xAI (free API credits available)": ["grok-beta"], + "xAI": ["grok-beta"], } @@ -183,7 +183,7 @@ ), OffCanvas( id="settings", - options=["OpenAI", "Anthropic", "Mistral", "xAI (free API credits available)"], + options=["OpenAI", "Anthropic", "Mistral", "xAI"], value="OpenAI", ), UserPromptTextArea(id="text-area-id"), diff --git a/vizro-ai/examples/dashboard_ui/assets/custom_css.css b/vizro-ai/examples/dashboard_ui/assets/custom_css.css index cd9c92d14..38419022f 100644 --- a/vizro-ai/examples/dashboard_ui/assets/custom_css.css +++ b/vizro-ai/examples/dashboard_ui/assets/custom_css.css @@ -306,3 +306,13 @@ #open-settings-id:hover { cursor: pointer; } + +.hover-effect:hover { + background-color: rgba(255, 255, 255, 0.1) !important; + box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); + transform: translateY(-2px); +} + +.hover-effect { + transition: all 0.2s ease !important; +} diff --git a/vizro-ai/examples/dashboard_ui/components.py b/vizro-ai/examples/dashboard_ui/components.py index 41a3d6d88..20a2d7eb4 100644 --- a/vizro-ai/examples/dashboard_ui/components.py +++ b/vizro-ai/examples/dashboard_ui/components.py @@ -148,6 +148,33 @@ def build(self): ) +def create_provider_item(name, url, note=None): + """Helper function to create a consistent ListGroupItem for each provider""" + return dbc.ListGroupItem( + [ + html.Div( + [ + html.Span(name, style={"color": "#ffffff"}), + (html.Small(note, style={"color": "rgba(255, 255, 255, 0.5)"}) if note else None), + html.Span("→", className="float-end", style={"color": "#ffffff"}), + ], + className="d-flex justify-content-between align-items-center", + ) + ], + href=url, + target="_blank", + action=True, + style={ + "background-color": "transparent", + "border": "1px solid rgba(255, 255, 255, 0.1)", + "margin-bottom": "8px", + "transition": "all 0.2s ease", + "cursor": "pointer", + }, + class_name="list-group-item-action hover-effect", + ) + + class OffCanvas(vm.VizroBaseModel): """OffCanvas component for settings.""" @@ -202,14 +229,44 @@ def build(self): className="mb-3", ) + providers = [ + {"name": "OpenAI", "url": "https://openai.com/index/openai-api/"}, + {"name": "Anthropic", "url": "https://docs.anthropic.com/en/api/getting-started"}, + {"name": "Mistral", "url": "https://docs.mistral.ai/getting-started/quickstart/"}, + {"name": "xAI", "url": "https://x.ai/blog/api", "note": "(Free API credits available)"}, + ] + + api_instructions = html.Div( + [ + html.Hr( + style={ + "margin": "2rem 0", + "border-color": "rgba(255, 255, 255, 0.1)", + "border-style": "solid", + "border-width": "0 0 1px 0", + } + ), + html.H6("Get API Keys:", className="mb-3", style={"color": "#ffffff"}), + dbc.ListGroup( + [ + create_provider_item(name=provider["name"], url=provider["url"], note=provider.get("note")) + for provider in providers + ], + flush=True, + className="border-0", + ), + ], + ) + offcanvas = dbc.Offcanvas( id=self.id, children=[ html.Div( children=[ input_groups, + api_instructions, ] - ) + ), ], title="Settings", is_open=True, diff --git a/vizro-ai/src/vizro_ai/plot/_response_models.py b/vizro-ai/src/vizro_ai/plot/_response_models.py index 49b3e98e1..efa10259d 100644 --- a/vizro-ai/src/vizro_ai/plot/_response_models.py +++ b/vizro-ai/src/vizro_ai/plot/_response_models.py @@ -94,23 +94,22 @@ class ChartPlan(BaseModel): @validator("chart_code") def _check_chart_code(cls, v): # Remove markdown code block if present - code = v - if code.startswith("```python\n") and code.endswith("```"): - code = code[len("```python\n") : -3].strip() - elif code.startswith("```\n") and code.endswith("```"): - code = code[len("```\n") : -3].strip() + if v.startswith("```python\n") and v.endswith("```"): + v = v[len("```python\n") : -3].strip() + elif v.startswith("```\n") and v.endswith("```"): + v = v[len("```\n") : -3].strip() # TODO: add more checks: ends with return, has return, no second function def, only one indented line - if f"def {CUSTOM_CHART_NAME}(" not in code: + if f"def {CUSTOM_CHART_NAME}(" not in v: raise ValueError(f"The chart code must be wrapped in a function named `{CUSTOM_CHART_NAME}`") - first_line = code.split("\n")[0].strip() + first_line = v.split("\n")[0].strip() if "data_frame" not in first_line: raise ValueError( """The chart code must accept a single argument `data_frame`, and it should be the first argument of the chart.""" ) - return code + return v def _get_imports(self, vizro: bool = False): imports = list(dict.fromkeys(self.imports + self._additional_vizro_imports)) # remove duplicates From 28de3761e6779091ac77403ca833c64666557ec4 Mon Sep 17 00:00:00 2001 From: Lingyi Zhang Date: Thu, 7 Nov 2024 14:26:47 -0500 Subject: [PATCH 5/6] lint --- vizro-ai/examples/dashboard_ui/assets/custom_css.css | 8 ++++---- vizro-ai/examples/dashboard_ui/components.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vizro-ai/examples/dashboard_ui/assets/custom_css.css b/vizro-ai/examples/dashboard_ui/assets/custom_css.css index 38419022f..dea230d73 100644 --- a/vizro-ai/examples/dashboard_ui/assets/custom_css.css +++ b/vizro-ai/examples/dashboard_ui/assets/custom_css.css @@ -307,12 +307,12 @@ cursor: pointer; } +.hover-effect { + transition: all 0.2s ease !important; +} + .hover-effect:hover { background-color: rgba(255, 255, 255, 0.1) !important; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); transform: translateY(-2px); } - -.hover-effect { - transition: all 0.2s ease !important; -} diff --git a/vizro-ai/examples/dashboard_ui/components.py b/vizro-ai/examples/dashboard_ui/components.py index 20a2d7eb4..6d8a42706 100644 --- a/vizro-ai/examples/dashboard_ui/components.py +++ b/vizro-ai/examples/dashboard_ui/components.py @@ -149,7 +149,7 @@ def build(self): def create_provider_item(name, url, note=None): - """Helper function to create a consistent ListGroupItem for each provider""" + """Helper function to create a consistent ListGroupItem for each provider.""" return dbc.ListGroupItem( [ html.Div( From 168f457701b7f5bab7d3faba2880e6fc6abb523d Mon Sep 17 00:00:00 2001 From: Lingyi Zhang Date: Thu, 7 Nov 2024 14:50:20 -0500 Subject: [PATCH 6/6] tidy --- vizro-ai/examples/dashboard_ui/components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vizro-ai/examples/dashboard_ui/components.py b/vizro-ai/examples/dashboard_ui/components.py index 6d8a42706..1e8397b3f 100644 --- a/vizro-ai/examples/dashboard_ui/components.py +++ b/vizro-ai/examples/dashboard_ui/components.py @@ -246,7 +246,7 @@ def build(self): "border-width": "0 0 1px 0", } ), - html.H6("Get API Keys:", className="mb-3", style={"color": "#ffffff"}), + html.Div("Get API Keys", className="mb-3", style={"color": "#ffffff"}), dbc.ListGroup( [ create_provider_item(name=provider["name"], url=provider["url"], note=provider.get("note"))