diff --git a/extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs b/extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs index 1ba5be2b6..13c57d22b 100644 --- a/extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs +++ b/extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs @@ -1,9 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; using System.Net.Http; using System.Text.Json.Serialization; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; +using static Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http.Filter; namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; @@ -38,42 +41,64 @@ public ScrollVectorsRequest HavingExternalId(string id) return this; } - public ScrollVectorsRequest HavingAllTags(IEnumerable? tags) + public ScrollVectorsRequest HavingAllTags(IEnumerable? tagFilters) { - if (tags == null) { return this; } + if (tagFilters == null) { return this; } - foreach (var tag in tags) + foreach (var tagFilter in tagFilters) { - if (!string.IsNullOrEmpty(tag)) + if (!string.IsNullOrEmpty(tagFilter.Tag)) { - this.Filters.AndValue(QdrantConstants.PayloadTagsField, tag); + if (tagFilter.FilterType == TagFilterType.NotEqual) + { + this.Filters.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag)); + } + else if (tagFilter.FilterType == TagFilterType.Equal) + { + this.Filters.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag); + } + else + { + throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory"); + } } } return this; } - public ScrollVectorsRequest HavingSomeTags(IEnumerable?>? tagGroups) + public ScrollVectorsRequest HavingSomeTags(IEnumerable>? tagFiltersGroups) { - if (tagGroups == null) { return this; } + if (tagFiltersGroups == null) { return this; } - var list = tagGroups.ToList(); + var list = tagFiltersGroups.ToList(); if (list.Count < 2) { return this.HavingAllTags(list.FirstOrDefault()); } var orFilter = new Filter.OrClause(); - foreach (var tags in list) + foreach (var tagFilters in list) { - if (tags == null) { continue; } + if (tagFilters == null) { continue; } var andFilter = new Filter.AndClause(); - foreach (var tag in tags) + foreach (var tagFilter in tagFilters) { - if (!string.IsNullOrEmpty(tag)) + if (!string.IsNullOrEmpty(tagFilter.Tag)) { - andFilter.AndValue(QdrantConstants.PayloadTagsField, tag); + if (tagFilter.FilterType == TagFilterType.NotEqual) + { + andFilter.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag)); + } + else if (tagFilter.FilterType == TagFilterType.Equal) + { + andFilter.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag); + } + else + { + throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory"); + } } } diff --git a/extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs b/extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs index 188ae3db8..3a2941be3 100644 --- a/extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs +++ b/extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; using System.Net.Http; using System.Text.Json.Serialization; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; using static Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http.Filter; namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; @@ -57,46 +59,61 @@ public SearchVectorsRequest HavingExternalId(string externalId) return this; } - public SearchVectorsRequest HavingAllTags(IEnumerable? tags) + public SearchVectorsRequest HavingAllTags(IEnumerable? tagFilters) { - if (tags == null) { return this; } + if (tagFilters == null) { return this; } - foreach (var tag in tags) + foreach (var tagFilter in tagFilters) { - if (!string.IsNullOrEmpty(tag)) + if (!string.IsNullOrEmpty(tagFilter.Tag)) { - this.Filters.AndValue(QdrantConstants.PayloadTagsField, tag); + if (tagFilter.FilterType == TagFilterType.NotEqual) + { + this.Filters.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag)); + } + else if (tagFilter.FilterType == TagFilterType.Equal) + { + this.Filters.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag); + } + else + { + throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory"); + } } } return this; } - public SearchVectorsRequest HavingSomeTags(IEnumerable?>? tagGroups) + public SearchVectorsRequest HavingSomeTags(List>? tagFiltersGroup) { - if (tagGroups == null) { return this; } + if (tagFiltersGroup == null) { return this; } - var list = tagGroups.ToList(); + var list = tagFiltersGroup.ToList(); if (list.Count < 2) { return this.HavingAllTags(list.FirstOrDefault()); } var orFilter = new Filter.OrClause(); - foreach (var tags in list) + foreach (var tagFilters in list) { - if (tags == null) { continue; } + if (tagFilters == null) { continue; } var andFilter = new Filter.AndClause(); - foreach (var tag in tags.Where(t => !string.IsNullOrEmpty(t))) + foreach (var tagFilter in tagFilters.Where(t => !string.IsNullOrEmpty(t.Tag))) { - if (tag[0] == '!') + if (tagFilter.FilterType == TagFilterType.NotEqual) + { + andFilter.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag)); + } + else if (tagFilter.FilterType == TagFilterType.Equal) { - andFilter.And(new MustNotClause(QdrantConstants.PayloadTagsField, tag[1..])); + andFilter.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag); } else { - andFilter.AndValue(QdrantConstants.PayloadTagsField, tag); + throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory"); } } diff --git a/extensions/Qdrant/Qdrant/Internals/QdrantClient.cs b/extensions/Qdrant/Qdrant/Internals/QdrantClient.cs index 8abf979ad..d797bb590 100644 --- a/extensions/Qdrant/Qdrant/Internals/QdrantClient.cs +++ b/extensions/Qdrant/Qdrant/Internals/QdrantClient.cs @@ -12,6 +12,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; using Microsoft.KernelMemory.MemoryStorage; namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client; @@ -279,7 +280,7 @@ public async Task DeleteVectorsAsync(string collectionName, IList vectorId /// List of vectors public async Task>> GetListAsync( string collectionName, - IEnumerable?>? requiredTags = null, + List>? requiredTags = null, int offset = 0, int limit = 1, bool withVectors = false, @@ -339,7 +340,7 @@ public async Task>> GetListAsync( double scoreThreshold, int limit = 1, bool withVectors = false, - IEnumerable?>? requiredTags = null, + List>? requiredTags = null, CancellationToken cancellationToken = default) { this._log.LogTrace("Searching top {0} nearest vectors", limit); diff --git a/extensions/Qdrant/Qdrant/Internals/TagFilter.cs b/extensions/Qdrant/Qdrant/Internals/TagFilter.cs new file mode 100644 index 000000000..8fe604f1b --- /dev/null +++ b/extensions/Qdrant/Qdrant/Internals/TagFilter.cs @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; + +internal record TagFilter(string Tag, TagFilterType FilterType); diff --git a/extensions/Qdrant/Qdrant/Internals/TagFilterType.cs b/extensions/Qdrant/Qdrant/Internals/TagFilterType.cs new file mode 100644 index 000000000..3150b36de --- /dev/null +++ b/extensions/Qdrant/Qdrant/Internals/TagFilterType.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; + +internal enum TagFilterType +{ + Unknown = 0, + Equal = 1, + NotEqual = 2, +} diff --git a/extensions/Qdrant/Qdrant/QdrantMemory.cs b/extensions/Qdrant/Qdrant/QdrantMemory.cs index dc1bc8666..e0540df52 100644 --- a/extensions/Qdrant/Qdrant/QdrantMemory.cs +++ b/extensions/Qdrant/Qdrant/QdrantMemory.cs @@ -12,6 +12,7 @@ using Microsoft.KernelMemory.AI; using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.MemoryDb.Qdrant.Client; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; using Microsoft.KernelMemory.MemoryStorage; namespace Microsoft.KernelMemory.MemoryDb.Qdrant; @@ -160,7 +161,7 @@ public async IAsyncEnumerable UpsertBatchAsync(string index, IEnumerable index = NormalizeIndexName(index); if (limit <= 0) { limit = int.MaxValue; } - List> requiredTags = CreateRequiredTagsFromMemoryFilters(filters); + var requiredTags = CreateRequiredTagsFromMemoryFilters(filters); Embedding textEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false); @@ -200,7 +201,7 @@ public async IAsyncEnumerable GetListAsync( index = NormalizeIndexName(index); if (limit <= 0) { limit = int.MaxValue; } - List> requiredTags = CreateRequiredTagsFromMemoryFilters(filters); + var requiredTags = CreateRequiredTagsFromMemoryFilters(filters); List> results; try @@ -268,9 +269,9 @@ private static string NormalizeIndexName(string index) return index.Trim(); } - private static List> CreateRequiredTagsFromMemoryFilters(ICollection? filters) + private static List> CreateRequiredTagsFromMemoryFilters(ICollection? filters) { - var requiredTags = new List>(); + var requiredTags = new List>(); // Check if we have at least one non-empty filter var nonEmptyFilters = filters?.Where(filters => !filters.IsEmpty()).ToArray() ?? Array.Empty(); if (nonEmptyFilters.Length > 0) @@ -278,16 +279,16 @@ private static List> CreateRequiredTagsFromMemoryFilters(ICo foreach (var memoryFilter in nonEmptyFilters) { var filtersList = memoryFilter.GetFilters(); - List stringFilters = new(); + List stringFilters = new(); foreach (var baseFilter in filtersList) { if (baseFilter is EqualFilter eq) { - stringFilters.Add($"{eq.Key}{Constants.ReservedEqualsChar}{eq.Value}"); + stringFilters.Add(new TagFilter($"{eq.Key}{Constants.ReservedEqualsChar}{eq.Value}", TagFilterType.Equal)); } else if (baseFilter is NotEqualFilter neq) { - stringFilters.Add($"!{neq.Key}{Constants.ReservedEqualsChar}{neq.Value}"); + stringFilters.Add(new TagFilter($"{neq.Key}{Constants.ReservedEqualsChar}{neq.Value}", TagFilterType.NotEqual)); } else {