diff --git a/.nuget/NuGet.Config b/.nuget/NuGet.Config deleted file mode 100644 index 67f8ea046..000000000 --- a/.nuget/NuGet.Config +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/examples/Telnet.Server/TelnetServerHandler.cs b/examples/Telnet.Server/TelnetServerHandler.cs index 563384761..2e85ae8a4 100644 --- a/examples/Telnet.Server/TelnetServerHandler.cs +++ b/examples/Telnet.Server/TelnetServerHandler.cs @@ -6,6 +6,8 @@ namespace Telnet.Server using System; using System.Net; using System.Threading.Tasks; + using DotNetty.Codecs; + using DotNetty.Common.Concurrency; using DotNetty.Transport.Channels; public class TelnetServerHandler : SimpleChannelInboundHandler @@ -16,7 +18,7 @@ public override void ChannelActive(IChannelHandlerContext contex) contex.WriteAndFlushAsync(string.Format("It is {0} now !\r\n", DateTime.Now)); } - protected override void ChannelRead0(IChannelHandlerContext contex, string msg) + protected override void ChannelRead0(IChannelHandlerContext context, string msg) { // Generate and write a response. string response; @@ -35,11 +37,10 @@ protected override void ChannelRead0(IChannelHandlerContext contex, string msg) response = "Did you say '" + msg + "'?\r\n"; } - Task wait_close = contex.WriteAndFlushAsync(response); + Task waitClose = context.WriteAndFlushAsync(response); if (close) { - Task.WaitAll(wait_close); - contex.CloseAsync(); + waitClose.CloseOnComplete(context); } } diff --git a/src/DotNetty.Buffers/DotNetty.Buffers.csproj b/src/DotNetty.Buffers/DotNetty.Buffers.csproj index 1d68f3c2a..5ba59d926 100644 --- a/src/DotNetty.Buffers/DotNetty.Buffers.csproj +++ b/src/DotNetty.Buffers/DotNetty.Buffers.csproj @@ -1,4 +1,5 @@ - + + netstandard1.3;net45 true diff --git a/src/DotNetty.Codecs.Http/Cors/CorsHandler.cs b/src/DotNetty.Codecs.Http/Cors/CorsHandler.cs index c6e38d2e0..feb5453f6 100644 --- a/src/DotNetty.Codecs.Http/Cors/CorsHandler.cs +++ b/src/DotNetty.Codecs.Http/Cors/CorsHandler.cs @@ -167,7 +167,7 @@ void SetExposeHeaders(IHttpResponse response) void SetMaxAge(IHttpResponse response) => response.Headers.Set(HttpHeaderNames.AccessControlMaxAge, this.config.MaxAge); - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { if (this.config.IsCorsSupportEnabled && message is IHttpResponse response) { @@ -177,7 +177,7 @@ public override Task WriteAsync(IChannelHandlerContext context, object message) this.SetExposeHeaders(response); } } - return context.WriteAndFlushAsync(message); + return context.WriteAndFlushAsync(message, true); } static void Forbidden(IChannelHandlerContext ctx, IHttpRequest request) @@ -197,15 +197,8 @@ static void Respond(IChannelHandlerContext ctx, IHttpRequest request, IHttpRespo Task task = ctx.WriteAndFlushAsync(response); if (!keepAlive) { - task.ContinueWith(CloseOnComplete, ctx, - TaskContinuationOptions.ExecuteSynchronously); + task.CloseOnComplete(ctx); } } - - static void CloseOnComplete(Task task, object state) - { - var ctx = (IChannelHandlerContext)state; - ctx.CloseAsync(); - } } } diff --git a/src/DotNetty.Codecs.Http/HttpClientUpgradeHandler.cs b/src/DotNetty.Codecs.Http/HttpClientUpgradeHandler.cs index 9f10eeea3..122fd78e4 100644 --- a/src/DotNetty.Codecs.Http/HttpClientUpgradeHandler.cs +++ b/src/DotNetty.Codecs.Http/HttpClientUpgradeHandler.cs @@ -72,7 +72,7 @@ public HttpClientUpgradeHandler(ISourceCodec sourceCodec, IUpgradeCodec upgradeC this.upgradeCodec = upgradeCodec; } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { if (!(message is IHttpRequest)) { @@ -81,14 +81,14 @@ public override Task WriteAsync(IChannelHandlerContext context, object message) if (this.upgradeRequested) { - return TaskEx.FromException(new InvalidOperationException("Attempting to write HTTP request with upgrade in progress")); + return new ValueTask(TaskEx.FromException(new InvalidOperationException("Attempting to write HTTP request with upgrade in progress"))); } this.upgradeRequested = true; this.SetUpgradeRequestHeaders(context, (IHttpRequest)message); // Continue writing the request. - Task task = context.WriteAsync(message); + ValueTask task = context.WriteAsync(message); // Notify that the upgrade request was issued. context.FireUserEventTriggered(UpgradeEvent.UpgradeIssued); diff --git a/src/DotNetty.Codecs.Http/HttpServerExpectContinueHandler.cs b/src/DotNetty.Codecs.Http/HttpServerExpectContinueHandler.cs index de4e41746..6ac07f518 100644 --- a/src/DotNetty.Codecs.Http/HttpServerExpectContinueHandler.cs +++ b/src/DotNetty.Codecs.Http/HttpServerExpectContinueHandler.cs @@ -39,27 +39,15 @@ public override void ChannelRead(IChannelHandlerContext context, object message) // the expectation failed so we refuse the request. IHttpResponse rejection = this.RejectResponse(req); ReferenceCountUtil.Release(message); - context.WriteAndFlushAsync(rejection) - .ContinueWith(CloseOnFailure, context, TaskContinuationOptions.ExecuteSynchronously); + context.WriteAndFlushAsync(rejection).CloseOnFailure(context); return; } - context.WriteAndFlushAsync(accept) - .ContinueWith(CloseOnFailure, context, TaskContinuationOptions.ExecuteSynchronously); + context.WriteAndFlushAsync(accept).CloseOnFailure(context); req.Headers.Remove(HttpHeaderNames.Expect); } base.ChannelRead(context, message); } } - - static Task CloseOnFailure(Task task, object state) - { - if (task.IsFaulted) - { - var context = (IChannelHandlerContext)state; - return context.CloseAsync(); - } - return TaskEx.Completed; - } } } diff --git a/src/DotNetty.Codecs.Http/HttpServerKeepAliveHandler.cs b/src/DotNetty.Codecs.Http/HttpServerKeepAliveHandler.cs index 981ff209b..aae9166bc 100644 --- a/src/DotNetty.Codecs.Http/HttpServerKeepAliveHandler.cs +++ b/src/DotNetty.Codecs.Http/HttpServerKeepAliveHandler.cs @@ -31,7 +31,7 @@ public override void ChannelRead(IChannelHandlerContext context, object message) base.ChannelRead(context, message); } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { // modify message on way out to add headers if needed if (message is IHttpResponse response) @@ -52,18 +52,13 @@ public override Task WriteAsync(IChannelHandlerContext context, object message) } if (message is ILastHttpContent && !this.ShouldKeepAlive()) { - return base.WriteAsync(context, message) - .ContinueWith(CloseOnComplete, context, TaskContinuationOptions.ExecuteSynchronously); + Task task = base.WriteAsync(context, message).AsTask(); + task.CloseOnComplete(context.Channel); + return new ValueTask(task); } return base.WriteAsync(context, message); } - static Task CloseOnComplete(Task task, object state) - { - var context = (IChannelHandlerContext)state; - return context.CloseAsync(); - } - void TrackResponse(IHttpResponse response) { if (!IsInformational(response)) diff --git a/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs b/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs index d47663731..3eee5b769 100644 --- a/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs +++ b/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs @@ -254,34 +254,28 @@ bool Upgrade(IChannelHandlerContext ctx, IFullHttpRequest request) var upgradeEvent = new UpgradeEvent(upgradeProtocol, request); IUpgradeCodec finalUpgradeCodec = upgradeCodec; - ctx.WriteAndFlushAsync(upgradeResponse).ContinueWith(t => - { - try - { - if (t.Status == TaskStatus.RanToCompletion) - { - // Perform the upgrade to the new protocol. - this.sourceCodec.UpgradeFrom(ctx); - finalUpgradeCodec.UpgradeTo(ctx, request); - - // Notify that the upgrade has occurred. Retain the event to offset - // the release() in the finally block. - ctx.FireUserEventTriggered(upgradeEvent.Retain()); - - // Remove this handler from the pipeline. - ctx.Channel.Pipeline.Remove(this); - } - else - { - ctx.Channel.CloseAsync(); - } - } - finally - { - // Release the event if the upgrade event wasn't fired. - upgradeEvent.Release(); - } - }, TaskContinuationOptions.ExecuteSynchronously); + try + { + Task writeTask = ctx.WriteAndFlushAsync(upgradeResponse); + + // Perform the upgrade to the new protocol. + this.sourceCodec.UpgradeFrom(ctx); + finalUpgradeCodec.UpgradeTo(ctx, request); + + // Remove this handler from the pipeline. + ctx.Channel.Pipeline.Remove(this); + + // Notify that the upgrade has occurred. Retain the event to offset + // the release() in the finally block. + ctx.FireUserEventTriggered(upgradeEvent.Retain()); + + writeTask.CloseOnFailure(ctx.Channel); + } + finally + { + // Release the event if the upgrade event wasn't fired. + upgradeEvent.Release(); + } return true; } diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketClientExtensionHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketClientExtensionHandler.cs index 5a72c6d1c..922ce7756 100644 --- a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketClientExtensionHandler.cs +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketClientExtensionHandler.cs @@ -19,7 +19,7 @@ public WebSocketClientExtensionHandler(params IWebSocketClientExtensionHandshake this.extensionHandshakers = new List(extensionHandshakers); } - public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + public override ValueTask WriteAsync(IChannelHandlerContext ctx, object msg) { if (msg is IHttpRequest request && WebSocketExtensionUtil.IsWebsocketUpgrade(request.Headers)) { diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketServerExtensionHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketServerExtensionHandler.cs index e0b07e36b..6d6b0c422 100644 --- a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketServerExtensionHandler.cs +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketServerExtensionHandler.cs @@ -15,12 +15,13 @@ public class WebSocketServerExtensionHandler : ChannelHandlerAdapter readonly List extensionHandshakers; List validExtensions; + Action upgradeCompletedContinuation; public WebSocketServerExtensionHandler(params IWebSocketServerExtensionHandshaker[] extensionHandshakers) { Contract.Requires(extensionHandshakers != null && extensionHandshakers.Length > 0); - this.extensionHandshakers = new List(extensionHandshakers); + this.upgradeCompletedContinuation = this.OnUpgradeCompleted; } public override void ChannelRead(IChannelHandlerContext ctx, object msg) @@ -67,16 +68,16 @@ public override void ChannelRead(IChannelHandlerContext ctx, object msg) base.ChannelRead(ctx, msg); } - public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + public override ValueTask WriteAsync(IChannelHandlerContext ctx, object msg) { - Action continuationAction = null; - + HttpHeaders responseHeaders; + string headerValue = null; + if (msg is IHttpResponse response - && WebSocketExtensionUtil.IsWebsocketUpgrade(response.Headers) + && WebSocketExtensionUtil.IsWebsocketUpgrade(responseHeaders = response.Headers) && this.validExtensions != null) { - string headerValue = null; - if (response.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)) + if (responseHeaders.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)) { headerValue = value?.ToString(); } @@ -88,31 +89,33 @@ public override Task WriteAsync(IChannelHandlerContext ctx, object msg) extensionData.Name, extensionData.Parameters); } - continuationAction = promise => - { - if (promise.Status == TaskStatus.RanToCompletion) - { - foreach (IWebSocketServerExtension extension in this.validExtensions) - { - WebSocketExtensionDecoder decoder = extension.NewExtensionDecoder(); - WebSocketExtensionEncoder encoder = extension.NewExtensionEncoder(); - ctx.Channel.Pipeline.AddAfter(ctx.Name, decoder.GetType().Name, decoder); - ctx.Channel.Pipeline.AddAfter(ctx.Name, encoder.GetType().Name, encoder); - } - } - ctx.Channel.Pipeline.Remove(ctx.Name); - }; - if (headerValue != null) { - response.Headers.Set(HttpHeaderNames.SecWebsocketExtensions, headerValue); + responseHeaders.Set(HttpHeaderNames.SecWebsocketExtensions, headerValue); } + + Task task = base.WriteAsync(ctx, msg).AsTask(); + task.ContinueWith(this.upgradeCompletedContinuation, ctx, TaskContinuationOptions.ExecuteSynchronously); + return new ValueTask(task); } - return continuationAction == null - ? base.WriteAsync(ctx, msg) - : base.WriteAsync(ctx, msg) - .ContinueWith(continuationAction, TaskContinuationOptions.ExecuteSynchronously); + return base.WriteAsync(ctx, msg); + } + + void OnUpgradeCompleted(Task task, object state) + { + var ctx = (IChannelHandlerContext)state; + if (task.Status == TaskStatus.RanToCompletion) + { + foreach (IWebSocketServerExtension extension in this.validExtensions) + { + WebSocketExtensionDecoder decoder = extension.NewExtensionDecoder(); + WebSocketExtensionEncoder encoder = extension.NewExtensionEncoder(); + ctx.Channel.Pipeline.AddAfter(ctx.Name, decoder.GetType().Name, decoder); + ctx.Channel.Pipeline.AddAfter(ctx.Name, encoder.GetType().Name, encoder); + } + } + ctx.Channel.Pipeline.Remove(ctx.Name); } } } diff --git a/src/DotNetty.Codecs.Mqtt/MqttDecoder.cs b/src/DotNetty.Codecs.Mqtt/MqttDecoder.cs index b2418db83..34894e704 100644 --- a/src/DotNetty.Codecs.Mqtt/MqttDecoder.cs +++ b/src/DotNetty.Codecs.Mqtt/MqttDecoder.cs @@ -241,7 +241,7 @@ static void DecodeConnectPacket(IByteBuffer buffer, ConnectPacket packet, ref in { var connAckPacket = new ConnAckPacket(); connAckPacket.ReturnCode = ConnectReturnCode.RefusedUnacceptableProtocolVersion; - context.WriteAndFlushAsync(connAckPacket); + context.WriteAndFlushAsync(connAckPacket, false); throw new DecoderException($"Unexpected protocol level. Expected: {Util.ProtocolLevel}. Actual: {packet.ProtocolLevel}"); } diff --git a/src/DotNetty.Codecs/Compression/JZlibEncoder.cs b/src/DotNetty.Codecs/Compression/JZlibEncoder.cs index 901bfcf80..19bdc4df0 100644 --- a/src/DotNetty.Codecs/Compression/JZlibEncoder.cs +++ b/src/DotNetty.Codecs/Compression/JZlibEncoder.cs @@ -241,8 +241,7 @@ Task FinishEncode(IChannelHandlerContext context) this.z.next_out = null; } - return context.WriteAndFlushAsync(footer) - .ContinueWith(_ => context.CloseAsync()); + return context.WriteAndFlushAsync(footer).CloseOnComplete(context); } public override void HandlerAdded(IChannelHandlerContext context) => this.ctx = context; diff --git a/src/DotNetty.Codecs/DotNetty.Codecs.csproj b/src/DotNetty.Codecs/DotNetty.Codecs.csproj index 12a113965..ef53cc113 100644 --- a/src/DotNetty.Codecs/DotNetty.Codecs.csproj +++ b/src/DotNetty.Codecs/DotNetty.Codecs.csproj @@ -1,4 +1,5 @@ - + + netstandard1.3;net45 true @@ -44,4 +45,7 @@ + + + \ No newline at end of file diff --git a/src/DotNetty.Codecs/MessageAggregator.cs b/src/DotNetty.Codecs/MessageAggregator.cs index 79f5e9c62..76106dd03 100644 --- a/src/DotNetty.Codecs/MessageAggregator.cs +++ b/src/DotNetty.Codecs/MessageAggregator.cs @@ -130,13 +130,10 @@ protected internal override void Decode(IChannelHandlerContext context, TMessage bool closeAfterWrite = this.CloseAfterContinueResponse(continueResponse); this.handlingOversizedMessage = this.IgnoreContentAfterContinueResponse(continueResponse); - Task task = context - .WriteAndFlushAsync(continueResponse) - .ContinueWith(ContinueResponseWriteAction, context, TaskContinuationOptions.ExecuteSynchronously); + WriteContinueResponse(context, continueResponse, closeAfterWrite); if (closeAfterWrite) { - task.ContinueWith(CloseAfterWriteAction, context, TaskContinuationOptions.ExecuteSynchronously); return; } @@ -245,19 +242,21 @@ protected internal override void Decode(IChannelHandlerContext context, TMessage throw new MessageAggregationException("Unknown aggregation state."); } } - - static void CloseAfterWriteAction(Task task, object state) + + static async void WriteContinueResponse(IChannelHandlerContext ctx, object message, bool closeAfterWrite) { - var ctx = (IChannelHandlerContext)state; - ctx.Channel.CloseAsync(); - } - - static void ContinueResponseWriteAction(Task task, object state) - { - if (task.IsFaulted) + try + { + await ctx.WriteAndFlushAsync(message); + } + catch (Exception ex) + { + ctx.FireExceptionCaught(ex); + } + + if (closeAfterWrite) { - var ctx = (IChannelHandlerContext)state; - ctx.FireExceptionCaught(task.Exception); + ctx.Channel.CloseAsync(); } } diff --git a/src/DotNetty.Codecs/MessageToByteEncoder.cs b/src/DotNetty.Codecs/MessageToByteEncoder.cs index cff1dee2e..9b4283b75 100644 --- a/src/DotNetty.Codecs/MessageToByteEncoder.cs +++ b/src/DotNetty.Codecs/MessageToByteEncoder.cs @@ -7,6 +7,7 @@ namespace DotNetty.Codecs using System.Diagnostics.Contracts; using System.Threading.Tasks; using DotNetty.Buffers; + using DotNetty.Common.Concurrency; using DotNetty.Common.Utilities; using DotNetty.Transport.Channels; @@ -14,12 +15,12 @@ public abstract class MessageToByteEncoder : ChannelHandlerAdapter { public virtual bool AcceptOutboundMessage(object message) => message is T; - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { Contract.Requires(context != null); IByteBuffer buffer = null; - Task result; + ValueTask result; try { if (this.AcceptOutboundMessage(message)) @@ -52,13 +53,13 @@ public override Task WriteAsync(IChannelHandlerContext context, object message) return context.WriteAsync(message); } } - catch (EncoderException e) + catch (EncoderException) { - return TaskEx.FromException(e); + throw; } catch (Exception ex) { - return TaskEx.FromException(new EncoderException(ex)); + throw new EncoderException(ex); } finally { diff --git a/src/DotNetty.Codecs/MessageToMessageCodec.cs b/src/DotNetty.Codecs/MessageToMessageCodec.cs index a990193a2..5b4368323 100644 --- a/src/DotNetty.Codecs/MessageToMessageCodec.cs +++ b/src/DotNetty.Codecs/MessageToMessageCodec.cs @@ -50,7 +50,7 @@ protected MessageToMessageCodec() public sealed override void ChannelRead(IChannelHandlerContext context, object message) => this.decoder.ChannelRead(context, message); - public sealed override Task WriteAsync(IChannelHandlerContext context, object message) => + public sealed override ValueTask WriteAsync(IChannelHandlerContext context, object message) => this.encoder.WriteAsync(context, message); public virtual bool AcceptInboundMessage(object msg) => msg is TInbound; diff --git a/src/DotNetty.Codecs/MessageToMessageEncoder.cs b/src/DotNetty.Codecs/MessageToMessageEncoder.cs index cecc8ebde..812d632a5 100644 --- a/src/DotNetty.Codecs/MessageToMessageEncoder.cs +++ b/src/DotNetty.Codecs/MessageToMessageEncoder.cs @@ -7,6 +7,7 @@ namespace DotNetty.Codecs using System.Collections.Generic; using System.Threading.Tasks; using DotNetty.Common; + using DotNetty.Common.Concurrency; using DotNetty.Common.Utilities; using DotNetty.Transport.Channels; @@ -18,9 +19,9 @@ public abstract class MessageToMessageEncoder : ChannelHandlerAdapter /// public virtual bool AcceptOutboundMessage(object msg) => msg is T; - public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + public override ValueTask WriteAsync(IChannelHandlerContext ctx, object msg) { - Task result; + ValueTask result; ThreadLocalObjectList output = null; try { @@ -50,13 +51,13 @@ public override Task WriteAsync(IChannelHandlerContext ctx, object msg) return ctx.WriteAsync(msg); } } - catch (EncoderException e) + catch (EncoderException) { - return TaskEx.FromException(e); + throw; } catch (Exception ex) { - return TaskEx.FromException(new EncoderException(ex)); // todo: we don't have a stack on EncoderException but it's present on inner exception. + throw new EncoderException(ex);// todo: we don't have a stack on EncoderException but it's present on inner exception. } finally { @@ -72,21 +73,21 @@ public override Task WriteAsync(IChannelHandlerContext ctx, object msg) for (int i = 0; i < lastItemIndex; i++) { // we don't care about output from these messages as failure while sending one of these messages will fail all messages up to the last message - which will be observed by the caller in Task result. - ctx.WriteAsync(output[i]); // todo: optimize: once IChannelHandlerContext allows, pass "not interested in task" flag + ctx.WriteAsync(output[i]); } result = ctx.WriteAsync(output[lastItemIndex]); } else { // 0 items in output - must never get here - result = null; + result = default(ValueTask); } output.Return(); } else { // output was reset during exception handling - must never get here - result = null; + result = default(ValueTask); } } return result; diff --git a/src/DotNetty.Codecs/TaskExtensions.cs b/src/DotNetty.Codecs/TaskExtensions.cs new file mode 100644 index 000000000..d8356c148 --- /dev/null +++ b/src/DotNetty.Codecs/TaskExtensions.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + using System.Threading.Tasks; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public static class TaskExtensions + { + public static async Task CloseOnComplete(this ValueTask task, IChannelHandlerContext ctx) + { + try + { + await task; + } + finally + { + await ctx.CloseAsync(); + } + } + + static readonly Func CloseOnCompleteContinuation = Close; + static readonly Func CloseOnFailureContinuation = CloseOnFailure; + + public static Task CloseOnComplete(this Task task, IChannelHandlerContext ctx) + => task.ContinueWith(CloseOnCompleteContinuation, ctx, TaskContinuationOptions.ExecuteSynchronously); + + public static Task CloseOnComplete(this Task task, IChannel channel) + => task.ContinueWith(CloseOnCompleteContinuation, channel, TaskContinuationOptions.ExecuteSynchronously); + + public static Task CloseOnFailure(this Task task, IChannelHandlerContext ctx) + => task.ContinueWith(CloseOnFailureContinuation, ctx, TaskContinuationOptions.ExecuteSynchronously); + + public static Task CloseOnFailure(this Task task, IChannel channel) + => task.ContinueWith(CloseOnFailureContinuation, channel, TaskContinuationOptions.ExecuteSynchronously); + + static Task Close(Task task, object state) + { + switch (state) + { + case IChannelHandlerContext ctx: + return ctx.CloseAsync(); + case IChannel ch: + return ch.CloseAsync(); + default: + throw new InvalidOperationException("must never get here"); + } + } + + static Task CloseOnFailure(Task task, object state) + { + if (task.IsFaulted) + { + return Close(task, state); + } + return TaskEx.Completed; + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Common/Concurrency/AbstractPromise.cs b/src/DotNetty.Common/Concurrency/AbstractPromise.cs new file mode 100644 index 000000000..e02c9bff7 --- /dev/null +++ b/src/DotNetty.Common/Concurrency/AbstractPromise.cs @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Concurrency +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Reflection; + using System.Runtime.CompilerServices; + using System.Runtime.ExceptionServices; + using System.Runtime.InteropServices.ComTypes; + using System.Threading; + using System.Threading.Tasks; + using System.Threading.Tasks.Sources; + + public abstract class AbstractPromise : IPromise, IValueTaskSource + { + static readonly ContextCallback ExecutionContextCallback = Execute; + static readonly SendOrPostCallback SyncContextCallback = Execute; + static readonly SendOrPostCallback SyncContextCallbackWithExecutionContext = ExecuteWithExecutionContext; + static readonly Action TaskSchedulerCallback = Execute; + static readonly Action TaskScheduleCallbackWithExecutionContext = ExecuteWithExecutionContext; + + protected static readonly Exception CompletedSentinel = new Exception(); + + short currentId; + protected Exception exception; + + Action continuation; + object state; + ExecutionContext executionContext; + object schedulingContext; + + public ValueTask ValueTask => new ValueTask(this, this.currentId); + + public bool TryComplete() => this.TryComplete0(CompletedSentinel, out _); + + public bool TrySetException(Exception exception) => this.TryComplete0(exception, out _); + + public bool TrySetCanceled(CancellationToken cancellationToken = default(CancellationToken)) => this.TryComplete0(new OperationCanceledException(cancellationToken), out _); + + protected virtual bool TryComplete0(Exception exception, out bool continuationInvoked) + { + continuationInvoked = false; + + if (this.exception == null) + { + // Set the exception object to the exception passed in or a sentinel value + this.exception = exception; + + if (this.continuation != null) + { + this.ExecuteContinuation(); + continuationInvoked = true; + } + return true; + } + + return false; + } + + public bool SetUncancellable() => true; + + public virtual ValueTaskSourceStatus GetStatus(short token) + { + this.EnsureValidToken(token); + + if (this.exception == null) + { + return ValueTaskSourceStatus.Pending; + } + else if (this.exception == CompletedSentinel) + { + return ValueTaskSourceStatus.Succeeded; + } + else if (this.exception is OperationCanceledException) + { + return ValueTaskSourceStatus.Canceled; + } + else + { + return ValueTaskSourceStatus.Faulted; + } + } + + public virtual void GetResult(short token) + { + this.EnsureValidToken(token); + + if (this.exception == null) + { + throw new InvalidOperationException("Attempt to get result on not yet completed promise"); + } + + this.currentId++; + + if (this.exception != CompletedSentinel) + { + this.ThrowLatchedException(); + } + } + + public virtual void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) + { + this.EnsureValidToken(token); + + if (this.continuation != null) + { + throw new InvalidOperationException("Attempt to subscribe same promise twice"); + } + + this.continuation = continuation; + this.state = state; + this.executionContext = (flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0 ? ExecutionContext.Capture() : null; + + if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0) + { + SynchronizationContext sc = SynchronizationContext.Current; + if (sc != null && sc.GetType() != typeof(SynchronizationContext)) + { + this.schedulingContext = sc; + } + else + { + TaskScheduler ts = TaskScheduler.Current; + if (ts != TaskScheduler.Default) + { + this.schedulingContext = ts; + } + } + } + + if (this.exception != null) + { + this.ExecuteContinuation(); + } + } + + public static implicit operator ValueTask(AbstractPromise promise) => promise.ValueTask; + + [MethodImpl(MethodImplOptions.NoInlining)] + void ThrowLatchedException() => ExceptionDispatchInfo.Capture(this.exception).Throw(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + protected void ClearCallback() + { + this.continuation = null; + this.state = null; + this.executionContext = null; + this.schedulingContext = null; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void EnsureValidToken(short token) + { + if (this.currentId != token) + { + throw new InvalidOperationException("Incorrect ValueTask token"); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void ExecuteContinuation() + { + ExecutionContext executionContext = this.executionContext; + object schedulingContext = this.schedulingContext; + + if (schedulingContext == null) + { + if (executionContext == null) + { + this.ExecuteContinuation0(); + } + else + { + ExecutionContext.Run(executionContext, ExecutionContextCallback, this); + } + } + else if (schedulingContext is SynchronizationContext sc) + { + sc.Post(executionContext == null ? SyncContextCallback : SyncContextCallbackWithExecutionContext, this); + } + else + { + TaskScheduler ts = (TaskScheduler)schedulingContext; + Contract.Assert(ts != null, "Expected a TaskScheduler"); + Task.Factory.StartNew(executionContext == null ? TaskSchedulerCallback : TaskScheduleCallbackWithExecutionContext, this, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts); + } + } + + static void Execute(object state) => ((AbstractPromise)state).ExecuteContinuation0(); + + static void ExecuteWithExecutionContext(object state) => ExecutionContext.Run(((AbstractPromise)state).executionContext, ExecutionContextCallback, state); + + protected virtual void ExecuteContinuation0() + { + Contract.Assert(this.continuation != null); + this.continuation(this.state); + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Common/Concurrency/AbstractRecyclablePromise.cs b/src/DotNetty.Common/Concurrency/AbstractRecyclablePromise.cs new file mode 100644 index 000000000..cfee262b3 --- /dev/null +++ b/src/DotNetty.Common/Concurrency/AbstractRecyclablePromise.cs @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Concurrency +{ + using System; + using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; + using System.Threading.Tasks.Sources; + + public abstract class AbstractRecyclablePromise : AbstractPromise + { + static readonly Action RecycleAction = Recycle; + + protected readonly ThreadLocalPool.Handle handle; + + protected bool recycled; + protected IEventExecutor executor; + + protected AbstractRecyclablePromise(ThreadLocalPool.Handle handle) + { + this.handle = handle; + } + + public override ValueTaskSourceStatus GetStatus(short token) + { + this.ThrowIfRecycled(); + return base.GetStatus(token); + } + + public override void GetResult(short token) + { + this.ThrowIfRecycled(); + base.GetResult(token); + } + + protected override bool TryComplete0(Exception exception, out bool continuationInvoked) + { + Contract.Assert(this.executor.InEventLoop, "must be invoked from an event loop"); + this.ThrowIfRecycled(); + + try + { + bool completed = base.TryComplete0(exception, out continuationInvoked); + if (!continuationInvoked) + { + this.Recycle(); + } + return completed; + } + catch + { + this.Recycle(); + throw; + } + } + + protected void Init(IEventExecutor executor) + { + this.executor = executor; + this.recycled = false; + } + + protected virtual void Recycle() + { + Contract.Assert(this.executor.InEventLoop, "must be invoked from an event loop"); + this.exception = null; + this.ClearCallback(); + this.executor = null; + this.recycled = true; + this.handle.Release(this); + } + + public override void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) + { + this.ThrowIfRecycled(); + base.OnCompleted(continuation,state, token, flags); + } + + protected override void ExecuteContinuation0() + { + try + { + base.ExecuteContinuation0(); + } + finally + { + if (this.executor.InEventLoop) + { + this.Recycle(); + } + else + { + this.executor.Execute(RecycleAction, this); + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void ThrowIfRecycled() + { + if (this.recycled) + { + throw new InvalidOperationException("Attempt to use recycled channel promise"); + } + } + + static void Recycle(object state) + { + AbstractRecyclablePromise promise = (AbstractRecyclablePromise)state; + Contract.Assert(promise.executor.InEventLoop, "must be invoked from an event loop"); + promise.Recycle(); + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Common/Concurrency/AggregatingPromise.cs b/src/DotNetty.Common/Concurrency/AggregatingPromise.cs new file mode 100644 index 000000000..5c140c435 --- /dev/null +++ b/src/DotNetty.Common/Concurrency/AggregatingPromise.cs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Concurrency +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Threading.Tasks.Sources; + + public sealed class AggregatingPromise : AbstractPromise + { + readonly IList futures; + int successCount; + int failureCount; + + IList failures; + + public AggregatingPromise(IList futures) + { + Contract.Requires(futures != null); + this.futures = futures; + + foreach (IValueTaskSource future in futures) + { + future.OnCompleted(this.OnFutureCompleted, future, 0, ValueTaskSourceOnCompletedFlags.None); + } + + // Done on arrival? + if (futures.Count == 0) + { + this.TryComplete(); + } + } + + void OnFutureCompleted(object obj) + { + var future = obj as IValueTaskSource; + Contract.Assert(future != null); + + try + { + future.GetResult(0); + this.successCount++; + } + catch(Exception ex) + { + this.failureCount++; + + if (this.failures == null) + { + this.failures = new List(); + } + this.failures.Add(ex); + } + + bool callSetDone = this.successCount + this.failureCount == this.futures.Count; + Contract.Assert(this.successCount + this.failureCount <= this.futures.Count); + + if (callSetDone) + { + if (this.failureCount > 0) + { + this.TrySetException(new AggregateException(this.failures)); + } + else + { + this.TryComplete(); + } + } + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Common/Concurrency/IPromise.cs b/src/DotNetty.Common/Concurrency/IPromise.cs new file mode 100644 index 000000000..f51e1c28d --- /dev/null +++ b/src/DotNetty.Common/Concurrency/IPromise.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Concurrency +{ + using System; + using System.Threading; + using System.Threading.Tasks; + + public interface IPromise + { + ValueTask ValueTask { get; } + + bool TryComplete(); + + bool TrySetException(Exception exception); + + bool TrySetCanceled(CancellationToken cancellationToken = default(CancellationToken)); + + bool SetUncancellable(); + } +} \ No newline at end of file diff --git a/src/DotNetty.Common/Concurrency/IRecyclable.cs b/src/DotNetty.Common/Concurrency/IRecyclable.cs new file mode 100644 index 000000000..3777dbca1 --- /dev/null +++ b/src/DotNetty.Common/Concurrency/IRecyclable.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Concurrency +{ + public interface IRecyclable + { + void Init(IEventExecutor executor); + + void Recycle(); + } +} \ No newline at end of file diff --git a/src/DotNetty.Common/DotNetty.Common.csproj b/src/DotNetty.Common/DotNetty.Common.csproj index c187ae363..d3539fb00 100644 --- a/src/DotNetty.Common/DotNetty.Common.csproj +++ b/src/DotNetty.Common/DotNetty.Common.csproj @@ -32,6 +32,7 @@ + diff --git a/src/DotNetty.Common/ThreadLocalObjectList.cs b/src/DotNetty.Common/ThreadLocalObjectList.cs index f2e2b6102..71414ffa5 100644 --- a/src/DotNetty.Common/ThreadLocalObjectList.cs +++ b/src/DotNetty.Common/ThreadLocalObjectList.cs @@ -5,7 +5,7 @@ namespace DotNetty.Common { using System.Collections.Generic; - public class ThreadLocalObjectList : List + public sealed class ThreadLocalObjectList : List { const int DefaultInitialCapacity = 8; @@ -13,7 +13,7 @@ public class ThreadLocalObjectList : List readonly ThreadLocalPool.Handle returnHandle; - ThreadLocalObjectList(ThreadLocalPool.Handle returnHandle) + protected ThreadLocalObjectList(ThreadLocalPool.Handle returnHandle) { this.returnHandle = returnHandle; } @@ -28,7 +28,6 @@ public static ThreadLocalObjectList NewInstance(int minCapacity) ret.Capacity = minCapacity; } return ret; - } public void Return() diff --git a/src/DotNetty.Common/Utilities/TaskEx.cs b/src/DotNetty.Common/Utilities/TaskEx.cs index d0093a97d..96b097125 100644 --- a/src/DotNetty.Common/Utilities/TaskEx.cs +++ b/src/DotNetty.Common/Utilities/TaskEx.cs @@ -4,6 +4,7 @@ namespace DotNetty.Common.Utilities { using System; + using System.Runtime.CompilerServices; using System.Threading.Tasks; using DotNetty.Common.Concurrency; @@ -19,6 +20,10 @@ public static class TaskEx public static readonly Task False = Task.FromResult(false); + public static ValueTask ToValueTask(this Exception ex) => new ValueTask(FromException(ex)); + + public static ValueTask ToValueTask(this Exception ex) => new ValueTask(FromException(ex)); + static Task CreateCancelledTask() { var tcs = new TaskCompletionSource(); @@ -40,7 +45,7 @@ public static Task FromException(Exception exception) return tcs.Task; } - static readonly Action LinkOutcomeContinuationAction = (t, tcs) => + static readonly Action LinkOutcomeTcs = (t, tcs) => { switch (t.Status) { @@ -57,7 +62,7 @@ public static Task FromException(Exception exception) throw new ArgumentOutOfRangeException(); } }; - + public static void LinkOutcome(this Task task, TaskCompletionSource taskCompletionSource) { switch (task.Status) @@ -72,14 +77,63 @@ public static void LinkOutcome(this Task task, TaskCompletionSource taskCompleti taskCompletionSource.TryUnwrap(task.Exception); break; default: - task.ContinueWith( - LinkOutcomeContinuationAction, - taskCompletionSource, - TaskContinuationOptions.ExecuteSynchronously); + task.ContinueWith(LinkOutcomeTcs, taskCompletionSource, TaskContinuationOptions.ExecuteSynchronously); + break; + } + } + + static readonly Action LinkOutcomePromise = (t, promise) => + { + switch (t.Status) + { + case TaskStatus.RanToCompletion: + ((IPromise)promise).TryComplete(); + break; + case TaskStatus.Canceled: + ((IPromise)promise).TrySetCanceled(); + break; + case TaskStatus.Faulted: + ((IPromise)promise).TryUnwrap(t.Exception); + break; + default: + throw new ArgumentOutOfRangeException(); + } + }; + + public static void LinkOutcome(this Task task, IPromise promise) + { + switch (task.Status) + { + case TaskStatus.RanToCompletion: + promise.TryComplete(); + break; + case TaskStatus.Canceled: + promise.TrySetCanceled(); + break; + case TaskStatus.Faulted: + promise.TryUnwrap(task.Exception); + break; + default: + task.ContinueWith(LinkOutcomePromise, promise, TaskContinuationOptions.ExecuteSynchronously); break; } } + public static async void LinkOutcome(this ValueTask future, IPromise promise) + { + try + { + //context capturing not required since callback executed synchronously on completion in eventloop + await future; + promise.TryComplete(); + } + catch (Exception ex) + { + promise.TrySetException(ex); + } + } + + static class LinkOutcomeActionHost { public static readonly Action, object> Action = @@ -132,6 +186,18 @@ public static void TryUnwrap(this TaskCompletionSource completionSource, E completionSource.TrySetException(exception); } } + + public static void TryUnwrap(this IPromise promise, Exception exception) + { + if (exception is AggregateException aggregateException) + { + promise.TrySetException(aggregateException.InnerException); + } + else + { + promise.TrySetException(exception); + } + } public static Exception Unwrap(this Exception exception) { diff --git a/src/DotNetty.Handlers/DotNetty.Handlers.csproj b/src/DotNetty.Handlers/DotNetty.Handlers.csproj index d9afdf486..37ad674b2 100644 --- a/src/DotNetty.Handlers/DotNetty.Handlers.csproj +++ b/src/DotNetty.Handlers/DotNetty.Handlers.csproj @@ -43,4 +43,7 @@ + + + \ No newline at end of file diff --git a/src/DotNetty.Handlers/Logging/LoggingHandler.cs b/src/DotNetty.Handlers/Logging/LoggingHandler.cs index 1115d1f41..7363bafcd 100644 --- a/src/DotNetty.Handlers/Logging/LoggingHandler.cs +++ b/src/DotNetty.Handlers/Logging/LoggingHandler.cs @@ -8,6 +8,7 @@ namespace DotNetty.Handlers.Logging using System.Text; using System.Threading.Tasks; using DotNetty.Buffers; + using DotNetty.Common.Concurrency; using DotNetty.Common.Internal.Logging; using DotNetty.Transport.Channels; @@ -208,7 +209,7 @@ public override void ChannelRead(IChannelHandlerContext ctx, object message) } ctx.FireChannelRead(message); } - + public override void ChannelReadComplete(IChannelHandlerContext ctx) { if (this.Logger.IsEnabled(this.InternalLevel)) @@ -251,7 +252,7 @@ public override void Read(IChannelHandlerContext ctx) ctx.Read(); } - public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + public override ValueTask WriteAsync(IChannelHandlerContext ctx, object msg) { if (this.Logger.IsEnabled(this.InternalLevel)) { diff --git a/src/DotNetty.Handlers/Streams/ChunkedWriteHandler.cs b/src/DotNetty.Handlers/Streams/ChunkedWriteHandler.cs index 438856ae5..bf7120af8 100644 --- a/src/DotNetty.Handlers/Streams/ChunkedWriteHandler.cs +++ b/src/DotNetty.Handlers/Streams/ChunkedWriteHandler.cs @@ -7,6 +7,7 @@ namespace DotNetty.Handlers.Streams using System.Collections.Generic; using System.Threading.Tasks; using DotNetty.Buffers; + using DotNetty.Common; using DotNetty.Common.Concurrency; using DotNetty.Common.Internal.Logging; using DotNetty.Common.Utilities; @@ -39,11 +40,11 @@ public void ResumeTransfer() } } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { - var pendingWrite = new PendingWrite(message); + var pendingWrite = PendingWrite.NewInstance(context.Executor, message); this.queue.Enqueue(pendingWrite); - return pendingWrite.PendingTask; + return pendingWrite; } public override void Flush(IChannelHandlerContext context) => this.DoFlush(context); @@ -97,16 +98,16 @@ void Discard(Exception cause = null) cause = new ClosedChannelException(); } - current.Fail(cause); + current.TrySetException(cause); } else { - current.Success(); + current.TryComplete(); } } catch (Exception exception) { - current.Fail(exception); + current.TrySetException(exception); Logger.Warn($"{StringUtil.SimpleClassName(typeof(ChunkedWriteHandler))}.IsEndOfInput failed", exception); } finally @@ -121,7 +122,7 @@ void Discard(Exception cause = null) cause = new ClosedChannelException(); } - current.Fail(cause); + current.TrySetException(cause); } } } @@ -197,7 +198,7 @@ void DoFlush(IChannelHandlerContext context) ReferenceCountUtil.Release(message); } - current.Fail(exception); + current.TrySetException(exception); CloseInput(chunks); break; @@ -218,7 +219,7 @@ void DoFlush(IChannelHandlerContext context) message = Unpooled.Empty; } - Task future = context.WriteAsync(message); + ValueTask writeFuture = context.WriteAsync(message); if (endOfInput) { this.currentWrite = null; @@ -228,54 +229,62 @@ void DoFlush(IChannelHandlerContext context) // be closed before its not written. // // See https://github.com/netty/netty/issues/303 - future.ContinueWith((_, state) => + CloseOnComplete(writeFuture, current, chunks); + + async void CloseOnComplete(ValueTask future, PendingWrite promise, IChunkedInput input) + { + try { - var pendingTask = (PendingWrite)state; - CloseInput((IChunkedInput)pendingTask.Message); - pendingTask.Success(); - }, - current, - TaskContinuationOptions.ExecuteSynchronously); + await future; + } + finally + { + promise.Progress(input.Progress, input.Length); + promise.TryComplete(); + CloseInput(input); + } + } } else if (channel.IsWritable) { - future.ContinueWith((task, state) => + ProgressOnComplete(writeFuture, current, chunks); + + async void ProgressOnComplete(ValueTask future, PendingWrite promise, IChunkedInput input) + { + try { - var pendingTask = (PendingWrite)state; - if (task.IsFaulted) - { - CloseInput((IChunkedInput)pendingTask.Message); - pendingTask.Fail(task.Exception); - } - else - { - pendingTask.Progress(chunks.Progress, chunks.Length); - } - }, - current, - TaskContinuationOptions.ExecuteSynchronously); + await future; + promise.Progress(input.Progress, input.Length); + } + catch(Exception ex) + { + CloseInput((IChunkedInput)promise.Message); + promise.TrySetException(ex); + } + } } else { - future.ContinueWith((task, state) => + ProgressAndResumeOnComplete(writeFuture, this, channel, chunks); + + async void ProgressAndResumeOnComplete(ValueTask future, ChunkedWriteHandler handler, IChannel ch, IChunkedInput input) { - var handler = (ChunkedWriteHandler) state; - if (task.IsFaulted) - { - CloseInput((IChunkedInput)handler.currentWrite.Message); - handler.currentWrite.Fail(task.Exception); - } - else + PendingWrite promise = handler.currentWrite; + try { - handler.currentWrite.Progress(chunks.Progress, chunks.Length); - if (channel.IsWritable) + await future; + promise.Progress(input.Progress, input.Length); + if (ch.IsWritable) { handler.ResumeTransfer(); } } - }, - this, - TaskContinuationOptions.ExecuteSynchronously); + catch(Exception ex) + { + CloseInput((IChunkedInput)promise.Message); + promise.TrySetException(ex); + } + } } // Flush each chunk to conserve memory @@ -284,22 +293,7 @@ void DoFlush(IChannelHandlerContext context) } else { - context.WriteAsync(pendingMessage) - .ContinueWith((task, state) => - { - var pendingTask = (PendingWrite)state; - if (task.IsFaulted) - { - pendingTask.Fail(task.Exception); - } - else - { - pendingTask.Success(); - } - }, - current, - TaskContinuationOptions.ExecuteSynchronously); - + context.WriteAsync(pendingMessage).LinkOutcome(current); this.currentWrite = null; requiresFlush = true; } @@ -332,37 +326,48 @@ static void CloseInput(IChunkedInput chunks) } } - sealed class PendingWrite + sealed class PendingWrite : AbstractRecyclablePromise { - readonly TaskCompletionSource promise; - - public PendingWrite(object msg) + static readonly ThreadLocalPool Pool = new ThreadLocalPool(h => new PendingWrite(h)); + + PendingWrite(ThreadLocalPool.Handle handle) + : base(handle) { - this.Message = msg; - this.promise = new TaskCompletionSource(); + } + + public static PendingWrite NewInstance(IEventExecutor executor, object msg) + { + PendingWrite entry = Pool.Take(); + entry.Init(executor); + entry.Message = msg; + return entry; } - public object Message { get; } - - public void Success() => this.promise.TryComplete(); + public object Message { get; private set; } - public void Fail(Exception error) + protected override bool TryComplete0(Exception exception, out bool continuationInvoked) { - ReferenceCountUtil.Release(this.Message); - this.promise.TrySetException(error); + if (exception != CompletedSentinel) + { + ReferenceCountUtil.Release(this.Message); + } + + return base.TryComplete0(exception, out continuationInvoked); } public void Progress(long progress, long total) { - if (progress < total) + /*if (progress < total) { return; - } - - this.Success(); + }*/ } - public Task PendingTask => this.promise.Task; + protected override void Recycle() + { + this.Message = null; + base.Recycle(); + } } } } diff --git a/src/DotNetty.Handlers/Timeout/IdleStateHandler.cs b/src/DotNetty.Handlers/Timeout/IdleStateHandler.cs index f20f52fab..0e509f967 100644 --- a/src/DotNetty.Handlers/Timeout/IdleStateHandler.cs +++ b/src/DotNetty.Handlers/Timeout/IdleStateHandler.cs @@ -201,7 +201,7 @@ public IdleStateHandler(bool observeOutput, ? TimeUtil.Max(allIdleTime, IdleStateHandler.MinTimeout) : TimeSpan.Zero; - this.writeListener = new Action(antecedent => + this.writeListener = new Action(t => { this.lastWriteTime = this.Ticks(); this.firstWriterIdleEvent = this.firstAllIdleEvent = true; @@ -303,17 +303,17 @@ public override void ChannelReadComplete(IChannelHandlerContext context) context.FireChannelReadComplete(); } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { + ValueTask future = context.WriteAsync(message); if (this.writerIdleTime.Ticks > 0 || this.allIdleTime.Ticks > 0) { - Task task = context.WriteAsync(message); + //task allocation since we attach continuation and returning task to a caller + Task task = future.AsTask(); task.ContinueWith(this.writeListener, TaskContinuationOptions.ExecuteSynchronously); - - return task; + return new ValueTask(task); } - - return context.WriteAsync(message); + return future; } void Initialize(IChannelHandlerContext context) diff --git a/src/DotNetty.Handlers/Timeout/WriteTimeoutHandler.cs b/src/DotNetty.Handlers/Timeout/WriteTimeoutHandler.cs index 25fa214fb..afc0e612e 100644 --- a/src/DotNetty.Handlers/Timeout/WriteTimeoutHandler.cs +++ b/src/DotNetty.Handlers/Timeout/WriteTimeoutHandler.cs @@ -83,16 +83,19 @@ public WriteTimeoutHandler(TimeSpan timeout) : TimeSpan.Zero; } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { - Task task = context.WriteAsync(message); + ValueTask future = context.WriteAsync(message); if (this.timeout.Ticks > 0) { + //allocating task cause we need to attach continuation + Task task = future.AsTask(); this.ScheduleTimeout(context, task); + return new ValueTask(task); } - return task; + return future; } public override void HandlerRemoved(IChannelHandlerContext context) diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index 063aa2db9..d220bd1b6 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -508,13 +508,13 @@ bool EnsureAuthenticated() return oldState.Has(TlsHandlerState.Authenticated); } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { if (!(message is IByteBuffer)) { - return TaskEx.FromException(new UnsupportedMessageTypeException(message, typeof(IByteBuffer))); + return new UnsupportedMessageTypeException(message, typeof(IByteBuffer)).ToValueTask(); } - return this.pendingUnencryptedWrites.Add(message); + return new ValueTask(this.pendingUnencryptedWrites.Add(message)); } public override void Flush(IChannelHandlerContext context) @@ -572,7 +572,7 @@ void Wrap(IChannelHandlerContext context) buf.ReadBytes(this.sslStream, buf.ReadableBytes); // this leads to FinishWrap being called 0+ times buf.Release(); - TaskCompletionSource promise = this.pendingUnencryptedWrites.Remove(); + IPromise promise = this.pendingUnencryptedWrites.Remove(); Task task = this.lastContextWriteTask; if (task != null) { @@ -606,14 +606,14 @@ void FinishWrap(byte[] buffer, int offset, int count) output.WriteBytes(buffer, offset, count); } - this.lastContextWriteTask = this.capturedContext.WriteAsync(output); + this.lastContextWriteTask = this.capturedContext.WriteAsync(output).AsTask(); } Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count) { - var future = this.capturedContext.WriteAndFlushAsync(Unpooled.WrappedBuffer(buffer, offset, count)); + Task task = this.capturedContext.WriteAndFlushAsync(Unpooled.WrappedBuffer(buffer, offset, count), true).AsTask(); this.ReadIfNeeded(this.capturedContext); - return future; + return task; } public override Task CloseAsync(IChannelHandlerContext context) @@ -816,6 +816,7 @@ IAsyncResult PrepareSyncReadResult(int readBytes, object state) public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => this.owner.FinishWrapNonAppDataAsync(buffer, offset, count); + #if !NETSTANDARD1_3 static readonly Action WriteCompleteCallback = HandleChannelWriteComplete; diff --git a/src/DotNetty.Transport/Channels/AbstractChannel.cs b/src/DotNetty.Transport/Channels/AbstractChannel.cs index 10f06f7a5..92496b2ec 100644 --- a/src/DotNetty.Transport/Channels/AbstractChannel.cs +++ b/src/DotNetty.Transport/Channels/AbstractChannel.cs @@ -189,9 +189,11 @@ public IChannel Read() return this; } - public Task WriteAsync(object msg) => this.pipeline.WriteAsync(msg); + public ValueTask WriteAsync(object msg) => this.pipeline.WriteAsync(msg); public Task WriteAndFlushAsync(object message) => this.pipeline.WriteAndFlushAsync(message); + + public ValueTask WriteAndFlushAsync(object message, bool notifyComplete) => this.pipeline.WriteAndFlushAsync(message, notifyComplete); public Task CloseCompletion => this.closeFuture.Task; @@ -670,7 +672,7 @@ public void BeginRead() } } - public Task WriteAsync(object msg) + public ValueTask WriteAsync(object msg) { this.AssertEventLoop(); @@ -684,7 +686,7 @@ public Task WriteAsync(object msg) // release message now to prevent resource-leak ReferenceCountUtil.Release(msg); - return TaskEx.FromException(new ClosedChannelException()); + return new ClosedChannelException().ToValueTask(); } int size; @@ -700,13 +702,10 @@ public Task WriteAsync(object msg) catch (Exception t) { ReferenceCountUtil.Release(msg); - - return TaskEx.FromException(t); + return t.ToValueTask(); } - var promise = new TaskCompletionSource(); - outboundBuffer.AddMessage(msg, size, promise); - return promise.Task; + return outboundBuffer.AddMessage(msg, size); } public void Flush() diff --git a/src/DotNetty.Transport/Channels/AbstractChannelHandlerContext.cs b/src/DotNetty.Transport/Channels/AbstractChannelHandlerContext.cs index 6401d5c1f..10036c5fa 100644 --- a/src/DotNetty.Transport/Channels/AbstractChannelHandlerContext.cs +++ b/src/DotNetty.Transport/Channels/AbstractChannelHandlerContext.cs @@ -4,11 +4,14 @@ namespace DotNetty.Transport.Channels { using System; + using System.Collections; + using System.Diagnostics; using System.Diagnostics.Contracts; using System.Net; using System.Reflection; using System.Runtime.CompilerServices; using System.Threading.Tasks; + using System.Threading.Tasks.Sources; using DotNetty.Buffers; using DotNetty.Common; using DotNetty.Common.Concurrency; @@ -793,16 +796,16 @@ void InvokeRead() } } - public Task WriteAsync(object msg) + public ValueTask WriteAsync(object msg) { Contract.Requires(msg != null); // todo: check for cancellation - return this.WriteAsync(msg, false); + return this.WriteAsync(msg, FlushMode.NoFlush); } - Task InvokeWriteAsync(object msg) => this.Added ? this.InvokeWriteAsync0(msg) : this.WriteAsync(msg); + ValueTask InvokeWriteAsync(object msg) => this.Added ? this.InvokeWriteAsync0(msg) : this.WriteAsync(msg); - Task InvokeWriteAsync0(object msg) + ValueTask InvokeWriteAsync0(object msg) { try { @@ -810,7 +813,7 @@ Task InvokeWriteAsync0(object msg) } catch (Exception ex) { - return ComposeExceptionTask(ex); + return ex.ToValueTask(); } } @@ -853,44 +856,64 @@ void InvokeFlush0() } } - public Task WriteAndFlushAsync(object message) + public Task WriteAndFlushAsync(object message) => this.WriteAndFlushAsync(message, true).AsTask(); + + public ValueTask WriteAndFlushAsync(object message, bool notifyComplete) { Contract.Requires(message != null); // todo: check for cancellation - return this.WriteAsync(message, true); + return this.WriteAsync(message, notifyComplete ? FlushMode.Flush : FlushMode.VoidFlush); } - Task InvokeWriteAndFlushAsync(object msg) + ValueTask InvokeWriteAndFlushAsync(object msg, bool notifyComplete) { if (this.Added) { - Task task = this.InvokeWriteAsync0(msg); + ValueTask task = this.InvokeWriteAsync0(msg); + //flush can synchronously complete write, hence Task allocation required to capture result + task = notifyComplete ? task.Preserve() : default(ValueTask); this.InvokeFlush0(); return task; } - return this.WriteAndFlushAsync(msg); + return this.WriteAndFlushAsync(msg, notifyComplete); } - Task WriteAsync(object msg, bool flush) + ValueTask WriteAsync(object msg, FlushMode mode) { AbstractChannelHandlerContext next = this.FindContextOutbound(); object m = this.pipeline.Touch(msg, next); IEventExecutor nextExecutor = next.Executor; if (nextExecutor.InEventLoop) { - return flush - ? next.InvokeWriteAndFlushAsync(m) - : next.InvokeWriteAsync(m); + return mode == FlushMode.NoFlush + ? next.InvokeWriteAsync(m) + : next.InvokeWriteAndFlushAsync(m, mode == FlushMode.Flush); } else { - var promise = new TaskCompletionSource(); - AbstractWriteTask task = flush - ? WriteAndFlushTask.NewInstance(next, m, promise) - : (AbstractWriteTask)WriteTask.NewInstance(next, m, promise); - SafeExecuteOutbound(nextExecutor, task, promise, msg); - return promise.Task; + AbstractWriteTask task = mode == FlushMode.NoFlush + ? WriteTask.NewInstance(next, m) + : (AbstractWriteTask)WriteAndFlushTask.NewInstance(next, m, mode == FlushMode.Flush); + + ValueTask result; + switch (mode) + { + case FlushMode.NoFlush: + result = task; + break; + case FlushMode.Flush: + //flush can synchronously complete write, hence Task allocation required to capture result + result = ((ValueTask)task).Preserve(); + break; + case FlushMode.VoidFlush: + default: + result = default(ValueTask); + break; + } + + SafeExecuteOutbound(nextExecutor, task, msg); + return result; } } @@ -952,7 +975,7 @@ static Task SafeExecuteOutboundAsync(IEventExecutor executor, Func functio return promise.Task; } - static void SafeExecuteOutbound(IEventExecutor executor, IRunnable task, TaskCompletionSource promise, object msg) + static void SafeExecuteOutbound(IEventExecutor executor, AbstractWriteTask task, object msg) { try { @@ -962,7 +985,7 @@ static void SafeExecuteOutbound(IEventExecutor executor, IRunnable task, TaskCom { try { - promise.TrySetException(cause); + task.TrySetException(cause); } finally { @@ -975,8 +998,14 @@ static void SafeExecuteOutbound(IEventExecutor executor, IRunnable task, TaskCom public override string ToString() => $"{typeof(IChannelHandlerContext).Name} ({this.Name}, {this.Channel})"; + enum FlushMode : byte + { + NoFlush = 0, + VoidFlush = 1, + Flush = 2 + } - abstract class AbstractWriteTask : IRunnable + abstract class AbstractWriteTask : AbstractRecyclablePromise, IRunnable { static readonly bool EstimateTaskSizeOnSubmit = SystemPropertyUtil.GetBoolean("io.netty.transport.estimateSizeOnSubmit", true); @@ -985,17 +1014,15 @@ abstract class AbstractWriteTask : IRunnable static readonly int WriteTaskOverhead = SystemPropertyUtil.GetInt("io.netty.transport.writeTaskSizeOverhead", 56); - ThreadLocalPool.Handle handle; AbstractChannelHandlerContext ctx; object msg; - TaskCompletionSource promise; int size; - protected static void Init(AbstractWriteTask task, AbstractChannelHandlerContext ctx, object msg, TaskCompletionSource promise) + protected static void Init(AbstractWriteTask task, AbstractChannelHandlerContext ctx, object msg) { + task.Init(ctx.Executor); task.ctx = ctx; task.msg = msg; - task.promise = promise; if (EstimateTaskSizeOnSubmit) { @@ -1018,9 +1045,9 @@ protected static void Init(AbstractWriteTask task, AbstractChannelHandlerContext } } - protected AbstractWriteTask(ThreadLocalPool.Handle handle) + protected AbstractWriteTask(ThreadLocalPool.Handle handle) : base(handle) { - this.handle = handle; + } public void Run() @@ -1033,28 +1060,42 @@ public void Run() { buffer?.DecrementPendingOutboundBytes(this.size); } - this.WriteAsync(this.ctx, this.msg).LinkOutcome(this.promise); + + this.WriteAsync(this.ctx, this.msg).LinkOutcome(this); + } + catch (Exception ex) + { + this.TrySetException(ex); } finally { // Set to null so the GC can collect them directly this.ctx = null; this.msg = null; - this.promise = null; - this.handle.Release(this); + + //this.Recycle(); + //this.handle.Release(this); } } - protected virtual Task WriteAsync(AbstractChannelHandlerContext ctx, object msg) => ctx.InvokeWriteAsync(msg); - } - sealed class WriteTask : AbstractWriteTask { + protected abstract ValueTask WriteAsync(AbstractChannelHandlerContext ctx, object msg); + /*public override void Recycle() + { + base.Recycle(); + this.handle.Release(this); + }*/ + } + + + sealed class WriteTask : AbstractWriteTask + { static readonly ThreadLocalPool Recycler = new ThreadLocalPool(handle => new WriteTask(handle)); - public static WriteTask NewInstance(AbstractChannelHandlerContext ctx, object msg, TaskCompletionSource promise) + public static WriteTask NewInstance(AbstractChannelHandlerContext ctx, object msg) { WriteTask task = Recycler.Take(); - Init(task, ctx, msg, promise); + Init(task, ctx, msg); return task; } @@ -1062,17 +1103,21 @@ public static WriteTask NewInstance(AbstractChannelHandlerContext ctx, object ms : base(handle) { } + + protected override ValueTask WriteAsync(AbstractChannelHandlerContext ctx, object msg) => ctx.InvokeWriteAsync(msg); } sealed class WriteAndFlushTask : AbstractWriteTask - { - + { + bool notifyComplete; + static readonly ThreadLocalPool Recycler = new ThreadLocalPool(handle => new WriteAndFlushTask(handle)); - public static WriteAndFlushTask NewInstance( - AbstractChannelHandlerContext ctx, object msg, TaskCompletionSource promise) { + public static WriteAndFlushTask NewInstance(AbstractChannelHandlerContext ctx, object msg, bool notifyComplete) + { WriteAndFlushTask task = Recycler.Take(); - Init(task, ctx, msg, promise); + Init(task, ctx, msg); + task.notifyComplete = notifyComplete; return task; } @@ -1080,13 +1125,10 @@ public static WriteAndFlushTask NewInstance( : base(handle) { } + + //notifyComplete is always true since continuation triggers WriteAndFlushTask completion + protected override ValueTask WriteAsync(AbstractChannelHandlerContext ctx, object msg) => ctx.InvokeWriteAndFlushAsync(msg, this.notifyComplete); - protected override Task WriteAsync(AbstractChannelHandlerContext ctx, object msg) - { - Task result = base.WriteAsync(ctx, msg); - ctx.InvokeFlush(); - return result; - } } } } \ No newline at end of file diff --git a/src/DotNetty.Transport/Channels/BatchingPendingWriteQueue.cs b/src/DotNetty.Transport/Channels/BatchingPendingWriteQueue.cs index 8d2fc2278..eb5ec0b99 100644 --- a/src/DotNetty.Transport/Channels/BatchingPendingWriteQueue.cs +++ b/src/DotNetty.Transport/Channels/BatchingPendingWriteQueue.cs @@ -6,7 +6,10 @@ namespace DotNetty.Transport.Channels using System; using System.Collections.Generic; using System.Diagnostics.Contracts; + using System.Globalization; + using System.Runtime.CompilerServices; using System.Threading.Tasks; + using System.Threading.Tasks.Sources; using DotNetty.Common; using DotNetty.Common.Concurrency; using DotNetty.Common.Internal.Logging; @@ -82,12 +85,11 @@ public Task Add(object msg) if (canBundle) { currentTail.Add(msg, messageSize); - return currentTail.Promise.Task; + return currentTail.Task; } } - var promise = new TaskCompletionSource(); - PendingWrite write = PendingWrite.NewInstance(msg, messageSize, promise); + PendingWrite write = PendingWrite.NewInstance(this.ctx.Executor, msg, messageSize); if (currentTail == null) { this.tail = this.head = write; @@ -102,7 +104,7 @@ public Task Add(object msg) // if the channel was already closed when constructing the PendingWriteQueue. // See https://github.com/netty/netty/issues/3967 this.buffer?.IncrementPendingOutboundBytes(messageSize); - return promise.Task; + return write.Task; } /// @@ -124,9 +126,8 @@ public void RemoveAndFailAll(Exception cause) { PendingWrite next = write.Next; ReleaseMessages(write.Messages); - TaskCompletionSource promise = write.Promise; this.Recycle(write, false); - Util.SafeSetFailure(promise, cause, Logger); + Util.SafeSetFailure(write, cause, Logger); write = next; } this.AssertEmpty(); @@ -149,8 +150,7 @@ public void RemoveAndFail(Exception cause) return; } ReleaseMessages(write.Messages); - TaskCompletionSource promise = write.Promise; - Util.SafeSetFailure(promise, cause, Logger); + Util.SafeSetFailure(write, cause, Logger); this.Recycle(write, true); } @@ -188,10 +188,10 @@ public Task RemoveAndWriteAllAsync() { PendingWrite next = write.Next; object msg = write.Messages; - TaskCompletionSource promise = write.Promise; this.Recycle(write, false); - this.ctx.WriteAsync(msg).LinkOutcome(promise); - tasks.Add(promise.Task); + this.ctx.WriteAsync(msg).LinkOutcome(write); + + tasks.Add(write.Task); write = next; } this.AssertEmpty(); @@ -218,17 +218,16 @@ public Task RemoveAndWriteAsync() return null; } object msg = write.Messages; - TaskCompletionSource promise = write.Promise; this.Recycle(write, true); - this.ctx.WriteAsync(msg).LinkOutcome(promise); - return promise.Task; + this.ctx.WriteAsync(msg).LinkOutcome(write); + return write.Task; } /// /// Removes a pending write operation and release it's message via . /// /// of the pending write or null if the queue is empty. - public TaskCompletionSource Remove() + public IPromise Remove() { Contract.Assert(this.ctx.Executor.InEventLoop); @@ -237,10 +236,9 @@ public TaskCompletionSource Remove() { return null; } - TaskCompletionSource promise = write.Promise; ReferenceCountUtil.SafeRelease(write.Messages); this.Recycle(write, true); - return promise; + return write; } /// @@ -303,7 +301,6 @@ void Recycle(PendingWrite write, bool update) } } - write.Recycle(); // We need to guard against null as channel.unsafe().outboundBuffer() may returned null // if the channel was already closed when constructing the PendingWriteQueue. // See https://github.com/netty/netty/issues/3967 @@ -319,27 +316,26 @@ static void ReleaseMessages(List messages) } /// Holds all meta-data and construct the linked-list structure. - sealed class PendingWrite + sealed class PendingWrite : AbstractRecyclablePromise { static readonly ThreadLocalPool Pool = new ThreadLocalPool(handle => new PendingWrite(handle)); - readonly ThreadLocalPool.Handle handle; + Task task; + public PendingWrite Next; public long Size; - public TaskCompletionSource Promise; public readonly List Messages; - PendingWrite(ThreadLocalPool.Handle handle) + PendingWrite(ThreadLocalPool.Handle handle) : base(handle) { this.Messages = new List(); - this.handle = handle; } - public static PendingWrite NewInstance(object msg, int size, TaskCompletionSource promise) + public static PendingWrite NewInstance(IEventExecutor executor, object msg, int size) { PendingWrite write = Pool.Take(); + write.Init(executor); write.Add(msg, size); - write.Promise = promise; return write; } @@ -349,13 +345,16 @@ public void Add(object msg, int size) this.Size += size; } - public void Recycle() + //pendingwrite instances can be returned to a caller times but AbstractPromise supports single continuation only + public Task Task => this.task ?? (this.task = this.ValueTask.AsTask()); + + protected override void Recycle() { this.Size = 0; this.Next = null; this.Messages.Clear(); - this.Promise = null; - this.handle.Release(this); + this.task = null; + base.Recycle(); } } } diff --git a/src/DotNetty.Transport/Channels/ChannelHandlerAdapter.cs b/src/DotNetty.Transport/Channels/ChannelHandlerAdapter.cs index 243dfb657..c6a53ea54 100644 --- a/src/DotNetty.Transport/Channels/ChannelHandlerAdapter.cs +++ b/src/DotNetty.Transport/Channels/ChannelHandlerAdapter.cs @@ -7,6 +7,7 @@ namespace DotNetty.Transport.Channels using System.Net; using System.Threading.Tasks; using DotNetty.Common.Utilities; + using DotNetty.Common.Concurrency; public class ChannelHandlerAdapter : IChannelHandler { @@ -47,7 +48,7 @@ public virtual void HandlerRemoved(IChannelHandlerContext context) public virtual void UserEventTriggered(IChannelHandlerContext context, object evt) => context.FireUserEventTriggered(evt); [Skip] - public virtual Task WriteAsync(IChannelHandlerContext context, object message) => context.WriteAsync(message); + public virtual ValueTask WriteAsync(IChannelHandlerContext context, object message) => context.WriteAsync(message); [Skip] public virtual void Flush(IChannelHandlerContext context) => context.Flush(); diff --git a/src/DotNetty.Transport/Channels/ChannelOutboundBuffer.cs b/src/DotNetty.Transport/Channels/ChannelOutboundBuffer.cs index 960091015..ebdd7842f 100644 --- a/src/DotNetty.Transport/Channels/ChannelOutboundBuffer.cs +++ b/src/DotNetty.Transport/Channels/ChannelOutboundBuffer.cs @@ -11,8 +11,14 @@ namespace DotNetty.Transport.Channels using System; using System.Collections.Generic; using System.Diagnostics; + using System.ComponentModel; using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; + using System.Runtime.ExceptionServices; + using System.Security.Cryptography; using System.Threading; + using System.Threading.Tasks; + using System.Threading.Tasks.Sources; using DotNetty.Buffers; using DotNetty.Common; using DotNetty.Common.Concurrency; @@ -53,15 +59,14 @@ internal ChannelOutboundBuffer(IChannel channel) } /// - /// Adds the given message to this . The given - /// will be notified once the message was written. + /// Adds the given message to this . Returned + /// will be notified once the message was written. /// /// The message to add to the buffer. /// The size of the message. - /// The to notify once the message is written. - public void AddMessage(object msg, int size, TaskCompletionSource promise) + public ValueTask AddMessage(object msg, int size) { - Entry entry = Entry.NewInstance(msg, size, promise); + Entry entry = Entry.NewInstance(this.channel.EventLoop, msg, size); if (this.tailEntry == null) { this.flushedEntry = null; @@ -81,6 +86,8 @@ public void AddMessage(object msg, int size, TaskCompletionSource promise) // increment pending bytes after adding message to the unflushed arrays. // See https://github.com/netty/netty/issues/1619 this.IncrementPendingOutboundBytes(size, false); + + return entry; } /// @@ -104,7 +111,7 @@ public void AddFlush() do { this.flushed++; - if (!entry.Promise.SetUncancellable()) + if (!entry.SetUncancellable()) { // Was cancelled so make sure we free up memory and notify about the freed bytes int pending = entry.Cancel(); @@ -201,7 +208,6 @@ public bool Remove() } object msg = e.Message; - TaskCompletionSource promise = e.Promise; int size = e.PendingSize; this.RemoveEntry(e); @@ -210,13 +216,10 @@ public bool Remove() { // only release message, notify and decrement if it was not canceled before. ReferenceCountUtil.SafeRelease(msg); - SafeSuccess(promise); + Util.SafeSetSuccess(e, Logger); this.DecrementPendingOutboundBytes(size, false, true); } - // recycle the entry - e.Recycle(); - return true; } @@ -239,7 +242,7 @@ bool Remove0(Exception cause, bool notifyWritability) } object msg = e.Message; - TaskCompletionSource promise = e.Promise; + //TaskCompletionSource promise = e.Promise; int size = e.PendingSize; this.RemoveEntry(e); @@ -248,12 +251,14 @@ bool Remove0(Exception cause, bool notifyWritability) { // only release message, fail and decrement if it was not canceled before. ReferenceCountUtil.SafeRelease(msg); - SafeFail(promise, cause); + + Util.SafeSetFailure(e, cause, Logger); + this.DecrementPendingOutboundBytes(size, false, notifyWritability); } // recycle the entry - e.Recycle(); + //e.Recycle(); return true; } @@ -665,7 +670,7 @@ internal void Close(Exception cause, bool allowChannelOpen) if (!e.Cancelled) { ReferenceCountUtil.SafeRelease(e.Message); - SafeFail(e.Promise, cause); + Util.SafeSetFailure(e, cause, Logger); } e = e.RecycleAndGetNext(); } @@ -783,31 +788,29 @@ public interface IMessageProcessor bool ProcessMessage(object msg); } - sealed class Entry + sealed class Entry : AbstractRecyclablePromise { static readonly ThreadLocalPool Pool = new ThreadLocalPool(h => new Entry(h)); - readonly ThreadLocalPool.Handle handle; public Entry Next; public object Message; public ArraySegment[] Buffers; public ArraySegment Buffer; - public TaskCompletionSource Promise; public int PendingSize; public int Count = -1; public bool Cancelled; - Entry(ThreadLocalPool.Handle handle) + Entry(ThreadLocalPool.Handle handle) + : base(handle) { - this.handle = handle; } - public static Entry NewInstance(object msg, int size, TaskCompletionSource promise) + public static Entry NewInstance(IEventExecutor executor, object msg, int size) { Entry entry = Pool.Take(); + entry.Init(executor); entry.Message = msg; entry.PendingSize = size; - entry.Promise = promise; return entry; } @@ -830,23 +833,23 @@ public int Cancel() return 0; } - public void Recycle() + protected override void Recycle() { this.Next = null; this.Buffers = null; this.Buffer = new ArraySegment(); this.Message = null; - this.Promise = null; this.PendingSize = 0; this.Count = -1; this.Cancelled = false; - this.handle.Release(this); + + base.Recycle(); } public Entry RecycleAndGetNext() { Entry next = this.Next; - this.Recycle(); + //this.Recycle(); return next; } } diff --git a/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs b/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs index 0f5dce1a9..bc6ab0c25 100644 --- a/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs +++ b/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs @@ -354,7 +354,7 @@ public override void Read(IChannelHandlerContext context) } } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { Contract.Assert(context == this.outboundCtx.InnerContext); @@ -491,7 +491,7 @@ public IChannelHandlerContext Read() return this; } - public Task WriteAsync(object message) => this.ctx.WriteAsync(message); + public ValueTask WriteAsync(object message) => this.ctx.WriteAsync(message); public IChannelHandlerContext Flush() { @@ -500,6 +500,8 @@ public IChannelHandlerContext Flush() } public Task WriteAndFlushAsync(object message) => this.ctx.WriteAndFlushAsync(message); + + public ValueTask WriteAndFlushAsync(object message, bool notifyComplete) => this.ctx.WriteAndFlushAsync(message, notifyComplete); public IAttribute GetAttribute(AttributeKey key) where T : class => this.ctx.GetAttribute(key); diff --git a/src/DotNetty.Transport/Channels/DefaultChannelPipeline.cs b/src/DotNetty.Transport/Channels/DefaultChannelPipeline.cs index 1531f5676..3186fc3e6 100644 --- a/src/DotNetty.Transport/Channels/DefaultChannelPipeline.cs +++ b/src/DotNetty.Transport/Channels/DefaultChannelPipeline.cs @@ -822,7 +822,7 @@ public IChannelPipeline Read() return this; } - public Task WriteAsync(object msg) => this.tail.WriteAsync(msg); + public ValueTask WriteAsync(object msg) => this.tail.WriteAsync(msg); public IChannelPipeline Flush() { @@ -831,6 +831,8 @@ public IChannelPipeline Flush() } public Task WriteAndFlushAsync(object msg) => this.tail.WriteAndFlushAsync(msg); + + public ValueTask WriteAndFlushAsync(object msg, bool notifyComplete) => this.tail.WriteAndFlushAsync(msg, notifyComplete); string FilterName(string name, IChannelHandler handler) { @@ -1049,7 +1051,7 @@ public void HandlerRemoved(IChannelHandlerContext context) public void UserEventTriggered(IChannelHandlerContext context, object evt) => ReferenceCountUtil.Release(evt); [Skip] - public Task WriteAsync(IChannelHandlerContext ctx, object message) => ctx.WriteAsync(message); + public ValueTask WriteAsync(IChannelHandlerContext ctx, object message) => ctx.WriteAsync(message); [Skip] public void Flush(IChannelHandlerContext context) => context.Flush(); @@ -1092,7 +1094,7 @@ public HeadContext(DefaultChannelPipeline pipeline) public void Read(IChannelHandlerContext context) => this.channelUnsafe.BeginRead(); - public Task WriteAsync(IChannelHandlerContext context, object message) => this.channelUnsafe.WriteAsync(message); + public ValueTask WriteAsync(IChannelHandlerContext context, object message) => this.channelUnsafe.WriteAsync(message); [Skip] public void HandlerAdded(IChannelHandlerContext context) diff --git a/src/DotNetty.Transport/Channels/Embedded/EmbeddedChannel.cs b/src/DotNetty.Transport/Channels/Embedded/EmbeddedChannel.cs index d3bd4ca07..47155cb2d 100644 --- a/src/DotNetty.Transport/Channels/Embedded/EmbeddedChannel.cs +++ b/src/DotNetty.Transport/Channels/Embedded/EmbeddedChannel.cs @@ -11,6 +11,7 @@ namespace DotNetty.Transport.Channels.Embedded using System.Runtime.ExceptionServices; using System.Threading.Tasks; using DotNetty.Common; + using DotNetty.Common.Concurrency; using DotNetty.Common.Internal.Logging; using DotNetty.Common.Utilities; @@ -294,53 +295,34 @@ public bool WriteOutbound(params object[] msgs) return IsNotEmpty(this.outboundMessages); } - ThreadLocalObjectList futures = ThreadLocalObjectList.NewInstance(msgs.Length); - foreach (object m in msgs) { if (m == null) { break; } - futures.Add(this.WriteAsync(m)); + WriteAsync(m); } // We need to call RunPendingTasks first as a IChannelHandler may have used IEventLoop.Execute(...) to // delay the write on the next event loop run. this.RunPendingTasks(); this.Flush(); + this.RunPendingTasks(); + this.CheckException(); + return IsNotEmpty(this.outboundMessages); - int size = futures.Count; - for (int i = 0; i < size; i++) + async void WriteAsync(object message) { - var future = (Task)futures[i]; - if (future.IsCompleted) + try { - this.RecordException(future); + //context capturing not required since callback executed synchrounusly on completion in eventloop + await this.WriteAsync(message).ConfigureAwait(false); } - else + catch (Exception e) { - // The write may be delayed to run later by runPendingTasks() - future.ContinueWith(t => this.RecordException(t)); + this.RecordException(e); } } - futures.Return(); - - this.RunPendingTasks(); - this.CheckException(); - return IsNotEmpty(this.outboundMessages); - } - - void RecordException(Task future) - { - switch (future.Status) - { - case TaskStatus.Canceled: - case TaskStatus.Faulted: - this.RecordException(future.Exception); - break; - default: - break; - } } void RecordException(Exception cause) diff --git a/src/DotNetty.Transport/Channels/Groups/DefaultChannelGroup.cs b/src/DotNetty.Transport/Channels/Groups/DefaultChannelGroup.cs index fc2b2565f..bd5bd49a3 100644 --- a/src/DotNetty.Transport/Channels/Groups/DefaultChannelGroup.cs +++ b/src/DotNetty.Transport/Channels/Groups/DefaultChannelGroup.cs @@ -54,13 +54,13 @@ public IChannel Find(IChannelId id) } } - public Task WriteAsync(object message) => this.WriteAsync(message, ChannelMatchers.All()); + public ValueTask WriteAsync(object message) => this.WriteAsync(message, ChannelMatchers.All()); - public Task WriteAsync(object message, IChannelMatcher matcher) + public ValueTask WriteAsync(object message, IChannelMatcher matcher) { Contract.Requires(message != null); Contract.Requires(matcher != null); - var futures = new Dictionary(); + var futures = new Dictionary(); foreach (IChannel c in this.nonServerChannels.Values) { if (matcher.Matches(c)) @@ -70,7 +70,7 @@ public Task WriteAsync(object message, IChannelMatcher matcher) } ReferenceCountUtil.Release(message); - return new DefaultChannelGroupCompletionSource(this, futures /*, this.executor*/).Task; + return new DefaultChannelGroupPromise(futures); } public IChannelGroup Flush(IChannelMatcher matcher) @@ -144,23 +144,23 @@ public bool Remove(IChannel channel) IEnumerator IEnumerable.GetEnumerator() => new CombinedEnumerator(this.serverChannels.Values.GetEnumerator(), this.nonServerChannels.Values.GetEnumerator()); - public Task WriteAndFlushAsync(object message) => this.WriteAndFlushAsync(message, ChannelMatchers.All()); + public ValueTask WriteAndFlushAsync(object message) => this.WriteAndFlushAsync(message, ChannelMatchers.All()); - public Task WriteAndFlushAsync(object message, IChannelMatcher matcher) + public ValueTask WriteAndFlushAsync(object message, IChannelMatcher matcher) { Contract.Requires(message != null); Contract.Requires(matcher != null); - var futures = new Dictionary(); + var futures = new Dictionary(); foreach (IChannel c in this.nonServerChannels.Values) { if (matcher.Matches(c)) { - futures.Add(c, c.WriteAndFlushAsync(SafeDuplicate(message))); + futures.Add(c, c.WriteAndFlushAsync(SafeDuplicate(message), true)); } } ReferenceCountUtil.Release(message); - return new DefaultChannelGroupCompletionSource(this, futures /*, this.executor*/).Task; + return new DefaultChannelGroupPromise(futures); } public Task DisconnectAsync() => this.DisconnectAsync(ChannelMatchers.All()); diff --git a/src/DotNetty.Transport/Channels/Groups/DefaultChannelGroupPromise.cs b/src/DotNetty.Transport/Channels/Groups/DefaultChannelGroupPromise.cs new file mode 100644 index 000000000..5e5ab81c0 --- /dev/null +++ b/src/DotNetty.Transport/Channels/Groups/DefaultChannelGroupPromise.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Transport.Channels.Groups +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; + using System.Threading.Tasks; + using DotNetty.Common.Concurrency; + + public sealed class DefaultChannelGroupPromise : AbstractPromise + { + readonly int count; + int failureCount; + int successCount; + IList> failures; + + public DefaultChannelGroupPromise(Dictionary futures) + { + Contract.Requires(futures != null); + + if (futures.Count == 0) + { + this.TryComplete(); + } + else + { + this.count = futures.Count; + foreach (KeyValuePair pair in futures) + { + this.Await(pair); + } + } + } + + async void Await(KeyValuePair pair) + { + try + { + await pair.Value; + this.successCount++; + } + catch(Exception ex) + { + this.failureCount++; + if (this.failures == null) + { + this.failures = new List>(); + } + this.failures.Add(new KeyValuePair(pair.Key, ex)); + } + + bool callSetDone = this.successCount + this.failureCount == this.count; + Contract.Assert(this.successCount + this.failureCount <= this.count); + + if (callSetDone) + { + if (this.failureCount > 0) + { + this.TrySetException(new ChannelGroupException(this.failures)); + } + else + { + this.TryComplete(); + } + } + } + + } +} \ No newline at end of file diff --git a/src/DotNetty.Transport/Channels/Groups/IChannelGroup.cs b/src/DotNetty.Transport/Channels/Groups/IChannelGroup.cs index 2256dbb12..b998b6be0 100644 --- a/src/DotNetty.Transport/Channels/Groups/IChannelGroup.cs +++ b/src/DotNetty.Transport/Channels/Groups/IChannelGroup.cs @@ -6,6 +6,7 @@ namespace DotNetty.Transport.Channels.Groups using System; using System.Collections.Generic; using System.Threading.Tasks; + using DotNetty.Common.Concurrency; public interface IChannelGroup : ICollection, IComparable { @@ -17,17 +18,17 @@ public interface IChannelGroup : ICollection, IComparable @@ -61,10 +62,12 @@ public interface IChannel : IAttributeMap, IComparable IChannel Read(); - Task WriteAsync(object message); + ValueTask WriteAsync(object message); IChannel Flush(); Task WriteAndFlushAsync(object message); + + ValueTask WriteAndFlushAsync(object message, bool notifyComplete); } } \ No newline at end of file diff --git a/src/DotNetty.Transport/Channels/IChannelHandler.cs b/src/DotNetty.Transport/Channels/IChannelHandler.cs index c80c34640..2155be5a2 100644 --- a/src/DotNetty.Transport/Channels/IChannelHandler.cs +++ b/src/DotNetty.Transport/Channels/IChannelHandler.cs @@ -6,6 +6,7 @@ namespace DotNetty.Transport.Channels using System; using System.Net; using System.Threading.Tasks; + using DotNetty.Common.Concurrency; public interface IChannelHandler { @@ -39,7 +40,7 @@ public interface IChannelHandler void HandlerRemoved(IChannelHandlerContext context); - Task WriteAsync(IChannelHandlerContext context, object message); + ValueTask WriteAsync(IChannelHandlerContext context, object message); void Flush(IChannelHandlerContext context); diff --git a/src/DotNetty.Transport/Channels/IChannelHandlerContext.cs b/src/DotNetty.Transport/Channels/IChannelHandlerContext.cs index 1a052cbf0..b3e17749d 100644 --- a/src/DotNetty.Transport/Channels/IChannelHandlerContext.cs +++ b/src/DotNetty.Transport/Channels/IChannelHandlerContext.cs @@ -67,11 +67,13 @@ public interface IChannelHandlerContext : IAttributeMap IChannelHandlerContext Read(); - Task WriteAsync(object message); // todo: optimize: add flag saying if handler is interested in task, do not produce task if it isn't needed + ValueTask WriteAsync(object message); // todo: optimize: add flag saying if handler is interested in task, do not produce task if it isn't needed IChannelHandlerContext Flush(); Task WriteAndFlushAsync(object message); + + ValueTask WriteAndFlushAsync(object message, bool notifyComplete); /// /// Request to bind to the given . diff --git a/src/DotNetty.Transport/Channels/IChannelPipeline.cs b/src/DotNetty.Transport/Channels/IChannelPipeline.cs index 074412650..270ad3eea 100644 --- a/src/DotNetty.Transport/Channels/IChannelPipeline.cs +++ b/src/DotNetty.Transport/Channels/IChannelPipeline.cs @@ -683,7 +683,7 @@ public interface IChannelPipeline : IEnumerable /// once you want to request to flush all pending data to the actual transport. /// /// An await-able task. - Task WriteAsync(object msg); + ValueTask WriteAsync(object msg); /// /// Request to flush all pending messages. @@ -695,5 +695,7 @@ public interface IChannelPipeline : IEnumerable /// Shortcut for calling both and . /// Task WriteAndFlushAsync(object msg); + + ValueTask WriteAndFlushAsync(object msg, bool notifyComplete); } } \ No newline at end of file diff --git a/src/DotNetty.Transport/Channels/IChannelUnsafe.cs b/src/DotNetty.Transport/Channels/IChannelUnsafe.cs index a28e833fd..f0eade341 100644 --- a/src/DotNetty.Transport/Channels/IChannelUnsafe.cs +++ b/src/DotNetty.Transport/Channels/IChannelUnsafe.cs @@ -5,6 +5,7 @@ namespace DotNetty.Transport.Channels { using System.Net; using System.Threading.Tasks; + using DotNetty.Common.Concurrency; public interface IChannelUnsafe { @@ -26,7 +27,7 @@ public interface IChannelUnsafe void BeginRead(); - Task WriteAsync(object message); + ValueTask WriteAsync(object message); void Flush(); diff --git a/src/DotNetty.Transport/Channels/PendingWriteQueue.cs b/src/DotNetty.Transport/Channels/PendingWriteQueue.cs index 80441bb22..f09a9dd3a 100644 --- a/src/DotNetty.Transport/Channels/PendingWriteQueue.cs +++ b/src/DotNetty.Transport/Channels/PendingWriteQueue.cs @@ -7,6 +7,7 @@ namespace DotNetty.Transport.Channels using System.Collections.Generic; using System.Diagnostics.Contracts; using System.Threading.Tasks; + using System.Threading.Tasks.Sources; using DotNetty.Common; using DotNetty.Common.Concurrency; using DotNetty.Common.Internal.Logging; @@ -70,7 +71,7 @@ public int Size /// /// The message to add to the . /// An await-able task. - public Task Add(object msg) + public ValueTask Add(object msg) { Contract.Assert(this.ctx.Executor.InEventLoop); Contract.Requires(msg != null); @@ -81,8 +82,8 @@ public Task Add(object msg) // Size may be unknow so just use 0 messageSize = 0; } - var promise = new TaskCompletionSource(); - PendingWrite write = PendingWrite.NewInstance(msg, messageSize, promise); + + PendingWrite write = PendingWrite.NewInstance(this.ctx.Executor, msg, messageSize); PendingWrite currentTail = this.tail; if (currentTail == null) { @@ -98,7 +99,8 @@ public Task Add(object msg) // if the channel was already closed when constructing the PendingWriteQueue. // See https://github.com/netty/netty/issues/3967 this.buffer?.IncrementPendingOutboundBytes(write.Size); - return promise.Task; + + return write; } /// @@ -120,9 +122,8 @@ public void RemoveAndFailAll(Exception cause) { PendingWrite next = write.Next; ReferenceCountUtil.SafeRelease(write.Msg); - TaskCompletionSource promise = write.Promise; this.Recycle(write, false); - Util.SafeSetFailure(promise, cause, Logger); + Util.SafeSetFailure(write, cause, Logger); write = next; } this.AssertEmpty(); @@ -145,8 +146,7 @@ public void RemoveAndFail(Exception cause) return; } ReferenceCountUtil.SafeRelease(write.Msg); - TaskCompletionSource promise = write.Promise; - Util.SafeSetFailure(promise, cause, Logger); + Util.SafeSetFailure(write, cause, Logger); this.Recycle(write, true); } @@ -154,7 +154,7 @@ public void RemoveAndFail(Exception cause) /// Removes all pending write operation and performs them via /// /// An await-able task. - public Task RemoveAndWriteAllAsync() + public ValueTask RemoveAndWriteAllAsync() { Contract.Assert(this.ctx.Executor.InEventLoop); @@ -167,7 +167,7 @@ public Task RemoveAndWriteAllAsync() if (write == null) { // empty so just return null - return null; + return default(ValueTask); } // Guard against re-entrance by directly reset @@ -175,19 +175,19 @@ public Task RemoveAndWriteAllAsync() int currentSize = this.size; this.size = 0; - var tasks = new List(currentSize); + var tasks = new List(currentSize); + while (write != null) { PendingWrite next = write.Next; object msg = write.Msg; - TaskCompletionSource promise = write.Promise; this.Recycle(write, false); - this.ctx.WriteAsync(msg).LinkOutcome(promise); - tasks.Add(promise.Task); + this.ctx.WriteAsync(msg).LinkOutcome(write); + tasks.Add(write); write = next; } this.AssertEmpty(); - return Task.WhenAll(tasks); + return new AggregatingPromise(tasks); } void AssertEmpty() => Contract.Assert(this.tail == null && this.head == null && this.size == 0); @@ -196,20 +196,19 @@ public Task RemoveAndWriteAllAsync() /// Removes a pending write operation and performs it via . /// /// An await-able task. - public Task RemoveAndWriteAsync() + public ValueTask RemoveAndWriteAsync() { Contract.Assert(this.ctx.Executor.InEventLoop); PendingWrite write = this.head; if (write == null) { - return null; + return default(ValueTask); } object msg = write.Msg; - TaskCompletionSource promise = write.Promise; this.Recycle(write, true); - this.ctx.WriteAsync(msg).LinkOutcome(promise); - return promise.Task; + this.ctx.WriteAsync(msg).LinkOutcome(write); + return write; } /// @@ -219,19 +218,18 @@ public Task RemoveAndWriteAsync() /// /// The of the pending write, or null if the queue is empty. /// - public TaskCompletionSource Remove() + public ValueTask Remove() { Contract.Assert(this.ctx.Executor.InEventLoop); PendingWrite write = this.head; if (write == null) { - return null; + return default(ValueTask); } - TaskCompletionSource promise = write.Promise; ReferenceCountUtil.SafeRelease(write.Msg); this.Recycle(write, true); - return promise; + return write; } /// @@ -268,8 +266,7 @@ void Recycle(PendingWrite write, bool update) Contract.Assert(this.size > 0); } } - - write.Recycle(); + // We need to guard against null as channel.unsafe().outboundBuffer() may returned null // if the channel was already closed when constructing the PendingWriteQueue. // See https://github.com/netty/netty/issues/3967 @@ -279,37 +276,34 @@ void Recycle(PendingWrite write, bool update) /// /// Holds all meta-data and constructs the linked-list structure. /// - sealed class PendingWrite + sealed class PendingWrite : AbstractRecyclablePromise { static readonly ThreadLocalPool Pool = new ThreadLocalPool(handle => new PendingWrite(handle)); - readonly ThreadLocalPool.Handle handle; public PendingWrite Next; public long Size; - public TaskCompletionSource Promise; public object Msg; PendingWrite(ThreadLocalPool.Handle handle) + : base(handle) { - this.handle = handle; } - public static PendingWrite NewInstance(object msg, int size, TaskCompletionSource promise) + public static PendingWrite NewInstance(IEventExecutor executor, object msg, int size) { PendingWrite write = Pool.Take(); + write.Init(executor); write.Size = size; write.Msg = msg; - write.Promise = promise; return write; } - public void Recycle() + protected override void Recycle() { this.Size = 0; this.Next = null; this.Msg = null; - this.Promise = null; - this.handle.Release(this); + base.Recycle(); } } } diff --git a/src/DotNetty.Transport/Channels/Util.cs b/src/DotNetty.Transport/Channels/Util.cs index 105860e87..51f439f3d 100644 --- a/src/DotNetty.Transport/Channels/Util.cs +++ b/src/DotNetty.Transport/Channels/Util.cs @@ -25,6 +25,17 @@ public static void SafeSetSuccess(TaskCompletionSource promise, IInternalLogger logger.Warn($"Failed to mark a promise as success because it is done already: {promise}"); } } + + /// + /// Marks the specified {@code promise} as success. If the {@code promise} is done already, log a message. + /// + public static void SafeSetSuccess(IPromise promise, IInternalLogger logger) + { + if (!promise.TryComplete()) + { + logger.Warn($"Failed to mark a promise as success because it is done already: {promise}"); + } + } /// /// Marks the specified as failure. If the @@ -40,6 +51,17 @@ public static void SafeSetFailure(TaskCompletionSource promise, Exception cause, logger.Warn($"Failed to mark a promise as failure because it's done already: {promise}", cause); } } + + /// + /// Marks the specified {@code promise} as failure. If the {@code promise} is done already, log a message. + /// + public static void SafeSetFailure(IPromise promise, Exception cause, IInternalLogger logger) + { + if (!promise.TrySetException(cause)) + { + logger.Warn($"Failed to mark a promise as failure because it's done already: {promise}", cause); + } + } public static void CloseSafe(this IChannel channel) { diff --git a/test/DotNetty.Codecs.Http.Tests/HttpContentCompressorTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpContentCompressorTest.cs index ed0bcd240..e5695cc10 100644 --- a/test/DotNetty.Codecs.Http.Tests/HttpContentCompressorTest.cs +++ b/test/DotNetty.Codecs.Http.Tests/HttpContentCompressorTest.cs @@ -366,12 +366,9 @@ public void TooManyResponses() ch.WriteOutbound(new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.Empty)); Assert.True(false, "Should not get here, expecting exception thrown"); } - catch (AggregateException e) + catch (EncoderException e) { - Assert.Single(e.InnerExceptions); - Assert.IsType(e.InnerExceptions[0]); - Exception exception = e.InnerExceptions[0]; - Assert.IsType(exception.InnerException); + Assert.IsType(e.InnerException); } Assert.True(ch.Finish()); diff --git a/test/DotNetty.Codecs.Http.Tests/HttpServerUpgradeHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpServerUpgradeHandlerTest.cs index cb8f3a6a2..dda9c5186 100644 --- a/test/DotNetty.Codecs.Http.Tests/HttpServerUpgradeHandlerTest.cs +++ b/test/DotNetty.Codecs.Http.Tests/HttpServerUpgradeHandlerTest.cs @@ -59,8 +59,8 @@ public override void ChannelRead(IChannelHandlerContext ctx, object msg) // written the upgrade response, and upgraded the pipeline. Assert.True(this.writeUpgradeMessage); Assert.False(this.writeFlushed); - //Assert.Null(ctx.Channel.Pipeline.Get()); - //Assert.NotNull(ctx.Channel.Pipeline.Get("marker")); + Assert.Null(ctx.Channel.Pipeline.Get()); + Assert.NotNull(ctx.Channel.Pipeline.Get("marker")); } finally { @@ -68,7 +68,7 @@ public override void ChannelRead(IChannelHandlerContext ctx, object msg) } } - public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + public override ValueTask WriteAsync(IChannelHandlerContext ctx, object msg) { // We ensure that we're in the read call and defer the write so we can // make sure the pipeline was reformed irrespective of the flush completing. @@ -76,22 +76,20 @@ public override Task WriteAsync(IChannelHandlerContext ctx, object msg) this.writeUpgradeMessage = true; var completion = new TaskCompletionSource(); - ctx.Channel.EventLoop.Execute(() => + ctx.Channel.EventLoop.Execute(async () => { - ctx.WriteAsync(msg) - .ContinueWith(t => - { - if (t.Status == TaskStatus.RanToCompletion) - { - this.writeFlushed = true; - completion.TryComplete(); - return; - } - completion.TrySetException(new InvalidOperationException($"Invalid WriteAsync task status {t.Status}")); - }, - TaskContinuationOptions.ExecuteSynchronously); + try + { + await ctx.WriteAsync(msg); + this.writeFlushed = true; + completion.TryComplete(); + } + catch(Exception ex) + { + completion.TrySetException(ex); + } }); - return completion.Task; + return new ValueTask(completion.Task); } } @@ -113,13 +111,11 @@ public void UpgradesPipelineInSameMethodInvocation() IByteBuffer upgrade = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(UpgradeString)); Assert.False(channel.WriteInbound(upgrade)); - //Assert.Null(channel.Pipeline.Get()); - //Assert.NotNull(channel.Pipeline.Get("marker")); - - channel.Flush(); Assert.Null(channel.Pipeline.Get()); Assert.NotNull(channel.Pipeline.Get("marker")); + channel.Flush(); + var upgradeMessage = channel.ReadOutbound(); const string ExpectedHttpResponse = "HTTP/1.1 101 Switching Protocols\r\n" + "connection: upgrade\r\n" + diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerProtocolHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerProtocolHandlerTest.cs index 9fe1705a3..341d1dd27 100644 --- a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerProtocolHandlerTest.cs +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerProtocolHandlerTest.cs @@ -140,10 +140,10 @@ public MockOutboundHandler(WebSocketServerProtocolHandlerTest owner) this.owner = owner; } - public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + public override ValueTask WriteAsync(IChannelHandlerContext ctx, object msg) { this.owner.responses.Enqueue((IFullHttpResponse)msg); - return TaskEx.Completed; + return new ValueTask(); } public override void Flush(IChannelHandlerContext ctx) diff --git a/test/DotNetty.Codecs.Tests/Frame/LengthFieldPrependerTest.cs b/test/DotNetty.Codecs.Tests/Frame/LengthFieldPrependerTest.cs index 6ece234cc..cfa0456f0 100644 --- a/test/DotNetty.Codecs.Tests/Frame/LengthFieldPrependerTest.cs +++ b/test/DotNetty.Codecs.Tests/Frame/LengthFieldPrependerTest.cs @@ -70,13 +70,11 @@ public void TestPrependAdjustedLength() public void TestPrependAdjustedLengthLessThanZero() { var ch = new EmbeddedChannel(new LengthFieldPrepender(4, -2)); - var ex = Assert.Throws(() => + var ex = Assert.Throws(() => { ch.WriteOutbound(this.msg); Assert.True(false, typeof(EncoderException).Name + " must be raised."); }); - - Assert.IsType(ex.InnerExceptions.Single()); } [Fact] diff --git a/test/DotNetty.Tests.Common/ChannelExtensions.cs b/test/DotNetty.Tests.Common/ChannelExtensions.cs index 921613d0f..fd19d3f27 100644 --- a/test/DotNetty.Tests.Common/ChannelExtensions.cs +++ b/test/DotNetty.Tests.Common/ChannelExtensions.cs @@ -14,7 +14,7 @@ public static Task WriteAndFlushManyAsync(this IChannel channel, params object[] var list = new List(); foreach (object m in messages) { - list.Add(channel.WriteAsync(m)); + list.Add(channel.WriteAsync(m).AsTask()); } IEnumerable tasks = list.ToArray(); channel.Flush(); diff --git a/test/DotNetty.Transport.Libuv.Tests/AutoReadTests.cs b/test/DotNetty.Transport.Libuv.Tests/AutoReadTests.cs index ef09ff54c..c8f3bc6b2 100644 --- a/test/DotNetty.Transport.Libuv.Tests/AutoReadTests.cs +++ b/test/DotNetty.Transport.Libuv.Tests/AutoReadTests.cs @@ -8,7 +8,9 @@ namespace DotNetty.Transport.Libuv.Tests using System.Threading; using System.Threading.Tasks; using DotNetty.Buffers; + using DotNetty.Common.Concurrency; using DotNetty.Common.Utilities; + using DotNetty.Tests.Common; using DotNetty.Transport.Bootstrapping; using DotNetty.Transport.Channels; using Xunit; diff --git a/test/DotNetty.Transport.Libuv.Tests/BufReleaseTests.cs b/test/DotNetty.Transport.Libuv.Tests/BufReleaseTests.cs index e104f9f27..91220118c 100644 --- a/test/DotNetty.Transport.Libuv.Tests/BufReleaseTests.cs +++ b/test/DotNetty.Transport.Libuv.Tests/BufReleaseTests.cs @@ -4,10 +4,14 @@ namespace DotNetty.Transport.Libuv.Tests { using System; + using System.Diagnostics; + using System.Globalization; using System.Net; + using System.Runtime.CompilerServices; using System.Threading.Tasks; using DotNetty.Buffers; using DotNetty.Common.Concurrency; + using DotNetty.Tests.Common; using DotNetty.Transport.Bootstrapping; using DotNetty.Transport.Channels; using Xunit; diff --git a/test/DotNetty.Transport.Libuv.Tests/CompositeBufferGatheringWriteTests.cs b/test/DotNetty.Transport.Libuv.Tests/CompositeBufferGatheringWriteTests.cs index ab276ce4a..aa762a1bc 100644 --- a/test/DotNetty.Transport.Libuv.Tests/CompositeBufferGatheringWriteTests.cs +++ b/test/DotNetty.Transport.Libuv.Tests/CompositeBufferGatheringWriteTests.cs @@ -131,10 +131,17 @@ public void AssertReceived(IByteBuffer expected) sealed class ServerHandler : ChannelHandlerAdapter { - public override void ChannelActive(IChannelHandlerContext ctx) => - ctx.WriteAndFlushAsync(NewCompositeBuffer(ctx.Allocator)) - .ContinueWith((t, s) => ((IChannelHandlerContext)s).CloseAsync(), - ctx, TaskContinuationOptions.ExecuteSynchronously); + public override async void ChannelActive(IChannelHandlerContext ctx) + { + try + { + await ctx.WriteAndFlushAsync(NewCompositeBuffer(ctx.Allocator)); + } + finally + { + ctx.CloseAsync(); + } + } } static IByteBuffer NewCompositeBuffer(IByteBufferAllocator alloc) diff --git a/test/DotNetty.Transport.Libuv.Tests/DetectPeerCloseWithoutReadTests.cs b/test/DotNetty.Transport.Libuv.Tests/DetectPeerCloseWithoutReadTests.cs index 3355c12b1..e89d8140e 100644 --- a/test/DotNetty.Transport.Libuv.Tests/DetectPeerCloseWithoutReadTests.cs +++ b/test/DotNetty.Transport.Libuv.Tests/DetectPeerCloseWithoutReadTests.cs @@ -8,7 +8,9 @@ namespace DotNetty.Transport.Libuv.Tests using System.Threading; using System.Threading.Tasks; using DotNetty.Buffers; + using DotNetty.Codecs; using DotNetty.Common.Concurrency; + using DotNetty.Tests.Common; using DotNetty.Transport.Bootstrapping; using DotNetty.Transport.Channels; using Xunit; @@ -62,7 +64,7 @@ public void ClientCloseWithoutServerReadIsDetected() IByteBuffer buf = this.clientChannel.Allocator.Buffer(ExpectedBytes); buf.SetWriterIndex(buf.WriterIndex + ExpectedBytes); - this.clientChannel.WriteAndFlushAsync(buf).ContinueWith(_ => this.clientChannel.CloseAsync()); + this.clientChannel.WriteAndFlushAsync(buf).CloseOnComplete(this.clientChannel); Task completion = serverHandler.Completion; Assert.True(completion.Wait(DefaultTimeout)); @@ -172,7 +174,7 @@ public override void ChannelActive(IChannelHandlerContext ctx) { IByteBuffer buf = ctx.Allocator.Buffer(this.expectedBytesRead); buf.SetWriterIndex(buf.WriterIndex + this.expectedBytesRead); - ctx.WriteAndFlushAsync(buf).ContinueWith(_ => ctx.CloseAsync()); + ctx.WriteAndFlushAsync(buf).CloseOnComplete(ctx); ctx.FireChannelActive(); } diff --git a/test/DotNetty.Transport.Libuv.Tests/ExceptionHandlingTests.cs b/test/DotNetty.Transport.Libuv.Tests/ExceptionHandlingTests.cs index 2ac71cf76..2d5e865aa 100644 --- a/test/DotNetty.Transport.Libuv.Tests/ExceptionHandlingTests.cs +++ b/test/DotNetty.Transport.Libuv.Tests/ExceptionHandlingTests.cs @@ -10,6 +10,7 @@ namespace DotNetty.Transport.Libuv.Tests using DotNetty.Buffers; using DotNetty.Common.Concurrency; using DotNetty.Common.Utilities; + using DotNetty.Tests.Common; using DotNetty.Transport.Bootstrapping; using DotNetty.Transport.Channels; using Xunit; diff --git a/test/DotNetty.Transport.Libuv.Tests/ReadPendingTests.cs b/test/DotNetty.Transport.Libuv.Tests/ReadPendingTests.cs index 0696eda56..a3644787c 100644 --- a/test/DotNetty.Transport.Libuv.Tests/ReadPendingTests.cs +++ b/test/DotNetty.Transport.Libuv.Tests/ReadPendingTests.cs @@ -10,6 +10,7 @@ namespace DotNetty.Transport.Libuv.Tests using DotNetty.Buffers; using DotNetty.Common.Concurrency; using DotNetty.Common.Utilities; + using DotNetty.Tests.Common; using DotNetty.Transport.Bootstrapping; using DotNetty.Transport.Channels; using Xunit; diff --git a/test/DotNetty.Transport.Tests.Performance/Sockets/SocketDatagramChannelPerfSpecs.cs b/test/DotNetty.Transport.Tests.Performance/Sockets/SocketDatagramChannelPerfSpecs.cs index 71695f7df..2e154e211 100644 --- a/test/DotNetty.Transport.Tests.Performance/Sockets/SocketDatagramChannelPerfSpecs.cs +++ b/test/DotNetty.Transport.Tests.Performance/Sockets/SocketDatagramChannelPerfSpecs.cs @@ -8,6 +8,7 @@ namespace DotNetty.Transport.Tests.Performance.Sockets using System.Threading; using System.Threading.Tasks; using DotNetty.Buffers; + using DotNetty.Common.Concurrency; using DotNetty.Common.Utilities; using DotNetty.Transport.Bootstrapping; using DotNetty.Transport.Channels; @@ -85,7 +86,7 @@ public OutboundCounter(Counter writes) this.writes = writes; } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { this.writes.Increment(); return context.WriteAsync(message); diff --git a/test/DotNetty.Transport.Tests.Performance/Transport/AbstractPingPongPerfSpecs.cs b/test/DotNetty.Transport.Tests.Performance/Transport/AbstractPingPongPerfSpecs.cs index 92f050f48..a8e9b336c 100644 --- a/test/DotNetty.Transport.Tests.Performance/Transport/AbstractPingPongPerfSpecs.cs +++ b/test/DotNetty.Transport.Tests.Performance/Transport/AbstractPingPongPerfSpecs.cs @@ -75,7 +75,7 @@ public void SetUp(BenchmarkContext context) public void RoundTrip(BenchmarkContext context) { this.clientHandler.Start(); - this.client.WriteAndFlushAsync(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes("PING"))); + this.client.WriteAndFlushAsync(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes("PING")), false); this.clientHandler.Completion.Wait(TimeSpan.FromSeconds(10)); } @@ -105,7 +105,7 @@ public override void ChannelRead(IChannelHandlerContext context, object message) this.counter.Increment(); if (this.stopwatch.Elapsed < this.duration) { - context.WriteAndFlushAsync(buffer); + context.WriteAndFlushAsync(buffer, false); } else { @@ -129,7 +129,7 @@ public override void ChannelRead(IChannelHandlerContext context, object message) { if (message is IByteBuffer buffer) { - context.WriteAndFlushAsync(buffer); + context.WriteAndFlushAsync(buffer, false); } else { diff --git a/test/DotNetty.Transport.Tests.Performance/Utilities/CounterHandlerOutbound.cs b/test/DotNetty.Transport.Tests.Performance/Utilities/CounterHandlerOutbound.cs index a5d830d79..063489907 100644 --- a/test/DotNetty.Transport.Tests.Performance/Utilities/CounterHandlerOutbound.cs +++ b/test/DotNetty.Transport.Tests.Performance/Utilities/CounterHandlerOutbound.cs @@ -4,6 +4,7 @@ namespace DotNetty.Transport.Tests.Performance.Utilities { using System.Threading.Tasks; + using DotNetty.Common.Concurrency; using DotNetty.Transport.Channels; using NBench; @@ -16,7 +17,7 @@ public CounterHandlerOutbound(Counter throughput) this.throughput = throughput; } - public override Task WriteAsync(IChannelHandlerContext context, object message) + public override ValueTask WriteAsync(IChannelHandlerContext context, object message) { this.throughput.Increment(); return context.WriteAsync(message);