Skip to content

Commit

Permalink
[Dev] Update LLM model wrapper mapping (#517)
Browse files Browse the repository at this point in the history
Co-authored-by: Lingyi Zhang <[email protected]>
  • Loading branch information
Anna-Xiong and lingyielia authored Jul 5, 2024
1 parent b5e8245 commit 542b4c1
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 34 deletions.
48 changes: 48 additions & 0 deletions vizro-ai/changelog.d/20240626_224646_anna_xiong_azure_openai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
58 changes: 24 additions & 34 deletions vizro-ai/src/vizro_ai/chains/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,24 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI

# TODO constant of model inventory, can be converted to yaml and link to docs
PREDEFINED_MODELS: Dict[str, Dict[str, Union[int, BaseChatModel]]] = {
"gpt-3.5-turbo-0613": {
"max_tokens": 4096,
"wrapper": ChatOpenAI,
},
"gpt-4-0613": {
"max_tokens": 8192,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo-1106": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-4-1106-preview": {
"max_tokens": 128000,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo-0125": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-4-turbo": {
"max_tokens": 128000,
"wrapper": ChatOpenAI,
},
SUPPORTED_MODELS = {
"OpenAI": [
"gpt-4-0613",
"gpt-3.5-turbo-1106",
"gpt-4-1106-preview",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-4-turbo",
"gpt-4o",
]
}

DEFAULT_WRAPPER_MAP: Dict[str, BaseChatModel] = {"OpenAI": ChatOpenAI}
DEFAULT_MODEL = "gpt-3.5-turbo"
DEFAULT_TEMPERATURE = 0

model_to_vendor = {model: key for key, models in SUPPORTED_MODELS.items() for model in models}


def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatModel:
"""Fetches and initializes an instance of the LLM.
Expand All @@ -54,14 +37,21 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo
"""
if not model:
return ChatOpenAI(model_name=DEFAULT_MODEL, temperature=DEFAULT_TEMPERATURE)
if isinstance(model, ChatOpenAI):

if isinstance(model, BaseChatModel):
return model
if isinstance(model, str) and model in PREDEFINED_MODELS:
return PREDEFINED_MODELS.get(model)["wrapper"](model_name=model, temperature=DEFAULT_TEMPERATURE)

if isinstance(model, str):
if any(model in model_list for model_list in SUPPORTED_MODELS.values()):
vendor = model_to_vendor[model]
return DEFAULT_WRAPPER_MAP.get(vendor)(model_name=model, temperature=DEFAULT_TEMPERATURE)

raise ValueError(
f"Model {model} not found! List of available model can be found at https://vizro.readthedocs.io/projects/vizro-ai/en/latest/pages/explanation/faq/#which-llms-are-supported-by-vizro-ai"
)


if __name__ == "__main__":
llm_chat_openai = _get_llm_model()
llm_chat_openai = _get_llm_model(model="gpt-3.5-turbo")
print(repr(llm_chat_openai)) # noqa: T201
print(llm_chat_openai.model_name) # noqa: T201

0 comments on commit 542b4c1

Please sign in to comment.