Skip to content

Commit

Permalink
Add the new chat session name parameter to Core client
Browse files Browse the repository at this point in the history
  • Loading branch information
alistar-andrei committed Aug 7, 2024
1 parent 0779492 commit 93d154a
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 27 deletions.
10 changes: 5 additions & 5 deletions src/dotnet/CoreClient/Clients/RESTClients/SessionRESTClient.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using System.Text.Encodings.Web;
using System.Text.Json;
using Azure.Core;
using Azure.Core;
using FoundationaLLM.Client.Core.Interfaces;
using FoundationaLLM.Common.Models.Chat;
using System.Text.Encodings.Web;
using System.Text.Json;

namespace FoundationaLLM.Client.Core.Clients.RESTClients
{
Expand All @@ -17,10 +17,10 @@ internal class SessionRESTClient(
private readonly string _instanceId = instanceId ?? throw new ArgumentNullException(nameof(instanceId));

/// <inheritdoc/>
public async Task<string> CreateSessionAsync()
public async Task<string> CreateSessionAsync(string chatSessionName)
{
var coreClient = await GetCoreClientAsync();
var responseSession = await coreClient.PostAsync($"instances/{_instanceId}/sessions", null);
var responseSession = await coreClient.PostAsync($"instances/{_instanceId}/sessions?chatSessionName={chatSessionName}", null);

if (responseSession.IsSuccessStatusCode)
{
Expand Down
24 changes: 17 additions & 7 deletions src/dotnet/CoreClient/CoreClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,9 @@ public CoreClient(
_coreRestClient = new CoreRESTClient(coreUri, credential, instanceId, options);

/// <inheritdoc/>
public async Task<string> CreateChatSessionAsync(string? sessionName)
public async Task<string> CreateChatSessionAsync(string sessionName)
{
var sessionId = await _coreRestClient.Sessions.CreateSessionAsync();
if (!string.IsNullOrWhiteSpace(sessionName))
{
await _coreRestClient.Sessions.RenameChatSession(sessionId, sessionName);
}

var sessionId = await _coreRestClient.Sessions.CreateSessionAsync(sessionName);
return sessionId;
}

Expand All @@ -71,6 +66,13 @@ public async Task<Completion> GetCompletionWithSessionAsync(string? sessionId, s
{
if (string.IsNullOrWhiteSpace(sessionId))
{
if (string.IsNullOrWhiteSpace(sessionName))
{
throw new ArgumentException(
"The completion request must contain a sessionName if no sessionId is provided. " +
"A new session will be created with the provided session name.");
}

sessionId = await CreateChatSessionAsync(sessionName);
}

Expand All @@ -81,6 +83,7 @@ public async Task<Completion> GetCompletionWithSessionAsync(string? sessionId, s
SessionId = sessionId,
UserPrompt = userPrompt
};

return await GetCompletionWithSessionAsync(orchestrationRequest);
}

Expand Down Expand Up @@ -139,6 +142,13 @@ public async Task<Completion> AttachFileAndAskQuestionAsync(Stream fileStream, s
{
if (string.IsNullOrWhiteSpace(sessionId))
{
if (string.IsNullOrWhiteSpace(sessionName))
{
throw new ArgumentException(
"The completion request must contain a sessionName if no sessionId is provided. " +
"A new session will be created with the provided session name.");
}

sessionId = await CreateChatSessionAsync(sessionName);
}

Expand Down
6 changes: 3 additions & 3 deletions src/dotnet/CoreClient/Interfaces/ICoreClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ namespace FoundationaLLM.Client.Core.Interfaces
public interface ICoreClient
{
/// <summary>
/// Creates a new chat session and renames it if a session name is provided.
/// Creates a new chat session with the specified name.
/// </summary>
/// <param name="sessionName">Renames the new chat session if not null or empty.</param>
/// <param name="sessionName">The chat session name.</param>
/// <returns>The new chat session ID.</returns>
Task<string> CreateChatSessionAsync(string? sessionName);
Task<string> CreateChatSessionAsync(string sessionName);

/// <summary>
/// Runs a single completion request with an agent using the Core API and a chat session.
Expand Down
6 changes: 3 additions & 3 deletions src/dotnet/CoreClient/Interfaces/ISessionRESTClient.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using FoundationaLLM.Common.Models.Chat;
using FoundationaLLM.Common.Models.Orchestration;

namespace FoundationaLLM.Client.Core.Interfaces
{
Expand All @@ -24,10 +23,11 @@ public interface ISessionRESTClient
Task RateMessageAsync(string sessionId, string messageId, bool rating);

/// <summary>
/// Creates and renames a session.
/// Creates a new session with the specified name.
/// </summary>
/// <param name="chatSessionName">The name for the chat session.</param>
/// <returns>Returns the new Session ID.</returns>
Task<string> CreateSessionAsync();
Task<string> CreateSessionAsync(string chatSessionName);

/// <summary>
/// Renames a chat session.
Expand Down
19 changes: 10 additions & 9 deletions tests/dotnet/Core.Client.Tests/CoreClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ public async Task CreateChatSessionAsync_WithName_CreatesAndRenamesSession()
// Arrange
var sessionName = "TestSession";
var sessionId = "session-id";
_coreRestClient.Sessions.CreateSessionAsync().Returns(Task.FromResult(sessionId));
_coreRestClient.Sessions.CreateSessionAsync(sessionName).Returns(Task.FromResult(sessionId));

// Act
var result = await _coreClient.CreateChatSessionAsync(sessionName);

// Assert
Assert.Equal(sessionId, result);
await _coreRestClient.Sessions.Received(1).CreateSessionAsync();
await _coreRestClient.Sessions.Received(1).RenameChatSession(sessionId, sessionName);
await _coreRestClient.Sessions.Received(1).CreateSessionAsync(sessionName);
}

[Fact]
Expand All @@ -43,17 +42,18 @@ public async Task GetCompletionWithSessionAsync_WithNewSession_CreatesSessionAnd
// Arrange
var userPrompt = "Hello, World!";
var agentName = "TestAgent";
var sessionName = "TestSession";
var sessionId = "new-session-id";
var completion = new Completion();
_coreRestClient.Sessions.CreateSessionAsync().Returns(Task.FromResult(sessionId));
_coreRestClient.Sessions.CreateSessionAsync(sessionName).Returns(Task.FromResult(sessionId));
_coreRestClient.Completions.GetChatCompletionAsync(Arg.Any<CompletionRequest>()).Returns(Task.FromResult(completion));

// Act
var result = await _coreClient.GetCompletionWithSessionAsync(null, "NewSession", userPrompt, agentName);
var result = await _coreClient.GetCompletionWithSessionAsync(null, sessionName, userPrompt, agentName);

// Assert
Assert.Equal(completion, result);
await _coreRestClient.Sessions.Received(1).CreateSessionAsync();
await _coreRestClient.Sessions.Received(1).CreateSessionAsync(sessionName);
await _coreRestClient.Completions.GetChatCompletionAsync(Arg.Is<CompletionRequest>(
r => r.SessionId == sessionId && r.AgentName == agentName && r.UserPrompt == userPrompt));
}
Expand Down Expand Up @@ -100,20 +100,21 @@ public async Task AttachFileAndAskQuestionAsync_UsesSession_UploadsFileAndSendsS
var contentType = "text/plain";
var agentName = "TestAgent";
var question = "What is this file about?";
var sessionName = "TestSession";
var sessionId = "session-id";
var objectId = "object-id";
var completion = new Completion();
_coreRestClient.Attachments.UploadAttachmentAsync(fileStream, fileName, contentType).Returns(Task.FromResult(objectId));
_coreRestClient.Sessions.CreateSessionAsync().Returns(Task.FromResult(sessionId));
_coreRestClient.Sessions.CreateSessionAsync(sessionName).Returns(Task.FromResult(sessionId));
_coreRestClient.Completions.GetChatCompletionAsync(Arg.Any<CompletionRequest>()).Returns(Task.FromResult(completion));

// Act
var result = await _coreClient.AttachFileAndAskQuestionAsync(fileStream, fileName, contentType, agentName, question, true, null, "NewSession");
var result = await _coreClient.AttachFileAndAskQuestionAsync(fileStream, fileName, contentType, agentName, question, true, null, sessionName);

// Assert
Assert.Equal(completion, result);
await _coreRestClient.Attachments.Received(1).UploadAttachmentAsync(fileStream, fileName, contentType);
await _coreRestClient.Sessions.Received(1).CreateSessionAsync();
await _coreRestClient.Sessions.Received(1).CreateSessionAsync(sessionName);
await _coreRestClient.Completions.GetChatCompletionAsync(Arg.Is<CompletionRequest>(
r => r.AgentName == agentName && r.SessionId == sessionId && r.UserPrompt == question && r.Attachments.Contains(objectId)));
}
Expand Down

0 comments on commit 93d154a

Please sign in to comment.