diff --git a/src/DotNetty.Common/Concurrency/AbstractPromise.cs b/src/DotNetty.Common/Concurrency/AbstractPromise.cs index e1d38745b..1989f1092 100644 --- a/src/DotNetty.Common/Concurrency/AbstractPromise.cs +++ b/src/DotNetty.Common/Concurrency/AbstractPromise.cs @@ -7,12 +7,33 @@ namespace DotNetty.Common.Concurrency using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; + using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Sources; public abstract class AbstractPromise : IPromise, IValueTaskSource { + struct CompletionData + { + public Action Continuation { get; } + public object State { get; } + public ExecutionContext ExecutionContext { get; } + public SynchronizationContext SynchronizationContext { get; } + + public CompletionData(Action continuation, object state, ExecutionContext executionContext, SynchronizationContext synchronizationContext) + { + this.Continuation = continuation; + this.State = state; + this.ExecutionContext = executionContext; + this.SynchronizationContext = synchronizationContext; + } + } + const short SourceToken = 0; + + static readonly ContextCallback ExecutionContextCallback = Execute; + static readonly SendOrPostCallback SyncContextCallbackWithExecutionContext = ExecuteWithExecutionContext; + static readonly SendOrPostCallback SyncContextCallback = Execute; static readonly Exception CanceledException = new OperationCanceledException(); static readonly Exception CompletedNoException = new Exception(); @@ -20,7 +41,7 @@ public abstract class AbstractPromise : IPromise, IValueTaskSource protected Exception exception; int callbackCount; - (Action, object)[] callbacks; + CompletionData[] completions; public bool TryComplete() => this.TryComplete0(CompletedNoException); @@ -34,7 +55,7 @@ protected virtual bool TryComplete0(Exception exception) { // Set the exception object to the exception passed in or a sentinel value this.exception = exception; - this.TryExecuteCallbacks(); + this.TryExecuteCompletions(); return true; } @@ -75,27 +96,31 @@ public virtual void GetResult(short token) public virtual void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) { - //todo: context preservation - if (this.callbacks == null) + if (this.completions == null) { - this.callbacks = new (Action, object)[1]; + this.completions = new CompletionData[1]; } int newIndex = this.callbackCount; this.callbackCount++; - if (newIndex == this.callbacks.Length) + if (newIndex == this.completions.Length) { - var newArray = new (Action, object)[this.callbacks.Length * 2]; - Array.Copy(this.callbacks, newArray, this.callbacks.Length); - this.callbacks = newArray; + var newArray = new CompletionData[this.completions.Length * 2]; + Array.Copy(this.completions, newArray, this.completions.Length); + this.completions = newArray; } - this.callbacks[newIndex] = (continuation, state); + this.completions[newIndex] = new CompletionData( + continuation, + state, + (flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0 ? ExecutionContext.Capture() : null, + (flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0 ? SynchronizationContext.Current : null + ); if (this.exception != null) { - this.TryExecuteCallbacks(); + this.TryExecuteCompletions(); } } @@ -120,9 +145,9 @@ bool IsCompletedOrThrow() [MethodImpl(MethodImplOptions.NoInlining)] void ThrowLatchedException() => ExceptionDispatchInfo.Capture(this.exception).Throw(); - bool TryExecuteCallbacks() + bool TryExecuteCompletions() { - if (this.callbackCount == 0 || this.callbacks == null) + if (this.callbackCount == 0 || this.completions == null) { return false; } @@ -133,8 +158,8 @@ bool TryExecuteCallbacks() { try { - (Action callback, object state) = this.callbacks[i]; - callback(state); + CompletionData completion = this.completions[i]; + ExecuteCompletion(completion); } catch (Exception ex) { @@ -154,15 +179,57 @@ bool TryExecuteCallbacks() throw new AggregateException(exceptions); } - + [MethodImpl(MethodImplOptions.AggressiveInlining)] protected void ClearCallbacks() { if (this.callbackCount > 0) { this.callbackCount = 0; - Array.Clear(this.callbacks, 0, this.callbacks.Length); + Array.Clear(this.completions, 0, this.completions.Length); } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void ExecuteCompletion(CompletionData completion) + { + if (completion.SynchronizationContext == null) + { + if (completion.ExecutionContext == null) + { + completion.Continuation(completion.State); + } + else + { + //boxing + ExecutionContext.Run(completion.ExecutionContext, ExecutionContextCallback, completion); + } + } + else + { + if (completion.ExecutionContext == null) + { + //boxing + completion.SynchronizationContext.Post(SyncContextCallback, completion); + } + else + { + //boxing + completion.SynchronizationContext.Post(SyncContextCallbackWithExecutionContext, completion); + } + } + } + + static void Execute(object state) + { + CompletionData completion = (CompletionData)state; + completion.Continuation(completion.State); + } + + static void ExecuteWithExecutionContext(object state) + { + CompletionData completion = (CompletionData)state; + ExecutionContext.Run(completion.ExecutionContext, ExecutionContextCallback, state); } } } \ No newline at end of file