diff --git a/src/Orleans.SignalR/Clients/ClientGrain.cs b/src/Orleans.SignalR/Clients/ClientGrain.cs index 6e847d5..d5497fd 100644 --- a/src/Orleans.SignalR/Clients/ClientGrain.cs +++ b/src/Orleans.SignalR/Clients/ClientGrain.cs @@ -2,7 +2,7 @@ using Microsoft.Extensions.Logging; using Orleans.Concurrency; using Orleans.Providers; -using Orleans.SignalR.Core; +using Orleans.SignalR.Connections; using Orleans.Streams; using System; using System.Collections.Generic; diff --git a/src/Orleans.SignalR/Core/ConnectionGrain.cs b/src/Orleans.SignalR/Connections/ConnectionGrain.cs similarity index 93% rename from src/Orleans.SignalR/Core/ConnectionGrain.cs rename to src/Orleans.SignalR/Connections/ConnectionGrain.cs index f3e03f9..d2e3ad4 100644 --- a/src/Orleans.SignalR/Core/ConnectionGrain.cs +++ b/src/Orleans.SignalR/Connections/ConnectionGrain.cs @@ -1,17 +1,15 @@ -using System.Buffers; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.Logging; -using Orleans; using Orleans.Concurrency; using Orleans.Streams; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; -namespace Orleans.SignalR.Core +namespace Orleans.SignalR.Connections { - internal abstract class ConnectionGrain : Grain, IConnectionGrain - where TGrainState : ConnectionState, new() + internal class ConnectionGrain : Grain, IConnectionGrain { private readonly ILogger _logger; private IStreamProvider _streamProvider; @@ -19,7 +17,7 @@ internal abstract class ConnectionGrain : Grain, IConn protected ConnectionGrainKey KeyData; - internal ConnectionGrain(ILogger logger) + public ConnectionGrain(ILogger logger) { _logger = logger; } diff --git a/src/Orleans.SignalR/Core/ConnectionGrainKey.cs b/src/Orleans.SignalR/Connections/ConnectionGrainKey.cs similarity index 94% rename from src/Orleans.SignalR/Core/ConnectionGrainKey.cs rename to src/Orleans.SignalR/Connections/ConnectionGrainKey.cs index b36c242..4b98818 100644 --- a/src/Orleans.SignalR/Core/ConnectionGrainKey.cs +++ b/src/Orleans.SignalR/Connections/ConnectionGrainKey.cs @@ -1,6 +1,6 @@ using System.Diagnostics; -namespace Orleans.SignalR.Core +namespace Orleans.SignalR.Connections { [DebuggerDisplay("{DebuggerDisplay,nq}")] internal struct ConnectionGrainKey diff --git a/src/Orleans.SignalR/Core/IConnectionGrain.cs b/src/Orleans.SignalR/Connections/IConnectionGrain.cs similarity index 92% rename from src/Orleans.SignalR/Core/IConnectionGrain.cs rename to src/Orleans.SignalR/Connections/IConnectionGrain.cs index 3361274..a587141 100644 --- a/src/Orleans.SignalR/Core/IConnectionGrain.cs +++ b/src/Orleans.SignalR/Connections/IConnectionGrain.cs @@ -1,8 +1,8 @@ -using System.Collections.Generic; +using Orleans.SignalR.Core; +using System.Collections.Generic; using System.Threading.Tasks; -using Orleans; -namespace Orleans.SignalR.Core +namespace Orleans.SignalR.Connections { /// /// Grain interface Grouped of connections, such as user or custom group. diff --git a/src/Orleans.SignalR/Core/GrainExtensions.cs b/src/Orleans.SignalR/Core/GrainExtensions.cs index aed8e03..568219d 100644 --- a/src/Orleans.SignalR/Core/GrainExtensions.cs +++ b/src/Orleans.SignalR/Core/GrainExtensions.cs @@ -1,6 +1,7 @@ using Microsoft.AspNetCore.SignalR.Protocol; using Orleans.Concurrency; using Orleans.SignalR.Clients; +using Orleans.SignalR.Connections; using Orleans.SignalR.Core; using Orleans.SignalR.Groups; using Orleans.SignalR.Users; @@ -55,12 +56,12 @@ public static void SendOneWay(this IHubMessageInvoker grain, string methodName, grain.InvokeOneWay(g => g.Send(methodName, args)); } - [Obsolete("Use Send instead", false)] - public static async Task SendSignalRMessage(this IConnectionGrain grain, string methodName, params object[] message) - { - var invocationMessage = new InvocationMessage(methodName, message).AsImmutable(); - await grain.Send(invocationMessage); - } + //[Obsolete("Use Send instead", false)] + //public static async Task SendSignalRMessage(this IConnectionGrain grain, string methodName, params object[] message) + //{ + // var invocationMessage = new InvocationMessage(methodName, message).AsImmutable(); + // await grain.Send(invocationMessage); + //} /// /// Invokes a method on the hub. diff --git a/src/Orleans.SignalR/Groups/GroupGrain.cs b/src/Orleans.SignalR/Groups/GroupGrain.cs index dd79f20..aa13da6 100644 --- a/src/Orleans.SignalR/Groups/GroupGrain.cs +++ b/src/Orleans.SignalR/Groups/GroupGrain.cs @@ -1,20 +1,125 @@ +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.Logging; using Orleans.Concurrency; using Orleans.Providers; -using Orleans.SignalR.Core; +using Orleans.SignalR.Connections; +using Orleans.Streams; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; namespace Orleans.SignalR.Groups { [StorageProvider(ProviderName = SignalrConstants.STORAGE_PROVIDER)] [Reentrant] - internal class GroupGrain : ConnectionGrain, IGroupGrain + internal class GroupGrain : Grain, IGroupGrain { - public GroupGrain(ILogger logger) : base(logger) + private readonly ILogger _logger; + private IStreamProvider _streamProvider; + private Dictionary> _connectionStreamHandles; + + protected ConnectionGrainKey KeyData; + + public GroupGrain(ILogger logger) + { + _logger = logger; + } + + public override async Task OnActivateAsync() + { + KeyData = new ConnectionGrainKey(this.GetPrimaryKeyString()); + _connectionStreamHandles = new Dictionary>(); + _streamProvider = GetStreamProvider(SignalrConstants.STREAM_PROVIDER); + var subscriptionTasks = new List(); + foreach (var connection in State.Connections) + { + var clientDisconnectStream = _streamProvider.GetStream(SignalrConstants.CLIENT_DISCONNECT_STREAM_ID, connection); + var subscriptions = await clientDisconnectStream.GetAllSubscriptionHandles(); + foreach (var subscription in subscriptions) + { + subscriptionTasks.Add(subscription.ResumeAsync(async (connectionId, _) => await Remove(connectionId))); + } + } + await Task.WhenAll(subscriptionTasks); + } + + public virtual async Task Add(string connectionId) { + var shouldWriteState = State.Connections.Add(connectionId); + if (!_connectionStreamHandles.ContainsKey(connectionId)) + { + var clientDisconnectStream = _streamProvider.GetStream(SignalrConstants.CLIENT_DISCONNECT_STREAM_ID, connectionId); + var subscription = await clientDisconnectStream.SubscribeAsync(async (connId, _) => await Remove(connId)); + _connectionStreamHandles[connectionId] = subscription; + } + + if (shouldWriteState) + await WriteStateAsync(); + } + + public virtual async Task Remove(string connectionId) + { + var shouldWriteState = State.Connections.Remove(connectionId); + if (_connectionStreamHandles.TryGetValue(connectionId, out var stream)) + { + await stream.UnsubscribeAsync(); + _connectionStreamHandles.Remove(connectionId); + } + + if (State.Connections.Count == 0) + { + await ClearStateAsync(); + DeactivateOnIdle(); + } + else if (shouldWriteState) + { + await WriteStateAsync(); + } + } + + public virtual Task Send(Immutable message) + { + return SendAll(message, State.Connections); + } + + public Task SendExcept(string methodName, object[] args, IReadOnlyList excludedConnectionIds) + { + var message = new Immutable(new InvocationMessage(methodName, args)); + return SendAll(message, State.Connections.Where(x => !excludedConnectionIds.Contains(x)).ToList()); + } + + public Task Count() + { + return Task.FromResult(State.Connections.Count); + } + + protected Task SendAll(Immutable message, IReadOnlyCollection connections) + { + _logger.LogDebug("Sending message to {hubName}.{targetMethod} on group {groupId} to {connectionsCount} connection(s)", + KeyData.HubName, message.Value.Target, KeyData.Id, connections.Count); + + var tasks = ArrayPool.Shared.Rent(connections.Count); + try + { + int index = 0; + foreach (var connection in connections) + { + var client = GrainFactory.GetClientGrain(KeyData.HubName, connection); + tasks[index++] = client.Send(message); + } + + return Task.WhenAll(tasks.Where(x => x != null).ToArray()); + } + finally + { + ArrayPool.Shared.Return(tasks); + } } } - internal class GroupState : ConnectionState + internal class GroupState { + public HashSet Connections { get; set; } = new HashSet(); } } \ No newline at end of file diff --git a/src/Orleans.SignalR/Groups/IGroupGrain.cs b/src/Orleans.SignalR/Groups/IGroupGrain.cs index 2906bc4..a0b0d32 100644 --- a/src/Orleans.SignalR/Groups/IGroupGrain.cs +++ b/src/Orleans.SignalR/Groups/IGroupGrain.cs @@ -1,8 +1,37 @@ using Orleans.SignalR.Core; +using System.Collections.Generic; +using System.Threading.Tasks; namespace Orleans.SignalR.Groups { - public interface IGroupGrain : IConnectionGrain + /// + /// Grain interface Grouped of connections, such as user or custom group. + /// + public interface IGroupGrain : IHubMessageInvoker, IGrainWithStringKey { + /// + /// Add connection id to the group. + /// + /// Connection id to add. + Task Add(string connectionId); + + /// + /// Remove the connection id to the group. + /// + /// Connection id to remove. + Task Remove(string connectionId); + + /// + /// Gets the connection count in the group. + /// + Task Count(); + + /// + /// Invokes a method on the hub except the specified connection ids. + /// + /// Target method name to invoke. + /// Arguments to pass to the target method. + /// Connection ids to exclude. + Task SendExcept(string methodName, object[] args, IReadOnlyList excludedConnectionIds); } } \ No newline at end of file diff --git a/src/Orleans.SignalR/Orleans.SignalR.csproj b/src/Orleans.SignalR/Orleans.SignalR.csproj index 7477588..e3f924c 100644 --- a/src/Orleans.SignalR/Orleans.SignalR.csproj +++ b/src/Orleans.SignalR/Orleans.SignalR.csproj @@ -20,6 +20,11 @@ https://github.com/zeus82/Orleans.SignalR.git + + + + + diff --git a/src/Orleans.SignalR/Users/IUserGrain.cs b/src/Orleans.SignalR/Users/IUserGrain.cs index 3fbfc6c..d38fe96 100644 --- a/src/Orleans.SignalR/Users/IUserGrain.cs +++ b/src/Orleans.SignalR/Users/IUserGrain.cs @@ -1,8 +1,37 @@ using Orleans.SignalR.Core; +using System.Collections.Generic; +using System.Threading.Tasks; namespace Orleans.SignalR.Users { - public interface IUserGrain : IConnectionGrain + /// + /// Grain interface Grouped of connections, such as user or custom group. + /// + public interface IUserGrain : IHubMessageInvoker, IGrainWithStringKey { + /// + /// Add connection id to the group. + /// + /// Connection id to add. + Task Add(string connectionId); + + /// + /// Remove the connection id to the group. + /// + /// Connection id to remove. + Task Remove(string connectionId); + + /// + /// Gets the connection count in the group. + /// + Task Count(); + + /// + /// Invokes a method on the hub except the specified connection ids. + /// + /// Target method name to invoke. + /// Arguments to pass to the target method. + /// Connection ids to exclude. + Task SendExcept(string methodName, object[] args, IReadOnlyList excludedConnectionIds); } } \ No newline at end of file diff --git a/src/Orleans.SignalR/Users/UserGrain.cs b/src/Orleans.SignalR/Users/UserGrain.cs index 3b583de..54ab2ec 100644 --- a/src/Orleans.SignalR/Users/UserGrain.cs +++ b/src/Orleans.SignalR/Users/UserGrain.cs @@ -1,20 +1,125 @@ -using Microsoft.Extensions.Logging; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Extensions.Logging; using Orleans.Concurrency; using Orleans.Providers; -using Orleans.SignalR.Core; +using Orleans.SignalR.Connections; +using Orleans.Streams; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; namespace Orleans.SignalR.Users { [StorageProvider(ProviderName = SignalrConstants.STORAGE_PROVIDER)] [Reentrant] - internal class UserGrain : ConnectionGrain, IUserGrain + internal class UserGrain : Grain, IUserGrain { - public UserGrain(ILogger logger) : base(logger) + private readonly ILogger _logger; + private IStreamProvider _streamProvider; + private Dictionary> _connectionStreamHandles; + + protected ConnectionGrainKey KeyData; + + public UserGrain(ILogger logger) + { + _logger = logger; + } + + public override async Task OnActivateAsync() + { + KeyData = new ConnectionGrainKey(this.GetPrimaryKeyString()); + _connectionStreamHandles = new Dictionary>(); + _streamProvider = GetStreamProvider(SignalrConstants.STREAM_PROVIDER); + var subscriptionTasks = new List(); + foreach (var connection in State.Connections) + { + var clientDisconnectStream = _streamProvider.GetStream(SignalrConstants.CLIENT_DISCONNECT_STREAM_ID, connection); + var subscriptions = await clientDisconnectStream.GetAllSubscriptionHandles(); + foreach (var subscription in subscriptions) + { + subscriptionTasks.Add(subscription.ResumeAsync(async (connectionId, _) => await Remove(connectionId))); + } + } + await Task.WhenAll(subscriptionTasks); + } + + public virtual async Task Add(string connectionId) { + var shouldWriteState = State.Connections.Add(connectionId); + if (!_connectionStreamHandles.ContainsKey(connectionId)) + { + var clientDisconnectStream = _streamProvider.GetStream(SignalrConstants.CLIENT_DISCONNECT_STREAM_ID, connectionId); + var subscription = await clientDisconnectStream.SubscribeAsync(async (connId, _) => await Remove(connId)); + _connectionStreamHandles[connectionId] = subscription; + } + + if (shouldWriteState) + await WriteStateAsync(); + } + + public virtual async Task Remove(string connectionId) + { + var shouldWriteState = State.Connections.Remove(connectionId); + if (_connectionStreamHandles.TryGetValue(connectionId, out var stream)) + { + await stream.UnsubscribeAsync(); + _connectionStreamHandles.Remove(connectionId); + } + + if (State.Connections.Count == 0) + { + await ClearStateAsync(); + DeactivateOnIdle(); + } + else if (shouldWriteState) + { + await WriteStateAsync(); + } + } + + public virtual Task Send(Immutable message) + { + return SendAll(message, State.Connections); + } + + public Task SendExcept(string methodName, object[] args, IReadOnlyList excludedConnectionIds) + { + var message = new Immutable(new InvocationMessage(methodName, args)); + return SendAll(message, State.Connections.Where(x => !excludedConnectionIds.Contains(x)).ToList()); + } + + public Task Count() + { + return Task.FromResult(State.Connections.Count); + } + + protected Task SendAll(Immutable message, IReadOnlyCollection connections) + { + _logger.LogDebug("Sending message to {hubName}.{targetMethod} on group {groupId} to {connectionsCount} connection(s)", + KeyData.HubName, message.Value.Target, KeyData.Id, connections.Count); + + var tasks = ArrayPool.Shared.Rent(connections.Count); + try + { + int index = 0; + foreach (var connection in connections) + { + var client = GrainFactory.GetClientGrain(KeyData.HubName, connection); + tasks[index++] = client.Send(message); + } + + return Task.WhenAll(tasks.Where(x => x != null).ToArray()); + } + finally + { + ArrayPool.Shared.Return(tasks); + } } } - internal class UserState : ConnectionState + internal class UserState { + public HashSet Connections { get; set; } = new HashSet(); } } \ No newline at end of file