diff --git a/src/dotnet/Agent/ResourceProviders/AgentResourceProviderService.cs b/src/dotnet/Agent/ResourceProviders/AgentResourceProviderService.cs index c5ade3861b..db2a2e6d57 100644 --- a/src/dotnet/Agent/ResourceProviders/AgentResourceProviderService.cs +++ b/src/dotnet/Agent/ResourceProviders/AgentResourceProviderService.cs @@ -6,9 +6,9 @@ using FoundationaLLM.Common.Constants.Configuration; using FoundationaLLM.Common.Constants.ResourceProviders; using FoundationaLLM.Common.Exceptions; +using FoundationaLLM.Common.Extensions; using FoundationaLLM.Common.Interfaces; using FoundationaLLM.Common.Models.Authentication; -using FoundationaLLM.Common.Models.Authorization; using FoundationaLLM.Common.Models.Configuration.Instance; using FoundationaLLM.Common.Models.Events; using FoundationaLLM.Common.Models.ResourceProviders; @@ -117,8 +117,9 @@ private async Task>> LoadAgents(Resour agents = (await Task.WhenAll(_agentReferences.Values .Where(ar => !ar.Deleted) .Select(ar => LoadAgent(ar)))) - .Where(agent => agent != null) - .ToList(); + .Where(agent => agent != null) + .Select(agent => agent!) + .ToList(); } else { @@ -127,50 +128,23 @@ private async Task>> LoadAgents(Resour { agent = await LoadAgent(null, instance.ResourceId); if (agent != null) - { agents.Add(agent); - } } else { if (agentReference.Deleted) - { throw new ResourceProviderException($"Could not locate the {instance.ResourceId} agent resource.", StatusCodes.Status404NotFound); - } agent = await LoadAgent(agentReference); if (agent != null) - { agents.Add(agent); - } - } - } - - var rolesWithActions = await _authorizationService.ProcessRoleAssignmentsWithActionsRequest( - _instanceSettings.Id, - new RoleAssignmentsWithActionsRequest() - { - Scopes = agents.Select(x => x.ObjectId!).ToList(), - PrincipalId = userIdentity.UserId!, - SecurityGroupIds = userIdentity.GroupIds - }); - - var results = new List>(); - foreach (var agent in agents) - { - if (rolesWithActions[agent.ObjectId!].Actions.Contains(AuthorizableActionNames.FoundationaLLM_Agent_Agents_Read)) - { - results.Add(new ResourceProviderGetResult() - { - Resource = agent, - Actions = rolesWithActions[agent.ObjectId!].Actions, - Roles = rolesWithActions[agent.ObjectId!].Roles - }); } } - return results; + return await _authorizationService.FilterResourcesByAuthorizableAction( + _instanceSettings.Id, userIdentity, agents, + AuthorizableActionNames.FoundationaLLM_Agent_Agents_Read); } private async Task LoadAgent(AgentReference? agentReference, string? resourceId = null) diff --git a/src/dotnet/Common/Extensions/AuthorizationServiceExtensions.cs b/src/dotnet/Common/Extensions/AuthorizationServiceExtensions.cs new file mode 100644 index 0000000000..7e94f4fc75 --- /dev/null +++ b/src/dotnet/Common/Extensions/AuthorizationServiceExtensions.cs @@ -0,0 +1,54 @@ +using FoundationaLLM.Common.Interfaces; +using FoundationaLLM.Common.Models.Authentication; +using FoundationaLLM.Common.Models.Authorization; +using FoundationaLLM.Common.Models.ResourceProviders; + +namespace FoundationaLLM.Common.Extensions +{ + /// + /// Extends the interface with helper methods. + /// + public static class AuthorizationServiceExtensions + { + /// + /// Filters the list of resources based on the authorizable action. + /// + /// The object type of the resource being retrieved. + /// The service. + /// The FoundationaLLM instance identifier. + /// The providing information about the calling user identity. + /// The list of all resources. + /// The authorizable action to be checked. + /// A list of resources on which the user identity is allowed to perform the authorizable action. + public static async Task>> FilterResourcesByAuthorizableAction( + this IAuthorizationService authorizationService, + string instanceId, + UnifiedUserIdentity userIdentity, + List resources, + string authorizableAction) + where T : ResourceBase + { + var rolesWithActions = await authorizationService.ProcessRoleAssignmentsWithActionsRequest( + instanceId, + new RoleAssignmentsWithActionsRequest() + { + Scopes = resources.Select(x => x.ObjectId!).ToList(), + PrincipalId = userIdentity.UserId!, + SecurityGroupIds = userIdentity.GroupIds + }); + + var results = new List>(); + + foreach (var resource in resources) + if (rolesWithActions[resource.ObjectId!].Actions.Contains(authorizableAction)) + results.Add(new ResourceProviderGetResult() + { + Resource = resource, + Actions = rolesWithActions[resource.ObjectId!].Actions, + Roles = rolesWithActions[resource.ObjectId!].Roles + }); + + return results; + } + } +} diff --git a/src/dotnet/Common/Extensions/ResourceProviderServiceExtensions.cs b/src/dotnet/Common/Extensions/ResourceProviderServiceExtensions.cs index 6424676cf0..3afeea04c7 100644 --- a/src/dotnet/Common/Extensions/ResourceProviderServiceExtensions.cs +++ b/src/dotnet/Common/Extensions/ResourceProviderServiceExtensions.cs @@ -29,8 +29,12 @@ public static async Task GetResource( var result = await resourceProviderService.HandleGetAsync( objectId, - userIdentity); - return (result as List>)!.First().Resource; + userIdentity) as List>; + + if (result == null || result.Count == 0) + throw new ResourceProviderException($"The resource provider {resourceProviderService.Name} is unable to retrieve the {objectId} resource."); + + return result.First().Resource; } /// diff --git a/src/dotnet/DataSource/ResourceProviders/DataSourceResourceProviderService.cs b/src/dotnet/DataSource/ResourceProviders/DataSourceResourceProviderService.cs index f9ecd41974..d9e13b5c06 100644 --- a/src/dotnet/DataSource/ResourceProviders/DataSourceResourceProviderService.cs +++ b/src/dotnet/DataSource/ResourceProviders/DataSourceResourceProviderService.cs @@ -1,9 +1,11 @@ using Azure.Messaging; using FluentValidation; using FoundationaLLM.Common.Constants; +using FoundationaLLM.Common.Constants.Authorization; using FoundationaLLM.Common.Constants.Configuration; using FoundationaLLM.Common.Constants.ResourceProviders; using FoundationaLLM.Common.Exceptions; +using FoundationaLLM.Common.Extensions; using FoundationaLLM.Common.Interfaces; using FoundationaLLM.Common.Models.Authentication; using FoundationaLLM.Common.Models.Configuration.Instance; @@ -99,25 +101,25 @@ await _storageService.WriteFileAsync( protected override async Task GetResourcesAsync(ResourcePath resourcePath, UnifiedUserIdentity userIdentity) => resourcePath.ResourceTypeInstances[0].ResourceType switch { - DataSourceResourceTypeNames.DataSources => await LoadDataSources(resourcePath.ResourceTypeInstances[0]), + DataSourceResourceTypeNames.DataSources => await LoadDataSources(resourcePath.ResourceTypeInstances[0], userIdentity), _ => throw new ResourceProviderException($"The resource type {resourcePath.ResourceTypeInstances[0].ResourceType} is not supported by the {_name} resource provider.", StatusCodes.Status400BadRequest) }; #region Helpers for GetResourcesAsyncInternal - private async Task>> LoadDataSources(ResourceTypeInstance instance) + private async Task>> LoadDataSources(ResourceTypeInstance instance, UnifiedUserIdentity userIdentity) { + var dataSources = new List(); + if (instance.ResourceId == null) { - var dataSources = (await Task.WhenAll( - _dataSourceReferences.Values - .Where(dsr => !dsr.Deleted) - .Select(dsr => LoadDataSource(dsr)))) - .Where(ds => ds != null) - .ToList(); - - return dataSources.Select(dataSource => new ResourceProviderGetResult() { Resource = dataSource, Actions = [], Roles = [] }).ToList(); + dataSources = (await Task.WhenAll(_dataSourceReferences.Values + .Where(dsr => !dsr.Deleted) + .Select(dsr => LoadDataSource(dsr)))) + .Where(ds => ds != null) + .Select(ds => ds!) + .ToList(); } else { @@ -126,26 +128,24 @@ private async Task>> LoadDataSour { dataSource = await LoadDataSource(null, instance.ResourceId); if (dataSource != null) - { - return [new ResourceProviderGetResult() { Resource = dataSource, Actions = [], Roles = [] }]; - } - return []; + dataSources.Add(dataSource); } - - if (dataSourceReference.Deleted) + else { - throw new ResourceProviderException( - $"Could not locate the {instance.ResourceId} data source resource.", - StatusCodes.Status404NotFound); - } + if (dataSourceReference.Deleted) + throw new ResourceProviderException( + $"Could not locate the {instance.ResourceId} data source resource.", + StatusCodes.Status404NotFound); - dataSource = await LoadDataSource(dataSourceReference); - if (dataSource != null) - { - return [new ResourceProviderGetResult() { Resource = dataSource, Actions = [], Roles = [] }]; + dataSource = await LoadDataSource(dataSourceReference); + if (dataSource != null) + dataSources.Add(dataSource); } - return []; } + + return await _authorizationService.FilterResourcesByAuthorizableAction( + _instanceSettings.Id, userIdentity, dataSources, + AuthorizableActionNames.FoundationaLLM_DataSource_DataSources_Read); } private async Task LoadDataSource(DataSourceReference? dataSourceReference, string? resourceId = null) diff --git a/src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs b/src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs index 5112a67b13..7341853b1d 100644 --- a/src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs +++ b/src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs @@ -5,6 +5,7 @@ using FoundationaLLM.Common.Interfaces; using FoundationaLLM.Common.Models.Authentication; using FoundationaLLM.Common.Models.ResourceProviders.Agent; +using FoundationaLLM.Common.Models.ResourceProviders.DataSource; using FoundationaLLM.Common.Models.ResourceProviders.Prompt; using FoundationaLLM.Common.Models.ResourceProviders.Vectorization; using FoundationaLLM.Orchestration.Core.Interfaces; @@ -64,7 +65,7 @@ public class OrchestrationBuilder return null; } - private static async Task LoadAgent( + private static async Task LoadAgent( string? agentName, Dictionary resourceProviderServices, UnifiedUserIdentity currentUserIdentity, @@ -79,6 +80,8 @@ private static async Task LoadAgent( throw new OrchestrationException($"The resource provider {ResourceProviderNames.FoundationaLLM_Prompt} was not loaded."); if (!resourceProviderServices.TryGetValue(ResourceProviderNames.FoundationaLLM_Vectorization, out var vectorizationResourceProvider)) throw new OrchestrationException($"The resource provider {ResourceProviderNames.FoundationaLLM_Vectorization} was not loaded."); + if (!resourceProviderServices.TryGetValue(ResourceProviderNames.FoundationaLLM_DataSource, out var dataSourceResourceProvider)) + throw new OrchestrationException($"The resource provider {ResourceProviderNames.FoundationaLLM_DataSource} was not loaded."); var agentBase = await agentResourceProvider.GetResource( $"/{AgentResourceTypeNames.Agents}/{agentName}", @@ -126,6 +129,16 @@ private static async Task LoadAgent( kmAgent.OrchestrationSettings!.AgentParameters![kmAgent.Vectorization.TextEmbeddingProfileObjectId!] = textEmbeddingProfile; } + + if (!string.IsNullOrWhiteSpace(kmAgent.Vectorization.DataSourceObjectId)) + { + var dataSource = await dataSourceResourceProvider.GetResource( + kmAgent.Vectorization.DataSourceObjectId, + currentUserIdentity); + + if (dataSource == null) + return null; + } } } diff --git a/src/dotnet/OrchestrationAPI/Program.cs b/src/dotnet/OrchestrationAPI/Program.cs index 2d20b6627a..599c9d9f83 100644 --- a/src/dotnet/OrchestrationAPI/Program.cs +++ b/src/dotnet/OrchestrationAPI/Program.cs @@ -58,6 +58,7 @@ public static void Main(string[] args) options.Select(AppConfigurationKeyFilters.FoundationaLLM_Prompt); options.Select(AppConfigurationKeyFilters.FoundationaLLM_Vectorization); options.Select(AppConfigurationKeyFilters.FoundationaLLM_Configuration); + options.Select(AppConfigurationKeyFilters.FoundationaLLM_DataSource); }); if (builder.Environment.IsDevelopment()) builder.Configuration.AddJsonFile("appsettings.development.json", true, true); @@ -127,6 +128,7 @@ public static void Main(string[] args) builder.AddPromptResourceProvider(); builder.AddVectorizationResourceProvider(); builder.AddConfigurationResourceProvider(); + builder.AddDataSourceResourceProvider(); // Register the downstream services and HTTP clients. RegisterDownstreamServices(builder);