diff --git a/backend/onyx/llm/chat_llm.py b/backend/onyx/llm/chat_llm.py
index aa627703b14..69b213dfce5 100644
--- a/backend/onyx/llm/chat_llm.py
+++ b/backend/onyx/llm/chat_llm.py
@@ -266,25 +266,28 @@ def __init__(
         #     )
 
         self._custom_config = custom_config
+        # Create a dictionary for model-specific arguments if it's None
+        model_kwargs = model_kwargs or {}
+
         # NOTE: have to set these as environment variables for Litellm since
         # not all are able to passed in but they always support them set as env
         # variables. We'll also try passing them in, since litellm just ignores
         # addtional kwargs (and some kwargs MUST be passed in rather than set as
         # env variables)
-
-        # Create a dictionary for model-specific arguments if it's None
-        model_kwargs = model_kwargs or {}
-
-        # Filter out empty or None values from custom_config before use
         if custom_config:
-            filtered_config = {k: v for k, v in custom_config.items() if v}
-
-            # Set non-empty config entries as environment variables for litellm
-            for k, v in filtered_config.items():
-                os.environ[k] = v
-
-            # Update model_kwargs with remaining non-empty config
-            model_kwargs.update(filtered_config)
+            # Specifically pass in "vertex_credentials" as a model_kwarg to the
+            # completion call for vertex AI. More details here:
+            # https://docs.litellm.ai/docs/providers/vertex
+            vertex_credentials_key = "vertex_credentials"
+            vertex_credentials = (
+                custom_config.get(vertex_credentials_key) if custom_config else None
+            )
+            if vertex_credentials and model_provider == "vertex_ai":
+                model_kwargs[vertex_credentials_key] = vertex_credentials
+            else:
+                # standard case
+                for k, v in custom_config.items():
+                    os.environ[k] = v
 
         if extra_headers:
             model_kwargs.update({"extra_headers": extra_headers})
@@ -521,4 +524,4 @@ def _stream_implementation(
                 log_msg = ""
             logger.debug(f"Raw Model Output:\n{log_msg}")
         else:
-            logger.debug(f"Raw Model Output:\n{content}")
+            logger.debug(f"Raw Model Output:\n{content}")  #
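
Context for the first hunk: litellm accepts Vertex AI service-account credentials as an explicit `vertex_credentials` kwarg on the completion call itself, whereas most other providers read their secrets from environment variables (see https://docs.litellm.ai/docs/providers/vertex). Below is a minimal standalone sketch of the two configuration paths this hunk sets up; it is not part of the diff, and the model names and key-file path are placeholders, not values taken from this PR.

    # Standalone sketch of the two config paths (placeholder values throughout).
    import os

    import litellm

    # Vertex AI path: the service-account JSON is passed as a completion kwarg,
    # mirroring model_kwargs["vertex_credentials"] in the hunk above.
    with open("service_account.json") as f:  # placeholder path
        vertex_credentials = f.read()

    litellm.completion(
        model="vertex_ai/gemini-1.5-pro",  # placeholder vertex_ai/* model
        messages=[{"role": "user", "content": "Hello"}],
        vertex_credentials=vertex_credentials,
    )

    # Standard case: other custom_config entries become env vars, which litellm
    # and the underlying provider SDKs pick up on their own.
    os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key
    litellm.completion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello"}],
    )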