Implement new streaming ask endpoint (WIP) #400

Draft · JonathanVelkeneers wants to merge 18 commits into main
Changes from 3 commits
47 changes: 47 additions & 0 deletions clients/dotnet/WebClient/MemoryWebClient.cs
@@ -6,6 +6,7 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
@@ -299,6 +300,52 @@ public async Task<MemoryAnswer> AskAsync(
return JsonSerializer.Deserialize<MemoryAnswer>(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer();
}

/// <inheritdoc />
public async IAsyncEnumerable<string> AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (filter != null)
{
if (filters == null) { filters = new List<MemoryFilter>(); }

filters.Add(filter);
}

MemoryQuery request = new()
{
Index = index,
Question = question,
Filters = (filters is { Count: > 0 }) ? filters.ToList() : new(),
MinRelevance = minRelevance
};
using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json");

using var httpRequest = new HttpRequestMessage(HttpMethod.Post, Constants.HttpAskStreamEndpoint);
httpRequest.Content = content;

HttpResponseMessage response = await this._client.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();

using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
using var reader = new StreamReader(stream, Encoding.UTF8);

while (!reader.EndOfStream)
{
var line = await reader.ReadLineAsync().ConfigureAwait(false);
if (line is null || line.Length <= 0)
{
continue;
}

yield return line;
}
}

#region private

/// <returns>Document ID</returns>
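For context, a caller consumes the new client method with await foreach; each yielded item is one non-empty line read from the HTTP response. The sketch below is illustrative only: it assumes the usual MemoryWebClient(endpoint) constructor, and the service URL, index name, and question are placeholders rather than values from this PR.

// Minimal consumption sketch for the streaming client method added above (hypothetical values).
using System;
using System.Threading.Tasks;
using Microsoft.KernelMemory;

public static class AskStreamingDemo
{
    public static async Task RunAsync()
    {
        var memory = new MemoryWebClient("http://127.0.0.1:9001/");

        await foreach (var chunk in memory.AskStreamingAsync(
            question: "What is Kernel Memory?",
            index: "default",
            minRelevance: 0.5))
        {
            Console.Write(chunk); // each chunk is one line of the streamed answer
        }
    }
}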
1 change: 1 addition & 0 deletions service/Abstractions/Constants.cs
@@ -49,6 +49,7 @@ public static class Constants

// Endpoints
public const string HttpAskEndpoint = "/ask";
public const string HttpAskStreamEndpoint = "/ask/stream";
public const string HttpSearchEndpoint = "/search";
public const string HttpUploadEndpoint = "/upload";
public const string HttpUploadStatusEndpoint = "/upload-status";
18 changes: 18 additions & 0 deletions service/Abstractions/IKernelMemory.cs
@@ -197,4 +197,22 @@ public Task<MemoryAnswer> AskAsync(
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// Search the given index for an answer to the given query, streaming the answer back as it is generated.
/// </summary>
/// <param name="question">Question to answer</param>
/// <param name="index">Optional index name</param>
/// <param name="filter">Filter to match</param>
/// <param name="filters">Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list.</param>
/// <param name="minRelevance">Minimum Cosine Similarity required</param>
/// <param name="cancellationToken">Async task cancellation token</param>
/// <returns>An async stream of strings that together form the answer to the query</returns>
public IAsyncEnumerable<string> AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);
}
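As the remarks above note, the single 'filter' argument is merged into the 'filters' list, and the filters act as an inclusive OR. The sketch below illustrates a call combining both; it assumes the existing MemoryFilters.ByTag helper, and the tag names, values, and the memory instance are placeholders.

// Sketch only: combining the single 'filter' argument with a 'filters' list (hypothetical tags).
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.KernelMemory;

public static class FilterMergeDemo
{
    public static async Task RunAsync(IKernelMemory memory)
    {
        MemoryFilter single = MemoryFilters.ByTag("user", "alice");
        var list = new List<MemoryFilter> { MemoryFilters.ByTag("type", "contract") };

        // 'single' is appended to 'list', so results may match user=alice OR type=contract.
        await foreach (var chunk in memory.AskStreamingAsync(
            "When does the contract expire?",
            filter: single,
            filters: list))
        {
            Console.Write(chunk);
        }
    }
}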
16 changes: 16 additions & 0 deletions service/Abstractions/Search/ISearchClient.cs
@@ -45,6 +45,22 @@ Task<MemoryAnswer> AskAsync(
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// Answer the given question, if possible, streaming the answer as it is generated.
/// </summary>
/// <param name="index">Index (aka collection) to search for grounding information</param>
/// <param name="question">Question to answer</param>
/// <param name="filters">Filtering criteria to select memories to consider</param>
/// <param name="minRelevance">Minimum relevance of the memories considered</param>
/// <param name="cancellationToken">Async task cancellation token</param>
/// <returns>An async stream of strings that together form the answer to the given question</returns>
IAsyncEnumerable<string> AskStreamingAsync(
string index,
string question,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// List the available memory indexes (aka collections).
/// </summary>
25 changes: 25 additions & 0 deletions service/Core/MemoryServerless.cs
@@ -243,4 +243,29 @@ public Task<MemoryAnswer> AskAsync(
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}

/// <inheritdoc />
public IAsyncEnumerable<string> AskStreamingAsync(
[Review thread on this line]

Collaborator: This new method returns only the answer, missing information about sources, relevance, etc. Why not stream a JSON object?

JonathanVelkeneers (Author), May 14, 2024: I did a small write-up here.
In this PR I made it a stream of strings, because using an object would require making some design choices:
is the extra information (sources, relevance, etc.) sent with every "answer part", or only with the first/last one?
In an implementation I did in a personal project I sent all the metadata in the first response part, and only sent messageID + messageChunk with every remaining part.
If you prefer that approach or have another solution, I can rewrite these changes.

dluc (Collaborator): I think everyone will want streaming without giving up the other metadata, in particular the list of sources and their relevance. I would use the same approach used by OpenAI streaming, returning a content_update property within the response object that includes the token, and potentially sending the list of sources only with the first response (something like sources_update) to avoid the overhead of sending the same sources with every token.

Something like this pseudo structure (I would reuse the existing class, just change it a bit to support "_update"):

response:

{
    sources_update {
        stream of items    // sent all at once with the first token, then the list stays empty
    },
    content_update {
        next string token
    }
}

(An illustrative sketch of this proposed shape follows this file's diff below.)

JonathanVelkeneers (Author): Made some changes to return MemoryAnswer instead of string.
I'm not sure exactly what you want with MemoryAnswer: new properties for streaming, or modifying the existing ones?
In my latest changes I've used the existing properties for the time being.

Contributor: Just my two cents, that sounds perfect @dluc. We're interested in streaming as well; it's one of the reasons why we have not implemented KM yet. We do indeed need the sources, and sending them only with the first response makes perfect sense.
Just wondering when this will be merged? Are there any blockers? I could possibly assist if need be.

[End of review thread]

string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default)
{
if (filter != null)
{
if (filters == null) { filters = new List<MemoryFilter>(); }

filters.Add(filter);
}

index = IndexName.CleanName(index, this._defaultIndexName);
return this._searchClient.AskStreamingAsync(
index: index,
question: question,
filters: filters,
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}
}
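For illustration only, here is one way the "_update" idea from the review thread above could look in C#. This is a hedged sketch, not part of the PR: the type and property names (StreamingAnswerSketch, SourcesUpdate, ContentUpdate) are assumptions, and in practice the sources would likely reuse the existing citation objects rather than plain strings.

// Hypothetical sketch of the streaming payload proposed in the review thread; not part of this PR.
using System.Collections.Generic;
using System.Text.Json.Serialization;

public class StreamingAnswerSketch
{
    // Populated only on the first chunk, then left empty to avoid resending the same sources.
    [JsonPropertyName("sources_update")]
    public List<string> SourcesUpdate { get; set; } = new();

    // The next token (or small group of tokens) of the generated answer.
    [JsonPropertyName("content_update")]
    public string ContentUpdate { get; set; } = string.Empty;
}

A producer would fill SourcesUpdate on the first yielded item only and ContentUpdate on every item, mirroring the OpenAI-style delta approach described by the reviewer.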
25 changes: 25 additions & 0 deletions service/Core/MemoryService.cs
@@ -220,4 +220,29 @@ public Task<MemoryAnswer> AskAsync(
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}

/// <inheritdoc />
public IAsyncEnumerable<string> AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default)
{
if (filter != null)
{
if (filters == null) { filters = new List<MemoryFilter>(); }

filters.Add(filter);
}

index = IndexName.CleanName(index, this._defaultIndexName);
return this._searchClient.AskStreamingAsync(
index: index,
question: question,
filters: filters,
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}
}
116 changes: 115 additions & 1 deletion service/Core/Search/SearchClient.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -331,7 +332,120 @@ public async Task<MemoryAnswer> AskAsync(
return answer;
}

private IAsyncEnumerable<string> GenerateAnswerAsync(string question, string facts)
/// <inheritdoc />
public async IAsyncEnumerable<string> AskStreamingAsync(
string index,
string question,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(question))
{
this._log.LogWarning("No question provided");
yield return this._config.EmptyAnswer;
yield break;
}

var facts = new StringBuilder();
var maxTokens = this._config.MaxAskPromptSize > 0
? this._config.MaxAskPromptSize
: this._textGenerator.MaxTokenTotal;
var tokensAvailable = maxTokens
- this._textGenerator.CountTokens(this._answerPrompt)
- this._textGenerator.CountTokens(question)
- this._config.AnswerTokens;

var factsUsedCount = 0;
var factsAvailableCount = 0;

this._log.LogTrace("Fetching relevant memories");
IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync(
index: index,
text: question,
filters: filters,
minRelevance: minRelevance,
limit: this._config.MaxMatchesCount,
withEmbeddings: false,
cancellationToken: cancellationToken);

// Memories are sorted by relevance, starting from the most relevant
await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false))
{
string fileName = memory.GetFileName(this._log);

var partitionText = memory.GetPartitionText(this._log).Trim();
if (string.IsNullOrEmpty(partitionText))
{
this._log.LogError("The document partition is empty, doc: {0}", memory.Id);
continue;
}

factsAvailableCount++;

// TODO: add file age in days, to push relevance of newer documents
var fact = $"==== [File:{fileName};Relevance:{relevance:P1}]:\n{partitionText}\n";

// Use the partition/chunk only if there's room for it
var size = this._textGenerator.CountTokens(fact);
if (size >= tokensAvailable)
{
// Stop after reaching the max number of tokens
break;
}

factsUsedCount++;
this._log.LogTrace("Adding text {0} with relevance {1}", factsUsedCount, relevance);

facts.Append(fact);
tokensAvailable -= size;
}

if (factsAvailableCount > 0 && factsUsedCount == 0)
{
this._log.LogError("Unable to inject memories in the prompt, not enough tokens available");
yield return this._config.EmptyAnswer;
yield break;
}

if (factsUsedCount == 0)
{
this._log.LogWarning("No memories available");
yield return this._config.EmptyAnswer;
yield break;
}

StringBuilder bufferedAnswer = new();
var watch = Stopwatch.StartNew();
await foreach (var x in this.GenerateAnswerAsync(question, facts.ToString())
.WithCancellation(cancellationToken).ConfigureAwait(false))
{
if (x is null || x.Length == 0)
{
continue;
}

bufferedAnswer.Append(x);
yield return x;

int currentLength = bufferedAnswer.Length;
if (currentLength <= this._config.EmptyAnswer.Length && ValueIsEquivalentTo(bufferedAnswer.ToString(), this._config.EmptyAnswer))
{
this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds);
yield break;
}

if (this._log.IsEnabled(LogLevel.Trace) && currentLength >= 30)
{
this._log.LogTrace("{0} chars generated", currentLength);
}
}

watch.Stop();
this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds);
}

private IAsyncEnumerable<string?> GenerateAnswerAsync(string question, string facts)
{
var prompt = this._answerPrompt;
prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
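To make the fact-budget arithmetic in SearchClient.AskStreamingAsync above concrete, the short sketch below walks through it with made-up token counts; none of these numbers are defaults from the configuration.

// Illustrative numbers only: how the token budget computed above shrinks as facts are appended.
using System;

int maxTokens      = 8192; // MaxAskPromptSize (assumed value)
int promptTokens   = 350;  // CountTokens(answer prompt) (assumed)
int questionTokens = 25;   // CountTokens(question) (assumed)
int answerTokens   = 300;  // config.AnswerTokens, reserved for the generated answer (assumed)

int tokensAvailable = maxTokens - promptTokens - questionTokens - answerTokens; // 7517

// Each fact is appended only while it still fits; a 1200-token fact leaves 6317,
// and the loop stops at the first fact whose token count reaches the remaining budget.
tokensAvailable -= 1200;
Console.WriteLine(tokensAvailable); // 6317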
25 changes: 25 additions & 0 deletions service/Service/WebAPIEndpoints.cs
@@ -31,6 +31,7 @@ public static void ConfigureMinimalAPI(this WebApplication app, KernelMemoryConf
app.UseDeleteIndexesEndpoint(authFilter);
app.UseDeleteDocumentsEndpoint(authFilter);
app.UseAskEndpoint(authFilter);
app.UseAskStreamEndpoint(authFilter);
app.UseSearchEndpoint(authFilter);
app.UseUploadStatusEndpoint(authFilter);
}
@@ -223,6 +224,30 @@ async Task<IResult> (
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
}

public static void UseAskStreamEndpoint(this IEndpointRouteBuilder app, IEndpointFilter? authFilter = null)
{
// Ask streaming endpoint
var route = app.MapPost(Constants.HttpAskStreamEndpoint, IAsyncEnumerable<string> (
MemoryQuery query,
IKernelMemory service,
ILogger<WebAPIEndpoint> log,
CancellationToken cancellationToken) =>
{
log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance);
return service.AskStreamingAsync(
question: query.Question,
index: query.Index,
filters: query.Filters,
minRelevance: query.MinRelevance,
cancellationToken: cancellationToken);
})
.Produces<IAsyncEnumerable<string>>(StatusCodes.Status200OK)
.Produces<ProblemDetails>(StatusCodes.Status401Unauthorized)
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden);

if (authFilter != null) { route.AddEndpointFilter(authFilter); }
}

public static void UseSearchEndpoint(this IEndpointRouteBuilder app, IEndpointFilter? authFilter = null)
{
// Search endpoint