Skip to content

Commit

Permalink
Merge pull request #1049 from solliancenet/aa-datasource-rbac
Browse files Browse the repository at this point in the history
Add RBAC support for data source resource provider
  • Loading branch information
ciprianjichici authored May 24, 2024
2 parents 03c5fd0 + d3a5d12 commit feafc1f
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 61 deletions.
40 changes: 7 additions & 33 deletions src/dotnet/Agent/ResourceProviders/AgentResourceProviderService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
using FoundationaLLM.Common.Constants.Configuration;
using FoundationaLLM.Common.Constants.ResourceProviders;
using FoundationaLLM.Common.Exceptions;
using FoundationaLLM.Common.Extensions;
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Common.Models.Authentication;
using FoundationaLLM.Common.Models.Authorization;
using FoundationaLLM.Common.Models.Configuration.Instance;
using FoundationaLLM.Common.Models.Events;
using FoundationaLLM.Common.Models.ResourceProviders;
Expand Down Expand Up @@ -117,8 +117,9 @@ private async Task<List<ResourceProviderGetResult<AgentBase>>> LoadAgents(Resour
agents = (await Task.WhenAll(_agentReferences.Values
.Where(ar => !ar.Deleted)
.Select(ar => LoadAgent(ar))))
.Where(agent => agent != null)
.ToList();
.Where(agent => agent != null)
.Select(agent => agent!)
.ToList();
}
else
{
Expand All @@ -127,50 +128,23 @@ private async Task<List<ResourceProviderGetResult<AgentBase>>> LoadAgents(Resour
{
agent = await LoadAgent(null, instance.ResourceId);
if (agent != null)
{
agents.Add(agent);
}
}
else
{
if (agentReference.Deleted)
{
throw new ResourceProviderException($"Could not locate the {instance.ResourceId} agent resource.",
StatusCodes.Status404NotFound);
}

agent = await LoadAgent(agentReference);
if (agent != null)
{
agents.Add(agent);
}
}
}

var rolesWithActions = await _authorizationService.ProcessRoleAssignmentsWithActionsRequest(
_instanceSettings.Id,
new RoleAssignmentsWithActionsRequest()
{
Scopes = agents.Select(x => x.ObjectId!).ToList(),
PrincipalId = userIdentity.UserId!,
SecurityGroupIds = userIdentity.GroupIds
});

var results = new List<ResourceProviderGetResult<AgentBase>>();
foreach (var agent in agents)
{
if (rolesWithActions[agent.ObjectId!].Actions.Contains(AuthorizableActionNames.FoundationaLLM_Agent_Agents_Read))
{
results.Add(new ResourceProviderGetResult<AgentBase>()
{
Resource = agent,
Actions = rolesWithActions[agent.ObjectId!].Actions,
Roles = rolesWithActions[agent.ObjectId!].Roles
});
}
}

return results;
return await _authorizationService.FilterResourcesByAuthorizableAction(
_instanceSettings.Id, userIdentity, agents,
AuthorizableActionNames.FoundationaLLM_Agent_Agents_Read);
}

private async Task<AgentBase?> LoadAgent(AgentReference? agentReference, string? resourceId = null)
Expand Down
54 changes: 54 additions & 0 deletions src/dotnet/Common/Extensions/AuthorizationServiceExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Common.Models.Authentication;
using FoundationaLLM.Common.Models.Authorization;
using FoundationaLLM.Common.Models.ResourceProviders;

