Skip to content

Commit

Permalink
Not filter on elasticsearch.
Browse files Browse the repository at this point in the history
  • Loading branch information
alkampfergit committed Aug 2, 2024
1 parent 42bd83b commit 49dad5b
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public void BadIndexNamesAreRejected(string indexName, int errorCount)
$"" +
$"The expected number of errors was {errorCount}.");

Assert.True(errorCount == exception.Errors.Count(), $"The number of errprs expected is different than the number of errors found.");
Assert.True(errorCount == exception.Errors.Count(), $"The number of errors expected is different than the number of errors found.");
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,29 @@ public class KernelMemoryTests : MemoryDbFunctionalTest
public KernelMemoryTests(IConfiguration cfg, ITestOutputHelper output)
: base(cfg, output)
{
this.KernelMemory = new KernelMemoryBuilder()
.With(new KernelMemoryConfig { DefaultIndexName = "default4tests" })
.WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound })
.WithOpenAI(this.OpenAiConfig)
.WithElasticsearchMemoryDb(this.ElasticsearchConfig)
.Build<MemoryServerless>();
if (cfg.GetValue<bool>("UseAzureOpenAI"))
{
Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey));

this.KernelMemory = new KernelMemoryBuilder()
.With(new KernelMemoryConfig { DefaultIndexName = "default4tests" })
.WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound })
.WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration)
.WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration)
.WithElasticsearchMemoryDb(this.ElasticsearchConfig)
.Build<MemoryServerless>();
}
else
{
Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey));

this.KernelMemory = new KernelMemoryBuilder()
.With(new KernelMemoryConfig { DefaultIndexName = "default4tests" })
.WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound })
.WithOpenAI(this.OpenAiConfig)
.WithElasticsearchMemoryDb(this.ElasticsearchConfig)
.Build<MemoryServerless>();
}
}

public IKernelMemory KernelMemory { get; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using Elastic.Clients.Elasticsearch;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.AzureOpenAI;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.MemoryDb.Elasticsearch;
using Microsoft.KernelMemory.MemoryDb.Elasticsearch.Internals;
Expand All @@ -24,12 +25,21 @@ protected MemoryDbFunctionalTest(IConfiguration cfg, ITestOutputHelper output)
: base(cfg, output)
{
this.Output = output ?? throw new ArgumentNullException(nameof(output));

if (cfg.GetValue<bool>("UseAzureOpenAI"))
{
this.TextEmbeddingGenerator = new AzureOpenAITextEmbeddingGenerator(
config: base.AzureOpenAIEmbeddingConfiguration,
textTokenizer: default,
loggerFactory: default);
}
else
{
#pragma warning disable KMEXP01 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
this.TextEmbeddingGenerator = new OpenAITextEmbeddingGenerator(
config: base.OpenAiConfig,
textTokenizer: default,
loggerFactory: default);
this.TextEmbeddingGenerator = new OpenAITextEmbeddingGenerator(
config: base.OpenAiConfig,
textTokenizer: default,
loggerFactory: default);
}
#pragma warning restore KMEXP01 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

this.Client = new ElasticsearchClient(base.ElasticsearchConfig.ToElasticsearchClientSettings());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,31 @@ public class DefaultTests : BaseFunctionalTestCase

public DefaultTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output)
{
Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey));

this._esConfig = cfg.GetSection("KernelMemory:Services:Elasticsearch").Get<ElasticsearchConfig>()!;

this._memory = new KernelMemoryBuilder()
.With(new KernelMemoryConfig { DefaultIndexName = "default4tests" })
.WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound })
.WithOpenAI(this.OpenAiConfig)
// .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration)
// .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration)
.WithElasticsearchMemoryDb(this._esConfig)
.Build<MemoryServerless>();
if (cfg.GetValue<bool>("UseAzureOpenAI"))
{
Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey));

this._memory = new KernelMemoryBuilder()
.With(new KernelMemoryConfig { DefaultIndexName = "default4tests" })
.WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound })
.WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration)
.WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration)
.WithElasticsearchMemoryDb(this._esConfig)
.Build<MemoryServerless>();
}
else
{
Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey));

this._memory = new KernelMemoryBuilder()
.With(new KernelMemoryConfig { DefaultIndexName = "default4tests" })
.WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound })
.WithOpenAI(this.OpenAiConfig)
.WithElasticsearchMemoryDb(this._esConfig)
.Build<MemoryServerless>();
}
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<UserSecretsId>3e71ea16-d53c-41a5-9202-c3f264cf1c24</UserSecretsId>
</PropertyGroup>

