diff --git a/src/Momento.Sdk/IPreviewVectorIndexClient.cs b/src/Momento.Sdk/IPreviewVectorIndexClient.cs index ea1456d5..11740fcf 100644 --- a/src/Momento.Sdk/IPreviewVectorIndexClient.cs +++ b/src/Momento.Sdk/IPreviewVectorIndexClient.cs @@ -182,6 +182,6 @@ public Task UpsertItemBatchAsync(string indexName /// } /// /// - public Task SearchAsync(string indexName, IEnumerable queryVector, uint topK = 10, + public Task SearchAsync(string indexName, IEnumerable queryVector, int topK = 10, MetadataFields? metadataFields = null); } \ No newline at end of file diff --git a/src/Momento.Sdk/Internal/VectorIndexDataClient.cs b/src/Momento.Sdk/Internal/VectorIndexDataClient.cs index 70f5dc20..526b09b9 100644 --- a/src/Momento.Sdk/Internal/VectorIndexDataClient.cs +++ b/src/Momento.Sdk/Internal/VectorIndexDataClient.cs @@ -67,13 +67,14 @@ public async Task DeleteItemBatchAsync(string ind } } - public async Task SearchAsync(string indexName, IEnumerable queryVector, uint topK, + public async Task SearchAsync(string indexName, IEnumerable queryVector, int topK, MetadataFields? metadataFields) { try { _logger.LogTraceVectorIndexRequest("search", indexName); CheckValidIndexName(indexName); + var validatedTopK = ValidateTopK(topK); metadataFields ??= new List(); var metadataRequest = metadataFields switch { @@ -89,7 +90,7 @@ public async Task SearchAsync(string indexName, IEnumerabl { IndexName = indexName, QueryVector = new _Vector { Elements = { queryVector } }, - TopK = topK, + TopK = validatedTopK, MetadataFields = metadataRequest }; @@ -119,28 +120,19 @@ private static IEnumerable<_Metadata> Convert(Dictionary var convertedMetadataList = new List<_Metadata>(); foreach (var metadataPair in metadata) { - _Metadata convertedMetadata; - switch (metadataPair.Value) + var convertedMetadata = metadataPair.Value switch { - case StringValue stringValue: - convertedMetadata = new _Metadata { Field = metadataPair.Key, StringValue = stringValue.Value }; - break; - case LongValue longValue: - convertedMetadata = new _Metadata { Field = metadataPair.Key, IntegerValue = longValue.Value }; - break; - case DoubleValue doubleValue: - convertedMetadata = new _Metadata { Field = metadataPair.Key, DoubleValue = doubleValue.Value }; - break; - case BoolValue boolValue: - convertedMetadata = new _Metadata { Field = metadataPair.Key, BooleanValue = boolValue.Value }; - break; - case StringListValue stringListValue: - var listOfStrings = new _Metadata.Types._ListOfStrings { Values = { stringListValue.Value } }; - convertedMetadata = new _Metadata { Field = metadataPair.Key, ListOfStringsValue = listOfStrings }; - break; - default: - throw new InvalidArgumentException($"Unknown metadata type {metadataPair.Value.GetType()}"); - } + StringValue stringValue => new _Metadata { Field = metadataPair.Key, StringValue = stringValue.Value }, + LongValue longValue => new _Metadata { Field = metadataPair.Key, IntegerValue = longValue.Value }, + DoubleValue doubleValue => new _Metadata { Field = metadataPair.Key, DoubleValue = doubleValue.Value }, + BoolValue boolValue => new _Metadata { Field = metadataPair.Key, BooleanValue = boolValue.Value }, + StringListValue stringListValue => new _Metadata + { + Field = metadataPair.Key, + ListOfStringsValue = new _Metadata.Types._ListOfStrings { Values = { stringListValue.Value } } + }, + _ => throw new InvalidArgumentException($"Unknown metadata type {metadataPair.Value.GetType()}") + }; convertedMetadataList.Add(convertedMetadata); } @@ -186,6 +178,16 @@ private static void CheckValidIndexName(string indexName) } } + private static uint ValidateTopK(long topK) + { + if (topK <= 0) + { + throw new InvalidArgumentException("topK must be greater than 0"); + } + + return (uint)topK; + } + private DateTime CalculateDeadline() { return DateTime.UtcNow.Add(deadline); diff --git a/src/Momento.Sdk/PreviewVectorIndexClient.cs b/src/Momento.Sdk/PreviewVectorIndexClient.cs index 1250a876..19b390e3 100644 --- a/src/Momento.Sdk/PreviewVectorIndexClient.cs +++ b/src/Momento.Sdk/PreviewVectorIndexClient.cs @@ -69,7 +69,7 @@ public async Task DeleteItemBatchAsync(string ind /// public async Task SearchAsync(string indexName, IEnumerable queryVector, - uint topK = 10, MetadataFields? metadataFields = null) + int topK = 10, MetadataFields? metadataFields = null) { return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields); }