Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add score threshold to MVI search #516

Merged
merged 1 commit into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions src/Momento.Sdk/IPreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,32 +156,36 @@ public Task<UpsertItemBatchResponse> UpsertItemBatchAsync(string indexName,
///</returns>
public Task<DeleteItemBatchResponse> DeleteItemBatchAsync(string indexName, IEnumerable<string> ids);

/// <summary>
/// Searches for the most similar vectors to the query vector in the index.
/// Ranks the vectors according to the similarity metric specified when the
/// index was created.
/// </summary>
/// <param name="indexName">The name of the vector index to search in.</param>
/// <param name="queryVector">The vector to search for.</param>
/// <param name="topK">The number of results to return. Defaults to 10.</param>
/// <param name="metadataFields">A list of metadata fields to return with each result.</param>
/// <returns>
/// Task representing the result of the upsert operation. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>SearchResponse.Success</description></item>
/// <item><description>SearchResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is SearchResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// </code>
///</returns>
/// <summary>
/// Searches for the most similar vectors to the query vector in the index.
/// Ranks the vectors according to the similarity metric specified when the
/// index was created.
/// </summary>
/// <param name="indexName">The name of the vector index to search in.</param>
/// <param name="queryVector">The vector to search for.</param>
/// <param name="topK">The number of results to return. Defaults to 10.</param>
/// <param name="metadataFields">A list of metadata fields to return with each result.</param>
/// <param name="scoreThreshold">A score threshold to filter results by. For cosine
/// similarity and inner product, scores lower than the threshold are excluded. For
/// euclidean similarity, scores higher than the threshold are excluded. The threshold
/// is exclusive. Defaults to None, ie no threshold.</param>
/// <returns>
/// Task representing the result of the upsert operation. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>SearchResponse.Success</description></item>
/// <item><description>SearchResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is SearchResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// </code>
/// </returns>
public Task<SearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, int topK = 10,
MetadataFields? metadataFields = null);
MetadataFields? metadataFields = null, float? scoreThreshold = null);
}
15 changes: 12 additions & 3 deletions src/Momento.Sdk/Internal/VectorIndexDataClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public async Task<DeleteItemBatchResponse> DeleteItemBatchAsync(string indexName
}

public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, int topK,
MetadataFields? metadataFields)
MetadataFields? metadataFields, float? scoreThreshold)
{
try
{
Expand All @@ -91,9 +91,18 @@ public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<floa
IndexName = indexName,
QueryVector = new _Vector { Elements = { queryVector } },
TopK = validatedTopK,
MetadataFields = metadataRequest
MetadataFields = metadataRequest,
};

if (scoreThreshold != null)
{
request.ScoreThreshold = scoreThreshold.Value;
}
else
{
request.NoScoreThreshold = new _NoScoreThreshold();
}

var response =
await grpcManager.Client.SearchAsync(request, new CallOptions(deadline: CalculateDeadline()));
var searchHits = response.Hits.Select(Convert).ToList();
Expand Down Expand Up @@ -167,7 +176,7 @@ private static MetadataValue Convert(_Metadata metadata)

private static SearchHit Convert(_SearchHit hit)
{
return new SearchHit(hit.Id, hit.Distance, Convert(hit.Metadata));
return new SearchHit(hit.Id, hit.Score, Convert(hit.Metadata));
}

private static void CheckValidIndexName(string indexName)
Expand Down
2 changes: 1 addition & 1 deletion src/Momento.Sdk/Momento.Sdk.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
<ItemGroup>
<PackageReference Include="Grpc.Net.Client" Version="2.49.0" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="7.0.0" />
<PackageReference Include="Momento.Protos" Version="0.91.1" />
<PackageReference Include="Momento.Protos" Version="0.94.1" />
<PackageReference Include="JWT" Version="9.0.3" />
<PackageReference Include="System.Threading.Channels" Version="6.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0" />
Expand Down
4 changes: 2 additions & 2 deletions src/Momento.Sdk/PreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ public async Task<DeleteItemBatchResponse> DeleteItemBatchAsync(string indexName

/// <inheritdoc />
public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector,
int topK = 10, MetadataFields? metadataFields = null)
int topK = 10, MetadataFields? metadataFields = null, float? searchThreshold = null)
{
return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields);
return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields, searchThreshold);
}

/// <inheritdoc />
Expand Down
85 changes: 85 additions & 0 deletions tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -317,4 +317,89 @@ public async Task UpsertAndSearch_WithDiverseMetadata()
await vectorIndexClient.DeleteIndexAsync(indexName);
}
}

public static IEnumerable<object[]> SearchThresholdTestCases =>
new List<object[]>
{
// similarity metric, scores, thresholds
new object[]
{
SimilarityMetric.CosineSimilarity,
new List<float> { 1.0f, 0.0f, -1.0f },
new List<float> { 0.5f, -1.01f, 1.0f }
},
new object[]
{
SimilarityMetric.InnerProduct,
new List<float> { 4.0f, 0.0f, -4.0f },
new List<float> { 0.0f, -4.01f, 4.0f }
},
new object[]
{
SimilarityMetric.EuclideanSimilarity,
new List<float> { 2.0f, 10.0f, 18.0f },
new List<float> { 3.0f, 20.0f, -0.01f }
}
};

[Theory]
[MemberData(nameof(SearchThresholdTestCases))]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

public async Task Search_PruneBasedOnThreshold(SimilarityMetric similarityMetric, List<float> scores,
List<float> thresholds)
{
var indexName = $"index-{Utils.NewGuidString()}";

var createResponse = await vectorIndexClient.CreateIndexAsync(indexName, 2, similarityMetric);
Assert.True(createResponse is CreateIndexResponse.Success, $"Unexpected response: {createResponse}");

try
{
var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List<Item>
{
new("test_item_1", new List<float> { 1.0f, 1.0f }),
new("test_item_2", new List<float> { -1.0f, 1.0f }),
new("test_item_3", new List<float> { -1.0f, -1.0f })
});
Assert.True(upsertResponse is UpsertItemBatchResponse.Success,
$"Unexpected response: {upsertResponse}");

await Task.Delay(2_000);

var queryVector = new List<float> { 2.0f, 2.0f };
var searchHits = new List<SearchHit>
{
new("test_item_1", scores[0]),
new("test_item_2", scores[1]),
new("test_item_3", scores[2])
};

// Test threshold to get only the top result
var searchResponse =
await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[0]);
Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}");
var successResponse = (SearchResponse.Success)searchResponse;
Assert.Equal(new List<SearchHit>
{
searchHits[0]
}, successResponse.Hits);

// Test threshold to get all results
searchResponse =
await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[1]);
Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}");
successResponse = (SearchResponse.Success)searchResponse;
Assert.Equal(searchHits, successResponse.Hits);

// Test threshold to get no results
searchResponse =
await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[2]);
Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}");
successResponse = (SearchResponse.Success)searchResponse;
Assert.Empty(successResponse.Hits);
}
finally
{
await vectorIndexClient.DeleteIndexAsync(indexName);
}
}
}
Loading