Skip to content

Commit

Permalink
BREAKING_CHANGE: [vertexai] Change VertexAI to Builder pattern and re…
Browse files Browse the repository at this point in the history
…move setters. (#10600)

PiperOrigin-RevId: 617935686

Co-authored-by: Zhenyi Qi <[email protected]>
  • Loading branch information
copybara-service[bot] and Zhenyi Qi authored Mar 22, 2024
1 parent 17b01c6 commit 42e806e
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
Expand All @@ -56,9 +57,10 @@ public class VertexAI implements AutoCloseable {

private final String projectId;
private final String location;
private String apiEndpoint;
private CredentialsProvider credentialsProvider = null;
private Transport transport = Transport.GRPC;
private final String apiEndpoint;
private final Transport transport;
// Will be null if the user doesn't provide Credentials or scopes
private final CredentialsProvider credentialsProvider;
// The clients will be instantiated lazily
private PredictionServiceClient predictionServiceClient = null;
private LlmUtilityServiceClient llmUtilityClient = null;
Expand All @@ -74,79 +76,104 @@ public VertexAI(String projectId, String location) {
this.projectId = projectId;
this.location = location;
this.apiEndpoint = String.format("%s-aiplatform.googleapis.com", this.location);
this.transport = Transport.GRPC;
this.credentialsProvider = null;
}

/**
* Construct a VertexAI instance with default transport layer.
*
* @param projectId the default project to use when making API calls
* @param location the default location to use when making API calls
* @param transport the default {@link Transport} layer to use to send API requests
*/
public VertexAI(String projectId, String location, Transport transport) {
this(projectId, location);
private VertexAI(
String projectId,
String location,
String apiEndpoint,
Transport transport,
Credentials credentials,
List<String> scopes) {
if (!scopes.isEmpty() && credentials != null) {
throw new IllegalArgumentException(
"At most one of Credentials and scopes should be specified.");
}
checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty");
checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty");
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "apiEndpoint can't be null or empty");
checkNotNull(transport, "transport can't be null");

this.projectId = projectId;
this.location = location;
this.apiEndpoint = apiEndpoint;
this.transport = transport;
if (credentials != null) {
this.credentialsProvider = FixedCredentialsProvider.create(credentials);
} else {
this.credentialsProvider =
scopes.size() == 0
? null
: GoogleCredentialsProvider.newBuilder()
.setScopesToApply(scopes)
.setUseJwtAccessWithScope(true)
.build();
}
}

/**
* Construct a VertexAI instance with custom credentials.
*
* @param projectId the default project to use when making API calls
* @param location the default location to use when making API calls
* @param credentials the custom credentials to use when making API calls
*/
public VertexAI(String projectId, String location, Credentials credentials) {
this(projectId, location);
this.credentialsProvider = FixedCredentialsProvider.create(credentials);
}
/** Builder for {@link VertexAI}. */
public static class Builder {
private String projectId;
private String location;
private Credentials credentials;
private String apiEndpoint;
private Transport transport = Transport.GRPC;
private ImmutableList<String> scopes = ImmutableList.of();

/**
* Construct a VertexAI instance with default transport layer and custom credentials.
*
* @param projectId the default project to use when making API calls
* @param location the default location to use when making API calls
* @param transport the default {@link Transport} layer to use to send API requests
* @param credentials the default custom credentials to use when making API calls
*/
public VertexAI(String projectId, String location, Transport transport, Credentials credentials) {
this(projectId, location, credentials);
this.transport = transport;
}
public VertexAI build() {
checkNotNull(projectId, "projectId must be set.");
checkNotNull(location, "location must be set.");
// Default ApiEndpoint is set here as we need to make sure location is set.
if (apiEndpoint == null) {
apiEndpoint = String.format("%s-aiplatform.googleapis.com", location);
}

/**
* Construct a VertexAI instance with application default credentials.
*
* @param projectId the default project to use when making API calls
* @param location the default location to use when making API calls
* @param scopes List of scopes in the default credentials. Make sure you have specified
* "https://www.googleapis.com/auth/cloud-platform" scope to access resources on Vertex AI.
*/
public VertexAI(String projectId, String location, List<String> scopes) throws IOException {
this(projectId, location);

CredentialsProvider credentialsProvider =
scopes.size() == 0
? null
: GoogleCredentialsProvider.newBuilder()
.setScopesToApply(scopes)
.setUseJwtAccessWithScope(true)
.build();
this.credentialsProvider = credentialsProvider;
}
return new VertexAI(projectId, location, apiEndpoint, transport, credentials, scopes);
}

/**
* Construct a VertexAI instance with default transport layer and application default credentials.
*
* @param projectId the default project to use when making API calls
* @param location the default location to use when making API calls
* @param transport the default {@link Transport} layer to use to send API requests
* @param scopes List of scopes in the default credentials. Make sure you have specified
* "https://www.googleapis.com/auth/cloud-platform" scope to access resources on Vertex AI.
*/
public VertexAI(String projectId, String location, Transport transport, List<String> scopes)
throws IOException {
this(projectId, location, scopes);
this.transport = transport;
public Builder setProjectId(String projectId) {
checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty");

this.projectId = projectId;
return this;
}

public Builder setLocation(String location) {
checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty");

this.location = location;
return this;
}

public Builder setApiEndpoint(String apiEndpoint) {
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "apiEndpoint can't be null or empty");

this.apiEndpoint = apiEndpoint;
return this;
}

public Builder setTransport(Transport transport) {
checkNotNull(transport, "transport can't be null");

this.transport = transport;
return this;
}

public Builder setCredentials(Credentials credentials) {
checkNotNull(credentials, "credentials can't be null");

this.credentials = credentials;
return this;
}

public Builder setScopes(List<String> scopes) {
checkNotNull(scopes, "scopes can't be null");

this.scopes = ImmutableList.copyOf(scopes);
return this;
}
}

/**
Expand All @@ -155,7 +182,7 @@ public VertexAI(String projectId, String location, Transport transport, List<Str
* @return {@link Transport} layer used when sending API requests.
*/
public Transport getTransport() {
return this.transport;
return transport;
}

/**
Expand All @@ -164,7 +191,7 @@ public Transport getTransport() {
* @return Project ID in string format.
*/
public String getProjectId() {
return this.projectId;
return projectId;
}

/**
Expand All @@ -173,7 +200,7 @@ public String getProjectId() {
* @return Location in string format.
*/
public String getLocation() {
return this.location;
return location;
}

/**
Expand All @@ -182,7 +209,7 @@ public String getLocation() {
* @return API endpoint in string format.
*/
public String getApiEndpoint() {
return this.apiEndpoint;
return apiEndpoint;
}

/**
Expand All @@ -192,40 +219,13 @@ public String getApiEndpoint() {
* VertexAI object.
*/
public Credentials getCredentials() throws IOException {
return credentialsProvider.getCredentials();
}

/** Sets the value for {@link #getTransport()}. */
public void setTransport(Transport transport) {
checkNotNull(transport, "Transport can't be null.");
if (this.transport == transport) {
return;
}

this.transport = transport;
resetClients();
}

/** Sets the value for {@link #getApiEndpoint()}. */
public void setApiEndpoint(String apiEndpoint) {
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "Api endpoint can't be null or empty.");
if (this.apiEndpoint == apiEndpoint) {
return;
}
this.apiEndpoint = apiEndpoint;
resetClients();
}

private void resetClients() {
if (this.predictionServiceClient != null) {
this.predictionServiceClient.close();
this.predictionServiceClient = null;
}

if (this.llmUtilityClient != null) {
this.llmUtilityClient.close();
this.llmUtilityClient = null;
// TODO(b/330780087): support getCredentials() when default credentials (no user provided
// credentials or scopes) are used.
if (credentialsProvider == null) {
throw new IllegalStateException(
"Either Credentials or scopes needs to be provided while instantiating VertexAI.");
}
return credentialsProvider.getCredentials();
}

/**
Expand Down
Loading

0 comments on commit 42e806e

Please sign in to comment.