diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index a4a699bcc5a2..f69f1b38a0e6 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -32,6 +32,7 @@ import com.google.cloud.vertexai.api.PredictionServiceClient; import com.google.cloud.vertexai.api.PredictionServiceSettings; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.List; import java.util.concurrent.locks.ReentrantLock; @@ -56,9 +57,10 @@ public class VertexAI implements AutoCloseable { private final String projectId; private final String location; - private String apiEndpoint; - private CredentialsProvider credentialsProvider = null; - private Transport transport = Transport.GRPC; + private final String apiEndpoint; + private final Transport transport; + // Will be null if the user doesn't provide Credentials or scopes + private final CredentialsProvider credentialsProvider; // The clients will be instantiated lazily private PredictionServiceClient predictionServiceClient = null; private LlmUtilityServiceClient llmUtilityClient = null; @@ -74,79 +76,104 @@ public VertexAI(String projectId, String location) { this.projectId = projectId; this.location = location; this.apiEndpoint = String.format("%s-aiplatform.googleapis.com", this.location); + this.transport = Transport.GRPC; + this.credentialsProvider = null; } - /** - * Construct a VertexAI instance with default transport layer. - * - * @param projectId the default project to use when making API calls - * @param location the default location to use when making API calls - * @param transport the default {@link Transport} layer to use to send API requests - */ - public VertexAI(String projectId, String location, Transport transport) { - this(projectId, location); + private VertexAI( + String projectId, + String location, + String apiEndpoint, + Transport transport, + Credentials credentials, + List scopes) { + if (!scopes.isEmpty() && credentials != null) { + throw new IllegalArgumentException( + "At most one of Credentials and scopes should be specified."); + } + checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty"); + checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty"); + checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "apiEndpoint can't be null or empty"); + checkNotNull(transport, "transport can't be null"); + + this.projectId = projectId; + this.location = location; + this.apiEndpoint = apiEndpoint; this.transport = transport; + if (credentials != null) { + this.credentialsProvider = FixedCredentialsProvider.create(credentials); + } else { + this.credentialsProvider = + scopes.size() == 0 + ? null + : GoogleCredentialsProvider.newBuilder() + .setScopesToApply(scopes) + .setUseJwtAccessWithScope(true) + .build(); + } } - /** - * Construct a VertexAI instance with custom credentials. - * - * @param projectId the default project to use when making API calls - * @param location the default location to use when making API calls - * @param credentials the custom credentials to use when making API calls - */ - public VertexAI(String projectId, String location, Credentials credentials) { - this(projectId, location); - this.credentialsProvider = FixedCredentialsProvider.create(credentials); - } + /** Builder for {@link VertexAI}. */ + public static class Builder { + private String projectId; + private String location; + private Credentials credentials; + private String apiEndpoint; + private Transport transport = Transport.GRPC; + private ImmutableList scopes = ImmutableList.of(); - /** - * Construct a VertexAI instance with default transport layer and custom credentials. - * - * @param projectId the default project to use when making API calls - * @param location the default location to use when making API calls - * @param transport the default {@link Transport} layer to use to send API requests - * @param credentials the default custom credentials to use when making API calls - */ - public VertexAI(String projectId, String location, Transport transport, Credentials credentials) { - this(projectId, location, credentials); - this.transport = transport; - } + public VertexAI build() { + checkNotNull(projectId, "projectId must be set."); + checkNotNull(location, "location must be set."); + // Default ApiEndpoint is set here as we need to make sure location is set. + if (apiEndpoint == null) { + apiEndpoint = String.format("%s-aiplatform.googleapis.com", location); + } - /** - * Construct a VertexAI instance with application default credentials. - * - * @param projectId the default project to use when making API calls - * @param location the default location to use when making API calls - * @param scopes List of scopes in the default credentials. Make sure you have specified - * "https://www.googleapis.com/auth/cloud-platform" scope to access resources on Vertex AI. - */ - public VertexAI(String projectId, String location, List scopes) throws IOException { - this(projectId, location); - - CredentialsProvider credentialsProvider = - scopes.size() == 0 - ? null - : GoogleCredentialsProvider.newBuilder() - .setScopesToApply(scopes) - .setUseJwtAccessWithScope(true) - .build(); - this.credentialsProvider = credentialsProvider; - } + return new VertexAI(projectId, location, apiEndpoint, transport, credentials, scopes); + } - /** - * Construct a VertexAI instance with default transport layer and application default credentials. - * - * @param projectId the default project to use when making API calls - * @param location the default location to use when making API calls - * @param transport the default {@link Transport} layer to use to send API requests - * @param scopes List of scopes in the default credentials. Make sure you have specified - * "https://www.googleapis.com/auth/cloud-platform" scope to access resources on Vertex AI. - */ - public VertexAI(String projectId, String location, Transport transport, List scopes) - throws IOException { - this(projectId, location, scopes); - this.transport = transport; + public Builder setProjectId(String projectId) { + checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty"); + + this.projectId = projectId; + return this; + } + + public Builder setLocation(String location) { + checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty"); + + this.location = location; + return this; + } + + public Builder setApiEndpoint(String apiEndpoint) { + checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "apiEndpoint can't be null or empty"); + + this.apiEndpoint = apiEndpoint; + return this; + } + + public Builder setTransport(Transport transport) { + checkNotNull(transport, "transport can't be null"); + + this.transport = transport; + return this; + } + + public Builder setCredentials(Credentials credentials) { + checkNotNull(credentials, "credentials can't be null"); + + this.credentials = credentials; + return this; + } + + public Builder setScopes(List scopes) { + checkNotNull(scopes, "scopes can't be null"); + + this.scopes = ImmutableList.copyOf(scopes); + return this; + } } /** @@ -155,7 +182,7 @@ public VertexAI(String projectId, String location, Transport transport, List vertexAi.getCredentials()); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + "Either Credentials or scopes needs to be provided while instantiating VertexAI."); } @Test - public void testCustomEndpointInVertexAI() throws IOException { + public void testInstantiateVertexAI_builderWithCredentials_shouldContainRightFields() + throws IOException { + vertexAi = + new VertexAI.Builder() + .setProjectId(TEST_PROJECT) + .setLocation(TEST_LOCATION) + .setCredentials(mockGoogleCredentials) + .build(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo(TEST_DEFAULT_ENDPOINT); + assertThat(vertexAi.getCredentials()).isEqualTo(mockGoogleCredentials); + } + + @Test + public void testInstantiateVertexAI_builderWithScopes_throwsIlegalArgumentException() + throws IOException { + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + new VertexAI.Builder() + .setProjectId(TEST_PROJECT) + .setLocation(TEST_LOCATION) + .setCredentials(mockGoogleCredentials) + .setScopes(ImmutableList.of("test_scope")) + .build()); + assertThat(thrown) + .hasMessageThat() + .isEqualTo("At most one of Credentials and scopes should be specified."); + } + + @Test + public void testInstantiateVertexAI_builderWithEndpoint_shouldContainRightFields() + throws IOException { try (MockedStatic mockStatic = mockStatic(PredictionServiceClient.class)) { mockStatic .when(() -> PredictionServiceClient.create(any(PredictionServiceSettings.class))) .thenReturn(mockPredictionServiceClient); - vertexAi = new VertexAI(TEST_PROJECT, TEST_LOCATION); - vertexAi.setApiEndpoint(TEST_ENDPOINT); + vertexAi = + new VertexAI.Builder() + .setApiEndpoint(TEST_ENDPOINT) + .setProjectId(TEST_PROJECT) + .setLocation(TEST_LOCATION) + .build(); + PredictionServiceClient unused = vertexAi.getPredictionServiceClient(); ArgumentCaptor settings = @@ -81,25 +129,25 @@ public void testCustomEndpointInVertexAI() throws IOException { } @Test - public void testSetApiEndpoint() throws IOException { - try (MockedStatic mockStatic = mockStatic(PredictionServiceClient.class)) { - mockStatic - .when(() -> PredictionServiceClient.create(any(PredictionServiceSettings.class))) - .thenReturn(mockPredictionServiceClient); - - vertexAi = new VertexAI(TEST_PROJECT, TEST_LOCATION); - PredictionServiceClient unused = vertexAi.getPredictionServiceClient(); + public void testInstantiateVertexAI_builderWithTransport_shouldContainRightFields() + throws IOException { - ArgumentCaptor settings = - ArgumentCaptor.forClass(PredictionServiceSettings.class); - mockStatic.verify(() -> PredictionServiceClient.create(settings.capture())); + vertexAi = + new VertexAI.Builder() + .setProjectId(TEST_PROJECT) + .setLocation(TEST_LOCATION) + .setTransport(Transport.REST) + .build(); - assertThat(settings.getValue().getEndpoint()) - .isEqualTo(String.format("%s:443", TEST_DEFAULT_ENDPOINT)); - - // After setting a new endpoint, clients should be closed and reset. - vertexAi.setApiEndpoint(TEST_ENDPOINT); - verify(mockPredictionServiceClient).close(); - } + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.REST); + assertThat(vertexAi.getApiEndpoint()).isEqualTo(TEST_DEFAULT_ENDPOINT); + IllegalStateException thrown = + assertThrows(IllegalStateException.class, () -> vertexAi.getCredentials()); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + "Either Credentials or scopes needs to be provided while instantiating VertexAI."); } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index 7b05548fc169..84f8d0b61fa5 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -144,7 +144,12 @@ public final class GenerativeModelTest { @Before public void doBeforeEachTest() { - vertexAi = new VertexAI(PROJECT, LOCATION, mockGoogleCredentials); + vertexAi = + new VertexAI.Builder() + .setProjectId(PROJECT) + .setLocation(LOCATION) + .setCredentials(mockGoogleCredentials) + .build(); } @Test