Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement new streaming ask endpoint (WIP) #400

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
<PackageVersion Include="Swashbuckle.AspNetCore" Version="6.5.0" />
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
<PackageVersion Include="System.Memory.Data" Version="8.0.0" />
<PackageVersion Include="System.Net.Http.Json" Version="8.0.0" />
<PackageVersion Include="System.Numerics.Tensors" Version="8.0.0" />
<PackageVersion Include="System.Text.Json" Version="8.0.3" />
</ItemGroup>
Expand Down
44 changes: 44 additions & 0 deletions clients/dotnet/WebClient/MemoryWebClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Json;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
Expand Down Expand Up @@ -358,6 +360,48 @@ public async Task<MemoryAnswer> AskAsync(
return JsonSerializer.Deserialize<MemoryAnswer>(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer();
}

/// <inheritdoc />
/// <remarks>
/// Posts the question to <see cref="Constants.HttpAskStreamEndpoint"/> and yields every non-empty
/// string read from the JSON response body. The caller's <paramref name="filters"/> collection
/// is never modified: <paramref name="filter"/> is merged into a private copy.
/// </remarks>
public async IAsyncEnumerable<string> AskStreamingAsync(
    string question,
    string? index = null,
    MemoryFilter? filter = null,
    ICollection<MemoryFilter>? filters = null,
    double minRelevance = 0,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Work on a copy rather than calling filters.Add(): the caller may pass a read-only or
    // fixed-size collection (e.g. an array), and must not see its own list grow on every call.
    List<MemoryFilter> mergedFilters = filters is { Count: > 0 }
        ? filters.ToList()
        : new List<MemoryFilter>();
    if (filter != null)
    {
        mergedFilters.Add(filter);
    }

    MemoryQuery request = new()
    {
        Index = index,
        Question = question,
        Filters = mergedFilters,
        MinRelevance = minRelevance
    };
    using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json");

    using var httpRequest = new HttpRequestMessage(HttpMethod.Post, Constants.HttpAskStreamEndpoint);
    httpRequest.Content = content;

    // ResponseHeadersRead returns as soon as the headers arrive, so the body can be consumed
    // incrementally instead of being buffered in full before the first part is yielded.
    using HttpResponseMessage response = await this._client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
    response.EnsureSuccessStatusCode();

    // ConfigureAwait(false) keeps the enumeration consistent with the SendAsync await above
    await foreach (var responsePart in response.Content.ReadFromJsonAsAsyncEnumerable<string>(cancellationToken).ConfigureAwait(false))
    {
        // Skip null/empty elements so consumers only receive actual text
        if (string.IsNullOrEmpty(responsePart))
        {
            continue;
        }

        yield return responsePart;
    }
}

#region private

private static (string contentType, long contentLength, DateTimeOffset lastModified) GetFileDetails(HttpResponseMessage response)
Expand Down
1 change: 1 addition & 0 deletions clients/dotnet/WebClient/WebClient.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

<ItemGroup>
<ProjectReference Include="..\..\..\service\Abstractions\Abstractions.csproj" />
<PackageReference Include="System.Net.Http.Json" />
</ItemGroup>

<PropertyGroup>
Expand Down
30 changes: 30 additions & 0 deletions examples/001-dotnet-WebClient/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public static async Task Main()
// =======================

await AskSimpleQuestion();
await AskSimpleQuestionWithStreaming();
await AskSimpleQuestionAndShowSources();
await AskQuestionAboutImageContent();
await AskQuestionUsingFilter();
Expand Down Expand Up @@ -275,6 +276,35 @@ due to the speed of light being a very large number when squared. This concept i
*/
}

// Simple question answered via the streaming API (restricted to memories tagged user:Taylor),
// printing each part of the answer as soon as it is received
private static async Task AskSimpleQuestionWithStreaming()
{
    const string question = "Any news from NASA about Orion?";
    Console.WriteLine($"Question: {question}");

    IAsyncEnumerable<string> answerStream = s_memory.AskStreamingAsync(
        question,
        filter: MemoryFilters.ByTag("user", "Taylor"));

    Console.WriteLine("\nAnswer:\n");

    // Write every part without a line break, so the parts join into one continuous answer
    await foreach (string part in answerStream)
    {
        Console.Write(part);
    }

    Console.WriteLine("\n====================================\n");

    /* OUTPUT

    Question: Any news from NASA about Orion?

    Answer: Yes, NASA has invited media to see the new test version of the Orion spacecraft and the hardware teams will use to recover the capsule and astronauts upon their return from space during the Artemis II mission.
    The event will take place at Naval Base San Diego on August 2.
    Personnel involved in recovery operations from NASA, the U.S. Navy, and the U.S. Air Force will be available to speak with media.
    Teams are currently conducting tests in the Pacific Ocean to demonstrate and evaluate the processes, procedures, and hardware for recovery operations for crewed Artemis missions.
    The tests will help prepare the team for Artemis II, NASA's first crewed mission under Artemis that will send four astronauts in Orion around the Moon to checkout systems ahead of future lunar missions.
    The Artemis II crew will participate in recovery testing at sea next year.

    */
}

