Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Microsoft.Extensions.VectorData support #48

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<CentralPackageTransitivePinningEnabled>true</CentralPackageTransitivePinningEnabled>
<MicrosoftExtensionsVersion>8.6.0</MicrosoftExtensionsVersion>
<AspireVersion>8.2.0</AspireVersion>
<MicrosoftExtensionsAiVersion>9.0.0-preview.9.24507.7</MicrosoftExtensionsAiVersion>
<MicrosoftExtensionsAiVersion>9.0.0-preview.9.24525.1</MicrosoftExtensionsAiVersion>
</PropertyGroup>
<ItemGroup>
<!-- Version together with Aspire -->
Expand All @@ -18,11 +18,11 @@
<PackageVersion Include="Aspire.Npgsql.EntityFrameworkCore.PostgreSQL" Version="$(AspireVersion)" />
<PackageVersion Include="Aspire.StackExchange.Redis" Version="$(AspireVersion)" />
<PackageVersion Include="Azure.AI.OpenAI" Version="2.1.0-beta.1" />
<PackageVersion Include="OllamaSharp" Version="4.0.4" />
<PackageVersion Include="IdentityModel" Version="7.0.0" />
<PackageVersion Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.7" />
<PackageVersion Include="Microsoft.AspNetCore.Authentication.OpenIdConnect" Version="8.0.7" />
<PackageVersion Include="Microsoft.Extensions.AI" Version="$(MicrosoftExtensionsAiVersion)" />
<PackageVersion Include="Microsoft.Extensions.AI.Ollama" Version="$(MicrosoftExtensionsAiVersion)" />
<PackageVersion Include="Microsoft.Extensions.AI.OpenAI" Version="$(MicrosoftExtensionsAiVersion)" />
<PackageVersion Include="Microsoft.Extensions.Hosting" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.ServiceDiscovery" Version="$(AspireVersion)" />
Expand All @@ -34,10 +34,10 @@
<PackageVersion Include="Microsoft.FluentUI.AspNetCore.Components.DataGrid.EntityFrameworkAdapter" Version="4.9.3" />
<PackageVersion Include="Microsoft.FluentUI.AspNetCore.Components.Icons" Version="4.9.3" />
<PackageVersion Include="Microsoft.Playwright" Version="1.45.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.OpenAI" Version="1.16.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.Qdrant" Version="1.16.0-alpha" />
<PackageVersion Include="Microsoft.SemanticKernel.Core" Version="1.16" />
<PackageVersion Include="Microsoft.SemanticKernel" Version="1.16.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.OpenAI" Version="1.26.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.Qdrant" Version="1.26.0-preview" />
<PackageVersion Include="Microsoft.SemanticKernel.Core" Version="1.26.0" />
<PackageVersion Include="Microsoft.SemanticKernel" Version="1.26.0" />
<!-- Open Telemetry -->
<PackageVersion Include="OpenTelemetry.Exporter.OpenTelemetryProtocol" Version="1.9.0" />
<PackageVersion Include="OpenTelemetry.Extensions.Hosting" Version="1.9.0" />
Expand Down
21 changes: 11 additions & 10 deletions seeddata/DataGenerator/Generators/TicketThreadGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,17 @@ private class AssistantTools(IEmbeddingGenerator<string, Embedding<float>> embed
{
// Obviously it would be more performant to chunk and embed each manual only once, but this is simpler for now
var chunks = SplitIntoChunks(manual.MarkdownText, 200).ToList();
var embeddings = await embedder.GenerateAsync(chunks);
var candidates = chunks.Zip(embeddings);
var queryEmbedding = (await embedder.GenerateAsync(query)).Single();

var closest = candidates
.Select(c => new { Text = c.First, Similarity = TensorPrimitives.CosineSimilarity(c.Second.Vector.Span, queryEmbedding.Vector.Span) })
.OrderByDescending(c => c.Similarity)
.Take(3)
.Where(c => c.Similarity > 0.6f)
.ToList();

var candidates = await embedder.GenerateAndZipAsync(chunks);
var queryEmbedding = await embedder.GenerateEmbeddingAsync(query);

var closest =
candidates
.Select(c => new { Text = c.Value, Similarity = TensorPrimitives.CosineSimilarity(c.Embedding.Vector.Span, queryEmbedding.Vector.Span) })
.OrderByDescending(c => c.Similarity)
.Take(3)
.Where(c => c.Similarity > 0.6f)
.ToList();

if (closest.Any())
{
Expand Down
31 changes: 12 additions & 19 deletions src/Backend/Api/AssistantApi.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.ComponentModel;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
Expand Down Expand Up @@ -105,24 +106,28 @@ public async Task<object> SearchManual(
await httpContext.Response.WriteAsync(JsonSerializer.Serialize(new AssistantChatReplyItem(AssistantChatReplyItemType.Search, searchPhrase)));

// Do the search, and supply the results to the UI so it can show one as a citation link
var searchResults = await manualSearch.SearchAsync(productId, searchPhrase);
var searchResults =
await (await manualSearch.SearchAsync(productId, searchPhrase))
.Results
.ToListAsync();

foreach (var r in searchResults)
{
await httpContext.Response.WriteAsync(",\n");
await httpContext.Response.WriteAsync(JsonSerializer.Serialize(new AssistantChatReplyItem(
AssistantChatReplyItemType.SearchResult,
string.Empty,
int.Parse(r.Metadata.Id),
GetProductId(r),
GetPageNumber(r))));
r.Record.ChunkId, // This is the ID of the record returned. Looking at the mapping, it was using the ChunkID
r.Record.ProductId,
r.Record.PageNumber)));
}

// Return the search results to the assistant
return searchResults.Select(r => new
{
ProductId = GetProductId(r),
SearchResultId = r.Metadata.Id,
r.Metadata.Text,
ProductId = r.Record.ProductId,
SearchResultId = r.Record.ChunkId,
r.Record.Text,
});
}
finally
Expand All @@ -131,16 +136,4 @@ await httpContext.Response.WriteAsync(JsonSerializer.Serialize(new AssistantChat
}
}
}

