Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nand4011 committed Oct 30, 2023
1 parent d9affd3 commit 078b735
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/Momento.Sdk/IPreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,6 @@ public Task<VectorUpsertItemBatchResponse> UpsertItemBatchAsync(string indexName
/// }
/// </code>
///</returns>
public Task<VectorSearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, uint topK = 10,
public Task<VectorSearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, int topK = 10,
MetadataFields? metadataFields = null);
}
48 changes: 25 additions & 23 deletions src/Momento.Sdk/Internal/VectorIndexDataClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ public async Task<VectorDeleteItemBatchResponse> DeleteItemBatchAsync(string ind
}
}

public async Task<VectorSearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, uint topK,
public async Task<VectorSearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, int topK,
MetadataFields? metadataFields)
{
try
{
_logger.LogTraceVectorIndexRequest("search", indexName);
CheckValidIndexName(indexName);
var validatedTopK = ValidateTopK(topK);
metadataFields ??= new List<string>();
var metadataRequest = metadataFields switch
{
Expand All @@ -89,7 +90,7 @@ public async Task<VectorSearchResponse> SearchAsync(string indexName, IEnumerabl
{
IndexName = indexName,
QueryVector = new _Vector { Elements = { queryVector } },
TopK = topK,
TopK = validatedTopK,
MetadataFields = metadataRequest
};

Expand Down Expand Up @@ -119,28 +120,19 @@ private static IEnumerable<_Metadata> Convert(Dictionary<string, MetadataValue>
var convertedMetadataList = new List<_Metadata>();
foreach (var metadataPair in metadata)
{
_Metadata convertedMetadata;
switch (metadataPair.Value)
var convertedMetadata = metadataPair.Value switch
{
case StringValue stringValue:
convertedMetadata = new _Metadata { Field = metadataPair.Key, StringValue = stringValue.Value };
break;
case LongValue longValue:
convertedMetadata = new _Metadata { Field = metadataPair.Key, IntegerValue = longValue.Value };
break;
case DoubleValue doubleValue:
convertedMetadata = new _Metadata { Field = metadataPair.Key, DoubleValue = doubleValue.Value };
break;
case BoolValue boolValue:
convertedMetadata = new _Metadata { Field = metadataPair.Key, BooleanValue = boolValue.Value };
break;
case StringListValue stringListValue:
var listOfStrings = new _Metadata.Types._ListOfStrings { Values = { stringListValue.Value } };
convertedMetadata = new _Metadata { Field = metadataPair.Key, ListOfStringsValue = listOfStrings };
break;
default:
throw new InvalidArgumentException($"Unknown metadata type {metadataPair.Value.GetType()}");
}
StringValue stringValue => new _Metadata { Field = metadataPair.Key, StringValue = stringValue.Value },
LongValue longValue => new _Metadata { Field = metadataPair.Key, IntegerValue = longValue.Value },
DoubleValue doubleValue => new _Metadata { Field = metadataPair.Key, DoubleValue = doubleValue.Value },
BoolValue boolValue => new _Metadata { Field = metadataPair.Key, BooleanValue = boolValue.Value },
StringListValue stringListValue => new _Metadata
{
Field = metadataPair.Key,
ListOfStringsValue = new _Metadata.Types._ListOfStrings { Values = { stringListValue.Value } }
},
_ => throw new InvalidArgumentException($"Unknown metadata type {metadataPair.Value.GetType()}")
};

convertedMetadataList.Add(convertedMetadata);
}
Expand Down Expand Up @@ -186,6 +178,16 @@ private static void CheckValidIndexName(string indexName)
}
}

private static uint ValidateTopK(long topK)
{
if (topK <= 0)
{
throw new InvalidArgumentException("topK must be greater than 0");
}

return (uint)topK;
}

private DateTime CalculateDeadline()
{
return DateTime.UtcNow.Add(deadline);
Expand Down
2 changes: 1 addition & 1 deletion src/Momento.Sdk/PreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public async Task<VectorDeleteItemBatchResponse> DeleteItemBatchAsync(string ind

/// <inheritdoc />
public async Task<VectorSearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector,
uint topK = 10, MetadataFields? metadataFields = null)
int topK = 10, MetadataFields? metadataFields = null)
{
return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields);
}
Expand Down

0 comments on commit 078b735

Please sign in to comment.