Skip to content

Commit

Permalink
chore: [vertexai] Make instantiation of clients thread safe.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615418373
  • Loading branch information
Zhenyi Qi authored and copybara-github committed Mar 21, 2024
1 parent 8bc8adb commit c680361
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand Down Expand Up @@ -59,6 +60,7 @@ public class VertexAI implements AutoCloseable {
private PredictionServiceClient predictionServiceRestClient = null;
private LlmUtilityServiceClient llmUtilityClient = null;
private LlmUtilityServiceClient llmUtilityRestClient = null;
private final ReentrantLock lock = new ReentrantLock();

/**
* Construct a VertexAI instance.
Expand Down Expand Up @@ -245,7 +247,11 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException {
* method calls that map to the API methods.
*/
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
if (predictionServiceClient == null) {
if (this.predictionServiceClient != null) {
return this.predictionServiceClient;
}
lock.lock();
try {
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
Expand All @@ -266,8 +272,10 @@ private PredictionServiceClient getPredictionServiceGrpcClient() throws IOExcept
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
return predictionServiceClient;
} finally {
lock.unlock();
}
return predictionServiceClient;
}

/**
Expand All @@ -278,7 +286,11 @@ private PredictionServiceClient getPredictionServiceGrpcClient() throws IOExcept
* method calls that map to the API methods.
*/
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
if (predictionServiceRestClient == null) {
if (predictionServiceClient != null) {
return predictionServiceClient;
}
lock.lock();
try {
PredictionServiceSettings.Builder settingsBuilder =
PredictionServiceSettings.newHttpJsonBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
Expand All @@ -300,8 +312,10 @@ private PredictionServiceClient getPredictionServiceRestClient() throws IOExcept
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
return predictionServiceRestClient;
} finally {
lock.unlock();
}
return predictionServiceRestClient;
}

/**
Expand All @@ -328,7 +342,11 @@ public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
* method calls that map to the API methods.
*/
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
if (llmUtilityClient == null) {
if (llmUtilityClient != null) {
return llmUtilityClient;
}
lock.lock();
try {
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
Expand All @@ -349,8 +367,10 @@ private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
return llmUtilityClient;
} finally {
lock.unlock();
}
return llmUtilityClient;
}

/**
Expand All @@ -361,7 +381,11 @@ private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
* method calls that map to the API methods.
*/
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
if (llmUtilityRestClient == null) {
if (llmUtilityClient != null) {
return llmUtilityClient;
}
lock.lock();
try {
LlmUtilityServiceSettings.Builder settingsBuilder =
LlmUtilityServiceSettings.newHttpJsonBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
Expand All @@ -383,8 +407,10 @@ private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
return llmUtilityRestClient;
} finally {
lock.unlock();
}
return llmUtilityRestClient;
}

/** Closes the VertexAI instance together with all its instantiated clients. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

Expand All @@ -41,9 +40,9 @@ public final class GenerativeModel {
private final String modelName;
private final String resourceName;
private final VertexAI vertexAi;
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
private ImmutableList<Tool> tools = ImmutableList.of();
private final GenerationConfig generationConfig;
private final ImmutableList<SafetySetting> safetySettings;
private final ImmutableList<Tool> tools;

/**
* Constructs a GenerativeModel instance.
Expand All @@ -59,8 +58,8 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
this(
modelName,
GenerationConfig.getDefaultInstance(),
new ArrayList<SafetySetting>(),
new ArrayList<Tool>(),
ImmutableList.of(),
ImmutableList.of(),
vertexAi);
}

Expand All @@ -81,22 +80,29 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
private GenerativeModel(
String modelName,
GenerationConfig generationConfig,
List<SafetySetting> safetySettings,
List<Tool> tools,
ImmutableList<SafetySetting> safetySettings,
ImmutableList<Tool> tools,
VertexAI vertexAi) {
checkArgument(
!Strings.isNullOrEmpty(modelName),
"modelName can't be null or empty. Please refer to"
+ " https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models"
+ " to find the right model name.");
checkNotNull(vertexAi, "VertexAI can't be null.");
checkNotNull(generationConfig, "GenerationConfig can't be null.");
checkNotNull(safetySettings, "ImmutableList<SafetySettings> can't be null.");
checkNotNull(tools, "ImmutableList<Tool> can't be null.");

modelName = reconcileModelName(modelName);
this.modelName = modelName;
this.resourceName =
String.format(
"projects/%s/locations/%s/publishers/google/models/%s",
vertexAi.getProjectId(), vertexAi.getLocation(), modelName);
checkNotNull(generationConfig, "GenerationConfig can't be null.");
checkNotNull(safetySettings, "List<SafetySettings> can't be null.");
checkNotNull(tools, "List<Tool> can't be null.");
this.vertexAi = vertexAi;
this.generationConfig = generationConfig;
this.safetySettings = ImmutableList.copyOf(safetySettings);
this.tools = ImmutableList.copyOf(tools);
this.safetySettings = safetySettings;
this.tools = tools;
}

/** Builder class for {@link GenerativeModel}. */
Expand Down Expand Up @@ -163,7 +169,6 @@ public Builder setSafetySettings(List<SafetySetting> safetySettings) {
checkNotNull(
safetySettings,
"safetySettings can't be null. Use an empty list if no safety settings is intended.");
safetySettings.removeIf(safetySetting -> safetySetting == null);
this.safetySettings = ImmutableList.copyOf(safetySettings);
return this;
}
Expand All @@ -175,47 +180,11 @@ public Builder setSafetySettings(List<SafetySetting> safetySettings) {
@BetaApi
public Builder setTools(List<Tool> tools) {
checkNotNull(tools, "tools can't be null. Use an empty list if no tool is to be used.");
tools.removeIf(tool -> tool == null);
this.tools = ImmutableList.copyOf(tools);
return this;
}
}

/**
* Creates a copy of the current model with updated GenerationConfig.
*
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be
* used in the new model.
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
*/
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
}

/**
* Creates a copy of the current model with updated safetySettings.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will
* be used in the new model.
* @return a new {@link GenerativeModel} instance with the specified safetySettings.
*/
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
return new GenerativeModel(
modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi);
}

/**
* Creates a copy of the current model with updated tools.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in
* the new model.
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withTools(List<Tool> tools) {
return new GenerativeModel(
modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi);
}

/**
* Counts tokens in a text message.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,34 +416,6 @@ public void testGenerateContentwithDefaultTools() throws Exception {
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentwithFluentApi() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockGenerateContentResponse);

GenerateContentResponse unused =
model
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.generateContent(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentStreamwithText() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
Expand Down Expand Up @@ -597,34 +569,4 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception {
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentStreamwithFluentApi() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
when(mockServerStreamCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockServerStream);
when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator);

ResponseStream unused =
model
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.generateContentStream(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}
}

0 comments on commit c680361

Please sign in to comment.