<ItemGroup>
Expand Down
35 changes: 20 additions & 15 deletions extensions/Elasticsearch/Elasticsearch/ElasticsearchMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -306,40 +306,45 @@ private QueryDescriptor<ElasticsearchMemoryRecord> ConvertTagFilters(
QueryDescriptor<ElasticsearchMemoryRecord> qd,
ICollection<MemoryFilter>? filters = null)
{
if ((filters == null) || (filters.Count == 0))
{
qd.MatchAll();
return qd;
}

filters = filters.Where(f => f.Keys.Count > 0)
.ToList(); // Remove empty filters

if (filters.Count == 0)
var hasOneNotEmptyFilter = filters != null && filters.Any(f => !f.IsEmpty());
if (!hasOneNotEmptyFilter)
{
qd.MatchAll();
return qd;
}

List<Query> super = new();

foreach (MemoryFilter filter in filters)
foreach (MemoryFilter filter in filters!)
{
List<Query> thisMust = new();

// Each filter is a list of key/value pairs.
foreach (var pair in filter.Pairs)
foreach (var baseFilter in filter.GetAllFilters())
{
Query newTagQuery = new TermQuery(ElasticsearchMemoryRecord.TagsName) { Value = pair.Key };
Query termQuery = new TermQuery(ElasticsearchMemoryRecord.TagsValue) { Value = pair.Value ?? string.Empty };
Query newTagQuery = new TermQuery(ElasticsearchMemoryRecord.TagsName) { Value = baseFilter.Key };
Query termQuery = new TermQuery(ElasticsearchMemoryRecord.TagsValue) { Value = baseFilter.Value ?? string.Empty };

newTagQuery &= termQuery;

var nestedQd = new NestedQuery();
nestedQd.Path = ElasticsearchMemoryRecord.TagsField;
nestedQd.Query = newTagQuery;

thisMust.Add(nestedQd);
if (baseFilter is EqualFilter eq)
{
thisMust.Add(nestedQd);
}
else if (baseFilter is NotEqualFilter neq)
{
var notQuery = new BoolQuery();
notQuery.MustNot = [nestedQd];
thisMust.Add(notQuery);
}
else
{
throw new ElasticsearchException($"Filter type {baseFilter.GetType().Name} is not supported.");
}
}

var filterQuery = new BoolQuery();
Expand Down
2 changes: 1 addition & 1 deletion extensions/Postgres/Postgres/PostgresMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ private static string NormalizeTableNamePrefix(string? name)
.Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}").ToList();
List<string> safeSqlPlaceholders = new();

List<string> conditions = new ();
List<string> conditions = new();

if (equalTags.Count > 0)
{
Expand Down
12 changes: 6 additions & 6 deletions extensions/SQLServer/SQLServer.FunctionalTests/DefaultTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,39 @@ public class DefaultTests : BaseFunctionalTestCase

public DefaultTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output)
{
IKernelMemoryBuilder builder;
if (cfg.GetValue<bool>("UseAzureOpenAI"))
{
Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey));

SqlServerConfig sqlServerConfig = cfg.GetSection("KernelMemory:Services:SqlServer").Get<SqlServerConfig>()!;

this._memory = new KernelMemoryBuilder()
builder = new KernelMemoryBuilder()
.With(new KernelMemoryConfig { DefaultIndexName = "default4tests" })
.Configure(kmb => kmb.Services.AddLogging(b => { b.AddConsole().SetMinimumLevel(LogLevel.Trace); }))
.WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound })
.WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration)
.WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration)
.WithSqlServerMemoryDb(sqlServerConfig)
.Build<MemoryServerless>();
.WithSqlServerMemoryDb(sqlServerConfig);
}
else
{
Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey));

SqlServerConfig sqlServerConfig = cfg.GetSection("KernelMemory:Services:SqlServer").Get<SqlServerConfig>()!;

this._memory = new KernelMemoryBuilder()
builder = new KernelMemoryBuilder()
.With(new KernelMemoryConfig { DefaultIndexName = "default4tests" })
.Configure(kmb => kmb.Services.AddLogging(b => { b.AddConsole().SetMinimumLevel(LogLevel.Trace); }))
.WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound })
.WithOpenAI(this.OpenAiConfig)
// .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration)
// .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration)
.WithSqlServerMemoryDb(sqlServerConfig)
.Build<MemoryServerless>();
.WithSqlServerMemoryDb(sqlServerConfig);
}

var serviceProvider = builder.Services.BuildServiceProvider();
this._memory = builder.Build<MemoryServerless>();
this._memoryDb = serviceProvider.GetService<IMemoryDb>()!;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,17 @@ await memory.ImportDocumentAsync(
log(answer.Result);
Assert.Contains(Found, answer.Result, StringComparison.OrdinalIgnoreCase);

// Simple filter: NOT the news.
answer = await memory.AskAsync("What is Orion?", filter: MemoryFilters.ByNotTag("type", "news"), index: indexName);
log(answer.Result);
Assert.Contains(NotFound, answer.Result, StringComparison.OrdinalIgnoreCase);

// Simple filter: the memory is of the user but we do not want to use memory of that user.
answer = await memory.AskAsync("What is Orion?", filter: MemoryFilters.ByNotTag("user", "owner"), index: indexName);
log(answer.Result);
Assert.Contains(NotFound, answer.Result, StringComparison.OrdinalIgnoreCase);

// Simple filter: Test AND logic between equality and not equality
// not equality on a field where we have two names
answer = await memory.AskAsync("What is Orion?", filter: MemoryFilters.ByTag("user", "owner").ByNotTag("type", "news"), index: indexName);
log(answer.Result);
Assert.Contains(NotFound, answer.Result, StringComparison.OrdinalIgnoreCase);
Expand Down

0 comments on commit 49dad5b

Please sign in to comment.