Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add xAI grok-beta to code #858

Merged
merged 8 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
10 changes: 9 additions & 1 deletion vizro-ai/examples/dashboard_ui/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled

SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI}
SUPPORTED_VENDORS = {
"OpenAI": ChatOpenAI,
"Anthropic": ChatAnthropic,
"Mistral": ChatMistralAI,
"xAI": ChatOpenAI,
}

SUPPORTED_MODELS = {
"OpenAI": [
Expand All @@ -43,6 +48,7 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI": ["grok-beta"],
}
DEFAULT_TEMPERATURE = 0.1
DEFAULT_RETRY = 3
Expand All @@ -62,6 +68,8 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
)
if vendor_input == "Mistral":
llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE)
if vendor_input == "xAI":
llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE)

vizro_ai = VizroAI(model=llm)
ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
Expand Down
7 changes: 6 additions & 1 deletion vizro-ai/examples/dashboard_ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI": ["grok-beta"],
}


Expand Down Expand Up @@ -180,7 +181,11 @@
MyDropdown(
options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id"
),
OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral"], value="OpenAI"),
OffCanvas(
id="settings",
options=["OpenAI", "Anthropic", "Mistral", "xAI"],
value="OpenAI",
),
lingyielia marked this conversation as resolved.
Show resolved Hide resolved
UserPromptTextArea(id="text-area-id"),
# Modal(id="modal"),
],
Expand Down
10 changes: 10 additions & 0 deletions vizro-ai/examples/dashboard_ui/assets/custom_css.css
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,13 @@
#open-settings-id:hover {
cursor: pointer;
}

.hover-effect {
transition: all 0.2s ease !important;
}

.hover-effect:hover {
background-color: rgba(255, 255, 255, 0.1) !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
transform: translateY(-2px);
}
59 changes: 58 additions & 1 deletion vizro-ai/examples/dashboard_ui/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,33 @@ def build(self):
)


def create_provider_item(name, url, note=None):
"""Helper function to create a consistent ListGroupItem for each provider."""
return dbc.ListGroupItem(
[
html.Div(
[
html.Span(name, style={"color": "#ffffff"}),
(html.Small(note, style={"color": "rgba(255, 255, 255, 0.5)"}) if note else None),
html.Span("→", className="float-end", style={"color": "#ffffff"}),
],
className="d-flex justify-content-between align-items-center",
)
],
href=url,
target="_blank",
action=True,
style={
"background-color": "transparent",
"border": "1px solid rgba(255, 255, 255, 0.1)",
"margin-bottom": "8px",
"transition": "all 0.2s ease",
"cursor": "pointer",
},
class_name="list-group-item-action hover-effect",
)


class OffCanvas(vm.VizroBaseModel):
"""OffCanvas component for settings."""

Expand Down Expand Up @@ -202,14 +229,44 @@ def build(self):
className="mb-3",
)

providers = [
{"name": "OpenAI", "url": "https://openai.com/index/openai-api/"},
{"name": "Anthropic", "url": "https://docs.anthropic.com/en/api/getting-started"},
{"name": "Mistral", "url": "https://docs.mistral.ai/getting-started/quickstart/"},
{"name": "xAI", "url": "https://x.ai/blog/api", "note": "(Free API credits available)"},
]

api_instructions = html.Div(
[
html.Hr(
style={
"margin": "2rem 0",
"border-color": "rgba(255, 255, 255, 0.1)",
"border-style": "solid",
"border-width": "0 0 1px 0",
}
),
html.Div("Get API Keys", className="mb-3", style={"color": "#ffffff"}),
dbc.ListGroup(
[
create_provider_item(name=provider["name"], url=provider["url"], note=provider.get("note"))
for provider in providers
],
flush=True,
className="border-0",
),
],
)

offcanvas = dbc.Offcanvas(
id=self.id,
children=[
html.Div(
children=[
input_groups,
api_instructions,
]
)
),
],
title="Settings",
is_open=True,
Expand Down
8 changes: 8 additions & 0 deletions vizro-ai/examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
"# llm = \"claude-3-5-sonnet-latest\"\n",
"# llm = \"mistral-large-latest\"\n",
"\n",
"# llm = \"grok-beta\" #xAI API is compatible with OpenAI. To use grok-beta,\n",
"# point `OPENAI_BASE_URL` to the xAI baseurl, use xAI API key for `OPENAI_API_KEY`\n",
"# when setting up the environment variables\n",
"# e.g.\n",
"# OPENAI_BASE_URL=\"https://api.x.ai/v1\"\n",
"# OPENAI_API_KEY=<xAI API key>\n",
"# reference: https://docs.x.ai/api/integrations#openai-sdk\n",
"\n",
"# from langchain_openai import ChatOpenAI\n",
"# llm = ChatOpenAI(\n",
"# model=\"gpt-4o\")\n",
Expand Down
2 changes: 2 additions & 0 deletions vizro-ai/src/vizro_ai/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI": ["grok-beta"],
}

DEFAULT_WRAPPER_MAP: dict[str, BaseChatModel] = {
"OpenAI": ChatOpenAI,
"Anthropic": ChatAnthropic,
"Mistral": ChatMistralAI,
"xAI": ChatOpenAI, # xAI API is compatible with OpenAI
}

DEFAULT_MODEL = "gpt-4o-mini"
Expand Down
9 changes: 8 additions & 1 deletion vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,18 @@ class ChartPlan(BaseModel):

@validator("chart_code")
def _check_chart_code(cls, v):
# Remove markdown code block if present
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ChatGPT just suggested me this:

import re

def strip_markdown(v):
    # Strip leading and trailing markdown code block delimiters if they exist
    return re.sub(r"^```(?:python\n)?|```$", "", v).strip()

or

def strip_markdown(v):
    if v.startswith("```python\n"):
        v = v[len("```python\n"):]
    elif v.startswith("```\n"):
        v = v[len("```\n"):]
    if v.endswith("```"):
        v = v[:-3]
    return v.strip()

Maybe better?

if v.startswith("```python\n") and v.endswith("```"):
v = v[len("```python\n") : -3].strip()
elif v.startswith("```\n") and v.endswith("```"):
v = v[len("```\n") : -3].strip()

# TODO: add more checks: ends with return, has return, no second function def, only one indented line
if f"def {CUSTOM_CHART_NAME}(" not in v:
raise ValueError(f"The chart code must be wrapped in a function named `{CUSTOM_CHART_NAME}`")

if "data_frame" not in v.split("\n")[0]:
first_line = v.split("\n")[0].strip()
if "data_frame" not in first_line:
raise ValueError(
"""The chart code must accept a single argument `data_frame`,
and it should be the first argument of the chart."""
Expand Down