Skip to content

Commit

Permalink
Fixed Qdrant filtering with NOT Clause.
Browse files Browse the repository at this point in the history
  • Loading branch information
alkampfergit committed Sep 3, 2024
1 parent 1d0c474 commit 5f5b820
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 36 deletions.
51 changes: 38 additions & 13 deletions extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Text.Json.Serialization;
using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals;
using static Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http.Filter;

namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http;

Expand Down Expand Up @@ -38,42 +41,64 @@ public ScrollVectorsRequest HavingExternalId(string id)
return this;
}

public ScrollVectorsRequest HavingAllTags(IEnumerable<string>? tags)
public ScrollVectorsRequest HavingAllTags(IEnumerable<TagFilter>? tagFilters)
{
if (tags == null) { return this; }
if (tagFilters == null) { return this; }

foreach (var tag in tags)
foreach (var tagFilter in tagFilters)
{
if (!string.IsNullOrEmpty(tag))
if (!string.IsNullOrEmpty(tagFilter.Tag))
{
this.Filters.AndValue(QdrantConstants.PayloadTagsField, tag);
if (tagFilter.FilterType == TagFilterType.NotEqual)
{
this.Filters.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag));
}
else if (tagFilter.FilterType == TagFilterType.Equal)
{
this.Filters.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag);
}
else
{
throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory");
}
}
}

return this;
}

public ScrollVectorsRequest HavingSomeTags(IEnumerable<IEnumerable<string>?>? tagGroups)
public ScrollVectorsRequest HavingSomeTags(IEnumerable<IEnumerable<TagFilter>>? tagFiltersGroups)
{
if (tagGroups == null) { return this; }
if (tagFiltersGroups == null) { return this; }

var list = tagGroups.ToList();
var list = tagFiltersGroups.ToList();
if (list.Count < 2)
{
return this.HavingAllTags(list.FirstOrDefault());
}

var orFilter = new Filter.OrClause();
foreach (var tags in list)
foreach (var tagFilters in list)
{
if (tags == null) { continue; }
if (tagFilters == null) { continue; }

var andFilter = new Filter.AndClause();
foreach (var tag in tags)
foreach (var tagFilter in tagFilters)
{
if (!string.IsNullOrEmpty(tag))
if (!string.IsNullOrEmpty(tagFilter.Tag))
{
andFilter.AndValue(QdrantConstants.PayloadTagsField, tag);
if (tagFilter.FilterType == TagFilterType.NotEqual)
{
andFilter.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag));
}
else if (tagFilter.FilterType == TagFilterType.Equal)
{
andFilter.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag);
}
else
{
throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory");
}
}
}

Expand Down
45 changes: 31 additions & 14 deletions extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Text.Json.Serialization;
using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals;
using static Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http.Filter;

namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http;
Expand Down Expand Up @@ -57,46 +59,61 @@ public SearchVectorsRequest HavingExternalId(string externalId)
return this;
}

public SearchVectorsRequest HavingAllTags(IEnumerable<string>? tags)
public SearchVectorsRequest HavingAllTags(IEnumerable<TagFilter>? tagFilters)
{
if (tags == null) { return this; }
if (tagFilters == null) { return this; }

foreach (var tag in tags)
foreach (var tagFilter in tagFilters)
{
if (!string.IsNullOrEmpty(tag))
if (!string.IsNullOrEmpty(tagFilter.Tag))
{
this.Filters.AndValue(QdrantConstants.PayloadTagsField, tag);
if (tagFilter.FilterType == TagFilterType.NotEqual)
{
this.Filters.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag));
}
else if (tagFilter.FilterType == TagFilterType.Equal)
{
this.Filters.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag);
}
else
{
throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory");
}
}
}

return this;
}

public SearchVectorsRequest HavingSomeTags(IEnumerable<IEnumerable<string>?>? tagGroups)
public SearchVectorsRequest HavingSomeTags(List<IEnumerable<TagFilter>>? tagFiltersGroup)
{
if (tagGroups == null) { return this; }
if (tagFiltersGroup == null) { return this; }

var list = tagGroups.ToList();
var list = tagFiltersGroup.ToList();
if (list.Count < 2)
{
return this.HavingAllTags(list.FirstOrDefault());
}

var orFilter = new Filter.OrClause();
foreach (var tags in list)
foreach (var tagFilters in list)
{
if (tags == null) { continue; }
if (tagFilters == null) { continue; }

var andFilter = new Filter.AndClause();
foreach (var tag in tags.Where(t => !string.IsNullOrEmpty(t)))
foreach (var tagFilter in tagFilters.Where(t => !string.IsNullOrEmpty(t.Tag)))
{
if (tag[0] == '!')
if (tagFilter.FilterType == TagFilterType.NotEqual)
{
andFilter.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag));
}
else if (tagFilter.FilterType == TagFilterType.Equal)
{
andFilter.And(new MustNotClause(QdrantConstants.PayloadTagsField, tag[1..]));
andFilter.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag);
}
else
{
andFilter.AndValue(QdrantConstants.PayloadTagsField, tag);
throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory");
}
}

