From c6803615ac0cb2ce469dfd912eb5e720afce77ea Mon Sep 17 00:00:00 2001 From: Zhenyi Qi Date: Wed, 13 Mar 2024 08:00:37 -0700 Subject: [PATCH] chore: [vertexai] Make instantiation of clients thread safe. PiperOrigin-RevId: 615418373 --- .../com/google/cloud/vertexai/VertexAI.java | 42 ++++++++--- .../generativeai/GenerativeModel.java | 69 +++++-------------- .../generativeai/GenerativeModelTest.java | 58 ---------------- 3 files changed, 53 insertions(+), 116 deletions(-) diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index 589272d20ce8..a282182445be 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -30,6 +30,7 @@ import com.google.cloud.vertexai.api.PredictionServiceSettings; import java.io.IOException; import java.util.List; +import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; import java.util.logging.Logger; @@ -59,6 +60,7 @@ public class VertexAI implements AutoCloseable { private PredictionServiceClient predictionServiceRestClient = null; private LlmUtilityServiceClient llmUtilityClient = null; private LlmUtilityServiceClient llmUtilityRestClient = null; + private final ReentrantLock lock = new ReentrantLock(); /** * Construct a VertexAI instance. @@ -245,7 +247,11 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException { * method calls that map to the API methods. */ private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException { - if (predictionServiceClient == null) { + if (this.predictionServiceClient != null) { + return this.predictionServiceClient; + } + lock.lock(); + try { PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder(); settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); if (this.credentialsProvider != null) { @@ -266,8 +272,10 @@ private PredictionServiceClient getPredictionServiceGrpcClient() throws IOExcept defaultCredentialsProviderLogger.setLevel(Level.SEVERE); predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build()); defaultCredentialsProviderLogger.setLevel(previousLevel); + return predictionServiceClient; + } finally { + lock.unlock(); } - return predictionServiceClient; } /** @@ -278,7 +286,11 @@ private PredictionServiceClient getPredictionServiceGrpcClient() throws IOExcept * method calls that map to the API methods. */ private PredictionServiceClient getPredictionServiceRestClient() throws IOException { - if (predictionServiceRestClient == null) { + if (predictionServiceClient != null) { + return predictionServiceClient; + } + lock.lock(); + try { PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newHttpJsonBuilder(); settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); @@ -300,8 +312,10 @@ private PredictionServiceClient getPredictionServiceRestClient() throws IOExcept defaultCredentialsProviderLogger.setLevel(Level.SEVERE); predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build()); defaultCredentialsProviderLogger.setLevel(previousLevel); + return predictionServiceRestClient; + } finally { + lock.unlock(); } - return predictionServiceRestClient; } /** @@ -328,7 +342,11 @@ public LlmUtilityServiceClient getLlmUtilityClient() throws IOException { * method calls that map to the API methods. */ private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException { - if (llmUtilityClient == null) { + if (llmUtilityClient != null) { + return llmUtilityClient; + } + lock.lock(); + try { LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder(); settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); if (this.credentialsProvider != null) { @@ -349,8 +367,10 @@ private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException { defaultCredentialsProviderLogger.setLevel(Level.SEVERE); llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build()); defaultCredentialsProviderLogger.setLevel(previousLevel); + return llmUtilityClient; + } finally { + lock.unlock(); } - return llmUtilityClient; } /** @@ -361,7 +381,11 @@ private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException { * method calls that map to the API methods. */ private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException { - if (llmUtilityRestClient == null) { + if (llmUtilityClient != null) { + return llmUtilityClient; + } + lock.lock(); + try { LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder(); settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); @@ -383,8 +407,10 @@ private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException { defaultCredentialsProviderLogger.setLevel(Level.SEVERE); llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build()); defaultCredentialsProviderLogger.setLevel(previousLevel); + return llmUtilityRestClient; + } finally { + lock.unlock(); } - return llmUtilityRestClient; } /** Closes the VertexAI instance together with all its instantiated clients. */ diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index 60a1b9d75d78..e50ea1d90067 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -32,7 +32,6 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -41,9 +40,9 @@ public final class GenerativeModel { private final String modelName; private final String resourceName; private final VertexAI vertexAi; - private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance(); - private ImmutableList safetySettings = ImmutableList.of(); - private ImmutableList tools = ImmutableList.of(); + private final GenerationConfig generationConfig; + private final ImmutableList safetySettings; + private final ImmutableList tools; /** * Constructs a GenerativeModel instance. @@ -59,8 +58,8 @@ public GenerativeModel(String modelName, VertexAI vertexAi) { this( modelName, GenerationConfig.getDefaultInstance(), - new ArrayList(), - new ArrayList(), + ImmutableList.of(), + ImmutableList.of(), vertexAi); } @@ -81,22 +80,29 @@ public GenerativeModel(String modelName, VertexAI vertexAi) { private GenerativeModel( String modelName, GenerationConfig generationConfig, - List safetySettings, - List tools, + ImmutableList safetySettings, + ImmutableList tools, VertexAI vertexAi) { + checkArgument( + !Strings.isNullOrEmpty(modelName), + "modelName can't be null or empty. Please refer to" + + " https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models" + + " to find the right model name."); + checkNotNull(vertexAi, "VertexAI can't be null."); + checkNotNull(generationConfig, "GenerationConfig can't be null."); + checkNotNull(safetySettings, "ImmutableList can't be null."); + checkNotNull(tools, "ImmutableList can't be null."); + modelName = reconcileModelName(modelName); this.modelName = modelName; this.resourceName = String.format( "projects/%s/locations/%s/publishers/google/models/%s", vertexAi.getProjectId(), vertexAi.getLocation(), modelName); - checkNotNull(generationConfig, "GenerationConfig can't be null."); - checkNotNull(safetySettings, "List can't be null."); - checkNotNull(tools, "List can't be null."); this.vertexAi = vertexAi; this.generationConfig = generationConfig; - this.safetySettings = ImmutableList.copyOf(safetySettings); - this.tools = ImmutableList.copyOf(tools); + this.safetySettings = safetySettings; + this.tools = tools; } /** Builder class for {@link GenerativeModel}. */ @@ -163,7 +169,6 @@ public Builder setSafetySettings(List safetySettings) { checkNotNull( safetySettings, "safetySettings can't be null. Use an empty list if no safety settings is intended."); - safetySettings.removeIf(safetySetting -> safetySetting == null); this.safetySettings = ImmutableList.copyOf(safetySettings); return this; } @@ -175,47 +180,11 @@ public Builder setSafetySettings(List safetySettings) { @BetaApi public Builder setTools(List tools) { checkNotNull(tools, "tools can't be null. Use an empty list if no tool is to be used."); - tools.removeIf(tool -> tool == null); this.tools = ImmutableList.copyOf(tools); return this; } } - /** - * Creates a copy of the current model with updated GenerationConfig. - * - * @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be - * used in the new model. - * @return a new {@link GenerativeModel} instance with the specified GenerationConfig. - */ - public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) { - return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi); - } - - /** - * Creates a copy of the current model with updated safetySettings. - * - * @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will - * be used in the new model. - * @return a new {@link GenerativeModel} instance with the specified safetySettings. - */ - public GenerativeModel withSafetySettings(List safetySettings) { - return new GenerativeModel( - modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi); - } - - /** - * Creates a copy of the current model with updated tools. - * - * @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in - * the new model. - * @return a new {@link GenerativeModel} instance with the specified tools. - */ - public GenerativeModel withTools(List tools) { - return new GenerativeModel( - modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi); - } - /** * Counts tokens in a text message. * diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index 0f74e84fc768..aba97637b0af 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -416,34 +416,6 @@ public void testGenerateContentwithDefaultTools() throws Exception { assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); } - @Test - public void testGenerateContentwithFluentApi() throws Exception { - model = new GenerativeModel(MODEL_NAME, vertexAi); - - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); - - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); - when(mockUnaryCallable.call(any(GenerateContentRequest.class))) - .thenReturn(mockGenerateContentResponse); - - GenerateContentResponse unused = - model - .withGenerationConfig(GENERATION_CONFIG) - .withSafetySettings(safetySettings) - .withTools(tools) - .generateContent(TEXT); - - ArgumentCaptor request = - ArgumentCaptor.forClass(GenerateContentRequest.class); - verify(mockUnaryCallable).call(request.capture()); - assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); - assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG); - assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING); - assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); - } - @Test public void testGenerateContentStreamwithText() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); @@ -597,34 +569,4 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception { verify(mockServerStreamCallable).call(request.capture()); assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); } - - @Test - public void testGenerateContentStreamwithFluentApi() throws Exception { - model = new GenerativeModel(MODEL_NAME, vertexAi); - - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); - - when(mockPredictionServiceClient.streamGenerateContentCallable()) - .thenReturn(mockServerStreamCallable); - when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) - .thenReturn(mockServerStream); - when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator); - - ResponseStream unused = - model - .withGenerationConfig(GENERATION_CONFIG) - .withSafetySettings(safetySettings) - .withTools(tools) - .generateContentStream(TEXT); - - ArgumentCaptor request = - ArgumentCaptor.forClass(GenerateContentRequest.class); - verify(mockServerStreamCallable).call(request.capture()); - assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); - assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG); - assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING); - assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); - } }