diff --git a/vizro-ai/changelog.d/20240626_224646_anna_xiong_azure_openai.md b/vizro-ai/changelog.d/20240626_224646_anna_xiong_azure_openai.md
new file mode 100644
index 000000000..f1f65e73c
--- /dev/null
+++ b/vizro-ai/changelog.d/20240626_224646_anna_xiong_azure_openai.md
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
diff --git a/vizro-ai/src/vizro_ai/chains/_llm_models.py b/vizro-ai/src/vizro_ai/chains/_llm_models.py
index 7c70952f2..b8f4b929f 100644
--- a/vizro-ai/src/vizro_ai/chains/_llm_models.py
+++ b/vizro-ai/src/vizro_ai/chains/_llm_models.py
@@ -3,41 +3,24 @@
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_openai import ChatOpenAI
 
-# TODO constant of model inventory, can be converted to yaml and link to docs
-PREDEFINED_MODELS: Dict[str, Dict[str, Union[int, BaseChatModel]]] = {
-    "gpt-3.5-turbo-0613": {
-        "max_tokens": 4096,
-        "wrapper": ChatOpenAI,
-    },
-    "gpt-4-0613": {
-        "max_tokens": 8192,
-        "wrapper": ChatOpenAI,
-    },
-    "gpt-3.5-turbo-1106": {
-        "max_tokens": 16385,
-        "wrapper": ChatOpenAI,
-    },
-    "gpt-4-1106-preview": {
-        "max_tokens": 128000,
-        "wrapper": ChatOpenAI,
-    },
-    "gpt-3.5-turbo-0125": {
-        "max_tokens": 16385,
-        "wrapper": ChatOpenAI,
-    },
-    "gpt-3.5-turbo": {
-        "max_tokens": 16385,
-        "wrapper": ChatOpenAI,
-    },
-    "gpt-4-turbo": {
-        "max_tokens": 128000,
-        "wrapper": ChatOpenAI,
-    },
+SUPPORTED_MODELS = {
+    "OpenAI": [
+        "gpt-4-0613",
+        "gpt-3.5-turbo-1106",
+        "gpt-4-1106-preview",
+        "gpt-3.5-turbo-0125",
+        "gpt-3.5-turbo",
+        "gpt-4-turbo",
+        "gpt-4o",
+    ]
 }
+DEFAULT_WRAPPER_MAP: Dict[str, BaseChatModel] = {"OpenAI": ChatOpenAI}
 
 DEFAULT_MODEL = "gpt-3.5-turbo"
 DEFAULT_TEMPERATURE = 0
 
+model_to_vendor = {model: key for key, models in SUPPORTED_MODELS.items() for model in models}
+
 
 def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatModel:
     """Fetches and initializes an instance of the LLM.
@@ -54,14 +37,21 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo
     """
     if not model:
         return ChatOpenAI(model_name=DEFAULT_MODEL, temperature=DEFAULT_TEMPERATURE)
-    if isinstance(model, ChatOpenAI):
+
+    if isinstance(model, BaseChatModel):
         return model
-    if isinstance(model, str) and model in PREDEFINED_MODELS:
-        return PREDEFINED_MODELS.get(model)["wrapper"](model_name=model, temperature=DEFAULT_TEMPERATURE)
+
+    if isinstance(model, str):
+        if any(model in model_list for model_list in SUPPORTED_MODELS.values()):
+            vendor = model_to_vendor[model]
+            return DEFAULT_WRAPPER_MAP.get(vendor)(model_name=model, temperature=DEFAULT_TEMPERATURE)
+
     raise ValueError(
         f"Model {model} not found! List of available model can be found at https://vizro.readthedocs.io/projects/vizro-ai/en/latest/pages/explanation/faq/#which-llms-are-supported-by-vizro-ai"
     )
 
 
 if __name__ == "__main__":
-    llm_chat_openai = _get_llm_model()
+    llm_chat_openai = _get_llm_model(model="gpt-3.5-turbo")
+    print(repr(llm_chat_openai))  # noqa: T201
+    print(llm_chat_openai.model_name)  # noqa: T201
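
For context, a minimal sketch of how the new vendor lookup in this diff is intended to resolve a model string. It mirrors the `SUPPORTED_MODELS` / `DEFAULT_WRAPPER_MAP` / `model_to_vendor` pattern above, but uses a placeholder wrapper class instead of `ChatOpenAI` so the snippet runs without langchain or an API key; the class and function names here are illustrative, not part of the actual change.

```python
from typing import Dict, List, Optional, Union


class FakeChatModel:
    """Placeholder standing in for langchain_openai.ChatOpenAI in this sketch."""

    def __init__(self, model_name: str, temperature: float):
        self.model_name = model_name
        self.temperature = temperature


# Mirrors the structure of SUPPORTED_MODELS and DEFAULT_WRAPPER_MAP in the diff.
SUPPORTED_MODELS: Dict[str, List[str]] = {"OpenAI": ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"]}
DEFAULT_WRAPPER_MAP = {"OpenAI": FakeChatModel}
DEFAULT_MODEL = "gpt-3.5-turbo"
DEFAULT_TEMPERATURE = 0

# Reverse lookup built once: model name -> vendor key.
model_to_vendor = {model: vendor for vendor, models in SUPPORTED_MODELS.items() for model in models}


def get_llm_model(model: Optional[Union[FakeChatModel, str]] = None) -> FakeChatModel:
    """Resolve a model name, or pass through an already-instantiated model."""
    if not model:
        return FakeChatModel(model_name=DEFAULT_MODEL, temperature=DEFAULT_TEMPERATURE)
    if isinstance(model, FakeChatModel):
        return model
    if isinstance(model, str) and model in model_to_vendor:
        wrapper = DEFAULT_WRAPPER_MAP[model_to_vendor[model]]
        return wrapper(model_name=model, temperature=DEFAULT_TEMPERATURE)
    raise ValueError(f"Model {model} not found!")


print(get_llm_model("gpt-4o").model_name)  # prints: gpt-4o
```

The point of the refactor shown in the diff is that adding a new vendor only requires a new entry in `SUPPORTED_MODELS` and a wrapper in `DEFAULT_WRAPPER_MAP`, rather than repeating per-model metadata.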