namespace FoundationaLLM.Common.Extensions
{
/// <summary>
/// Extends the <see cref="IAuthorizationService"/> interface with helper methods.
/// </summary>
public static class AuthorizationServiceExtensions
{
/// <summary>
/// Filters the list of resources based on the authorizable action.
/// </summary>
/// <typeparam name="T">The object type of the resource being retrieved.</typeparam>
/// <param name="authorizationService">The <see cref="IAuthorizationService"/> service.</param>
/// <param name="instanceId">The FoundationaLLM instance identifier.</param>
/// <param name="userIdentity">The <see cref="UnifiedUserIdentity"/> providing information about the calling user identity.</param>
/// <param name="resources">The list of all resources.</param>
/// <param name="authorizableAction">The authorizable action to be checked.</param>
/// <returns>A list of resources on which the user identity is allowed to perform the authorizable action.</returns>
public static async Task<List<ResourceProviderGetResult<T>>> FilterResourcesByAuthorizableAction<T>(
this IAuthorizationService authorizationService,
string instanceId,
UnifiedUserIdentity userIdentity,
List<T> resources,
string authorizableAction)
where T : ResourceBase
{
var rolesWithActions = await authorizationService.ProcessRoleAssignmentsWithActionsRequest(
instanceId,
new RoleAssignmentsWithActionsRequest()
{
Scopes = resources.Select(x => x.ObjectId!).ToList(),
PrincipalId = userIdentity.UserId!,
SecurityGroupIds = userIdentity.GroupIds
});

var results = new List<ResourceProviderGetResult<T>>();

foreach (var resource in resources)
if (rolesWithActions[resource.ObjectId!].Actions.Contains(authorizableAction))
results.Add(new ResourceProviderGetResult<T>()
{
Resource = resource,
Actions = rolesWithActions[resource.ObjectId!].Actions,
Roles = rolesWithActions[resource.ObjectId!].Roles
});

return results;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ public static async Task<T> GetResource<T>(

var result = await resourceProviderService.HandleGetAsync(
objectId,
userIdentity);
return (result as List<ResourceProviderGetResult<T>>)!.First().Resource;
userIdentity) as List<ResourceProviderGetResult<T>>;

if (result == null || result.Count == 0)
throw new ResourceProviderException($"The resource provider {resourceProviderService.Name} is unable to retrieve the {objectId} resource.");

return result.First().Resource;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using Azure.Messaging;
using FluentValidation;
using FoundationaLLM.Common.Constants;
using FoundationaLLM.Common.Constants.Authorization;
using FoundationaLLM.Common.Constants.Configuration;
using FoundationaLLM.Common.Constants.ResourceProviders;
using FoundationaLLM.Common.Exceptions;
using FoundationaLLM.Common.Extensions;
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Common.Models.Authentication;
using FoundationaLLM.Common.Models.Configuration.Instance;
Expand Down Expand Up @@ -99,25 +101,25 @@ await _storageService.WriteFileAsync(
protected override async Task<object> GetResourcesAsync(ResourcePath resourcePath, UnifiedUserIdentity userIdentity) =>
resourcePath.ResourceTypeInstances[0].ResourceType switch
{
DataSourceResourceTypeNames.DataSources => await LoadDataSources(resourcePath.ResourceTypeInstances[0]),
DataSourceResourceTypeNames.DataSources => await LoadDataSources(resourcePath.ResourceTypeInstances[0], userIdentity),
_ => throw new ResourceProviderException($"The resource type {resourcePath.ResourceTypeInstances[0].ResourceType} is not supported by the {_name} resource provider.",
StatusCodes.Status400BadRequest)
};

#region Helpers for GetResourcesAsyncInternal

private async Task<List<ResourceProviderGetResult<DataSourceBase>>> LoadDataSources(ResourceTypeInstance instance)
private async Task<List<ResourceProviderGetResult<DataSourceBase>>> LoadDataSources(ResourceTypeInstance instance, UnifiedUserIdentity userIdentity)
{
var dataSources = new List<DataSourceBase>();

if (instance.ResourceId == null)
{
var dataSources = (await Task.WhenAll(
_dataSourceReferences.Values
.Where(dsr => !dsr.Deleted)
.Select(dsr => LoadDataSource(dsr))))
.Where(ds => ds != null)
.ToList();

return dataSources.Select(dataSource => new ResourceProviderGetResult<DataSourceBase>() { Resource = dataSource, Actions = [], Roles = [] }).ToList();
dataSources = (await Task.WhenAll(_dataSourceReferences.Values
.Where(dsr => !dsr.Deleted)
.Select(dsr => LoadDataSource(dsr))))
.Where(ds => ds != null)
.Select(ds => ds!)
.ToList();
}
else
{
Expand All @@ -126,26 +128,24 @@ private async Task<List<ResourceProviderGetResult<DataSourceBase>>> LoadDataSour
{
dataSource = await LoadDataSource(null, instance.ResourceId);
if (dataSource != null)
{
return [new ResourceProviderGetResult<DataSourceBase>() { Resource = dataSource, Actions = [], Roles = [] }];
}
return [];
dataSources.Add(dataSource);
}

if (dataSourceReference.Deleted)
else
{
throw new ResourceProviderException(
$"Could not locate the {instance.ResourceId} data source resource.",
StatusCodes.Status404NotFound);
}
if (dataSourceReference.Deleted)
throw new ResourceProviderException(
$"Could not locate the {instance.ResourceId} data source resource.",
StatusCodes.Status404NotFound);

dataSource = await LoadDataSource(dataSourceReference);
if (dataSource != null)
{
return [new ResourceProviderGetResult<DataSourceBase>() { Resource = dataSource, Actions = [], Roles = [] }];
dataSource = await LoadDataSource(dataSourceReference);
if (dataSource != null)
dataSources.Add(dataSource);
}
return [];
}

return await _authorizationService.FilterResourcesByAuthorizableAction(
_instanceSettings.Id, userIdentity, dataSources,
AuthorizableActionNames.FoundationaLLM_DataSource_DataSources_Read);
}

private async Task<DataSourceBase?> LoadDataSource(DataSourceReference? dataSourceReference, string? resourceId = null)
Expand Down
15 changes: 14 additions & 1 deletion src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Common.Models.Authentication;
using FoundationaLLM.Common.Models.ResourceProviders.Agent;
using FoundationaLLM.Common.Models.ResourceProviders.DataSource;
using FoundationaLLM.Common.Models.ResourceProviders.Prompt;
using FoundationaLLM.Common.Models.ResourceProviders.Vectorization;
using FoundationaLLM.Orchestration.Core.Interfaces;
Expand Down Expand Up @@ -64,7 +65,7 @@ public class OrchestrationBuilder
return null;
}

private static async Task<AgentBase> LoadAgent(
private static async Task<AgentBase?> LoadAgent(
string? agentName,
Dictionary<string, IResourceProviderService> resourceProviderServices,
UnifiedUserIdentity currentUserIdentity,
Expand All @@ -79,6 +80,8 @@ private static async Task<AgentBase> LoadAgent(
throw new OrchestrationException($"The resource provider {ResourceProviderNames.FoundationaLLM_Prompt} was not loaded.");
if (!resourceProviderServices.TryGetValue(ResourceProviderNames.FoundationaLLM_Vectorization, out var vectorizationResourceProvider))
throw new OrchestrationException($"The resource provider {ResourceProviderNames.FoundationaLLM_Vectorization} was not loaded.");
if (!resourceProviderServices.TryGetValue(ResourceProviderNames.FoundationaLLM_DataSource, out var dataSourceResourceProvider))
throw new OrchestrationException($"The resource provider {ResourceProviderNames.FoundationaLLM_DataSource} was not loaded.");

var agentBase = await agentResourceProvider.GetResource<AgentBase>(
$"/{AgentResourceTypeNames.Agents}/{agentName}",
Expand Down Expand Up @@ -126,6 +129,16 @@ private static async Task<AgentBase> LoadAgent(

kmAgent.OrchestrationSettings!.AgentParameters![kmAgent.Vectorization.TextEmbeddingProfileObjectId!] = textEmbeddingProfile;
}

if (!string.IsNullOrWhiteSpace(kmAgent.Vectorization.DataSourceObjectId))
{
var dataSource = await dataSourceResourceProvider.GetResource<DataSourceBase>(
kmAgent.Vectorization.DataSourceObjectId,
currentUserIdentity);

if (dataSource == null)
return null;
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/dotnet/OrchestrationAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public static void Main(string[] args)
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Prompt);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Vectorization);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Configuration);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_DataSource);
});
if (builder.Environment.IsDevelopment())
builder.Configuration.AddJsonFile("appsettings.development.json", true, true);
Expand Down Expand Up @@ -127,6 +128,7 @@ public static void Main(string[] args)
builder.AddPromptResourceProvider();
builder.AddVectorizationResourceProvider();
builder.AddConfigurationResourceProvider();
builder.AddDataSourceResourceProvider();

// Register the downstream services and HTTP clients.
RegisterDownstreamServices(builder);
Expand Down

0 comments on commit feafc1f

Please sign in to comment.