chore(llm): major refactor of VertexAI models #636
Walkthrough
The changes primarily revolve around the [...]
Review Status
Actionable comments generated: 3
Files selected for processing (2)
- pandasai/llm/google_vertexai.py (3 hunks)
- tests/llms/test_google_vertexai.py (3 hunks)
Additional comments (Suppressed): 4
pandasai/llm/google_vertexai.py (3)
- 19-33: The `init_vertexai` function is a new addition that initializes the VertexAI library with the provided project ID, location, and credentials. This function is a good abstraction as it separates the initialization logic from the `GoogleVertexAI` class. However, it's important to ensure that the `vertexai` library is thread-safe since it's being initialized globally. If it's not thread-safe, it could lead to issues when multiple instances of `GoogleVertexAI` are used in different threads.
- 54-91: The `__init__` method of the `GoogleVertexAI` class has been significantly refactored. The `credentials` parameter has been added, and the `init_vertexai` function is now called to initialize the VertexAI library. The `_configure` method has been removed, and its functionality has been incorporated into `init_vertexai`. The model initialization logic has been simplified and now uses the `from_pretrained` method of the `CodeGenerationModel` and `TextGenerationModel` classes directly. This simplification improves the readability and maintainability of the code.
- 119-128: Note: This review was outside of the patch, so it was mapped to the patch with the greatest overlap. Original lines [111-126]. The `_generate_text` method has been simplified to use the `self.client.predict` method with parameters based on the model type. This simplification improves the readability and maintainability of the code. However, it's important to ensure that the `predict` method of the `CodeGenerationModel` and `TextGenerationModel` classes can handle all the parameters in `_default_params` correctly.
tests/llms/test_google_vertexai.py (1)
- 54-59: The test case for an invalid model is well written. It correctly asserts that an `UnsupportedModelError` is raised when an invalid model is passed to the `GoogleVertexAI` constructor. No changes are needed here.
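The thread-safety concern raised for lines 19-33 could be handled by guarding the global initialization with a lock. The following is only a sketch under stated assumptions: `init_once` and `record` are hypothetical names that do not appear in the PR, and `record` stands in for the real library initializer so the example runs without any Google dependency.

```python
import threading

_init_lock = threading.Lock()
_init_args = None  # settings used by the first successful initialization


def init_once(init_fn, project_id, location):
    """Run init_fn(project_id, location) at most once per process.

    init_fn is a hypothetical stand-in for the real library initializer.
    """
    global _init_args
    with _init_lock:
        if _init_args is None:
            init_fn(project_id, location)
            _init_args = (project_id, location)
        elif _init_args != (project_id, location):
            # Surface conflicting settings instead of silently ignoring them.
            raise RuntimeError("already initialized with different settings")


calls = []


def record(project_id, location):
    # Stand-in initializer: only records that it was called.
    calls.append((project_id, location))


threads = [
    threading.Thread(
        target=init_once,
        args=(record, "test-project", "northamerica-northeast1"),
    )
    for _ in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

With this guard, eight concurrent callers trigger a single initialization. Whether the real library needs such protection is not established anywhere in this thread.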
def test_init_with_default_model(self, mocker):
    mocker.patch(
        "vertexai.preview.language_models.TextGenerationModel.from_pretrained",
        return_value="Test",
    )
    project_id = "your_project_id"
    location = "northamerica-northeast1"
    vertexai_instance = GoogleVertexAI(project_id, location)
The `mocker.patch` method is used to mock the `from_pretrained` method of the `TextGenerationModel` class. This is a good practice for unit testing as it isolates the method being tested from external dependencies. However, it would be beneficial to verify that the mocked method is called with the expected arguments. This can be done by storing the mock object and asserting that it was called with the correct arguments.
mock_from_pretrained = mocker.patch(
    "vertexai.preview.language_models.TextGenerationModel.from_pretrained",
    return_value="Test",
)
# ... rest of the test code ...
# The constructor falls back to text-bison@001 when no model is passed.
mock_from_pretrained.assert_called_once_with("text-bison@001")
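The suggested pattern can be tried in isolation with the standard library's `unittest.mock`. This is a minimal sketch, not the project's test: `TextGenerationModel` and `FakeLLM` below are stand-ins defined only for illustration, and the default model name `text-bison@001` is taken from the constructor docstring in this PR.

```python
from unittest import mock


class TextGenerationModel:
    """Stand-in for the Vertex AI class; not the real library."""

    @classmethod
    def from_pretrained(cls, name):
        raise RuntimeError("the real method would contact the network")


class FakeLLM:
    """Hypothetical wrapper mirroring the constructor described in the PR."""

    def __init__(self, model=None):
        self.model = model or "text-bison@001"
        self.client = TextGenerationModel.from_pretrained(self.model)


# Keep a handle on the mock so its calls can be inspected afterwards.
with mock.patch.object(
    TextGenerationModel, "from_pretrained", return_value="Test"
) as mock_from_pretrained:
    llm = FakeLLM()

# Fails loudly if the loader was skipped or given another model name.
mock_from_pretrained.assert_called_once_with("text-bison@001")
```

The same assertion works with the `mocker` fixture from pytest-mock, since `mocker.patch` also returns the mock object.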
def test_init_with_custom_model(self, mocker):
    mocker.patch(
        "vertexai.preview.language_models.CodeGenerationModel.from_pretrained",
        return_value="Test",
    )
    project_id = "test-project"
    location = "northamerica-northeast1"
    custom_model = "code-bison@001"
Similar to the previous comment, the `mocker.patch` method is used to mock the `from_pretrained` method of the `CodeGenerationModel` class. Again, it would be beneficial to verify that the mocked method is called with the expected arguments.
mock_from_pretrained = mocker.patch(
    "vertexai.preview.language_models.CodeGenerationModel.from_pretrained",
    return_value="Test",
)
# ... rest of the test code ...
mock_from_pretrained.assert_called_once_with(custom_model)
+def test_validate_with_model(self, mocker):
+    mocker.patch(
+        "vertexai.preview.language_models.TextGenerationModel.from_pretrained",
+        return_value="Test",
+    )
+    model = "text-bison@001"
     project_id = "test-project"
     location = "northamerica-northeast1"
-    custom_model = "code-bison@001"
-    return GoogleVertexAI(project_id, location, custom_model)

-def test_validate_with_model(self, google_vertexai: GoogleVertexAI):
-    google_vertexai.model = "text-bison@001"
-    google_vertexai._validate()  # Should not raise any errors
+    llm = GoogleVertexAI(project_id, location, model)
+    llm._validate()  # Should not raise any errors
The `mocker.patch` method is used to mock the `from_pretrained` method of the `TextGenerationModel` class. The `_validate` method is then called on the `GoogleVertexAI` instance. It would be beneficial to verify that the mocked method is called with the expected arguments.
mock_from_pretrained = mocker.patch(
    "vertexai.preview.language_models.TextGenerationModel.from_pretrained",
    return_value="Test",
)
# ... rest of the test code ...
mock_from_pretrained.assert_called_once_with(model)
Review Status
Actionable comments generated: 1
Files selected for processing (2)
- pandasai/llm/google_vertexai.py (3 hunks)
- tests/llms/test_google_vertexai.py (3 hunks)
Additional comments (Suppressed): 10
tests/llms/test_google_vertexai.py (4)
- 15-22: The test `test_init_with_default_model` has been updated to include a mock for the `from_pretrained` method of the `TextGenerationModel` class. This change seems to be in line with the changes made to the `GoogleVertexAI` class, where the `init_vertexai` function now calls this method. Ensure that the mock return value is representative of what the actual method would return.
- 28-35: The test `test_init_with_custom_model` now includes a mock for the `from_pretrained` method of the `CodeGenerationModel` class. Similar to the previous comment, ensure that the mock return value is representative of what the actual method would return.
- 43-52: The test `test_validate_with_model` has been updated to include a mock for the `from_pretrained` method of the `TextGenerationModel` class. This change seems to be in line with the changes made to the `GoogleVertexAI` class, where the `init_vertexai` function now calls this method. Ensure that the mock return value is representative of what the actual method would return.
- 54-65: The test `test_validate_with_invalid_model` has been updated to create a new instance of `GoogleVertexAI` with an invalid model, instead of modifying the model of an existing instance. This change seems to be in line with the changes made to the `GoogleVertexAI` class, where the `init_vertexai` function now validates the model. Ensure that the exception message matches the one raised by the `init_vertexai` function.
pandasai/llm/google_vertexai.py (6)
- 10-36: The import statement for `GoogleVertexAI` has been updated to reflect the new location of the class. The `init_vertexai` function has been introduced to initialize VertexAI with the provided project ID, location, and credentials. This function is a good abstraction as it encapsulates the initialization logic and makes the code more modular.
- 55-91: The `__init__` method now includes a `credentials` parameter and a call to the `init_vertexai` function. The `_configure` method has been removed, and its functionality has been integrated into the `init_vertexai` function. The model initialization has been moved from the `_generate_text` method to the `__init__` method, which simplifies the `_generate_text` method and makes the code more efficient by initializing the model only once. The `model` parameter is now optional and defaults to `text-bison@001` if not provided. This change improves the usability of the class by providing a sensible default value.
- 96-109: A `_default_params` property has been added to provide default parameters based on the model type. This property improves the readability and maintainability of the code by encapsulating the logic for determining the default parameters in one place.
- 119-128: Note: This review was outside of the patch, so it was mapped to the patch with the greatest overlap. Original lines [111-126]. The `_generate_text` method has been simplified to use the `self.client.predict` method with parameters based on the model type. This change makes the method more concise and easier to understand.
- 55-91: The `model` parameter is now optional and defaults to `text-bison@001` if not provided. Ensure that all calls to this function throughout the codebase have been updated to match the new signature.
- 55-91: The `credentials` parameter has been added to the `__init__` method. Ensure that all calls to this function throughout the codebase have been updated to match the new signature.
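The point about loading the model once in `__init__` rather than on every generation can be illustrated as follows. This is a rough sketch under assumptions: `FakeModel` and `SketchLLM` are hypothetical stand-ins, not the Vertex AI or PandasAI classes, and the `temperature` keyword is only an example parameter.

```python
class FakeModel:
    """Stand-in for a Vertex AI model class; counts how often it is loaded."""

    loads = 0

    def __init__(self, name):
        self.name = name

    @classmethod
    def from_pretrained(cls, name):
        cls.loads += 1
        return cls(name)

    def predict(self, prompt, **params):
        # The stand-in ignores params and simply echoes the prompt.
        return f"{self.name}: {prompt}"


class SketchLLM:
    """Hypothetical wrapper: the client is created once, then reused."""

    def __init__(self, model=None, **kwargs):
        self.model = model or "text-bison@001"
        self.client = FakeModel.from_pretrained(self.model)  # loaded here only
        self.params = kwargs

    def generate(self, prompt):
        # No model loading on the hot path, only a predict call.
        return self.client.predict(prompt, **self.params)


llm = SketchLLM(temperature=0)
answers = [llm.generate(p) for p in ("a", "b", "c")]
```

After three generations `FakeModel.loads` is still 1, which is the efficiency gain the review describes.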
         self,
         project_id: str,
         location: str,
         model: Optional[str] = None,
+        credentials: Any = None,
         **kwargs
     ):
         """
-        A init class to implement the Google Vertexai Models
+        An init class to implement the Google Vertexai Models

         Args:
-            project_id (str): GCP project
-            location (str): GCP project Location
-            model Optional (str): Model to use Default to text-bison@001
+            project_id (str): GCP project to use when making Vertex API calls
+            location (str): GCP project location to use when making Vertex API calls
+            model (str): VertexAI Large Language Model to use. Default to text-bison@001
+            credentials: The default custom credentials to use when making API calls.
+                If not provided, credentials will be ascertained from the environment.
             **kwargs: Arguments to control the Model Parameters
         """

-        if model is None:
-            self.model = "text-bison@001"
-        else:
+        init_vertexai(project_id, location, credentials)
+        if model:
             self.model = model

-        self._configure(project_id, location)
+        if self.model in self._supported_code_models:
+            from vertexai.preview.language_models import CodeGenerationModel
+
+            self.client = CodeGenerationModel.from_pretrained(self.model)
+        elif self.model in self._supported_text_models:
+            from vertexai.preview.language_models import TextGenerationModel
+
+            self.client = TextGenerationModel.from_pretrained(self.model)
+        else:
+            raise UnsupportedModelError(self.model)
         self.project_id = project_id
         self.location = location
         self._set_params(**kwargs)

-    def _configure(self, project_id: str, location: str):
-        """
-        Configure Google VertexAi. Set value `self.vertexai` attribute.
-
-        Args:
-            project_id (str): GCP Project.
-            location (str): Location of Project.
-
-        Returns:
-            None.
-
-        """
-
-        err_msg = "Install google-cloud-aiplatform for Google Vertexai"
-        vertexai = import_dependency("vertexai", extra=err_msg)
-        vertexai.init(project=project_id, location=location)
-        self.vertexai = vertexai
-        self._validate()
The `UnsupportedModelError` is raised if the provided `model` is not supported. This is a good practice as it provides a clear error message to the user. However, it would be helpful to include a list of supported models in the error message to guide the user.
- raise UnsupportedModelError(self.model)
+ raise UnsupportedModelError(f"{self.model} is not supported. Supported models are {self._supported_code_models + self._supported_text_models}")
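A self-contained sketch of the suggested error message is shown below. It rests on assumptions: the exception class is a local stand-in (the project's own `UnsupportedModelError` may accept different arguments), `check_model` is a hypothetical helper, and the two model lists contain only the names that appear in the tests, not the project's full lists.

```python
class UnsupportedModelError(Exception):
    """Stand-in exception; the project's own class may differ."""


# Illustrative lists only; the real class attributes may hold more names.
SUPPORTED_CODE_MODELS = ["code-bison@001"]
SUPPORTED_TEXT_MODELS = ["text-bison@001"]


def check_model(model):
    """Return model if it is supported, otherwise raise with guidance."""
    supported = SUPPORTED_CODE_MODELS + SUPPORTED_TEXT_MODELS
    if model not in supported:
        raise UnsupportedModelError(
            f"{model} is not supported. "
            f"Supported models are {', '.join(supported)}"
        )
    return model


try:
    check_model("not-a-model")
except UnsupportedModelError as exc:
    message = str(exc)
```

Joining the names with commas keeps the message readable; formatting the raw list, as in the suggestion above, would also work.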
Codecov Report
@@           Coverage Diff           @@
##             main     #636   +/-   ##
=======================================
  Coverage   84.40%   84.40%
=======================================
  Files          65       65
  Lines        3026     3026
=======================================
  Hits         2554     2554
  Misses        472      472
Review Status
Actionable comments generated: 0
Files selected for processing (1)
- tests/llms/test_google_vertexai.py (3 hunks)
Additional comments (Suppressed): 4
tests/llms/test_google_vertexai.py (4)
- 15-22: The test `test_init_with_default_model` has been updated to mock the `from_pretrained` method of the `TextGenerationModel` class. This change is likely due to the refactoring of the `GoogleVertexAI` class, which now calls this method during initialization. Ensure that this mock correctly simulates the behavior of the actual method.
- 28-35: The test `test_init_with_custom_model` has been updated to mock the `from_pretrained` method of the `CodeGenerationModel` class. This change is likely due to the refactoring of the `GoogleVertexAI` class, which now calls this method during initialization when a custom model is provided. Ensure that this mock correctly simulates the behavior of the actual method.
- 43-52: The test `test_validate_with_model` has been updated to mock the `from_pretrained` method of the `TextGenerationModel` class. This change is likely due to the refactoring of the `GoogleVertexAI` class, which now calls this method during the `_validate` method when a model is provided. Ensure that this mock correctly simulates the behavior of the actual method.
- 54-65: The test `test_validate_with_invalid_model` has been updated to directly instantiate the `GoogleVertexAI` class with an invalid model, instead of modifying the `model` attribute of an existing instance. This change is likely due to the refactoring of the `GoogleVertexAI` class, which now validates the model during initialization. Ensure that this change correctly tests the new behavior of the class.
@mspronesti tried it, but for some reason it doesn't work. It seems to block at some point when we instantiate it. I haven't investigated further, but I'm happy to debug more if you can't reproduce it on your machine!
@mspronesti I finally got an error: google.api_core.exceptions.RetryError: Deadline of 120.0s exceeded while calling target function, last exception: 503 Getting
Here's how I instantiate it:
@gventuri I can't reproduce the error :/ Do you also have it with text models?
@mspronesti same error with any model :( Are you instantiating it in a similar way to what I do?
@gventuri Do you happen to have a longer stack trace?
@mspronesti not really, that's all I have. I can try to debug further and catch the whole traceback. I've retried with the old version and it seems to work, so I'm quite sure it's being affected by some of the changes. I'll try to investigate further.
I'm off for a few days. I will debug the issue next week unless you figure it out first :)
@mspronesti sure, I'll try to look into it! :)
Hi @gventuri,
this PR refactors the VertexAI models. In particular, it initializes the client only once rather than at every generation. It also allows the user to explicitly pass the credentials for their GCP account. Lastly, I removed some unneeded checks and validations.
Summary by CodeRabbit
- [...] `GoogleVertexAI` class, including initialization with custom credentials.
- [...] `GoogleVertexAI` class by using the `predict` method of the Vertex AI client.
- [...] `GoogleVertexAI` class to include mocking of model loading and validation of model types.
- [...] `GoogleVertexAI` class by raising an exception when an unsupported or no model is provided.