private static int? GetProductId(MemoryQueryResult result)
{
var match = Regex.Match(result.Metadata.ExternalSourceName, @"productid:(\d+)");
return match.Success ? int.Parse(match.Groups[1].Value) : null;
}

private static int? GetPageNumber(MemoryQueryResult result)
{
var match = Regex.Match(result.Metadata.AdditionalMetadata, @"pagenumber:(\d+)");
return match.Success ? int.Parse(match.Groups[1].Value) : null;
}
}
10 changes: 9 additions & 1 deletion src/Backend/Data/ManualChunk.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
namespace eShopSupport.Backend.Data;
using Microsoft.Extensions.VectorData;

namespace eShopSupport.Backend.Data;

/// <summary>
/// A chunk of product-manual text together with its embedding, annotated so it can be
/// stored and retrieved as a record through Microsoft.Extensions.VectorData.
/// </summary>
public class ManualChunk
{
    /// <summary>
    /// Identifier of this chunk. This is the record key: the seed-data import used the
    /// chunk id as the record id, and search results are reported back by chunk id.
    /// </summary>
    [VectorStoreRecordKey]
    public int ChunkId { get; set; }

    /// <summary>
    /// Identifier of the product this chunk belongs to. Stored as plain data rather than
    /// as the key, because it is used to narrow searches to a single product
    /// (presumably several chunks share the same product id).
    /// </summary>
    [VectorStoreRecordData]
    public int ProductId { get; set; }

    /// <summary>Page number associated with this chunk.</summary>
    [VectorStoreRecordData]
    public int PageNumber { get; set; }

    /// <summary>Text content of the chunk.</summary>
    [VectorStoreRecordData]
    public required string Text { get; set; }

    /// <summary>Embedding of <see cref="Text"/>, declared as a 384-dimension vector.</summary>
    // NOTE(review): the seed data holds the embedding as raw bytes, which the previous import code
    // reinterpreted as floats (MemoryMarshal.Cast<byte, float>) before upserting. Confirm that the
    // vector store connector in use accepts byte[] for a vector property - TODO verify.
    [VectorStoreRecordVector(384, DistanceFunction.CosineDistance)]
    public required byte[] Embedding { get; set; }
}
69 changes: 34 additions & 35 deletions src/Backend/Services/ProductManualSemanticSearch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,36 @@
using System.Text.Json;
using Azure.Storage.Blobs;
using eShopSupport.Backend.Data;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Memory;

namespace eShopSupport.Backend.Services;

