Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyielia committed Jul 6, 2024
1 parent 5da09eb commit 1600b55
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 105 deletions.
48 changes: 48 additions & 0 deletions vizro-ai/changelog.d/20240626_224646_anna_xiong_azure_openai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
82 changes: 31 additions & 51 deletions vizro-ai/src/vizro_ai/chains/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,63 +4,34 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI

# TODO is there a better way to handle this import?
try:
from langchain_anthropic import ChatAnthropic
except ImportError:
ChatAnthropic = None

# TODO to be removed, just use BaseChatModel should be enough
LLM_MODELS = Union[ChatOpenAI]

# TODO constant of model inventory, can be converted to yaml and link to docs
PREDEFINED_MODELS: Dict[str, Dict[str, Union[int, BaseChatModel]]] = {
"gpt-3.5-turbo-0613": {
"max_tokens": 4096,
"wrapper": ChatOpenAI,
},
"gpt-4-0613": {
"max_tokens": 8192,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo-1106": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-4-1106-preview": {
"max_tokens": 128000,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo-0125": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-4-turbo": {
"max_tokens": 128000,
"wrapper": ChatOpenAI,
},
"gpt-4o": {
"max_tokens": 128000,
"wrapper": ChatOpenAI,
},
}

# TODO add new wrappers in if new model support is added
if ChatAnthropic is not None:
PREDEFINED_MODELS = {
**PREDEFINED_MODELS,
**{"claude-3-haiku-20240307": {"max_tokens": 200000, "wrapper": ChatAnthropic}},
**{"claude-3-sonnet-20240229": {"max_tokens": 200000, "wrapper": ChatAnthropic}},
}

SUPPORTED_MODELS = {
"OpenAI": [
"gpt-4-0613",
"gpt-3.5-turbo-1106",
"gpt-4-1106-preview",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-4-turbo",
"gpt-4o",
],
"Anthropic": [
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
],
}

DEFAULT_WRAPPER_MAP: Dict[str, BaseChatModel] = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic}
DEFAULT_MODEL = "gpt-3.5-turbo"
DEFAULT_TEMPERATURE = 0

model_to_vendor = {model: key for key, models in SUPPORTED_MODELS.items() for model in models}


def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatModel:
"""Fetches and initializes an instance of the LLM.
Expand All @@ -77,10 +48,17 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo
"""
if not model:
return ChatOpenAI(model_name=DEFAULT_MODEL, temperature=DEFAULT_TEMPERATURE)
if isinstance(model, ChatOpenAI):

if isinstance(model, BaseChatModel):
return model
if isinstance(model, str) and model in PREDEFINED_MODELS:
return PREDEFINED_MODELS.get(model)["wrapper"](model_name=model, temperature=DEFAULT_TEMPERATURE)

if isinstance(model, str):
if any(model in model_list for model_list in SUPPORTED_MODELS.values()):
vendor = model_to_vendor[model]
if DEFAULT_WRAPPER_MAP.get(vendor) is None:
raise ValueError(f"Addtitional library to support {vendor} models is not installed.")
return DEFAULT_WRAPPER_MAP.get(vendor)(model_name=model, temperature=DEFAULT_TEMPERATURE)

raise ValueError(
f"Model {model} not found! List of available model can be found at https://vizro.readthedocs.io/projects/vizro-ai/en/latest/pages/explanation/faq/#which-llms-are-supported-by-vizro-ai"
)
Expand All @@ -100,4 +78,6 @@ def _get_model_name(model):


if __name__ == "__main__":
llm_chat_openai = _get_llm_model()
llm_chat_openai = _get_llm_model(model="gpt-3.5-turbo")
print(repr(llm_chat_openai)) # noqa: T201
print(llm_chat_openai.model_name) # noqa: T201
8 changes: 4 additions & 4 deletions vizro-ai/src/vizro_ai/dashboard/nodes/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create(self, model, df_metadata) -> Union[ComponentType, None]:
return vm.AgGrid(id=self.component_id, figure=dash_ag_grid(data_frame=self.data_frame))
elif self.component_type == "Card":
return _get_proxy_model(
query=self.component_description, model=model, result_model=vm.Card, df_metadata=df_metadata
query=self.component_description, llm_model=model, result_model=vm.Card, df_metadata=df_metadata
)


Expand Down Expand Up @@ -142,7 +142,7 @@ def create(self, model, available_components, df_metadata):
df_cols=_df_cols, df_sample=_df_sample, available_components=available_components
)
proxy = _get_proxy_model(
query=filter_prompt, model=model, result_model=result_proxy, df_metadata=df_metadata
query=filter_prompt, llm_model=model, result_model=result_proxy, df_metadata=df_metadata
)
logger.info(
f"`Control` proxy: {proxy.dict()}"
Expand Down Expand Up @@ -199,7 +199,7 @@ def create(self, model, df_metadata) -> Union[vm.Layout, None]:

try:
proxy = _get_proxy_model(
query=self.layout_description, model=model, result_model=LayoutProxyModel, df_metadata=df_metadata
query=self.layout_description, llm_model=model, result_model=LayoutProxyModel, df_metadata=df_metadata
)
actual = vm.Layout.parse_obj(proxy.dict(exclude={}))
except (ValidationError, AttributeError) as e:
Expand Down Expand Up @@ -238,7 +238,7 @@ def _get_dashboard_plan(
model: Union[ChatOpenAI],
df_metadata: Dict[str, Dict[str, str]],
) -> DashboardPlanner:
return _get_proxy_model(query=query, model=model, result_model=DashboardPlanner, df_metadata=df_metadata)
return _get_proxy_model(query=query, llm_model=model, result_model=DashboardPlanner, df_metadata=df_metadata)


def _print_dashboard_plan(dashboard_plan) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->

### Fixed

- Ensure that categorical selectors always return a list of values. ([#562](https://github.com/mckinsey/vizro/pull/562))

<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
75 changes: 30 additions & 45 deletions vizro-core/examples/_dev/app.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,49 @@
"""Dev app to try things out."""

from typing import Optional

import dash_bootstrap_components as dbc
import pandas as pd
import vizro.models as vm
import vizro.plotly.express as px
from dash import html
from vizro import Vizro
from vizro.figures import kpi_card
from vizro.models.types import capture

tips = px.data.tips
df_stocks = px.data.stocks(datetimes=True)
df_stocks_long = pd.melt(
df_stocks,
id_vars="date",
value_vars=["GOOG", "AAPL", "AMZN", "FB", "NFLX", "MSFT"],
var_name="stocks",
value_name="value",
)


@capture("figure") # (1)!
def custom_kpi_card( # noqa: PLR0913
data_frame: pd.DataFrame,
value_column: str,
*,
value_format: str = "{value}",
agg_func: str = "sum",
title: Optional[str] = None,
icon: Optional[str] = None,
) -> dbc.Card: # (2)!
"""Creates a custom KPI card."""
title = title or f"{agg_func} {value_column}".title()
value = data_frame[value_column].agg(agg_func)
@capture("graph")
def vizro_plot(data_frame, stocks_selected, **kwargs):
"""Custom chart function."""
return px.line(data_frame[data_frame["stocks"].isin(stocks_selected)], **kwargs)

header = dbc.CardHeader(
[
html.H2(title),
html.P(icon, className="material-symbols-outlined") if icon else None, # (3)!
]
)
body = dbc.CardBody([value_format.format(value=value)])
return dbc.Card([header, body], className="card-kpi")

df_stocks_long["value"] = df_stocks_long["value"].round(3)

page = vm.Page(
title="Create your own KPI card",
layout=vm.Layout(grid=[[0, 1, -1, -1]] + [[-1, -1, -1, -1]] * 3), # (4)!
title="My first page",
components=[
vm.Figure(
figure=kpi_card( # (5)!
data_frame=tips,
value_column="tip",
value_format="${value:.2f}",
icon="shopping_cart",
title="Default KPI card",
)
vm.Graph(
id="my_graph",
figure=vizro_plot(
data_frame=df_stocks_long,
stocks_selected=list(df_stocks_long["stocks"].unique()),
x="date",
y="value",
color="stocks",
),
),
vm.Figure(
figure=custom_kpi_card( # (6)!
data_frame=tips,
value_column="tip",
value_format="${value:.2f}",
icon="payment",
title="Custom KPI card",
)
],
controls=[
vm.Parameter(
targets=["my_graph.stocks_selected"],
selector=vm.Dropdown(
options=[{"label": s, "value": s} for s in df_stocks_long["stocks"].unique()],
),
),
],
)
Expand Down
15 changes: 11 additions & 4 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,19 @@ def _get_parametrized_config(target: ModelID, ctd_parameters: List[CallbackTrigg
config["data_frame"] = {}

for ctd in ctd_parameters:
selector_value = ctd[
"value"
] # TODO: needs to be refactored so that it is independent of implementation details
# TODO: needs to be refactored so that it is independent of implementation details
selector_value = ctd["value"]

if hasattr(selector_value, "__iter__") and ALL_OPTION in selector_value: # type: ignore[operator]
selector: SelectorType = model_manager[ctd["id"]]
selector_value = selector.options

# Even if options are provided as List[Dict], the Dash component only returns a List of values.
# So we need to ensure that we always return a List only as well to provide consistent types.
if all(isinstance(option, dict) for option in selector.options):
selector_value = [option["value"] for option in selector.options]
else:
selector_value = selector.options

selector_value = _validate_selector_value_none(selector_value)
selector_actions = _get_component_actions(model_manager[ctd["id"]])

Expand Down
2 changes: 1 addition & 1 deletion vizro-core/src/vizro/models/_action/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _action_callback_function(
) -> Any:
logger.debug("===== Running action with id %s, function %s =====", self.id, self.function._function.__name__)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Action inputs:\n%s", pformat(inputs, depth=2, width=200))
logger.debug("Action inputs:\n%s", pformat(inputs, depth=3, width=200))
logger.debug("Action outputs:\n%s", pformat(outputs, width=200))

if isinstance(inputs, Mapping):
Expand Down
Loading

0 comments on commit 1600b55

Please sign in to comment.