diff --git a/pandasai/llm/google_vertexai.py b/pandasai/llm/google_vertexai.py index 0a3c340bd..1643cd811 100644 --- a/pandasai/llm/google_vertexai.py +++ b/pandasai/llm/google_vertexai.py @@ -7,15 +7,31 @@ Example: Use below example to call Google VertexAI - >>> from pandasai.llm.google_palm import GoogleVertexAI + >>> from pandasai.llm import GoogleVertexAI """ -from typing import Optional +from typing import Optional, Dict, Any from .base import BaseGoogle from ..exceptions import UnsupportedModelError from ..helpers.optional import import_dependency +def init_vertexai( + project_id, location, credentials=None, +) -> None: + vertexai = import_dependency( + "vertexai", + extra="Could not import VertexAI. Please, install " + "it with `pip install google-cloud-aiplatform`" + ) + init_params = { + "project": project_id, + "location": location, + **({"credentials": credentials} if credentials is not None else {}) + } + vertexai.init(**init_params) + + class GoogleVertexAI(BaseGoogle): """Google Palm Vertexai LLM BaseGoogle class is extended for Google Palm model using Vertexai. @@ -33,62 +49,64 @@ class GoogleVertexAI(BaseGoogle): "text-bison-32k", "text-bison@001", ] + model: str = "text-bison@001" def __init__( - self, project_id: str, location: str, model: Optional[str] = None, **kwargs + self, + project_id: str, + location: str, + model: Optional[str] = None, + credentials: Any = None, + **kwargs ): """ - A init class to implement the Google Vertexai Models + An init class to implement the Google Vertexai Models Args: - project_id (str): GCP project - location (str): GCP project Location - model Optional (str): Model to use Default to text-bison@001 + project_id (str): GCP project to use when making Vertex API calls + location (str): GCP project location to use when making Vertex API calls + model (str): VertexAI Large Language Model to use. Default to text-bison@001 + credentials: The default custom credentials to use when making API calls. + If not provided, credentials will be ascertained from the environment. **kwargs: Arguments to control the Model Parameters """ - - if model is None: - self.model = "text-bison@001" - else: + init_vertexai(project_id, location, credentials) + if model: self.model = model - self._configure(project_id, location) + if self.model in self._supported_code_models: + from vertexai.preview.language_models import CodeGenerationModel + + self.client = CodeGenerationModel.from_pretrained(self.model) + elif self.model in self._supported_text_models: + from vertexai.preview.language_models import TextGenerationModel + + self.client = TextGenerationModel.from_pretrained(self.model) + else: + raise UnsupportedModelError(self.model) self.project_id = project_id self.location = location self._set_params(**kwargs) - - def _configure(self, project_id: str, location: str): - """ - Configure Google VertexAi. Set value `self.vertexai` attribute. - - Args: - project_id (str): GCP Project. - location (str): Location of Project. - - Returns: - None. - - """ - - err_msg = "Install google-cloud-aiplatform for Google Vertexai" - vertexai = import_dependency("vertexai", extra=err_msg) - vertexai.init(project=project_id, location=location) - self.vertexai = vertexai + self._validate() def _valid_params(self): """Returns if the Parameters are valid or Not""" return super()._valid_params() + ["model"] - def _validate(self): - """ - A method to Validate the Model - - """ - - super()._validate() - - if not self.model: - raise ValueError("model is required.") + @property + def _default_params(self) -> Dict[str, Any]: + if "code" in self.model: + return { + "temperature": self.temperature, + "max_output_tokens": self.max_output_tokens, + } + else: + return { + "temperature": self.temperature, + "max_output_tokens": self.max_output_tokens, + "top_k": self.top_k, + "top_p": self.top_p, + } def _generate_text(self, prompt: str) -> str: """ @@ -101,35 +119,10 @@ def _generate_text(self, prompt: str) -> str: str: LLM response. """ - self._validate() - - from vertexai.preview.language_models import ( - CodeGenerationModel, - TextGenerationModel, + completion = self.client.predict( + prompt, + **self._default_params ) - - if self.model in self._supported_code_models: - code_generation = CodeGenerationModel.from_pretrained(self.model) - - completion = code_generation.predict( - prefix=prompt, - temperature=self.temperature, - max_output_tokens=self.max_output_tokens, - ) - elif self.model in self._supported_text_models: - text_generation = TextGenerationModel.from_pretrained(self.model) - - completion = text_generation.predict( - prompt=prompt, - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - max_output_tokens=self.max_output_tokens, - ) - else: - raise UnsupportedModelError(self.model) - - return str(completion) @property diff --git a/tests/llms/test_google_vertexai.py b/tests/llms/test_google_vertexai.py index ff477bc8d..def187ce5 100644 --- a/tests/llms/test_google_vertexai.py +++ b/tests/llms/test_google_vertexai.py @@ -12,7 +12,11 @@ def __init__(self, result: str): class TestGoogleVertexAI: - def test_init_with_default_model(self): + def test_init_with_default_model(self, mocker): + mocker.patch( + "vertexai.preview.language_models.TextGenerationModel.from_pretrained", + return_value="Test", + ) project_id = "your_project_id" location = "northamerica-northeast1" vertexai_instance = GoogleVertexAI(project_id, location) @@ -21,7 +25,11 @@ def test_init_with_default_model(self): assert vertexai_instance.project_id == project_id assert vertexai_instance.location == location - def test_init_with_custom_model(self): + def test_init_with_custom_model(self, mocker): + mocker.patch( + "vertexai.preview.language_models.CodeGenerationModel.from_pretrained", + return_value="Test", + ) project_id = "test-project" location = "northamerica-northeast1" custom_model = "code-bison@001" @@ -32,30 +40,26 @@ def test_init_with_custom_model(self): assert vertexai_instance.project_id == project_id assert vertexai_instance.location == location - @pytest.fixture - def google_vertexai(self): - # Create an instance of YourClass for testing + def test_validate_with_model(self, mocker): + mocker.patch( + "vertexai.preview.language_models.TextGenerationModel.from_pretrained", + return_value="Test", + ) + model = "text-bison@001" project_id = "test-project" location = "northamerica-northeast1" - custom_model = "code-bison@001" - return GoogleVertexAI(project_id, location, custom_model) - - def test_validate_with_model(self, google_vertexai: GoogleVertexAI): - google_vertexai.model = "text-bison@001" - google_vertexai._validate() # Should not raise any errors + llm = GoogleVertexAI(project_id, location, model) + llm._validate() # Should not raise any errors - def test_validate_with_invalid_model(self, google_vertexai: GoogleVertexAI): - google_vertexai.model = "invalid-model" + def test_validate_with_invalid_model(self): + model = "invalid_model" + project_id = "test-project" + location = "northamerica-northeast1" with pytest.raises( UnsupportedModelError, match=( - "Unsupported model: The model 'invalid-model' doesn't exist " + "Unsupported model: The model 'invalid_model' doesn't exist " "or is not supported yet." ), ): - google_vertexai._generate_text("Test prompt") - - def test_validate_without_model(self, google_vertexai: GoogleVertexAI): - google_vertexai.model = None - with pytest.raises(ValueError, match="model is required."): - google_vertexai._validate() + GoogleVertexAI(project_id, location, model)