public class ProductManualSemanticSearch(ITextEmbeddingGenerationService embedder, IServiceProvider services)
public class ProductManualSemanticSearch(ITextEmbeddingGenerationService embedder, IVectorStore store)
{
private const string ManualCollectionName = "manuals";

public async Task<IReadOnlyList<MemoryQueryResult>> SearchAsync(int? productId, string query)
public async Task<VectorSearchResults<ManualChunk>> SearchAsync(int? productId, string query)
{
var embedding = await embedder.GenerateEmbeddingAsync(query);
var filter = !productId.HasValue
? null
: new
{
must = new[]
{
new { key = "external_source_name", match = new { value = $"productid:{productId}" } }
}
};

var httpClient = services.GetQdrantHttpClient("vector-db");
var response = await httpClient.PostAsync($"collections/{ManualCollectionName}/points/search",
JsonContent.Create(new
{
vector = embedding,
with_payload = new[] { "id", "text", "external_source_name", "additional_metadata" },
limit = 3,
filter,
}));

var responseParsed = await response.Content.ReadFromJsonAsync<QdrantResult>();

return responseParsed!.Result.Select(r => new MemoryQueryResult(
new MemoryRecordMetadata(true, r.Payload.Id, r.Payload.Text, "", r.Payload.External_Source_Name, r.Payload.Additional_Metadata),
r.Score,
null)).ToList();
// Restrict the search to one product only when the caller supplied a product id.
// The filter targets the ManualChunk.ProductId property; the legacy
// "external_source_name" / "productid:{n}" payload field is not part of the ManualChunk record.
// With no product id the filter is left empty so every manual is searched, matching the
// pre-migration behavior of sending no filter.
var filter = productId.HasValue
    ? new VectorSearchFilter().EqualTo(nameof(ManualChunk.ProductId), productId.Value)
    : new VectorSearchFilter();

// Return at most the three closest chunks.
var searchOptions = new VectorSearchOptions
{
    Filter = filter,
    Top = 3
};

var collection = store.GetCollection<int,ManualChunk>(ManualCollectionName);

var results = await collection.VectorizedSearchAsync(embedding, searchOptions);

return results;
}

public static async Task EnsureSeedDataImportedAsync(IServiceProvider services, string? initialImportDataDir)
Expand Down Expand Up @@ -75,26 +67,33 @@ private static async Task ImportManualFilesSeedDataAsync(string importDataFromDi

private static async Task ImportManualChunkSeedDataAsync(string importDataFromDir, IServiceScope scope)
{
var semanticMemory = scope.ServiceProvider.GetRequiredService<IMemoryStore>();
var collections = await semanticMemory.GetCollectionsAsync().ToListAsync();
var semanticMemory = scope.ServiceProvider.GetRequiredService<IVectorStore>();
var collections = await semanticMemory.ListCollectionNamesAsync().ToListAsync();

if (!collections.Contains(ManualCollectionName))
{
await semanticMemory.CreateCollectionAsync(ManualCollectionName);
var collection = semanticMemory.GetCollection<int,ManualChunk>(ManualCollectionName);

using var fileStream = File.OpenRead(Path.Combine(importDataFromDir, "manual-chunks.json"));
var manualChunks = JsonSerializer.DeserializeAsyncEnumerable<ManualChunk>(fileStream);
await foreach (var chunkChunk in ReadChunkedAsync(manualChunks, 1000))
{

var mappedRecords = chunkChunk.Select(chunk =>
{
var id = chunk!.ChunkId.ToString();
var metadata = new MemoryRecordMetadata(false, id, chunk.Text, "", $"productid:{chunk.ProductId}", $"pagenumber:{chunk.PageNumber}");
var embedding = MemoryMarshal.Cast<byte, float>(new ReadOnlySpan<byte>(chunk.Embedding)).ToArray();
return new MemoryRecord(metadata, embedding, null);

});

await foreach (var _ in semanticMemory.UpsertBatchAsync(ManualCollectionName, mappedRecords)) { }

//var mappedRecords = chunkChunk.Select(chunk =>
//{
// var id = chunk!.ChunkId.ToString();
// var metadata = new MemoryRecordMetadata(false, id, chunk.Text, "", $"productid:{chunk.ProductId}", $"pagenumber:{chunk.PageNumber}");
// var embedding = MemoryMarshal.Cast<byte, float>(new ReadOnlySpan<byte>(chunk.Embedding)).ToArray();
// return new MemoryRecord(metadata, embedding, null);
//});

//await foreach (var _ in semanticMemory.UpsertBatchAsync(ManualCollectionName, mappedRecords)) { }
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Azure.AI.OpenAI;
using OllamaSharp;
using System.ClientModel;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -47,7 +48,7 @@ public static IServiceCollection AddOllamaChatClient(
pipeline.UsePreventStreamingWithFunctions();

var httpClient = pipeline.Services.GetService<HttpClient>() ?? new();
return pipeline.Use(new OllamaChatClient(uri, modelName, httpClient));
return pipeline.Use(new OllamaApiClient(httpClient, modelName));
});
}

Expand Down
2 changes: 1 addition & 1 deletion src/ServiceDefaults/ServiceDefaults.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="Microsoft.Extensions.AI" />
<PackageReference Include="Microsoft.Extensions.AI.Ollama" />
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" />
<PackageReference Include="Microsoft.Extensions.Http.Resilience" />
<PackageReference Include="Microsoft.Extensions.ServiceDiscovery" />
<PackageReference Include="OllamaSharp" />
<PackageReference Include="OpenTelemetry.Exporter.OpenTelemetryProtocol" />
<PackageReference Include="OpenTelemetry.Extensions.Hosting" />
<PackageReference Include="OpenTelemetry.Instrumentation.AspNetCore" />
Expand Down