Skip to content

Commit

Permalink
BREAKING_CHANGE: [vertexai] make client getters in VertexAI private (#…
Browse files Browse the repository at this point in the history
…10550)

PiperOrigin-RevId: 616185850

Co-authored-by: Jaycee Li <[email protected]>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Mar 18, 2024
1 parent 840df96 commit e8994c7
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.cloud.vertexai;

import com.google.api.core.InternalApi;
import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.core.GaxProperties;
Expand Down Expand Up @@ -220,11 +221,30 @@ public void setApiEndpoint(String apiEndpoint) {
}
}

/**
* Returns the {@link PredictionServiceClient} with GRPC or REST, based on the Transport type. The
* client will be instantiated when the first prediction API call is made.
*
* @return {@link PredictionServiceClient} that send requests to the backing service through
* method calls that map to the API methods.
*/
@InternalApi
public PredictionServiceClient getPredictionServiceClient() throws IOException {
if (this.transport == Transport.GRPC) {
return getPredictionServiceGrpcClient();
} else {
return getPredictionServiceRestClient();
}
}

/**
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
*
* @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
* method calls that map to the API methods.
*/
public PredictionServiceClient getPredictionServiceClient() throws IOException {
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
if (predictionServiceClient == null) {
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
Expand Down Expand Up @@ -257,7 +277,7 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException {
* @return {@link PredictionServiceClient} that send REST requests to the backing service through
* method calls that map to the API methods.
*/
public PredictionServiceClient getPredictionServiceRestClient() throws IOException {
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
if (predictionServiceRestClient == null) {
PredictionServiceSettings.Builder settingsBuilder =
PredictionServiceSettings.newHttpJsonBuilder();
Expand All @@ -284,14 +304,30 @@ public PredictionServiceClient getPredictionServiceRestClient() throws IOExcepti
return predictionServiceRestClient;
}

/**
* Returns the {@link LlmUtilityServiceClient} with GRPC or REST, based on the Transport type. The
* client will be instantiated when the first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes calls to the backing service through method
* calls that map to the API methods.
*/
@InternalApi
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
if (this.transport == Transport.GRPC) {
return getLlmUtilityGrpcClient();
} else {
return getLlmUtilityRestClient();
}
}

/**
* Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
* first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
* method calls that map to the API methods.
*/
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
if (llmUtilityClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
Expand Down Expand Up @@ -319,12 +355,12 @@ public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {

/**
* Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
* first prediction API call is made.
* first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
* method calls that map to the API methods.
*/
public LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
if (llmUtilityRestClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder =
LlmUtilityServiceSettings.newHttpJsonBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.google.cloud.vertexai.generativeai;

import com.google.api.core.BetaApi;
import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensRequest;
Expand Down Expand Up @@ -289,11 +288,7 @@ public CountTokensResponse countTokens(List<Content> contents) throws IOExceptio
@BetaApi
private CountTokensResponse countTokensFromRequest(CountTokensRequest request)
throws IOException {
if (vertexAi.getTransport() == Transport.REST) {
return vertexAi.getLlmUtilityRestClient().countTokens(request);
} else {
return vertexAi.getLlmUtilityClient().countTokens(request);
}
return vertexAi.getLlmUtilityClient().countTokens(request);
}

/**
Expand Down Expand Up @@ -520,11 +515,7 @@ public GenerateContentResponse generateContent(
*/
private GenerateContentResponse generateContent(GenerateContentRequest request)
throws IOException {
if (vertexAi.getTransport() == Transport.REST) {
return vertexAi.getPredictionServiceRestClient().generateContentCallable().call(request);
} else {
return vertexAi.getPredictionServiceClient().generateContentCallable().call(request);
}
return vertexAi.getPredictionServiceClient().generateContentCallable().call(request);
}

/**
Expand Down Expand Up @@ -932,23 +923,13 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
*/
private ResponseStream<GenerateContentResponse> generateContentStream(
GenerateContentRequest request) throws IOException {
if (vertexAi.getTransport() == Transport.REST) {
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceRestClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
} else {
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
}
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
}

/**
Expand Down

0 comments on commit e8994c7

Please sign in to comment.