Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(llm): major refactor of VertexAI models #636

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 61 additions & 67 deletions pandasai/llm/google_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,31 @@
Example:
Use below example to call Google VertexAI

>>> from pandasai.llm.google_palm import GoogleVertexAI
>>> from pandasai.llm import GoogleVertexAI

"""
from typing import Optional
from typing import Optional, Dict, Any
from .base import BaseGoogle
from ..exceptions import UnsupportedModelError
from ..helpers.optional import import_dependency


def init_vertexai(
    project_id, location, credentials=None,
) -> None:
    """Initialise the ``vertexai`` SDK for a GCP project and location.

    Args:
        project_id: GCP project to bind the SDK to.
        location: GCP region to bind the SDK to.
        credentials: Optional custom credentials object. When ``None``,
            it is omitted so the SDK falls back to ambient environment
            credentials.
    """
    # Lazy import: google-cloud-aiplatform is an optional dependency.
    sdk = import_dependency(
        "vertexai",
        extra="Could not import VertexAI. Please, install "
        "it with `pip install google-cloud-aiplatform`"
    )
    kwargs = {"project": project_id, "location": location}
    if credentials is not None:
        kwargs["credentials"] = credentials
    sdk.init(**kwargs)


class GoogleVertexAI(BaseGoogle):
"""Google Palm Vertexai LLM
BaseGoogle class is extended for Google Palm model using Vertexai.
Expand All @@ -33,62 +49,64 @@ class GoogleVertexAI(BaseGoogle):
"text-bison-32k",
"text-bison@001",
]
model: str = "text-bison@001"

def __init__(
self, project_id: str, location: str, model: Optional[str] = None, **kwargs
self,
project_id: str,
location: str,
model: Optional[str] = None,
credentials: Any = None,
**kwargs
):
"""
A init class to implement the Google Vertexai Models
An init class to implement the Google Vertexai Models

Args:
project_id (str): GCP project
location (str): GCP project Location
model Optional (str): Model to use Default to text-bison@001
project_id (str): GCP project to use when making Vertex API calls
location (str): GCP project location to use when making Vertex API calls
model (str): VertexAI Large Language Model to use. Default to text-bison@001
credentials: The default custom credentials to use when making API calls.
If not provided, credentials will be ascertained from the environment.
**kwargs: Arguments to control the Model Parameters
"""

if model is None:
self.model = "text-bison@001"
else:
init_vertexai(project_id, location, credentials)
if model:
self.model = model

self._configure(project_id, location)
if self.model in self._supported_code_models:
from vertexai.preview.language_models import CodeGenerationModel

self.client = CodeGenerationModel.from_pretrained(self.model)
elif self.model in self._supported_text_models:
from vertexai.preview.language_models import TextGenerationModel

self.client = TextGenerationModel.from_pretrained(self.model)
else:
raise UnsupportedModelError("Unsupported model")
self.project_id = project_id
self.location = location
self._set_params(**kwargs)

def _configure(self, project_id: str, location: str):
"""
Configure Google VertexAi. Set value `self.vertexai` attribute.

Args:
project_id (str): GCP Project.
location (str): Location of Project.

Returns:
None.

"""

err_msg = "Install google-cloud-aiplatform for Google Vertexai"
vertexai = import_dependency("vertexai", extra=err_msg)
vertexai.init(project=project_id, location=location)
self.vertexai = vertexai
self._validate()

def _valid_params(self):
"""Returns if the Parameters are valid or Not"""
return super()._valid_params() + ["model"]

def _validate(self):
"""
A method to Validate the Model

"""

super()._validate()

if not self.model:
raise ValueError("model is required.")
@property
def _default_params(self) -> Dict[str, Any]:
if "code" in self.model:
return {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
}
else:
return {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
}

def _generate_text(self, prompt: str) -> str:
"""
Expand All @@ -101,34 +119,10 @@ def _generate_text(self, prompt: str) -> str:
str: LLM response.

"""
self._validate()

