Skip to content

Commit

Permalink
feat: support more palm models (#633)
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri authored Oct 10, 2023
1 parent 3813663 commit 720811f
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 14 deletions.
4 changes: 2 additions & 2 deletions pandasai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ class MethodNotImplementedError(Exception):
"""


class UnsupportedOpenAIModelError(Exception):
class UnsupportedModelError(Exception):
"""
Raised when an unsupported OpenAI model is used.
Args:
Exception (Exception): UnsupportedOpenAIModelError
Exception (Exception): UnsupportedModelError
"""


Expand Down
4 changes: 2 additions & 2 deletions pandasai/llm/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import openai
from ..helpers import load_dotenv

from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
from ..exceptions import APIKeyNotFoundError, UnsupportedModelError
from ..prompts.base import AbstractPrompt
from .base import BaseOpenAI

Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(
openai.api_type = self.api_type

if deployment_name is None:
raise UnsupportedOpenAIModelError("Model deployment name is required.")
raise UnsupportedModelError("Model deployment name is required.")

self.is_chat_model = is_chat_model
self.engine = deployment_name
Expand Down
19 changes: 16 additions & 3 deletions pandasai/llm/google_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""
from typing import Optional
from .base import BaseGoogle
from ..exceptions import UnsupportedModelError
from ..helpers.optional import import_dependency


Expand All @@ -20,9 +21,19 @@ class GoogleVertexAI(BaseGoogle):
BaseGoogle class is extended for Google Palm model using Vertexai.
The default model support at the moment is text-bison-001.
However, user can choose to use code-bison-001 too.
"""

_supported_code_models = [
"code-bison",
"code-bison-32k",
"code-bison@001",
]
_supported_text_models = [
"text-bison",
"text-bison-32k",
"text-bison@001",
]

def __init__(
self, project_id: str, location: str, model: Optional[str] = None, **kwargs
):
Expand Down Expand Up @@ -97,15 +108,15 @@ def _generate_text(self, prompt: str) -> str:
TextGenerationModel,
)

if self.model == "code-bison@001":
if self.model in self._supported_code_models:
code_generation = CodeGenerationModel.from_pretrained(self.model)

completion = code_generation.predict(
prefix=prompt,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
)
else:
elif self.model in self._supported_text_models:
text_generation = TextGenerationModel.from_pretrained(self.model)

completion = text_generation.predict(
Expand All @@ -115,6 +126,8 @@ def _generate_text(self, prompt: str) -> str:
top_k=self.top_k,
max_output_tokens=self.max_output_tokens,
)
else:
raise UnsupportedModelError("Unsupported model")

return str(completion)

Expand Down
6 changes: 3 additions & 3 deletions pandasai/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import openai
from ..helpers import load_dotenv

from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
from ..exceptions import APIKeyNotFoundError, UnsupportedModelError
from ..prompts.base import AbstractPrompt
from .base import BaseOpenAI

Expand Down Expand Up @@ -95,7 +95,7 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
suffix (str): Suffix to pass.
Raises:
UnsupportedOpenAIModelError: Unsupported model
UnsupportedModelError: Unsupported model
Returns:
str: Response
Expand All @@ -107,7 +107,7 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
elif self.model in self._supported_completion_models:
response = self.completion(self.last_prompt)
else:
raise UnsupportedOpenAIModelError("Unsupported model")
raise UnsupportedModelError("Unsupported model")

return response

Expand Down
4 changes: 2 additions & 2 deletions tests/llms/test_azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import openai
import pytest

from pandasai.exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError
from pandasai.llm import AzureOpenAI
from openai.openai_object import OpenAIObject

Expand All @@ -23,7 +23,7 @@ def test_type_without_api_version(self):
AzureOpenAI(api_token="test", api_base="test")

def test_type_without_deployment(self):
with pytest.raises(UnsupportedOpenAIModelError):
with pytest.raises(UnsupportedModelError):
AzureOpenAI(api_token="test", api_base="test", api_version="test")

def test_type_with_token(self):
Expand Down
6 changes: 6 additions & 0 deletions tests/llms/test_google_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from pandasai.llm import GoogleVertexAI
from pandasai.exceptions import UnsupportedModelError


class MockedCompletion:
Expand Down Expand Up @@ -43,6 +44,11 @@ def test_validate_with_model(self, google_vertexai: GoogleVertexAI):
google_vertexai.model = "text-bison@001"
google_vertexai._validate() # Should not raise any errors

def test_validate_with_invalid_model(self, google_vertexai: GoogleVertexAI):
google_vertexai.model = "invalid-model"
with pytest.raises(UnsupportedModelError, match="Unsupported model"):
google_vertexai._generate_text("Test prompt")

def test_validate_without_model(self, google_vertexai: GoogleVertexAI):
google_vertexai.model = None
with pytest.raises(ValueError, match="model is required."):
Expand Down
4 changes: 2 additions & 2 deletions tests/llms/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import openai
import pytest

from pandasai.exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError
from pandasai.llm import OpenAI
from pandasai.prompts import AbstractPrompt
from openai.openai_object import OpenAIObject
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_chat_completion(self, mocker):
assert result == expected_response

def test_call_with_unsupported_model(self, prompt):
with pytest.raises(UnsupportedOpenAIModelError):
with pytest.raises(UnsupportedModelError):
llm = OpenAI(api_token="test", model="not a model")
llm.call(instruction=prompt)

Expand Down

0 comments on commit 720811f

Please sign in to comment.