Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add xAI grok-beta to code #858

Merged
merged 8 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
10 changes: 9 additions & 1 deletion vizro-ai/examples/dashboard_ui/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled

SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI}
SUPPORTED_VENDORS = {
"OpenAI": ChatOpenAI,
"Anthropic": ChatAnthropic,
"Mistral": ChatMistralAI,
"xAI (free API credits available)": ChatOpenAI,
}

SUPPORTED_MODELS = {
"OpenAI": [
Expand All @@ -43,6 +48,7 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI (free API credits available)": ["grok-beta"],
lingyielia marked this conversation as resolved.
Show resolved Hide resolved
}
DEFAULT_TEMPERATURE = 0.1
DEFAULT_RETRY = 3
Expand All @@ -62,6 +68,8 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
)
if vendor_input == "Mistral":
llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE)
if vendor_input == "xAI (free API credits available)":
llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE)

vizro_ai = VizroAI(model=llm)
ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
Expand Down
7 changes: 6 additions & 1 deletion vizro-ai/examples/dashboard_ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI (free API credits available)": ["grok-beta"],
}


Expand Down Expand Up @@ -180,7 +181,11 @@
MyDropdown(
options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id"
),
OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral"], value="OpenAI"),
OffCanvas(
id="settings",
options=["OpenAI", "Anthropic", "Mistral", "xAI (free API credits available)"],
value="OpenAI",
),
lingyielia marked this conversation as resolved.
Show resolved Hide resolved
UserPromptTextArea(id="text-area-id"),
# Modal(id="modal"),
],
Expand Down
8 changes: 8 additions & 0 deletions vizro-ai/examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
"# llm = \"claude-3-5-sonnet-latest\"\n",
"# llm = \"mistral-large-latest\"\n",
"\n",
"# llm = \"grok-beta\" #xAI API is compatible with OpenAI. To use grok-beta,\n",
"# point `OPENAI_BASE_URL` to the xAI baseurl, use xAI API key for `OPENAI_API_KEY`\n",
"# when setting up the environment variables\n",
"# e.g.\n",
"# OPENAI_BASE_URL=\"https://api.x.ai/v1\"\n",
"# OPENAI_API_KEY=<xAI API key>\n",
"# reference: https://docs.x.ai/api/integrations#openai-sdk\n",
"\n",
"# from langchain_openai import ChatOpenAI\n",
"# llm = ChatOpenAI(\n",
"# model=\"gpt-4o\")\n",
Expand Down
2 changes: 2 additions & 0 deletions vizro-ai/src/vizro_ai/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI": ["grok-beta"],
}

DEFAULT_WRAPPER_MAP: dict[str, BaseChatModel] = {
"OpenAI": ChatOpenAI,
"Anthropic": ChatAnthropic,
"Mistral": ChatMistralAI,
"xAI": ChatOpenAI, # xAI API is compatible with OpenAI
}

DEFAULT_MODEL = "gpt-4o-mini"
Expand Down
14 changes: 11 additions & 3 deletions vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,24 @@ class ChartPlan(BaseModel):

@validator("chart_code")
def _check_chart_code(cls, v):
# Remove markdown code block if present
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ChatGPT just suggested me this:

import re

def strip_markdown(v):
    # Strip leading and trailing markdown code block delimiters if they exist
    return re.sub(r"^```(?:python\n)?|```$", "", v).strip()

or

def strip_markdown(v):
    if v.startswith("```python\n"):
        v = v[len("```python\n"):]
    elif v.startswith("```\n"):
        v = v[len("```\n"):]
    if v.endswith("```"):
        v = v[:-3]
    return v.strip()

Maybe better?

code = v
lingyielia marked this conversation as resolved.
Show resolved Hide resolved
if code.startswith("```python\n") and code.endswith("```"):
code = code[len("```python\n") : -3].strip()
elif code.startswith("```\n") and code.endswith("```"):
code = code[len("```\n") : -3].strip()

# TODO: add more checks: ends with return, has return, no second function def, only one indented line
if f"def {CUSTOM_CHART_NAME}(" not in v:
if f"def {CUSTOM_CHART_NAME}(" not in code:
raise ValueError(f"The chart code must be wrapped in a function named `{CUSTOM_CHART_NAME}`")

if "data_frame" not in v.split("\n")[0]:
first_line = code.split("\n")[0].strip()
if "data_frame" not in first_line:
raise ValueError(
"""The chart code must accept a single argument `data_frame`,
and it should be the first argument of the chart."""
)
return v
return code

def _get_imports(self, vizro: bool = False):
imports = list(dict.fromkeys(self.imports + self._additional_vizro_imports)) # remove duplicates
Expand Down