Expand Down
5 changes: 3 additions & 2 deletions extensions/Qdrant/Qdrant/Internals/QdrantClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http;
using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals;
using Microsoft.KernelMemory.MemoryStorage;

namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client;
Expand Down Expand Up @@ -279,7 +280,7 @@ public async Task DeleteVectorsAsync(string collectionName, IList<Guid> vectorId
/// <returns>List of vectors</returns>
public async Task<List<QdrantPoint<T>>> GetListAsync(
string collectionName,
IEnumerable<IEnumerable<string>?>? requiredTags = null,
List<IEnumerable<TagFilter>>? requiredTags = null,
int offset = 0,
int limit = 1,
bool withVectors = false,
Expand Down Expand Up @@ -339,7 +340,7 @@ public async Task<List<QdrantPoint<T>>> GetListAsync(
double scoreThreshold,
int limit = 1,
bool withVectors = false,
IEnumerable<IEnumerable<string>?>? requiredTags = null,
List<IEnumerable<TagFilter>>? requiredTags = null,
CancellationToken cancellationToken = default)
{
this._log.LogTrace("Searching top {0} nearest vectors", limit);
Expand Down
5 changes: 5 additions & 0 deletions extensions/Qdrant/Qdrant/Internals/TagFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Internals;

internal record TagFilter(string Tag, TagFilterType FilterType);
10 changes: 10 additions & 0 deletions extensions/Qdrant/Qdrant/Internals/TagFilterType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Internals;

internal enum TagFilterType
{
Unknown = 0,
Equal = 1,
NotEqual = 2,
}
15 changes: 8 additions & 7 deletions extensions/Qdrant/Qdrant/QdrantMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryDb.Qdrant.Client;
using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals;
using Microsoft.KernelMemory.MemoryStorage;

namespace Microsoft.KernelMemory.MemoryDb.Qdrant;
Expand Down Expand Up @@ -160,7 +161,7 @@ public async IAsyncEnumerable<string> UpsertBatchAsync(string index, IEnumerable
index = NormalizeIndexName(index);
if (limit <= 0) { limit = int.MaxValue; }

List<IEnumerable<string>> requiredTags = CreateRequiredTagsFromMemoryFilters(filters);
var requiredTags = CreateRequiredTagsFromMemoryFilters(filters);

Embedding textEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -200,7 +201,7 @@ public async IAsyncEnumerable<MemoryRecord> GetListAsync(
index = NormalizeIndexName(index);
if (limit <= 0) { limit = int.MaxValue; }

List<IEnumerable<string>> requiredTags = CreateRequiredTagsFromMemoryFilters(filters);
var requiredTags = CreateRequiredTagsFromMemoryFilters(filters);

List<QdrantPoint<DefaultQdrantPayload>> results;
try
Expand Down Expand Up @@ -268,26 +269,26 @@ private static string NormalizeIndexName(string index)
return index.Trim();
}

private static List<IEnumerable<string>> CreateRequiredTagsFromMemoryFilters(ICollection<MemoryFilter>? filters)
private static List<IEnumerable<TagFilter>> CreateRequiredTagsFromMemoryFilters(ICollection<MemoryFilter>? filters)
{
var requiredTags = new List<IEnumerable<string>>();
var requiredTags = new List<IEnumerable<TagFilter>>();
// Check if we have at least one non-empty filter
var nonEmptyFilters = filters?.Where(filters => !filters.IsEmpty()).ToArray() ?? Array.Empty<MemoryFilter>();
if (nonEmptyFilters.Length > 0)
{
foreach (var memoryFilter in nonEmptyFilters)
{
var filtersList = memoryFilter.GetFilters();
List<string> stringFilters = new();
List<TagFilter> stringFilters = new();
foreach (var baseFilter in filtersList)
{
if (baseFilter is EqualFilter eq)
{
stringFilters.Add($"{eq.Key}{Constants.ReservedEqualsChar}{eq.Value}");
stringFilters.Add(new TagFilter($"{eq.Key}{Constants.ReservedEqualsChar}{eq.Value}", TagFilterType.Equal));
}
else if (baseFilter is NotEqualFilter neq)
{
stringFilters.Add($"!{neq.Key}{Constants.ReservedEqualsChar}{neq.Value}");
stringFilters.Add(new TagFilter($"{neq.Key}{Constants.ReservedEqualsChar}{neq.Value}", TagFilterType.NotEqual));
}
else
{
Expand Down

0 comments on commit 5f5b820

Please sign in to comment.