from vertexai.preview.language_models import (
CodeGenerationModel,
TextGenerationModel,
completion = self.client.predict(
prompt,
**self._default_params
)

if self.model in self._supported_code_models:
code_generation = CodeGenerationModel.from_pretrained(self.model)

completion = code_generation.predict(
prefix=prompt,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
)
elif self.model in self._supported_text_models:
text_generation = TextGenerationModel.from_pretrained(self.model)

completion = text_generation.predict(
prompt=prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
max_output_tokens=self.max_output_tokens,
)
else:
raise UnsupportedModelError("Unsupported model")

return str(completion)

@property
Expand Down
41 changes: 23 additions & 18 deletions tests/llms/test_google_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def __init__(self, result: str):


class TestGoogleVertexAI:
def test_init_with_default_model(self):
def test_init_with_default_model(self, mocker):
mocker.patch(
"vertexai.preview.language_models.TextGenerationModel.from_pretrained",
return_value="Test",
)
project_id = "your_project_id"
location = "northamerica-northeast1"
vertexai_instance = GoogleVertexAI(project_id, location)
Comment on lines +15 to 22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mocker.patch method is used to mock the from_pretrained method of the TextGenerationModel class. This is a good practice for unit testing as it isolates the method being tested from external dependencies. However, it would be beneficial to verify that the mocked method is called with the expected arguments. This can be done by storing the mock object and asserting that it was called with the correct arguments.

mock_from_pretrained = mocker.patch(
    "vertexai.preview.language_models.TextGenerationModel.from_pretrained",
    return_value="Test",
)
# ... rest of the test code ...
mock_from_pretrained.assert_called_once_with("default")

Expand All @@ -21,7 +25,11 @@ def test_init_with_default_model(self):
assert vertexai_instance.project_id == project_id
assert vertexai_instance.location == location

def test_init_with_custom_model(self):
def test_init_with_custom_model(self, mocker):
mocker.patch(
"vertexai.preview.language_models.CodeGenerationModel.from_pretrained",
return_value="Test",
)
project_id = "test-project"
location = "northamerica-northeast1"
custom_model = "code-bison@001"
Comment on lines +28 to 35
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the previous comment, the mocker.patch method is used to mock the from_pretrained method of the CodeGenerationModel class. Again, it would be beneficial to verify that the mocked method is called with the expected arguments.

mock_from_pretrained = mocker.patch(
    "vertexai.preview.language_models.CodeGenerationModel.from_pretrained",
    return_value="Test",
)
# ... rest of the test code ...
mock_from_pretrained.assert_called_once_with(custom_model)

Expand All @@ -32,24 +40,21 @@ def test_init_with_custom_model(self):
assert vertexai_instance.project_id == project_id
assert vertexai_instance.location == location

@pytest.fixture
def google_vertexai(self):
# Create an instance of YourClass for testing
def test_validate_with_model(self, mocker):
mocker.patch(
"vertexai.preview.language_models.TextGenerationModel.from_pretrained",
return_value="Test",
)
model = "text-bison@001"
project_id = "test-project"
location = "northamerica-northeast1"
custom_model = "code-bison@001"
return GoogleVertexAI(project_id, location, custom_model)

def test_validate_with_model(self, google_vertexai: GoogleVertexAI):
google_vertexai.model = "text-bison@001"
google_vertexai._validate() # Should not raise any errors
llm = GoogleVertexAI(project_id, location, model)
llm._validate() # Should not raise any errors
Comment on lines +43 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mocker.patch method is used to mock the from_pretrained method of the TextGenerationModel class. The _validate method is then called on the GoogleVertexAI instance. It would be beneficial to verify that the mocked method is called with the expected arguments.

mock_from_pretrained = mocker.patch(
    "vertexai.preview.language_models.TextGenerationModel.from_pretrained",
    return_value="Test",
)
# ... rest of the test code ...
mock_from_pretrained.assert_called_once_with(model)


def test_validate_with_invalid_model(self, google_vertexai: GoogleVertexAI):
google_vertexai.model = "invalid-model"
def test_validate_with_invalid_model(self):
model = "invalid_model"
project_id = "test-project"
location = "northamerica-northeast1"
with pytest.raises(UnsupportedModelError, match="Unsupported model"):
google_vertexai._generate_text("Test prompt")
GoogleVertexAI(project_id, location, model)

def test_validate_without_model(self, google_vertexai: GoogleVertexAI):
google_vertexai.model = None
with pytest.raises(ValueError, match="model is required."):
google_vertexai._validate()