Skip to content

Commit

Permalink
Merge pull request #1371 from solliancenet/cjg-080-audio-agent
Browse files Browse the repository at this point in the history
Merge pull request #1369 from solliancenet/cjg-audio-agent
  • Loading branch information
ciprianjichici authored Aug 8, 2024
2 parents eaf9197 + 19e68e4 commit 99e5719
Show file tree
Hide file tree
Showing 12 changed files with 470 additions and 114 deletions.
1 change: 1 addition & 0 deletions src/dotnet/Agent/Models/Resources/AgentReference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class AgentReference : ResourceReference
{
AgentTypes.Basic => typeof(AgentBase),
AgentTypes.KnowledgeManagement => typeof(KnowledgeManagementAgent),
AgentTypes.AudioClassification => typeof(AudioClassificationAgent),
_ => throw new ResourceProviderException($"The agent type {Type} is not supported.")
};
}
Expand Down
1 change: 1 addition & 0 deletions src/dotnet/Agent/ResourceProviders/DependencyInjection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public static void AddAgentResourceProvider(this IHostApplicationBuilder builder
// Register validators.
builder.Services.AddSingleton<IValidator<AgentBase>, AgentBaseValidator>();
builder.Services.AddSingleton<IValidator<KnowledgeManagementAgent>, KnowledgeManagementAgentValidator>();
builder.Services.AddSingleton<IValidator<AudioClassificationAgent>, KnowledgeManagementAgentValidator>();

builder.Services.AddSingleton<IResourceProviderService, AgentResourceProviderService>(sp =>
new AgentResourceProviderService(
Expand Down
2 changes: 2 additions & 0 deletions src/dotnet/Common/Models/ResourceProviders/Agent/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace FoundationaLLM.Common.Models.ResourceProviders.Agent
/// </summary>
[JsonPolymorphic(TypeDiscriminatorPropertyName = "type")]
[JsonDerivedType(typeof(KnowledgeManagementAgent), "knowledge-management")]
[JsonDerivedType(typeof(AudioClassificationAgent), "audio-classification")]
public class AgentBase : ResourceBase
{
/// <inheritdoc/>
Expand Down Expand Up @@ -70,6 +71,7 @@ public class AgentBase : ResourceBase
Type switch
{
AgentTypes.KnowledgeManagement => typeof(KnowledgeManagementAgent),
AgentTypes.AudioClassification => typeof(AudioClassificationAgent),
_ => throw new ResourceProviderException($"The agent type {Type} is not supported.")
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,11 @@ public static class AgentTypes
/// Analytic agents are best for querying, analyzing, calculating, and reporting on tabular data.
/// </summary>
public const string Analytic = "analytic";

/// <summary>
/// The agent type identifier for audio classification agents.
/// </summary>
public const string AudioClassification = "audio-classification";

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using System.Text.Json.Serialization;

namespace FoundationaLLM.Common.Models.ResourceProviders.Agent
{
    /// <summary>
    /// The Audio Classification agent metadata model.
    /// </summary>
    /// <remarks>
    /// Derives from <see cref="KnowledgeManagementAgent"/> and adds no members of its own;
    /// it differs only in the agent type assigned by its constructor.
    /// </remarks>
    public class AudioClassificationAgent : KnowledgeManagementAgent
    {
        /// <summary>
        /// Initializes a new instance and sets the agent type to
        /// <see cref="AgentTypes.AudioClassification"/>.
        /// </summary>
        public AudioClassificationAgent() =>
            Type = AgentTypes.AudioClassification;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,8 @@ private async Task Authorize(ResourcePath resourcePath, UnifiedUserIdentity? use
}
catch (AuthorizationException)
{
_logger.LogWarning("The {ActionType} access to the resource path {ResourcePath} was not authorized for user {UserName}.",
actionType, resourcePath.GetObjectId(_instanceSettings.Id, _name), userIdentity!.Username);
_logger.LogWarning("The {ActionType} access to the resource path {ResourcePath} was not authorized for user {UserName} : userId {UserId}.",
actionType, resourcePath.GetObjectId(_instanceSettings.Id, _name), userIdentity!.Username, userIdentity!.UserId);
throw new ResourceProviderException("Access is not authorized.", StatusCodes.Status403Forbidden);
}
catch (Exception ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using FoundationaLLM.Common.Models.ResourceProviders.Prompt;
using FoundationaLLM.Common.Models.ResourceProviders.Vectorization;
using FoundationaLLM.Orchestration.Core.Interfaces;
using Microsoft.Azure.Cosmos.Serialization.HybridRow;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using System.Net;
Expand Down Expand Up @@ -64,7 +65,7 @@ public class OrchestrationBuilder

if (result.Agent == null) return null;

if (result.Agent.AgentType == typeof(KnowledgeManagementAgent))
if (result.Agent.AgentType == typeof(KnowledgeManagementAgent) || result.Agent.AgentType == typeof(AudioClassificationAgent))
{
var orchestrationName = string.IsNullOrWhiteSpace(result.Agent.OrchestrationSettings?.Orchestrator)
? LLMOrchestrationServiceNames.LangChain
Expand Down Expand Up @@ -151,8 +152,10 @@ public class OrchestrationBuilder
.ToDictionary(x => x.Name, x => x.Description);
explodedObjects[CompletionRequestObjectsKeys.AllAgents] = allAgentsDescriptions;

if (agentBase is KnowledgeManagementAgent kmAgent)
if (agentBase.AgentType == typeof(KnowledgeManagementAgent) || agentBase.AgentType == typeof(AudioClassificationAgent))
{
KnowledgeManagementAgent kmAgent = (KnowledgeManagementAgent)agentBase;

// Check for inline-context agents, they are valid KM agents that do not have a vectorization section.
if (kmAgent is {Vectorization: not null, InlineContext: false})
{
Expand Down
44 changes: 24 additions & 20 deletions src/python/LangChainAPI/app/routers/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,18 @@
responses={404: {'description':'Not found'}}
)

async def resolve_completion_request(request_body: dict = Body(...)) -> CompletionRequestBase:
agent_type = request_body.get("agent", {}).get("type", None)
async def resolve_completion_request(request_body: dict = Body(...)) -> CompletionRequestBase:
agent_type = request_body.get("agent", {}).get("type", None)

match agent_type:
case "knowledge-management":
request = KnowledgeManagementCompletionRequest(**request_body)
request.agent.type = agent_type
return request
case "audio-classification":
request = KnowledgeManagementCompletionRequest(**request_body)
request.agent.type = agent_type
return request
case _:
raise ValueError(f"Unsupported agent type: {agent_type}")

Expand All @@ -71,7 +75,7 @@ async def submit_completion_request(
) -> LongRunningOperation:
"""
Initiates the creation of a completion response in the background.
Returns
-------
CompletionOperation
Expand All @@ -81,19 +85,19 @@ async def submit_completion_request(
try:
# Get the operation_id from the completion request.
operation_id = completion_request.operation_id

span.set_attribute('operation_id', operation_id)
span.set_attribute('instance_id', instance_id)
span.set_attribute('user_identity', x_user_identity)

location = f'{raw_request.base_url}instances/{instance_id}/async-completions/{operation_id}/status'
response.headers['location'] = location

# Create an operations manager to create the operation.
operations_manager = OperationsManager(raw_request.app.extra['config'])
# Submit the completion request operation to the state API.
operation = await operations_manager.create_operation(operation_id, instance_id)

# Start a background task to perform the completion request.
background_tasks.add_task(
create_completion_response,
Expand All @@ -106,7 +110,7 @@ async def submit_completion_request(

# Return the long running operation object.
return operation

except Exception as e:
handle_exception(e)

Expand All @@ -123,7 +127,7 @@ async def create_completion_response(
with tracer.start_as_current_span(f'create_completion_response') as span:
# Create an operations manager to update the operation status.
operations_manager = OperationsManager(configuration)

try:
span.set_attribute('operation_id', operation_id)
span.set_attribute('instance_id', instance_id)
Expand All @@ -136,7 +140,7 @@ async def create_completion_response(
status = OperationStatus.INPROGRESS,
status_message = 'Operation state changed to in progress.'
)

# Create an orchestration manager to process the completion request.
orchestration_manager = OrchestrationManager(
completion_request = completion_request,
Expand All @@ -157,7 +161,7 @@ async def create_completion_response(
status=OperationStatus.COMPLETED,
status_message=f'Operation {operation_id} completed successfully.'
)
)
)
except Exception as e:
# Send the completion response to the State API and mark the operation as failed.
print(f'Operation {operation_id} failed with error: {e}')
Expand All @@ -177,7 +181,7 @@ async def create_completion_response(
status = OperationStatus.FAILED,
status_message = f'Operation failed with error: {e}'
)
)
)

@router.get(
'/async-completions/{operation_id}/status',
Expand All @@ -195,7 +199,7 @@ async def get_operation_status(
with tracer.start_as_current_span(f'get_operation_status') as span:
# Create an operations manager to get the operation status.
operations_manager = OperationsManager(raw_request.app.extra['config'])

try:
span.set_attribute('operation_id', operation_id)
span.set_attribute('instance_id', instance_id)
Expand All @@ -204,7 +208,7 @@ async def get_operation_status(
operation_id,
instance_id
)

if operation is None:
raise HTTPException(status_code=404)

Expand All @@ -228,16 +232,16 @@ async def get_operation_result(
with tracer.start_as_current_span(f'get_operation_result') as span:
# Create an operations manager to get the operation result.
operations_manager = OperationsManager(raw_request.app.extra['config'])

try:
span.set_attribute('operation_id', operation_id)
span.set_attribute('instance_id', instance_id)

completion_response = await operations_manager.get_operation_result(
operation_id,
instance_id
)

if completion_response is None:
raise HTTPException(status_code=404)

Expand All @@ -261,16 +265,16 @@ async def get_operation_log(
with tracer.start_as_current_span(f'get_operation_log') as span:
# Create an operations manager to get the operation log.
operations_manager = OperationsManager(raw_request.app.extra['config'])

try:
span.set_attribute('operation_id', operation_id)
span.set_attribute('instance_id', instance_id)

log = await operations_manager.get_operation_log(
operation_id,
instance_id
)

if log is None:
raise HTTPException(status_code=404)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""LangChain Agents module"""
from .langchain_agent_base import LangChainAgentBase
from .langchain_knowledge_management_agent import LangChainKnowledgeManagementAgent
from .langchain_audio_classifier_agent import LangChainAudioClassifierAgent
from .agent_factory import AgentFactory
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from foundationallm.config import Configuration
from foundationallm.langchain.agents import (
LangChainAgentBase,
LangChainKnowledgeManagementAgent
LangChainKnowledgeManagementAgent,
LangChainAudioClassifierAgent
)

class AgentFactory:
Expand All @@ -28,17 +29,19 @@ def get_agent(self, agent_type: str) -> LangChainAgentBase:
----------
agent_type : str
The type to assign to the agent returned.
Returns
-------
AgentBase
Returns an agent of the requested type.
"""
if agent_type is None:
raise ValueError("Agent not constructed. Cannot access an object of 'NoneType'.")

match agent_type:
case 'knowledge-management':
return LangChainKnowledgeManagementAgent(config=self.config)
case 'audio-classification':
return LangChainAudioClassifierAgent(config=self.config)
case _:
raise ValueError(f'The agent type {agent_type} is not supported.')
Loading

0 comments on commit 99e5719

Please sign in to comment.