Skip to content

Commit

Permalink
Merge branch 'test/xai' of github.com:mckinsey/vizro into test/xai
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Nov 7, 2024
2 parents d54508c + 0f51c7d commit 252df7f
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
11 changes: 7 additions & 4 deletions vizro-ai/examples/dashboard_ui/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled

SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI, "xAI (free API credits available)": ChatOpenAI}
SUPPORTED_VENDORS = {
"OpenAI": ChatOpenAI,
"Anthropic": ChatAnthropic,
"Mistral": ChatMistralAI,
"xAI (free API credits available)": ChatOpenAI,
}

SUPPORTED_MODELS = {
"OpenAI": [
Expand Down Expand Up @@ -64,9 +69,7 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
if vendor_input == "Mistral":
llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE)
if vendor_input == "xAI (free API credits available)":
llm = vendor(
model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE
)
llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE)

vizro_ai = VizroAI(model=llm)
ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
Expand Down
6 changes: 5 additions & 1 deletion vizro-ai/examples/dashboard_ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@
MyDropdown(
options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id"
),
OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral", "xAI (free API credits available)"], value="OpenAI"),
OffCanvas(
id="settings",
options=["OpenAI", "Anthropic", "Mistral", "xAI (free API credits available)"],
value="OpenAI",
),
UserPromptTextArea(id="text-area-id"),
# Modal(id="modal"),
],
Expand Down
2 changes: 1 addition & 1 deletion vizro-ai/src/vizro_ai/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"OpenAI": ChatOpenAI,
"Anthropic": ChatAnthropic,
"Mistral": ChatMistralAI,
"xAI": ChatOpenAI, # xAI API is compatible with OpenAI
"xAI": ChatOpenAI, # xAI API is compatible with OpenAI
}

DEFAULT_MODEL = "gpt-4o-mini"
Expand Down
6 changes: 3 additions & 3 deletions vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def _check_chart_code(cls, v):
# Remove markdown code block if present
code = v
if code.startswith("```python\n") and code.endswith("```"):
code = code[len("```python\n"):-3].strip()
code = code[len("```python\n") : -3].strip()
elif code.startswith("```\n") and code.endswith("```"):
code = code[len("```\n"):-3].strip()
code = code[len("```\n") : -3].strip()

# TODO: add more checks: ends with return, has return, no second function def, only one indented line
if f"def {CUSTOM_CHART_NAME}(" not in code:
Expand All @@ -110,7 +110,7 @@ def _check_chart_code(cls, v):
"""The chart code must accept a single argument `data_frame`,
and it should be the first argument of the chart."""
)
return code
return code

def _get_imports(self, vizro: bool = False):
imports = list(dict.fromkeys(self.imports + self._additional_vizro_imports)) # remove duplicates
Expand Down

0 comments on commit 252df7f

Please sign in to comment.