Skip to content

Commit

Permalink
Fixing GoogleVertex tests after merge from upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-rubinstein committed Nov 18, 2024
1 parent aa36841 commit 42e3f5a
Showing 1 changed file with 29 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAre
}

public void testOverrideWith_SetsInputTypeToOverride_WhenFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() {
var model = createModel("model", Boolean.FALSE, null);
var model = createModel("model", Boolean.FALSE, (InputType) null);
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.SEARCH);

var expectedModel = createModel("model", Boolean.FALSE, InputType.SEARCH);
Expand All @@ -80,7 +80,7 @@ public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredT
}

public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() {
var model = createModel("model", Boolean.FALSE, null);
var model = createModel("model", Boolean.FALSE, (InputType) null);
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, InputType.CLUSTERING), InputType.SEARCH);

var expectedModel = createModel("model", Boolean.FALSE, InputType.SEARCH);
Expand All @@ -96,10 +96,10 @@ public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_Wh
}

public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() {
var model = createModel("model", Boolean.FALSE, null);
var model = createModel("model", Boolean.FALSE, (InputType) null);
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.UNSPECIFIED);

var expectedModel = createModel("model", Boolean.FALSE, null);
var expectedModel = createModel("model", Boolean.FALSE, (InputType) null);
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
}

Expand Down Expand Up @@ -136,6 +136,31 @@ public static GoogleVertexAiEmbeddingsModel createModel(
);
}

public static GoogleVertexAiEmbeddingsModel createModel(
String modelId,
@Nullable Boolean autoTruncate,
SimilarityMeasure similarityMeasure
) {
return new GoogleVertexAiEmbeddingsModel(
"id",
TaskType.TEXT_EMBEDDING,
"service",
new GoogleVertexAiEmbeddingsServiceSettings(
randomAlphaOfLength(8),
randomAlphaOfLength(8),
modelId,
false,
null,
null,
similarityMeasure,
null
),
new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, randomFrom(InputType.INGEST, InputType.SEARCH)),
null,
new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray()))
);
}

public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullable Boolean autoTruncate, @Nullable InputType inputType) {
return new GoogleVertexAiEmbeddingsModel(
"id",
Expand Down

0 comments on commit 42e3f5a

Please sign in to comment.