// Another question without filters and show sources
private static async Task AskSimpleQuestionAndShowSources()
{
Expand Down
1 change: 1 addition & 0 deletions service/Abstractions/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public static class Constants

// Endpoints
public const string HttpAskEndpoint = "/ask";
// Streaming counterpart of the ask endpoint: the web client POSTs the query here
// and reads the answer back in parts.
public const string HttpAskStreamEndpoint = "/ask/stream";
public const string HttpSearchEndpoint = "/search";
public const string HttpDownloadEndpoint = "/download";
public const string HttpUploadEndpoint = "/upload";
Expand Down
18 changes: 18 additions & 0 deletions service/Abstractions/IKernelMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,22 @@ public Task<MemoryAnswer> AskAsync(
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// Search the given index for an answer to the given query, returning the answer
/// as a stream of text parts rather than as a single result.
/// </summary>
/// <remarks>
/// Only answer text is streamed: the elements are plain strings and carry no
/// information about sources or relevance.
/// </remarks>
/// <param name="question">Question to answer</param>
/// <param name="index">Optional index name</param>
/// <param name="filter">Filter to match</param>
/// <param name="filters">Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list.</param>
/// <param name="minRelevance">Minimum Cosine Similarity required</param>
/// <param name="cancellationToken">Async task cancellation token</param>
/// <returns>An asynchronous stream of strings that together form the answer to the query</returns>
public IAsyncEnumerable<string> AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);
}
16 changes: 16 additions & 0 deletions service/Abstractions/Search/ISearchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ Task<MemoryAnswer> AskAsync(
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// Answer the given question, if possible, streaming the answer text in parts
/// instead of returning it as a single result.
/// </summary>
/// <param name="index">Index (aka collection) to search for grounding information</param>
/// <param name="question">Question to answer</param>
/// <param name="filters">Filtering criteria to select memories to consider</param>
/// <param name="minRelevance">Minimum relevance of the memories considered</param>
/// <param name="cancellationToken">Async task cancellation token</param>
/// <returns>An asynchronous stream of strings that together form the answer to the given question</returns>
IAsyncEnumerable<string> AskStreamingAsync(
string index,
string question,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// List the available memory indexes (aka collections).
/// </summary>
Expand Down
25 changes: 25 additions & 0 deletions service/Core/MemoryServerless.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,29 @@ public Task<MemoryAnswer> AskAsync(
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}

/// <inheritdoc />
public IAsyncEnumerable<string> AskStreamingAsync(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this new method returns only the answer, missing information about sources, relevance, etc. Why not stream a json object?

Copy link
Author

@JonathanVelkeneers JonathanVelkeneers May 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this new method returns only the answer, missing information about sources, relevance, etc. Why not stream a json object?

I did a small write-up here.

In this PR I made it as a stream of strings, because using an object would require making some design choices.
Is the extra information (sources, relevance, etc) sent with every "answer part", or only the first/last one?

In an implementation I did in a personal project I sent all the metadata in the first response part, and only sent messageID + messageChunk with every remaining part.

If you prefer that approach or have another solution I can rewrite these changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think everyone will want streaming, without giving up the other metadata, in particular the list of sources and their relevance. I would use the same approach used by OpenAI streaming, returning a content_update property within the response object, that includes the token - potentially sending the list of sources on the first response only using the same approach (something like sources_update) to avoid the overhead of sending the same sources with every token.

Something like, pseudo structure (I would reuse the existing class, just change it a bit to support "_update"):

response:

{
    sources_update {
          stream of items.   // sent all at once with the first token, then the list stays empty
   },
   content_update {
          next string token
   }
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made some changes to return MemoryAnswer instead of string.
I'm not sure what exactly you want with MemoryAnswer: new properties for streaming, or modifications to the existing ones?

In my latest changes I've used the existing properties for the time being

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just my two cents: that sounds perfect, @dluc. We're interested in streaming as well; it's one of the reasons why we have not implemented KM yet. We do indeed need the sources as well.
To send the sources only on the first response makes perfect sense.

Just wondering when this will be merged? Are there any blockers? I could possibly assist if need be.

string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default)
{
if (filter != null)
{
if (filters == null) { filters = new List<MemoryFilter>(); }

filters.Add(filter);
}

index = IndexName.CleanName(index, this._defaultIndexName);
return this._searchClient.AskStreamingAsync(
index: index,
question: question,
filters: filters,
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}
}
25 changes: 25 additions & 0 deletions service/Core/MemoryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,29 @@ public Task<MemoryAnswer> AskAsync(
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}

/// <inheritdoc />
/// <remarks>
/// Resolves the index name and delegates to the search client. The caller's
/// <paramref name="filters"/> collection is never modified: when <paramref name="filter"/>
/// is provided, it is merged into a new list that is passed on instead.
/// </remarks>
public IAsyncEnumerable<string> AskStreamingAsync(
    string question,
    string? index = null,
    MemoryFilter? filter = null,
    ICollection<MemoryFilter>? filters = null,
    double minRelevance = 0,
    CancellationToken cancellationToken = default)
{
    if (filter != null)
    {
        // Build a new list rather than calling filters.Add(): the caller may pass a read-only
        // or fixed-size collection (e.g. an array), and must not see its own list grow.
        filters = filters == null
            ? new List<MemoryFilter> { filter }
            : new List<MemoryFilter>(filters) { filter };
    }

    // Fall back to the default index name when none (or an empty one) is supplied
    index = IndexName.CleanName(index, this._defaultIndexName);
    return this._searchClient.AskStreamingAsync(
        index: index,
        question: question,
        filters: filters,
        minRelevance: minRelevance,
        cancellationToken: cancellationToken);
}
}
116 changes: 115 additions & 1 deletion service/Core/Search/SearchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -345,7 +346,120 @@ public async Task<MemoryAnswer> AskAsync(
return answer;
}

private IAsyncEnumerable<string> GenerateAnswerAsync(string question, string facts)
/// <inheritdoc />
public async IAsyncEnumerable<string> AskStreamingAsync(
    string index,
    string question,
    ICollection<MemoryFilter>? filters = null,
    double minRelevance = 0,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Without a question there is nothing to search for: emit the configured empty answer
    if (string.IsNullOrEmpty(question))
    {
        this._log.LogWarning("No question provided");
        yield return this._config.EmptyAnswer;
        yield break;
    }

    var factsBuilder = new StringBuilder();

    // Prompt size limit: the configured value when positive, otherwise the text generator's total
    var promptLimit = this._config.MaxAskPromptSize <= 0
        ? this._textGenerator.MaxTokenTotal
        : this._config.MaxAskPromptSize;

    // Tokens left for facts once the prompt template, the question and the answer are accounted for
    var remainingTokens = promptLimit - this._textGenerator.CountTokens(this._answerPrompt) - this._textGenerator.CountTokens(question) - this._config.AnswerTokens;

    var usedCount = 0;
    var availableCount = 0;

    this._log.LogTrace("Fetching relevant memories");
    IAsyncEnumerable<(MemoryRecord, double)> candidates = this._memoryDb.GetSimilarListAsync(
        index: index,
        text: question,
        filters: filters,
        minRelevance: minRelevance,
        limit: this._config.MaxMatchesCount,
        withEmbeddings: false,
        cancellationToken: cancellationToken);

    // Candidates arrive ordered by relevance, most relevant first
    await foreach ((MemoryRecord record, double score) in candidates.ConfigureAwait(false))
    {
        string sourceName = record.GetFileName(this._log);

        var chunk = record.GetPartitionText(this._log).Trim();
        if (string.IsNullOrEmpty(chunk))
        {
            this._log.LogError("The document partition is empty, doc: {0}", record.Id);
            continue;
        }

        availableCount++;

        // TODO: add file age in days, to push relevance of newer documents
        var entry = $"==== [File:{sourceName};Relevance:{score:P1}]:\n{chunk}\n";

        // The first chunk that does not fit in the remaining budget ends the collection
        var cost = this._textGenerator.CountTokens(entry);
        if (cost >= remainingTokens)
        {
            break;
        }

        usedCount++;
        this._log.LogTrace("Adding text {0} with relevance {1}", usedCount, score);

        factsBuilder.Append(entry);
        remainingTokens -= cost;
    }

    // No usable facts: either nothing fitted in the prompt, or nothing was found at all
    if (usedCount == 0)
    {
        if (availableCount > 0)
        {
            this._log.LogError("Unable to inject memories in the prompt, not enough tokens available");
        }
        else
        {
            this._log.LogWarning("No memories available");
        }

        yield return this._config.EmptyAnswer;
        yield break;
    }

    StringBuilder answerSoFar = new();
    var timer = Stopwatch.StartNew();
    await foreach (var token in this.GenerateAnswerAsync(question, factsBuilder.ToString())
                       .WithCancellation(cancellationToken).ConfigureAwait(false))
    {
        // Ignore null/empty tokens
        if (token is not { Length: > 0 })
        {
            continue;
        }

        answerSoFar.Append(token);
        yield return token;

        // Stop streaming once the text generated so far is equivalent to the configured empty answer
        int generatedLength = answerSoFar.Length;
        bool isEmptyAnswer = generatedLength <= this._config.EmptyAnswer.Length
                             && ValueIsEquivalentTo(answerSoFar.ToString(), this._config.EmptyAnswer);
        if (isEmptyAnswer)
        {
            this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", timer.ElapsedMilliseconds);
            yield break;
        }

        if (this._log.IsEnabled(LogLevel.Trace) && generatedLength >= 30)
        {
            this._log.LogTrace("{0} chars generated", generatedLength);
        }
    }

    timer.Stop();
    this._log.LogTrace("Answer generated in {0} msecs", timer.ElapsedMilliseconds);
}

private IAsyncEnumerable<string?> GenerateAnswerAsync(string question, string facts)
{
var prompt = this._answerPrompt;
prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
Expand Down
Loading
Loading