diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index c63fb94f0..74de081dc 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -42,12 +42,12 @@ class MethodNotImplementedError(Exception): """ -class UnsupportedOpenAIModelError(Exception): +class UnsupportedModelError(Exception): """ Raised when an unsupported OpenAI model is used. Args: - Exception (Exception): UnsupportedOpenAIModelError + Exception (Exception): UnsupportedModelError """ diff --git a/pandasai/llm/azure_openai.py b/pandasai/llm/azure_openai.py index ffe722146..edf0939d0 100644 --- a/pandasai/llm/azure_openai.py +++ b/pandasai/llm/azure_openai.py @@ -17,7 +17,7 @@ import openai from ..helpers import load_dotenv -from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError +from ..exceptions import APIKeyNotFoundError, UnsupportedModelError from ..prompts.base import AbstractPrompt from .base import BaseOpenAI @@ -83,7 +83,7 @@ def __init__( openai.api_type = self.api_type if deployment_name is None: - raise UnsupportedOpenAIModelError("Model deployment name is required.") + raise UnsupportedModelError("Model deployment name is required.") self.is_chat_model = is_chat_model self.engine = deployment_name diff --git a/pandasai/llm/google_vertexai.py b/pandasai/llm/google_vertexai.py index 42acc77b6..cb0e256e3 100644 --- a/pandasai/llm/google_vertexai.py +++ b/pandasai/llm/google_vertexai.py @@ -12,6 +12,7 @@ """ from typing import Optional from .base import BaseGoogle +from ..exceptions import UnsupportedModelError from ..helpers.optional import import_dependency @@ -20,9 +21,19 @@ class GoogleVertexAI(BaseGoogle): BaseGoogle class is extended for Google Palm model using Vertexai. The default model support at the moment is text-bison-001. However, user can choose to use code-bison-001 too. - """ + _supported_code_models = [ + "code-bison", + "code-bison-32k", + "code-bison@001", + ] + _supported_text_models = [ + "text-bison", + "text-bison-32k", + "text-bison@001", + ] + def __init__( self, project_id: str, location: str, model: Optional[str] = None, **kwargs ): @@ -97,7 +108,7 @@ def _generate_text(self, prompt: str) -> str: TextGenerationModel, ) - if self.model == "code-bison@001": + if self.model in self._supported_code_models: code_generation = CodeGenerationModel.from_pretrained(self.model) completion = code_generation.predict( @@ -105,7 +116,7 @@ def _generate_text(self, prompt: str) -> str: temperature=self.temperature, max_output_tokens=self.max_output_tokens, ) - else: + elif self.model in self._supported_text_models: text_generation = TextGenerationModel.from_pretrained(self.model) completion = text_generation.predict( @@ -115,6 +126,8 @@ def _generate_text(self, prompt: str) -> str: top_k=self.top_k, max_output_tokens=self.max_output_tokens, ) + else: + raise UnsupportedModelError("Unsupported model") return str(completion) diff --git a/pandasai/llm/openai.py b/pandasai/llm/openai.py index c4ebf0ff2..bd34ed770 100644 --- a/pandasai/llm/openai.py +++ b/pandasai/llm/openai.py @@ -14,7 +14,7 @@ import openai from ..helpers import load_dotenv -from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError +from ..exceptions import APIKeyNotFoundError, UnsupportedModelError from ..prompts.base import AbstractPrompt from .base import BaseOpenAI @@ -95,7 +95,7 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: suffix (str): Suffix to pass. Raises: - UnsupportedOpenAIModelError: Unsupported model + UnsupportedModelError: Unsupported model Returns: str: Response @@ -107,7 +107,7 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: elif self.model in self._supported_completion_models: response = self.completion(self.last_prompt) else: - raise UnsupportedOpenAIModelError("Unsupported model") + raise UnsupportedModelError("Unsupported model") return response diff --git a/tests/llms/test_azure_openai.py b/tests/llms/test_azure_openai.py index 62a8bffc3..13275baab 100644 --- a/tests/llms/test_azure_openai.py +++ b/tests/llms/test_azure_openai.py @@ -2,7 +2,7 @@ import openai import pytest -from pandasai.exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError +from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError from pandasai.llm import AzureOpenAI from openai.openai_object import OpenAIObject @@ -23,7 +23,7 @@ def test_type_without_api_version(self): AzureOpenAI(api_token="test", api_base="test") def test_type_without_deployment(self): - with pytest.raises(UnsupportedOpenAIModelError): + with pytest.raises(UnsupportedModelError): AzureOpenAI(api_token="test", api_base="test", api_version="test") def test_type_with_token(self): diff --git a/tests/llms/test_google_vertexai.py b/tests/llms/test_google_vertexai.py index d57873c3c..23bf136de 100644 --- a/tests/llms/test_google_vertexai.py +++ b/tests/llms/test_google_vertexai.py @@ -3,6 +3,7 @@ import pytest from pandasai.llm import GoogleVertexAI +from pandasai.exceptions import UnsupportedModelError class MockedCompletion: @@ -43,6 +44,11 @@ def test_validate_with_model(self, google_vertexai: GoogleVertexAI): google_vertexai.model = "text-bison@001" google_vertexai._validate() # Should not raise any errors + def test_validate_with_invalid_model(self, google_vertexai: GoogleVertexAI): + google_vertexai.model = "invalid-model" + with pytest.raises(UnsupportedModelError, match="Unsupported model"): + google_vertexai._generate_text("Test prompt") + def test_validate_without_model(self, google_vertexai: GoogleVertexAI): google_vertexai.model = None with pytest.raises(ValueError, match="model is required."): diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 18fa3ffa6..749a2fd6a 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -2,7 +2,7 @@ import openai import pytest -from pandasai.exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError +from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError from pandasai.llm import OpenAI from pandasai.prompts import AbstractPrompt from openai.openai_object import OpenAIObject @@ -107,7 +107,7 @@ def test_chat_completion(self, mocker): assert result == expected_response def test_call_with_unsupported_model(self, prompt): - with pytest.raises(UnsupportedOpenAIModelError): + with pytest.raises(UnsupportedModelError): llm = OpenAI(api_token="test", model="not a model") llm.call(instruction=prompt)