From 2c0c10939e457a13b2299a5f8d7193d4c8206e95 Mon Sep 17 00:00:00 2001 From: Johnny Z Date: Tue, 14 Aug 2018 03:58:52 +1000 Subject: [PATCH] WebSocket support. (#400) --- DotNetty.sln | 14 + examples/Examples.Common/ClientSettings.cs | 9 + examples/WebSockets.Client/Program.cs | 139 ++++++ .../WebSocketClientHandler.cs | 85 ++++ .../WebSockets.Client.csproj | 31 ++ examples/WebSockets.Client/appsettings.json | 7 + examples/WebSockets.Server/Program.cs | 124 +++++ .../WebSocketServerBenchmarkPage.cs | 172 +++++++ .../WebSocketServerHandler.cs | 160 ++++++ .../WebSockets.Server.csproj | 32 ++ examples/WebSockets.Server/appsettings.json | 4 + src/DotNetty.Buffers/Unpooled.cs | 2 +- src/DotNetty.Codecs.Http/HttpScheme.cs | 6 + src/DotNetty.Codecs.Http/HttpServerCodec.cs | 1 - .../HttpServerUpgradeHandler.cs | 2 - .../WebSockets/BinaryWebSocketFrame.cs | 27 + .../WebSockets/CloseWebSocketFrame.cs | 95 ++++ .../WebSockets/ContinuationWebSocketFrame.cs | 38 ++ .../Extensions/Compression/DeflateDecoder.cs | 134 +++++ .../Extensions/Compression/DeflateEncoder.cs | 136 +++++ .../DeflateFrameClientExtensionHandshaker.cs | 69 +++ .../DeflateFrameServerExtensionHandshaker.cs | 68 +++ .../Compression/PerFrameDeflateDecoder.cs | 21 + .../Compression/PerFrameDeflateEncoder.cs | 22 + ...MessageDeflateClientExtensionHandshaker.cs | 179 +++++++ .../Compression/PerMessageDeflateDecoder.cs | 42 ++ .../Compression/PerMessageDeflateEncoder.cs | 44 ++ ...MessageDeflateServerExtensionHandshaker.cs | 174 +++++++ .../WebSocketClientCompressionHandler.cs | 19 + .../WebSocketServerCompressionHandler.cs | 13 + .../Extensions/IWebSocketClientExtension.cs | 9 + .../IWebSocketClientExtensionHandshaker.cs | 12 + .../Extensions/IWebSocketExtension.cs | 24 + .../Extensions/IWebSocketServerExtension.cs | 10 + .../IWebSocketServerExtensionHandshaker.cs | 10 + .../WebSocketClientExtensionHandler.cs | 101 ++++ .../Extensions/WebSocketExtensionData.cs | 28 ++ .../Extensions/WebSocketExtensionDecoder.cs | 9 + .../Extensions/WebSocketExtensionEncoder.cs | 9 + .../Extensions/WebSocketExtensionUtil.cs | 84 ++++ .../WebSocketServerExtensionHandler.cs | 118 +++++ .../WebSockets/IWebSocketFrameDecoder.cs | 15 + .../WebSockets/IWebSocketFrameEncoder.cs | 15 + .../WebSockets/PingWebSocketFrame.cs | 27 + .../WebSockets/PongWebSocketFrame.cs | 27 + .../WebSockets/TextWebSocketFrame.cs | 43 ++ .../WebSockets/Utf8FrameValidator.cs | 90 ++++ .../WebSockets/Utf8Validator.cs | 76 +++ .../WebSockets/WebSocket00FrameDecoder.cs | 133 +++++ .../WebSockets/WebSocket00FrameEncoder.cs | 104 ++++ .../WebSockets/WebSocket07FrameDecoder.cs | 18 + .../WebSockets/WebSocket07FrameEncoder.cs | 13 + .../WebSockets/WebSocket08FrameDecoder.cs | 467 ++++++++++++++++++ .../WebSockets/WebSocket08FrameEncoder.cs | 219 ++++++++ .../WebSockets/WebSocket13FrameDecoder.cs | 19 + .../WebSockets/WebSocket13FrameEncoder.cs | 13 + .../WebSockets/WebSocketChunkedInput.cs | 42 ++ .../WebSockets/WebSocketClientHandshaker.cs | 425 ++++++++++++++++ .../WebSockets/WebSocketClientHandshaker00.cs | 163 ++++++ .../WebSockets/WebSocketClientHandshaker07.cs | 117 +++++ .../WebSockets/WebSocketClientHandshaker08.cs | 118 +++++ .../WebSockets/WebSocketClientHandshaker13.cs | 119 +++++ .../WebSocketClientHandshakerFactory.cs | 50 ++ .../WebSocketClientProtocolHandler.cs | 100 ++++ ...WebSocketClientProtocolHandshakeHandler.cs | 66 +++ .../WebSockets/WebSocketFrame.cs | 44 ++ .../WebSockets/WebSocketFrameAggregator.cs | 58 +++ .../WebSockets/WebSocketHandshakeException.cs | 25 + .../WebSockets/WebSocketProtocolHandler.cs | 47 ++ .../WebSockets/WebSocketScheme.cs | 39 ++ .../WebSockets/WebSocketServerHandshaker.cs | 257 ++++++++++ .../WebSockets/WebSocketServerHandshaker00.cs | 114 +++++ .../WebSockets/WebSocketServerHandshaker07.cs | 81 +++ .../WebSockets/WebSocketServerHandshaker08.cs | 80 +++ .../WebSockets/WebSocketServerHandshaker13.cs | 79 +++ .../WebSocketServerHandshakerFactory.cs | 100 ++++ .../WebSocketServerProtocolHandler.cs | 185 +++++++ ...WebSocketServerProtocolHandshakeHandler.cs | 126 +++++ .../WebSockets/WebSocketUtil.cs | 67 +++ .../WebSockets/WebSocketVersion.cs | 56 +++ src/DotNetty.Codecs/ByteToMessageDecoder.cs | 10 +- .../Compression/JZlibEncoder.cs | 1 + .../Compression/ZlibCodecFactory.cs | 3 + src/DotNetty.Codecs/Compression/ZlibUtil.cs | 4 + src/DotNetty.Common/Utilities/AsciiString.cs | 31 ++ src/DotNetty.Common/Utilities/NetUtil.cs | 251 ++++++++++ .../Channels/CombinedChannelDuplexHandler.cs | 1 - .../DotNetty.Codecs.Http.Tests.csproj | 1 + .../HttpServerUpgradeHandlerTest.cs | 132 +++++ ...flateFrameClientExtensionHandshakerTest.cs | 66 +++ ...flateFrameServerExtensionHandshakerTest.cs | 55 +++ .../Compression/PerFrameDeflateDecoderTest.cs | 106 ++++ .../Compression/PerFrameDeflateEncoderTest.cs | 144 ++++++ ...ageDeflateClientExtensionHandshakerTest.cs | 92 ++++ .../PerMessageDeflateDecoderTest.cs | 171 +++++++ .../PerMessageDeflateEncoderTest.cs | 142 ++++++ ...ageDeflateServerExtensionHandshakerTest.cs | 122 +++++ .../WebSocketServerCompressionHandlerTest.cs | 195 ++++++++ .../WebSocketClientExtensionHandlerTest.cs | 239 +++++++++ .../Extensions/WebSocketExtensionTestUtil.cs | 82 +++ .../Extensions/WebSocketExtensionUtilTest.cs | 24 + .../WebSocketServerExtensionHandlerTest.cs | 192 +++++++ .../WebSockets/WebSocket00FrameEncoderTest.cs | 35 ++ .../WebSocket08EncoderDecoderTest.cs | 158 ++++++ .../WebSockets/WebSocket08FrameDecoderTest.cs | 24 + .../WebSocketClientHandshaker00Test.cs | 17 + .../WebSocketClientHandshaker07Test.cs | 16 + .../WebSocketClientHandshaker08Test.cs | 13 + .../WebSocketClientHandshaker13Test.cs | 13 + .../WebSocketClientHandshakerTest.cs | 315 ++++++++++++ .../WebSocketFrameAggregatorTest.cs | 140 ++++++ .../WebSocketHandshakeHandOverTest.cs | 166 +++++++ .../WebSocketProtocolHandlerTest.cs | 82 +++ .../WebSockets/WebSocketRequestBuilder.cs | 143 ++++++ .../WebSocketServerHandshaker00Test.cs | 74 +++ .../WebSocketServerHandshaker08Test.cs | 72 +++ .../WebSocketServerHandshaker13Test.cs | 72 +++ .../WebSocketServerHandshakerFactoryTest.cs | 32 ++ .../WebSocketServerProtocolHandlerTest.cs | 168 +++++++ 119 files changed, 9718 insertions(+), 6 deletions(-) create mode 100644 examples/WebSockets.Client/Program.cs create mode 100644 examples/WebSockets.Client/WebSocketClientHandler.cs create mode 100644 examples/WebSockets.Client/WebSockets.Client.csproj create mode 100644 examples/WebSockets.Client/appsettings.json create mode 100644 examples/WebSockets.Server/Program.cs create mode 100644 examples/WebSockets.Server/WebSocketServerBenchmarkPage.cs create mode 100644 examples/WebSockets.Server/WebSocketServerHandler.cs create mode 100644 examples/WebSockets.Server/WebSockets.Server.csproj create mode 100644 examples/WebSockets.Server/appsettings.json create mode 100644 src/DotNetty.Codecs.Http/WebSockets/BinaryWebSocketFrame.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/CloseWebSocketFrame.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/ContinuationWebSocketFrame.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateFrameClientExtensionHandshaker.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateFrameServerExtensionHandshaker.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerFrameDeflateDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerFrameDeflateEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateClientExtensionHandshaker.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateServerExtensionHandshaker.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/WebSocketClientCompressionHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/WebSocketServerCompressionHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketClientExtension.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketClientExtensionHandshaker.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketExtension.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketServerExtension.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketServerExtensionHandshaker.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketClientExtensionHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionData.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionUtil.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketServerExtensionHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/IWebSocketFrameDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/IWebSocketFrameEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/PingWebSocketFrame.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/PongWebSocketFrame.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/TextWebSocketFrame.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Utf8FrameValidator.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/Utf8Validator.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocket00FrameDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocket00FrameEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocket07FrameDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocket07FrameEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocket08FrameDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocket08FrameEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocket13FrameDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocket13FrameEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketChunkedInput.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker00.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker07.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker08.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker13.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshakerFactory.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketClientProtocolHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketClientProtocolHandshakeHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketFrame.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketFrameAggregator.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketHandshakeException.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketProtocolHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketScheme.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker00.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker07.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker08.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker13.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshakerFactory.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketServerProtocolHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketServerProtocolHandshakeHandler.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketUtil.cs create mode 100644 src/DotNetty.Codecs.Http/WebSockets/WebSocketVersion.cs create mode 100644 src/DotNetty.Common/Utilities/NetUtil.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpServerUpgradeHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/DeflateFrameClientExtensionHandshakerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/DeflateFrameServerExtensionHandshakerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerFrameDeflateDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerFrameDeflateEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateClientExtensionHandshakerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateServerExtensionHandshakerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/WebSocketServerCompressionHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketClientExtensionHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketExtensionTestUtil.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketExtensionUtilTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketServerExtensionHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket00FrameEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket08EncoderDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket08FrameDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker00Test.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker07Test.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker08Test.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker13Test.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshakerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketFrameAggregatorTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketHandshakeHandOverTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketProtocolHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketRequestBuilder.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker00Test.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker08Test.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker13Test.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshakerFactoryTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerProtocolHandlerTest.cs diff --git a/DotNetty.sln b/DotNetty.sln index d9f9ccb52..f2809a025 100644 --- a/DotNetty.sln +++ b/DotNetty.sln @@ -103,6 +103,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DotNetty.Codecs.Http", "src EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DotNetty.Codecs.Http.Tests", "test\DotNetty.Codecs.Http.Tests\DotNetty.Codecs.Http.Tests.csproj", "{16C89E7C-1575-4685-8DFA-8E7E2C6101BF}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "WebSockets.Server", "examples\WebSockets.Server\WebSockets.Server.csproj", "{EA387B4B-DAD0-4E34-B8A3-79EA4616726A}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "WebSockets.Client", "examples\WebSockets.Client\WebSockets.Client.csproj", "{3326DB6E-023E-483F-9A1C-5905D3091B57}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -273,6 +277,14 @@ Global {16C89E7C-1575-4685-8DFA-8E7E2C6101BF}.Debug|Any CPU.Build.0 = Debug|Any CPU {16C89E7C-1575-4685-8DFA-8E7E2C6101BF}.Release|Any CPU.ActiveCfg = Release|Any CPU {16C89E7C-1575-4685-8DFA-8E7E2C6101BF}.Release|Any CPU.Build.0 = Release|Any CPU + {EA387B4B-DAD0-4E34-B8A3-79EA4616726A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EA387B4B-DAD0-4E34-B8A3-79EA4616726A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EA387B4B-DAD0-4E34-B8A3-79EA4616726A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EA387B4B-DAD0-4E34-B8A3-79EA4616726A}.Release|Any CPU.Build.0 = Release|Any CPU + {3326DB6E-023E-483F-9A1C-5905D3091B57}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3326DB6E-023E-483F-9A1C-5905D3091B57}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3326DB6E-023E-483F-9A1C-5905D3091B57}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3326DB6E-023E-483F-9A1C-5905D3091B57}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -319,6 +331,8 @@ Global {A7CACAE7-66E7-43DA-948B-28EB0DDDB582} = {F716F1EF-81EF-4020-914A-5422A13A9E13} {5F68A5B1-7907-4B16-8AFE-326E9DD7D65B} = {126EA539-4B28-4B07-8B5D-D1D7F794D189} {16C89E7C-1575-4685-8DFA-8E7E2C6101BF} = {541093F6-616E-43D9-B671-FCD1F9C0A181} + {EA387B4B-DAD0-4E34-B8A3-79EA4616726A} = {F716F1EF-81EF-4020-914A-5422A13A9E13} + {3326DB6E-023E-483F-9A1C-5905D3091B57} = {F716F1EF-81EF-4020-914A-5422A13A9E13} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution {9FE6A783-C20D-4097-9988-4178E2C4CE75} = {126EA539-4B28-4B07-8B5D-D1D7F794D189} diff --git a/examples/Examples.Common/ClientSettings.cs b/examples/Examples.Common/ClientSettings.cs index 1bc1c1c78..ac689fbf1 100644 --- a/examples/Examples.Common/ClientSettings.cs +++ b/examples/Examples.Common/ClientSettings.cs @@ -21,5 +21,14 @@ public static bool IsSsl public static int Port => int.Parse(ExampleHelper.Configuration["port"]); public static int Size => int.Parse(ExampleHelper.Configuration["size"]); + + public static bool UseLibuv + { + get + { + string libuv = ExampleHelper.Configuration["libuv"]; + return !string.IsNullOrEmpty(libuv) && bool.Parse(libuv); + } + } } } \ No newline at end of file diff --git a/examples/WebSockets.Client/Program.cs b/examples/WebSockets.Client/Program.cs new file mode 100644 index 000000000..002f2f616 --- /dev/null +++ b/examples/WebSockets.Client/Program.cs @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace WebSockets.Client +{ + using System; + using System.IO; + using System.Net; + using System.Net.Security; + using System.Security.Cryptography.X509Certificates; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Codecs.Http; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using DotNetty.Handlers.Tls; + using DotNetty.Transport.Bootstrapping; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Sockets; + using DotNetty.Transport.Libuv; + using Examples.Common; + + class Program + { + static async Task RunClientAsync() + { + var builder = new UriBuilder + { + Scheme = ClientSettings.IsSsl ? "wss" : "ws", + Host = ClientSettings.Host.ToString(), + Port = ClientSettings.Port + }; + + string path = ExampleHelper.Configuration["path"]; + if (!string.IsNullOrEmpty(path)) + { + builder.Path = path; + } + + Uri uri = builder.Uri; + ExampleHelper.SetConsoleLogger(); + + bool useLibuv = ClientSettings.UseLibuv; + Console.WriteLine("Transport type : " + (useLibuv ? "Libuv" : "Socket")); + + IEventLoopGroup group; + if (useLibuv) + { + group = new EventLoopGroup(); + } + else + { + group = new MultithreadEventLoopGroup(); + } + + X509Certificate2 cert = null; + string targetHost = null; + if (ClientSettings.IsSsl) + { + cert = new X509Certificate2(Path.Combine(ExampleHelper.ProcessDirectory, "dotnetty.com.pfx"), "password"); + targetHost = cert.GetNameInfo(X509NameType.DnsName, false); + } + try + { + var bootstrap = new Bootstrap(); + bootstrap + .Group(group) + .Option(ChannelOption.TcpNodelay, true); + if (useLibuv) + { + bootstrap.Channel(); + } + else + { + bootstrap.Channel(); + } + + // Connect with V13 (RFC 6455 aka HyBi-17). You can change it to V08 or V00. + // If you change it to V00, ping is not supported and remember to change + // HttpResponseDecoder to WebSocketHttpResponseDecoder in the pipeline. + var handler =new WebSocketClientHandler( + WebSocketClientHandshakerFactory.NewHandshaker( + uri, WebSocketVersion.V13, null, true, new DefaultHttpHeaders())); + + bootstrap.Handler(new ActionChannelInitializer(channel => + { + IChannelPipeline pipeline = channel.Pipeline; + if (cert != null) + { + pipeline.AddLast("tls", new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost))); + } + + pipeline.AddLast( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + WebSocketClientCompressionHandler.Instance, + handler); + })); + + IChannel ch = await bootstrap.ConnectAsync(new IPEndPoint(ClientSettings.Host, ClientSettings.Port)); + await handler.HandshakeCompletion; + + Console.WriteLine("WebSocket handshake completed.\n"); + Console.WriteLine("\t[bye]:Quit \n\t [ping]:Send ping frame\n\t Enter any text and Enter: Send text frame"); + while (true) + { + string msg = Console.ReadLine(); + if (msg == null) + { + break; + } + else if ("bye".Equals(msg.ToLower())) + { + await ch.WriteAndFlushAsync(new CloseWebSocketFrame()); + break; + } + else if ("ping".Equals(msg.ToLower())) + { + var frame = new PingWebSocketFrame(Unpooled.WrappedBuffer(new byte[] { 8, 1, 8, 1 })); + await ch.WriteAndFlushAsync(frame); + } + else + { + WebSocketFrame frame = new TextWebSocketFrame(msg); + await ch.WriteAndFlushAsync(frame); + } + } + + await ch.CloseAsync(); + } + finally + { + await group.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromSeconds(1)); + } + } + + static void Main() => RunClientAsync().Wait(); + } +} diff --git a/examples/WebSockets.Client/WebSocketClientHandler.cs b/examples/WebSockets.Client/WebSocketClientHandler.cs new file mode 100644 index 000000000..6219b1b3a --- /dev/null +++ b/examples/WebSockets.Client/WebSocketClientHandler.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace WebSockets.Client +{ + using System; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Codecs.Http; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class WebSocketClientHandler : SimpleChannelInboundHandler + { + readonly WebSocketClientHandshaker handshaker; + readonly TaskCompletionSource completionSource; + + public WebSocketClientHandler(WebSocketClientHandshaker handshaker) + { + this.handshaker = handshaker; + this.completionSource = new TaskCompletionSource(); + } + + public Task HandshakeCompletion => this.completionSource.Task; + + public override void ChannelActive(IChannelHandlerContext ctx) => + this.handshaker.HandshakeAsync(ctx.Channel).LinkOutcome(this.completionSource); + + public override void ChannelInactive(IChannelHandlerContext context) + { + Console.WriteLine("WebSocket Client disconnected!"); + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + IChannel ch = ctx.Channel; + if (!this.handshaker.IsHandshakeComplete) + { + try + { + this.handshaker.FinishHandshake(ch, (IFullHttpResponse)msg); + Console.WriteLine("WebSocket Client connected!"); + this.completionSource.TryComplete(); + } + catch (WebSocketHandshakeException e) + { + Console.WriteLine("WebSocket Client failed to connect"); + this.completionSource.TrySetException(e); + } + + return; + } + + + if (msg is IFullHttpResponse response) + { + throw new InvalidOperationException( + $"Unexpected FullHttpResponse (getStatus={response.Status}, content={response.Content.ToString(Encoding.UTF8)})"); + } + + if (msg is TextWebSocketFrame textFrame) + { + Console.WriteLine($"WebSocket Client received message: {textFrame.Text()}"); + } + else if (msg is PongWebSocketFrame) + { + Console.WriteLine("WebSocket Client received pong"); + } + else if (msg is CloseWebSocketFrame) + { + Console.WriteLine("WebSocket Client received closing"); + ch.CloseAsync(); + } + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception exception) + { + Console.WriteLine("Exception: " + exception); + this.completionSource.TrySetException(exception); + ctx.CloseAsync(); + } + } +} diff --git a/examples/WebSockets.Client/WebSockets.Client.csproj b/examples/WebSockets.Client/WebSockets.Client.csproj new file mode 100644 index 000000000..1f27494e3 --- /dev/null +++ b/examples/WebSockets.Client/WebSockets.Client.csproj @@ -0,0 +1,31 @@ + + + + Exe + netcoreapp1.1;net451 + 1.6.1 + false + + + win-x64 + + + + PreserveNewest + + + PreserveNewest + + + + + + + + + + + + + + diff --git a/examples/WebSockets.Client/appsettings.json b/examples/WebSockets.Client/appsettings.json new file mode 100644 index 000000000..a0f36590b --- /dev/null +++ b/examples/WebSockets.Client/appsettings.json @@ -0,0 +1,7 @@ +{ + "ssl": "false", + "host": "127.0.0.1", + "port": "8080", + "path": "/websocket", + "libuv": "true" +} \ No newline at end of file diff --git a/examples/WebSockets.Server/Program.cs b/examples/WebSockets.Server/Program.cs new file mode 100644 index 000000000..f41d9aa2f --- /dev/null +++ b/examples/WebSockets.Server/Program.cs @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace WebSockets.Server +{ + using System; + using System.IO; + using System.Net; + using System.Runtime; + using System.Runtime.InteropServices; + using System.Security.Cryptography.X509Certificates; + using System.Threading.Tasks; + using DotNetty.Codecs.Http; + using DotNetty.Common; + using DotNetty.Handlers.Tls; + using DotNetty.Transport.Bootstrapping; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Sockets; + using DotNetty.Transport.Libuv; + using Examples.Common; + + class Program + { + static Program() + { + ResourceLeakDetector.Level = ResourceLeakDetector.DetectionLevel.Disabled; + } + + static async Task RunServerAsync() + { + Console.WriteLine( + $"\n{RuntimeInformation.OSArchitecture} {RuntimeInformation.OSDescription}" + + $"\n{RuntimeInformation.ProcessArchitecture} {RuntimeInformation.FrameworkDescription}" + + $"\nProcessor Count : {Environment.ProcessorCount}\n"); + + bool useLibuv = ServerSettings.UseLibuv; + Console.WriteLine("Transport type : " + (useLibuv ? "Libuv" : "Socket")); + + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + GCSettings.LatencyMode = GCLatencyMode.SustainedLowLatency; + } + + Console.WriteLine($"Server garbage collection : {(GCSettings.IsServerGC ? "Enabled" : "Disabled")}"); + Console.WriteLine($"Current latency mode for garbage collection: {GCSettings.LatencyMode}"); + Console.WriteLine("\n"); + + IEventLoopGroup bossGroup; + IEventLoopGroup workGroup; + if (useLibuv) + { + var dispatcher = new DispatcherEventLoopGroup(); + bossGroup = dispatcher; + workGroup = new WorkerEventLoopGroup(dispatcher); + } + else + { + bossGroup = new MultithreadEventLoopGroup(1); + workGroup = new MultithreadEventLoopGroup(); + } + + X509Certificate2 tlsCertificate = null; + if (ServerSettings.IsSsl) + { + tlsCertificate = new X509Certificate2(Path.Combine(ExampleHelper.ProcessDirectory, "dotnetty.com.pfx"), "password"); + } + try + { + var bootstrap = new ServerBootstrap(); + bootstrap.Group(bossGroup, workGroup); + + if (useLibuv) + { + bootstrap.Channel(); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) + || RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + bootstrap + .Option(ChannelOption.SoReuseport, true) + .ChildOption(ChannelOption.SoReuseaddr, true); + } + } + else + { + bootstrap.Channel(); + } + + bootstrap + .Option(ChannelOption.SoBacklog, 8192) + .ChildHandler(new ActionChannelInitializer(channel => + { + IChannelPipeline pipeline = channel.Pipeline; + if (tlsCertificate != null) + { + pipeline.AddLast(TlsHandler.Server(tlsCertificate)); + } + pipeline.AddLast(new HttpServerCodec()); + pipeline.AddLast(new HttpObjectAggregator(65536)); + pipeline.AddLast(new WebSocketServerHandler()); + })); + + int port = ServerSettings.Port; + IChannel bootstrapChannel = await bootstrap.BindAsync(IPAddress.Loopback, port); + + Console.WriteLine("Open your web browser and navigate to " + + $"{(ServerSettings.IsSsl ? "https" : "http")}" + + $"://127.0.0.1:{port}/"); + Console.WriteLine("Listening on " + + $"{(ServerSettings.IsSsl ? "wss" : "ws")}" + + $"://127.0.0.1:{port}/websocket"); + Console.ReadLine(); + + await bootstrapChannel.CloseAsync(); + } + finally + { + workGroup.ShutdownGracefullyAsync().Wait(); + bossGroup.ShutdownGracefullyAsync().Wait(); + } + } + + static void Main() => RunServerAsync().Wait(); + } +} diff --git a/examples/WebSockets.Server/WebSocketServerBenchmarkPage.cs b/examples/WebSockets.Server/WebSocketServerBenchmarkPage.cs new file mode 100644 index 000000000..f142fb83f --- /dev/null +++ b/examples/WebSockets.Server/WebSocketServerBenchmarkPage.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace WebSockets.Server +{ + using System.Text; + using DotNetty.Buffers; + + static class WebSocketServerBenchmarkPage + { + const string Newline = "\r\n"; + + public static IByteBuffer GetContent(string webSocketLocation) => + Unpooled.WrappedBuffer( + Encoding.ASCII.GetBytes( + "Web Socket Performance Test" + Newline + + "" + Newline + + "

WebSocket Performance Test

" + Newline + + "" + Newline + + "
" + Newline + + + "
" + Newline + + "Message size:" + + "
" + Newline + + "Number of messages:" + + "
" + Newline + + "Data Type:" + + "text" + + "binary
" + Newline + + "Mode:
" + Newline + + "" + + "Wait for response after each messages
" + Newline + + "" + + "Send all messages and then wait for all responses
" + Newline + + "Verify responded messages
" + Newline + + "" + Newline + + "

Output

" + Newline + + "" + Newline + + "
" + Newline + + "" + Newline + + "
" + Newline + + + "" + Newline + + "" + Newline + + "" + Newline)); + } +} diff --git a/examples/WebSockets.Server/WebSocketServerHandler.cs b/examples/WebSockets.Server/WebSocketServerHandler.cs new file mode 100644 index 000000000..3d9d8faf0 --- /dev/null +++ b/examples/WebSockets.Server/WebSocketServerHandler.cs @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace WebSockets.Server +{ + using System; + using System.Diagnostics; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Codecs.Http; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using Examples.Common; + + using static DotNetty.Codecs.Http.HttpVersion; + using static DotNetty.Codecs.Http.HttpResponseStatus; + + public sealed class WebSocketServerHandler : SimpleChannelInboundHandler + { + const string WebsocketPath = "/websocket"; + + WebSocketServerHandshaker handshaker; + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + if (msg is IFullHttpRequest request) + { + this.HandleHttpRequest(ctx, request); + } + else if (msg is WebSocketFrame frame) + { + this.HandleWebSocketFrame(ctx, frame); + } + } + + public override void ChannelReadComplete(IChannelHandlerContext context) => context.Flush(); + + void HandleHttpRequest(IChannelHandlerContext ctx, IFullHttpRequest req) + { + // Handle a bad request. + if (!req.Result.IsSuccess) + { + SendHttpResponse(ctx, req, new DefaultFullHttpResponse(Http11, BadRequest)); + return; + } + + // Allow only GET methods. + if (!Equals(req.Method, HttpMethod.Get)) + { + SendHttpResponse(ctx, req, new DefaultFullHttpResponse(Http11, Forbidden)); + return; + } + + // Send the demo page and favicon.ico + if ("/".Equals(req.Uri)) + { + IByteBuffer content = WebSocketServerBenchmarkPage.GetContent(GetWebSocketLocation(req)); + var res = new DefaultFullHttpResponse(Http11, OK, content); + + res.Headers.Set(HttpHeaderNames.ContentType, "text/html; charset=UTF-8"); + HttpUtil.SetContentLength(res, content.ReadableBytes); + + SendHttpResponse(ctx, req, res); + return; + } + if ("/favicon.ico".Equals(req.Uri)) + { + var res = new DefaultFullHttpResponse(Http11, NotFound); + SendHttpResponse(ctx, req, res); + return; + } + + // Handshake + var wsFactory = new WebSocketServerHandshakerFactory( + GetWebSocketLocation(req), null, true, 5 * 1024 * 1024); + this.handshaker = wsFactory.NewHandshaker(req); + if (this.handshaker == null) + { + WebSocketServerHandshakerFactory.SendUnsupportedVersionResponse(ctx.Channel); + } + else + { + this.handshaker.HandshakeAsync(ctx.Channel, req); + } + } + + void HandleWebSocketFrame(IChannelHandlerContext ctx, WebSocketFrame frame) + { + // Check for closing frame + if (frame is CloseWebSocketFrame) + { + this.handshaker.CloseAsync(ctx.Channel, (CloseWebSocketFrame)frame.Retain()); + return; + } + + if (frame is PingWebSocketFrame) + { + ctx.WriteAsync(new PongWebSocketFrame((IByteBuffer)frame.Content.Retain())); + return; + } + + if (frame is TextWebSocketFrame) + { + // Echo the frame + ctx.WriteAsync(frame.Retain()); + return; + } + + if (frame is BinaryWebSocketFrame) + { + // Echo the frame + ctx.WriteAsync(frame.Retain()); + } + } + + static void SendHttpResponse(IChannelHandlerContext ctx, IFullHttpRequest req, IFullHttpResponse res) + { + // Generate an error page if response getStatus code is not OK (200). + if (res.Status.Code != 200) + { + IByteBuffer buf = Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes(res.Status.ToString())); + res.Content.WriteBytes(buf); + buf.Release(); + HttpUtil.SetContentLength(res, res.Content.ReadableBytes); + } + + // Send the response and close the connection if necessary. + Task task = ctx.Channel.WriteAndFlushAsync(res); + if (!HttpUtil.IsKeepAlive(req) || res.Status.Code != 200) + { + task.ContinueWith((t, c) => ((IChannelHandlerContext)c).CloseAsync(), + ctx, TaskContinuationOptions.ExecuteSynchronously); + } + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception e) + { + Console.WriteLine($"{nameof(WebSocketServerHandler)} {0}", e); + ctx.CloseAsync(); + } + + static string GetWebSocketLocation(IFullHttpRequest req) + { + bool result = req.Headers.TryGet(HttpHeaderNames.Host, out ICharSequence value); + Debug.Assert(result, "Host header does not exist."); + string location= value.ToString() + WebsocketPath; + + if (ServerSettings.IsSsl) + { + return "wss://" + location; + } + else + { + return "ws://" + location; + } + } + } +} diff --git a/examples/WebSockets.Server/WebSockets.Server.csproj b/examples/WebSockets.Server/WebSockets.Server.csproj new file mode 100644 index 000000000..09b16b3ea --- /dev/null +++ b/examples/WebSockets.Server/WebSockets.Server.csproj @@ -0,0 +1,32 @@ + + + + Exe + netcoreapp1.1;net451 + 1.6.1 + false + true + + + win-x64 + + + + PreserveNewest + + + PreserveNewest + + + + + + + + + + + + + + diff --git a/examples/WebSockets.Server/appsettings.json b/examples/WebSockets.Server/appsettings.json new file mode 100644 index 000000000..9f992a912 --- /dev/null +++ b/examples/WebSockets.Server/appsettings.json @@ -0,0 +1,4 @@ +{ + "port": "8080", + "libuv": "true" +} \ No newline at end of file diff --git a/src/DotNetty.Buffers/Unpooled.cs b/src/DotNetty.Buffers/Unpooled.cs index d832dc764..7e3206763 100644 --- a/src/DotNetty.Buffers/Unpooled.cs +++ b/src/DotNetty.Buffers/Unpooled.cs @@ -338,7 +338,7 @@ public static IByteBuffer CopiedBuffer(char[] array, int offset, int length, Enc return length == 0 ? Empty : CopiedBuffer(new string(array, offset, length), encoding); } - static IByteBuffer CopiedBuffer(string value, Encoding encoding) => ByteBufferUtil.EncodeString0(Allocator, true, value, encoding, 0); + public static IByteBuffer CopiedBuffer(string value, Encoding encoding) => ByteBufferUtil.EncodeString0(Allocator, true, value, encoding, 0); /// /// Creates a new 4-byte big-endian buffer that holds the specified 32-bit integer. diff --git a/src/DotNetty.Codecs.Http/HttpScheme.cs b/src/DotNetty.Codecs.Http/HttpScheme.cs index 8d8548aed..9bba09bf1 100644 --- a/src/DotNetty.Codecs.Http/HttpScheme.cs +++ b/src/DotNetty.Codecs.Http/HttpScheme.cs @@ -9,6 +9,12 @@ namespace DotNetty.Codecs.Http public sealed class HttpScheme { + // Scheme for non-secure HTTP connection. + public static readonly HttpScheme Http = new HttpScheme(80, "http"); + + // Scheme for secure HTTP connection. + public static readonly HttpScheme Https = new HttpScheme(443, "https"); + readonly int port; readonly AsciiString name; diff --git a/src/DotNetty.Codecs.Http/HttpServerCodec.cs b/src/DotNetty.Codecs.Http/HttpServerCodec.cs index 7cae4e15e..6c2a75943 100644 --- a/src/DotNetty.Codecs.Http/HttpServerCodec.cs +++ b/src/DotNetty.Codecs.Http/HttpServerCodec.cs @@ -98,7 +98,6 @@ protected override void SanitizeHeadersBeforeEncode(IHttpResponse msg, bool isAl base.SanitizeHeadersBeforeEncode(msg, isAlwaysEmpty); } - protected override bool IsContentAlwaysEmpty(IHttpResponse msg) { this.method = this.serverCodec.queue.Count > 0 ? this.serverCodec.queue.Dequeue() : null; diff --git a/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs b/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs index e7e4d5572..d47663731 100644 --- a/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs +++ b/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs @@ -219,7 +219,6 @@ bool Upgrade(IChannelHandlerContext ctx, IFullHttpRequest request) } // Make sure the CONNECTION header is present. - ; if (!request.Headers.TryGet(HttpHeaderNames.Connection, out ICharSequence connectionHeader)) { return false; @@ -310,7 +309,6 @@ static IList SplitHeader(ICharSequence header) } if (c == ',') { - // Add the string and reset the builder for the next protocol. // Add the string and reset the builder for the next protocol. protocols.Add(new AsciiString(builder.ToString())); builder.Length = 0; diff --git a/src/DotNetty.Codecs.Http/WebSockets/BinaryWebSocketFrame.cs b/src/DotNetty.Codecs.Http/WebSockets/BinaryWebSocketFrame.cs new file mode 100644 index 000000000..e9a553a38 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/BinaryWebSocketFrame.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using DotNetty.Buffers; + + public class BinaryWebSocketFrame : WebSocketFrame + { + public BinaryWebSocketFrame() + : base(Unpooled.Buffer(0)) + { + } + + public BinaryWebSocketFrame(IByteBuffer binaryData) + : base(binaryData) + { + } + + public BinaryWebSocketFrame(bool finalFragment, int rsv, IByteBuffer binaryData) + : base(finalFragment, rsv, binaryData) + { + } + + public override IByteBufferHolder Replace(IByteBuffer content) => new BinaryWebSocketFrame(this.IsFinalFragment, this.Rsv, content); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/CloseWebSocketFrame.cs b/src/DotNetty.Codecs.Http/WebSockets/CloseWebSocketFrame.cs new file mode 100644 index 000000000..b290dcc47 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/CloseWebSocketFrame.cs @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + public class CloseWebSocketFrame : WebSocketFrame + { + public CloseWebSocketFrame() + : base(Unpooled.Buffer(0)) + { + } + + public CloseWebSocketFrame(int statusCode, ICharSequence reasonText) + : this(true, 0, statusCode, reasonText) + { + } + + public CloseWebSocketFrame(bool finalFragment, int rsv) + : this(finalFragment, rsv, Unpooled.Buffer(0)) + { + } + + public CloseWebSocketFrame(bool finalFragment, int rsv, int statusCode, ICharSequence reasonText) + : base(finalFragment, rsv, NewBinaryData(statusCode, reasonText)) + { + } + + static IByteBuffer NewBinaryData(int statusCode, ICharSequence reasonText) + { + if (reasonText == null) + { + reasonText = StringCharSequence.Empty; + } + + IByteBuffer binaryData = Unpooled.Buffer(2 + reasonText.Count); + binaryData.WriteShort(statusCode); + if (reasonText.Count > 0) + { + binaryData.WriteCharSequence(reasonText, Encoding.UTF8); + } + + binaryData.SetReaderIndex(0); + return binaryData; + } + + public CloseWebSocketFrame(bool finalFragment, int rsv, IByteBuffer binaryData) + : base(finalFragment, rsv, binaryData) + { + } + + /// + /// Returns the closing status code as per http://tools.ietf.org/html/rfc6455#section-7.4 RFC 6455. + /// If a getStatus code is set, -1 is returned. + /// + public int StatusCode() + { + IByteBuffer binaryData = this.Content; + if (binaryData == null || binaryData.Capacity == 0) + { + return -1; + } + + binaryData.SetReaderIndex(0); + int statusCode = binaryData.ReadShort(); + binaryData.SetReaderIndex(0); + + return statusCode; + } + + /// + /// Returns the reason text as per http://tools.ietf.org/html/rfc6455#section-7.4 RFC 6455 + /// If a reason text is not supplied, an empty string is returned. + /// + public ICharSequence ReasonText() + { + IByteBuffer binaryData = this.Content; + if (binaryData == null || binaryData.Capacity <= 2) + { + return StringCharSequence.Empty; + } + + binaryData.SetReaderIndex(2); + string reasonText = binaryData.ToString(Encoding.UTF8); + binaryData.SetReaderIndex(0); + + return new StringCharSequence(reasonText); + } + + public override IByteBufferHolder Replace(IByteBuffer content) => new CloseWebSocketFrame(this.IsFinalFragment, this.Rsv, content); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/ContinuationWebSocketFrame.cs b/src/DotNetty.Codecs.Http/WebSockets/ContinuationWebSocketFrame.cs new file mode 100644 index 000000000..6854fe2b2 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/ContinuationWebSocketFrame.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Text; + using DotNetty.Buffers; + + public class ContinuationWebSocketFrame : WebSocketFrame + { + public ContinuationWebSocketFrame() + : this(Unpooled.Buffer(0)) + { + } + + public ContinuationWebSocketFrame(IByteBuffer binaryData) + : base(binaryData) + { + } + + public ContinuationWebSocketFrame(bool finalFragment, int rsv, IByteBuffer binaryData) + : base(finalFragment, rsv, binaryData) + { + } + + public ContinuationWebSocketFrame(bool finalFragment, int rsv, string text) + : this(finalFragment, rsv, FromText(text)) + { + } + + public string Text() => this.Content.ToString(Encoding.UTF8); + + static IByteBuffer FromText(string text) => string.IsNullOrEmpty(text) + ? Unpooled.Empty : Unpooled.CopiedBuffer(text, Encoding.UTF8); + + public override IByteBufferHolder Replace(IByteBuffer content) => new ContinuationWebSocketFrame(this.IsFinalFragment, this.Rsv, content); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateDecoder.cs new file mode 100644 index 000000000..b119e2f25 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateDecoder.cs @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Buffers; + using DotNetty.Codecs.Compression; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + + abstract class DeflateDecoder : WebSocketExtensionDecoder + { + internal static readonly byte[] FrameTail = { 0x00, 0x00, 0xff, 0xff }; + + readonly bool noContext; + + EmbeddedChannel decoder; + + protected DeflateDecoder(bool noContext) + { + this.noContext = noContext; + } + + protected abstract bool AppendFrameTail(WebSocketFrame msg); + + protected abstract int NewRsv(WebSocketFrame msg); + + protected override void Decode(IChannelHandlerContext ctx, WebSocketFrame msg, List output) + { + if (this.decoder == null) + { + if (!(msg is TextWebSocketFrame) && !(msg is BinaryWebSocketFrame)) + { + throw new CodecException($"unexpected initial frame type: {msg.GetType().Name}"); + } + + this.decoder = new EmbeddedChannel(ZlibCodecFactory.NewZlibDecoder(ZlibWrapper.None)); + } + + bool readable = msg.Content.IsReadable(); + this.decoder.WriteInbound(msg.Content.Retain()); + if (this.AppendFrameTail(msg)) + { + this.decoder.WriteInbound(Unpooled.WrappedBuffer(FrameTail)); + } + + CompositeByteBuffer compositeUncompressedContent = ctx.Allocator.CompositeDirectBuffer(); + for (;;) + { + var partUncompressedContent = this.decoder.ReadInbound(); + if (partUncompressedContent == null) + { + break; + } + + if (!partUncompressedContent.IsReadable()) + { + partUncompressedContent.Release(); + continue; + } + + compositeUncompressedContent.AddComponent(true, partUncompressedContent); + } + + // Correctly handle empty frames + // See https://github.com/netty/netty/issues/4348 + if (readable && compositeUncompressedContent.NumComponents <= 0) + { + compositeUncompressedContent.Release(); + throw new CodecException("cannot read uncompressed buffer"); + } + + if (msg.IsFinalFragment && this.noContext) + { + this.Cleanup(); + } + + WebSocketFrame outMsg; + if (msg is TextWebSocketFrame) + { + outMsg = new TextWebSocketFrame(msg.IsFinalFragment, this.NewRsv(msg), compositeUncompressedContent); + } + else if (msg is BinaryWebSocketFrame) + { + outMsg = new BinaryWebSocketFrame(msg.IsFinalFragment, this.NewRsv(msg), compositeUncompressedContent); + } + else if (msg is ContinuationWebSocketFrame) + { + outMsg = new ContinuationWebSocketFrame(msg.IsFinalFragment, this.NewRsv(msg), compositeUncompressedContent); + } + else + { + throw new CodecException($"unexpected frame type: {msg.GetType().Name}"); + } + + output.Add(outMsg); + } + + public override void HandlerRemoved(IChannelHandlerContext ctx) + { + this.Cleanup(); + base.HandlerRemoved(ctx); + } + + public override void ChannelInactive(IChannelHandlerContext ctx) + { + this.Cleanup(); + base.ChannelInactive(ctx); + } + + void Cleanup() + { + if (this.decoder != null) + { + // Clean-up the previous encoder if not cleaned up correctly. + if (this.decoder.Finish()) + { + for (;;) + { + var buf = this.decoder.ReadOutbound(); + if (buf == null) + { + break; + } + // Release the buffer + buf.Release(); + } + } + this.decoder = null; + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateEncoder.cs new file mode 100644 index 000000000..abcf36ac7 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateEncoder.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Buffers; + using DotNetty.Codecs.Compression; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + + using static DeflateDecoder; + + abstract class DeflateEncoder : WebSocketExtensionEncoder + { + readonly int compressionLevel; + readonly int windowSize; + readonly bool noContext; + + EmbeddedChannel encoder; + + protected DeflateEncoder(int compressionLevel, int windowSize, bool noContext) + { + this.compressionLevel = compressionLevel; + this.windowSize = windowSize; + this.noContext = noContext; + } + + protected abstract int Rsv(WebSocketFrame msg); + + protected abstract bool RemoveFrameTail(WebSocketFrame msg); + + protected override void Encode(IChannelHandlerContext ctx, WebSocketFrame msg, List output) + { + if (this.encoder == null) + { + this.encoder = new EmbeddedChannel( + ZlibCodecFactory.NewZlibEncoder( + ZlibWrapper.None, + this.compressionLevel, + this.windowSize, + 8)); + } + + this.encoder.WriteOutbound(msg.Content.Retain()); + + CompositeByteBuffer fullCompressedContent = ctx.Allocator.CompositeBuffer(); + for (;;) + { + var partCompressedContent = this.encoder.ReadOutbound(); + if (partCompressedContent == null) + { + break; + } + + if (!partCompressedContent.IsReadable()) + { + partCompressedContent.Release(); + continue; + } + + fullCompressedContent.AddComponent(true, partCompressedContent); + } + + if (fullCompressedContent.NumComponents <= 0) + { + fullCompressedContent.Release(); + throw new CodecException("cannot read compressed buffer"); + } + + if (msg.IsFinalFragment && this.noContext) + { + this.Cleanup(); + } + + IByteBuffer compressedContent; + if (this.RemoveFrameTail(msg)) + { + int realLength = fullCompressedContent.ReadableBytes - FrameTail.Length; + compressedContent = fullCompressedContent.Slice(0, realLength); + } + else + { + compressedContent = fullCompressedContent; + } + + WebSocketFrame outMsg; + if (msg is TextWebSocketFrame) + { + outMsg = new TextWebSocketFrame(msg.IsFinalFragment, this.Rsv(msg), compressedContent); + } + else if (msg is BinaryWebSocketFrame) + { + outMsg = new BinaryWebSocketFrame(msg.IsFinalFragment, this.Rsv(msg), compressedContent); + } + else if (msg is ContinuationWebSocketFrame) + { + outMsg = new ContinuationWebSocketFrame(msg.IsFinalFragment, this.Rsv(msg), compressedContent); + } + else + { + throw new CodecException($"unexpected frame type: {msg.GetType().Name}"); + } + + output.Add(outMsg); + } + + public override void HandlerRemoved(IChannelHandlerContext ctx) + { + this.Cleanup(); + base.HandlerRemoved(ctx); + } + + void Cleanup() + { + if (this.encoder != null) + { + // Clean-up the previous encoder if not cleaned up correctly. + if (this.encoder.Finish()) + { + for (;;) + { + var buf = this.encoder.ReadOutbound(); + if (buf == null) + { + break; + } + // Release the buffer + buf.Release(); + } + } + this.encoder = null; + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateFrameClientExtensionHandshaker.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateFrameClientExtensionHandshaker.cs new file mode 100644 index 000000000..c5600dae7 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateFrameClientExtensionHandshaker.cs @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + using System; + using System.Collections.Generic; + + using static DeflateFrameServerExtensionHandshaker; + + public sealed class DeflateFrameClientExtensionHandshaker : IWebSocketClientExtensionHandshaker + { + readonly int compressionLevel; + readonly bool useWebkitExtensionName; + + public DeflateFrameClientExtensionHandshaker(bool useWebkitExtensionName) + : this(6, useWebkitExtensionName) + { + } + + public DeflateFrameClientExtensionHandshaker(int compressionLevel, bool useWebkitExtensionName) + { + if (compressionLevel < 0 || compressionLevel > 9) + { + throw new ArgumentException($"compressionLevel: {compressionLevel} (expected: 0-9)"); + } + this.compressionLevel = compressionLevel; + this.useWebkitExtensionName = useWebkitExtensionName; + } + + public WebSocketExtensionData NewRequestData() => new WebSocketExtensionData( + this.useWebkitExtensionName ? XWebkitDeflateFrameExtension : DeflateFrameExtension, + new Dictionary()); + + public IWebSocketClientExtension HandshakeExtension(WebSocketExtensionData extensionData) + { + if (!XWebkitDeflateFrameExtension.Equals(extensionData.Name) && + !DeflateFrameExtension.Equals(extensionData.Name)) + { + return null; + } + + if (extensionData.Parameters.Count == 0) + { + return new DeflateFrameClientExtension(this.compressionLevel); + } + else + { + return null; + } + } + + sealed class DeflateFrameClientExtension : IWebSocketClientExtension + { + readonly int compressionLevel; + + public DeflateFrameClientExtension(int compressionLevel) + { + this.compressionLevel = compressionLevel; + } + + public int Rsv => WebSocketRsv.Rsv1; + + public WebSocketExtensionEncoder NewExtensionEncoder() => new PerFrameDeflateEncoder(this.compressionLevel, 15, false); + + public WebSocketExtensionDecoder NewExtensionDecoder() => new PerFrameDeflateDecoder(false); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateFrameServerExtensionHandshaker.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateFrameServerExtensionHandshaker.cs new file mode 100644 index 000000000..9d6c24a38 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/DeflateFrameServerExtensionHandshaker.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + using System; + using System.Collections.Generic; + + public sealed class DeflateFrameServerExtensionHandshaker : IWebSocketServerExtensionHandshaker + { + internal static readonly string XWebkitDeflateFrameExtension = "x-webkit-deflate-frame"; + internal static readonly string DeflateFrameExtension = "deflate-frame"; + + readonly int compressionLevel; + + public DeflateFrameServerExtensionHandshaker() + : this(6) + { + } + + public DeflateFrameServerExtensionHandshaker(int compressionLevel) + { + if (compressionLevel < 0 || compressionLevel > 9) + { + throw new ArgumentException($"compressionLevel: {compressionLevel} (expected: 0-9)"); + } + this.compressionLevel = compressionLevel; + } + + public IWebSocketServerExtension HandshakeExtension(WebSocketExtensionData extensionData) + { + if (!XWebkitDeflateFrameExtension.Equals(extensionData.Name) + && !DeflateFrameExtension.Equals(extensionData.Name)) + { + return null; + } + + if (extensionData.Parameters.Count == 0) + { + return new DeflateFrameServerExtension(this.compressionLevel, extensionData.Name); + } + else + { + return null; + } + } + + sealed class DeflateFrameServerExtension : IWebSocketServerExtension + { + readonly string extensionName; + readonly int compressionLevel; + + public DeflateFrameServerExtension(int compressionLevel, string extensionName) + { + this.extensionName = extensionName; + this.compressionLevel = compressionLevel; + } + + public int Rsv => WebSocketRsv.Rsv1; + + public WebSocketExtensionEncoder NewExtensionEncoder() => new PerFrameDeflateEncoder(this.compressionLevel, 15, false); + + public WebSocketExtensionDecoder NewExtensionDecoder() => new PerFrameDeflateDecoder(false); + + public WebSocketExtensionData NewReponseData() => new WebSocketExtensionData(this.extensionName, new Dictionary()); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerFrameDeflateDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerFrameDeflateDecoder.cs new file mode 100644 index 000000000..00e595d71 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerFrameDeflateDecoder.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + class PerFrameDeflateDecoder : DeflateDecoder + { + public PerFrameDeflateDecoder(bool noContext) + : base(noContext) + { + } + + public override bool AcceptInboundMessage(object msg) => + (msg is TextWebSocketFrame || msg is BinaryWebSocketFrame || msg is ContinuationWebSocketFrame) + && (((WebSocketFrame) msg).Rsv & WebSocketRsv.Rsv1) > 0; + + protected override int NewRsv(WebSocketFrame msg) => msg.Rsv ^ WebSocketRsv.Rsv1; + + protected override bool AppendFrameTail(WebSocketFrame msg) => true; + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerFrameDeflateEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerFrameDeflateEncoder.cs new file mode 100644 index 000000000..4b5937d80 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerFrameDeflateEncoder.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + class PerFrameDeflateEncoder : DeflateEncoder + { + public PerFrameDeflateEncoder(int compressionLevel, int windowSize, bool noContext) + : base(compressionLevel, windowSize, noContext) + { + } + + public override bool AcceptOutboundMessage(object msg) => + (msg is TextWebSocketFrame || msg is BinaryWebSocketFrame || msg is ContinuationWebSocketFrame) + && ((WebSocketFrame)msg).Content.ReadableBytes > 0 + && (((WebSocketFrame)msg).Rsv & WebSocketRsv.Rsv1) == 0; + + protected override int Rsv(WebSocketFrame msg) => msg.Rsv | WebSocketRsv.Rsv1; + + protected override bool RemoveFrameTail(WebSocketFrame msg) => true; + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateClientExtensionHandshaker.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateClientExtensionHandshaker.cs new file mode 100644 index 000000000..45599f331 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateClientExtensionHandshaker.cs @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + using System; + using System.Collections.Generic; + using DotNetty.Codecs.Compression; + + using static PerMessageDeflateServerExtensionHandshaker; + + public sealed class PerMessageDeflateClientExtensionHandshaker : IWebSocketClientExtensionHandshaker + { + readonly int compressionLevel; + readonly bool allowClientWindowSize; + readonly int requestedServerWindowSize; + readonly bool allowClientNoContext; + readonly bool requestedServerNoContext; + + public PerMessageDeflateClientExtensionHandshaker() + : this(6, ZlibCodecFactory.IsSupportingWindowSizeAndMemLevel, MaxWindowSize, false, false) + { + } + + public PerMessageDeflateClientExtensionHandshaker(int compressionLevel, + bool allowClientWindowSize, int requestedServerWindowSize, + bool allowClientNoContext, bool requestedServerNoContext) + { + if (requestedServerWindowSize > MaxWindowSize || requestedServerWindowSize < MinWindowSize) + { + throw new ArgumentException($"requestedServerWindowSize: {requestedServerWindowSize} (expected: 8-15)"); + } + if (compressionLevel < 0 || compressionLevel > 9) + { + throw new ArgumentException($"compressionLevel: {compressionLevel} (expected: 0-9)"); + } + this.compressionLevel = compressionLevel; + this.allowClientWindowSize = allowClientWindowSize; + this.requestedServerWindowSize = requestedServerWindowSize; + this.allowClientNoContext = allowClientNoContext; + this.requestedServerNoContext = requestedServerNoContext; + } + + public WebSocketExtensionData NewRequestData() + { + var parameters = new Dictionary(4); + if (this.requestedServerWindowSize != MaxWindowSize) + { + parameters.Add(ServerNoContext, null); + } + if (this.allowClientNoContext) + { + parameters.Add(ClientNoContext, null); + } + if (this.requestedServerWindowSize != MaxWindowSize) + { + parameters.Add(ServerMaxWindow, Convert.ToString(this.requestedServerWindowSize)); + } + if (this.allowClientWindowSize) + { + parameters.Add(ClientMaxWindow, null); + } + return new WebSocketExtensionData(PerMessageDeflateExtension, parameters); + } + + public IWebSocketClientExtension HandshakeExtension(WebSocketExtensionData extensionData) + { + if (!PerMessageDeflateExtension.Equals(extensionData.Name)) + { + return null; + } + + bool succeed = true; + int clientWindowSize = MaxWindowSize; + int serverWindowSize = MaxWindowSize; + bool serverNoContext = false; + bool clientNoContext = false; + + foreach (KeyValuePair parameter in extensionData.Parameters) + { + if (ClientMaxWindow.Equals(parameter.Key, StringComparison.OrdinalIgnoreCase)) + { + // allowed client_window_size_bits + if (this.allowClientWindowSize) + { + clientWindowSize = int.Parse(parameter.Value); + } + else + { + succeed = false; + } + } + else if (ServerMaxWindow.Equals(parameter.Key, StringComparison.OrdinalIgnoreCase)) + { + // acknowledged server_window_size_bits + serverWindowSize = int.Parse(parameter.Value); + if (clientWindowSize > MaxWindowSize || clientWindowSize < MinWindowSize) + { + succeed = false; + } + } + else if (ClientNoContext.Equals(parameter.Key, StringComparison.OrdinalIgnoreCase)) + { + // allowed client_no_context_takeover + if (this.allowClientNoContext) + { + clientNoContext = true; + } + else + { + succeed = false; + } + } + else if (ServerNoContext.Equals(parameter.Key, StringComparison.OrdinalIgnoreCase)) + { + // acknowledged server_no_context_takeover + if (this.requestedServerNoContext) + { + serverNoContext = true; + } + else + { + succeed = false; + } + } + else + { + // unknown parameter + succeed = false; + } + + if (!succeed) + { + break; + } + } + + if ((this.requestedServerNoContext && !serverNoContext) + || this.requestedServerWindowSize != serverWindowSize) + { + succeed = false; + } + + if (succeed) + { + return new WebSocketPermessageDeflateExtension(serverNoContext, serverWindowSize, + clientNoContext, this.compressionLevel); + } + else + { + return null; + } + } + + sealed class WebSocketPermessageDeflateExtension : IWebSocketClientExtension + { + readonly bool serverNoContext; + readonly int serverWindowSize; + readonly bool clientNoContext; + readonly int compressionLevel; + + public int Rsv => WebSocketRsv.Rsv1; + + public WebSocketPermessageDeflateExtension(bool serverNoContext, int serverWindowSize, + bool clientNoContext, int compressionLevel) + { + this.serverNoContext = serverNoContext; + this.serverWindowSize = serverWindowSize; + this.clientNoContext = clientNoContext; + this.compressionLevel = compressionLevel; + } + + public WebSocketExtensionEncoder NewExtensionEncoder() => + new PerMessageDeflateEncoder(this.compressionLevel, this.serverWindowSize, this.serverNoContext); + + public WebSocketExtensionDecoder NewExtensionDecoder() => new PerMessageDeflateDecoder(this.clientNoContext); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateDecoder.cs new file mode 100644 index 000000000..516f10ce4 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateDecoder.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Transport.Channels; + + class PerMessageDeflateDecoder : DeflateDecoder + { + bool compressing; + + public PerMessageDeflateDecoder(bool noContext) + : base(noContext) + { + } + + public override bool AcceptInboundMessage(object msg) => + ((msg is TextWebSocketFrame || msg is BinaryWebSocketFrame) + && (((WebSocketFrame)msg).Rsv & WebSocketRsv.Rsv1) > 0) + || (msg is ContinuationWebSocketFrame && this.compressing); + + protected override int NewRsv(WebSocketFrame msg) => + (msg.Rsv & WebSocketRsv.Rsv1) > 0 ? msg.Rsv ^ WebSocketRsv.Rsv1 : msg.Rsv; + + protected override bool AppendFrameTail(WebSocketFrame msg) => msg.IsFinalFragment; + + protected override void Decode(IChannelHandlerContext ctx, WebSocketFrame msg, List output) + { + base.Decode(ctx, msg, output); + + if (msg.IsFinalFragment) + { + this.compressing = false; + } + else if (msg is TextWebSocketFrame || msg is BinaryWebSocketFrame) + { + this.compressing = true; + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateEncoder.cs new file mode 100644 index 000000000..9bd2a91fa --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateEncoder.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Transport.Channels; + + class PerMessageDeflateEncoder : DeflateEncoder + { + bool compressing; + + public PerMessageDeflateEncoder(int compressionLevel, int windowSize, bool noContext) + : base(compressionLevel, windowSize, noContext) + { + } + + public override bool AcceptOutboundMessage(object msg) => + ((msg is TextWebSocketFrame || msg is BinaryWebSocketFrame) + && (((WebSocketFrame) msg).Rsv & WebSocketRsv.Rsv1) == 0) + || (msg is ContinuationWebSocketFrame && this.compressing); + + protected override int Rsv(WebSocketFrame msg) => + msg is TextWebSocketFrame || msg is BinaryWebSocketFrame + ? msg.Rsv | WebSocketRsv.Rsv1 + : msg.Rsv; + + protected override bool RemoveFrameTail(WebSocketFrame msg) => msg.IsFinalFragment; + + protected override void Encode(IChannelHandlerContext ctx, WebSocketFrame msg, List output) + { + base.Encode(ctx, msg, output); + + if (msg.IsFinalFragment) + { + this.compressing = false; + } + else if (msg is TextWebSocketFrame || msg is BinaryWebSocketFrame) + { + this.compressing = true; + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateServerExtensionHandshaker.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateServerExtensionHandshaker.cs new file mode 100644 index 000000000..6275ac929 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/PerMessageDeflateServerExtensionHandshaker.cs @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + using System; + using System.Collections.Generic; + using DotNetty.Codecs.Compression; + + public sealed class PerMessageDeflateServerExtensionHandshaker : IWebSocketServerExtensionHandshaker + { + public static readonly int MinWindowSize = 8; + public static readonly int MaxWindowSize = 15; + + internal static readonly string PerMessageDeflateExtension = "permessage-deflate"; + internal static readonly string ClientMaxWindow = "client_max_window_bits"; + internal static readonly string ServerMaxWindow = "server_max_window_bits"; + internal static readonly string ClientNoContext = "client_no_context_takeover"; + internal static readonly string ServerNoContext = "server_no_context_takeover"; + + readonly int compressionLevel; + readonly bool allowServerWindowSize; + readonly int preferredClientWindowSize; + readonly bool allowServerNoContext; + readonly bool preferredClientNoContext; + + public PerMessageDeflateServerExtensionHandshaker() + : this(6, ZlibCodecFactory.IsSupportingWindowSizeAndMemLevel, MaxWindowSize, false, false) + { + } + + public PerMessageDeflateServerExtensionHandshaker(int compressionLevel, + bool allowServerWindowSize, int preferredClientWindowSize, + bool allowServerNoContext, bool preferredClientNoContext) + { + if (preferredClientWindowSize > MaxWindowSize || preferredClientWindowSize < MinWindowSize) + { + throw new ArgumentException($"preferredServerWindowSize: {preferredClientWindowSize} (expected: 8-15)"); + } + if (compressionLevel < 0 || compressionLevel > 9) + { + throw new ArgumentException($"compressionLevel: {compressionLevel} (expected: 0-9)"); + } + this.compressionLevel = compressionLevel; + this.allowServerWindowSize = allowServerWindowSize; + this.preferredClientWindowSize = preferredClientWindowSize; + this.allowServerNoContext = allowServerNoContext; + this.preferredClientNoContext = preferredClientNoContext; + } + + public IWebSocketServerExtension HandshakeExtension(WebSocketExtensionData extensionData) + { + if (!PerMessageDeflateExtension.Equals(extensionData.Name)) + { + return null; + } + + bool deflateEnabled = true; + int clientWindowSize = MaxWindowSize; + int serverWindowSize = MaxWindowSize; + bool serverNoContext = false; + bool clientNoContext = false; + + foreach (KeyValuePair parameter in extensionData.Parameters) + { + if (ClientMaxWindow.Equals(parameter.Key, StringComparison.OrdinalIgnoreCase)) + { + // use preferred clientWindowSize because client is compatible with customization + clientWindowSize = this.preferredClientWindowSize; + } + else if (ServerMaxWindow.Equals(parameter.Key, StringComparison.OrdinalIgnoreCase)) + { + // use provided windowSize if it is allowed + if (this.allowServerWindowSize) + { + serverWindowSize = int.Parse(parameter.Value); + if (serverWindowSize > MaxWindowSize || serverWindowSize < MinWindowSize) + { + deflateEnabled = false; + } + } + else + { + deflateEnabled = false; + } + } + else if (ClientNoContext.Equals(parameter.Key, StringComparison.OrdinalIgnoreCase)) + { + // use preferred clientNoContext because client is compatible with customization + clientNoContext = this.preferredClientNoContext; + } + else if (ServerNoContext.Equals(parameter.Key, StringComparison.OrdinalIgnoreCase)) + { + // use server no context if allowed + if (this.allowServerNoContext) + { + serverNoContext = true; + } + else + { + deflateEnabled = false; + } + } + else + { + // unknown parameter + deflateEnabled = false; + } + if (!deflateEnabled) + { + break; + } + } + + if (deflateEnabled) + { + return new WebSocketPermessageDeflateExtension(this.compressionLevel, serverNoContext, + serverWindowSize, clientNoContext, clientWindowSize); + } + else + { + return null; + } + } + + sealed class WebSocketPermessageDeflateExtension : IWebSocketServerExtension + { + readonly int compressionLevel; + readonly bool serverNoContext; + readonly int serverWindowSize; + readonly bool clientNoContext; + readonly int clientWindowSize; + + public WebSocketPermessageDeflateExtension(int compressionLevel, bool serverNoContext, + int serverWindowSize, bool clientNoContext, int clientWindowSize) + { + this.compressionLevel = compressionLevel; + this.serverNoContext = serverNoContext; + this.serverWindowSize = serverWindowSize; + this.clientNoContext = clientNoContext; + this.clientWindowSize = clientWindowSize; + } + + public int Rsv => WebSocketRsv.Rsv1; + + public WebSocketExtensionEncoder NewExtensionEncoder() => + new PerMessageDeflateEncoder(this.compressionLevel, this.clientWindowSize, this.clientNoContext); + + public WebSocketExtensionDecoder NewExtensionDecoder() => new PerMessageDeflateDecoder(this.serverNoContext); + + public WebSocketExtensionData NewReponseData() + { + var parameters = new Dictionary(4); + if (this.serverNoContext) + { + parameters.Add(ServerNoContext, null); + } + if (this.clientNoContext) + { + parameters.Add(ClientNoContext, null); + } + if (this.serverWindowSize != MaxWindowSize) + { + parameters.Add(ServerMaxWindow, Convert.ToString(this.serverWindowSize)); + } + if (this.clientWindowSize != MaxWindowSize) + { + parameters.Add(ClientMaxWindow, Convert.ToString(this.clientWindowSize)); + } + return new WebSocketExtensionData(PerMessageDeflateExtension, parameters); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/WebSocketClientCompressionHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/WebSocketClientCompressionHandler.cs new file mode 100644 index 000000000..1ca3c8007 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/WebSocketClientCompressionHandler.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + public sealed class WebSocketClientCompressionHandler : WebSocketClientExtensionHandler + { + public static readonly WebSocketClientCompressionHandler Instance = new WebSocketClientCompressionHandler(); + + public override bool IsSharable => true; + + WebSocketClientCompressionHandler() + : base(new PerMessageDeflateClientExtensionHandshaker(), + new DeflateFrameClientExtensionHandshaker(false), + new DeflateFrameClientExtensionHandshaker(true)) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/WebSocketServerCompressionHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/WebSocketServerCompressionHandler.cs new file mode 100644 index 000000000..6d8d9c870 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/Compression/WebSocketServerCompressionHandler.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions.Compression +{ + public class WebSocketServerCompressionHandler : WebSocketServerExtensionHandler + { + public WebSocketServerCompressionHandler() + : base(new PerMessageDeflateServerExtensionHandshaker(), new DeflateFrameServerExtensionHandshaker()) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketClientExtension.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketClientExtension.cs new file mode 100644 index 000000000..9b31a2e5c --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketClientExtension.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + public interface IWebSocketClientExtension : IWebSocketExtension + { + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketClientExtensionHandshaker.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketClientExtensionHandshaker.cs new file mode 100644 index 000000000..0a8adb102 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketClientExtensionHandshaker.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + public interface IWebSocketClientExtensionHandshaker + { + WebSocketExtensionData NewRequestData(); + + IWebSocketClientExtension HandshakeExtension(WebSocketExtensionData extensionData); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketExtension.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketExtension.cs new file mode 100644 index 000000000..af3c5ad4d --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketExtension.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + public interface IWebSocketExtension + { + /// + /// The reserved bit value to ensure that no other extension should interfere. + /// + int Rsv { get; } + + WebSocketExtensionEncoder NewExtensionEncoder(); + + WebSocketExtensionDecoder NewExtensionDecoder(); + } + + public static class WebSocketRsv + { + public static readonly int Rsv1 = 0x04; + public static readonly int Rsv2 = 0x02; + public static readonly int Rsv3 = 0x01; + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketServerExtension.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketServerExtension.cs new file mode 100644 index 000000000..fbdc98102 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketServerExtension.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + public interface IWebSocketServerExtension : IWebSocketExtension + { + WebSocketExtensionData NewReponseData(); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketServerExtensionHandshaker.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketServerExtensionHandshaker.cs new file mode 100644 index 000000000..b02c1386c --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/IWebSocketServerExtensionHandshaker.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + public interface IWebSocketServerExtensionHandshaker + { + IWebSocketServerExtension HandshakeExtension(WebSocketExtensionData extensionData); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketClientExtensionHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketClientExtensionHandler.cs new file mode 100644 index 000000000..5a72c6d1c --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketClientExtensionHandler.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Threading.Tasks; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class WebSocketClientExtensionHandler : ChannelHandlerAdapter + { + readonly List extensionHandshakers; + + public WebSocketClientExtensionHandler(params IWebSocketClientExtensionHandshaker[] extensionHandshakers) + { + Contract.Requires(extensionHandshakers != null && extensionHandshakers.Length > 0); + this.extensionHandshakers = new List(extensionHandshakers); + } + + public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + { + if (msg is IHttpRequest request && WebSocketExtensionUtil.IsWebsocketUpgrade(request.Headers)) + { + string headerValue = null; + if (request.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)) + { + headerValue = value.ToString(); + } + + foreach (IWebSocketClientExtensionHandshaker extensionHandshaker in this.extensionHandshakers) + { + WebSocketExtensionData extensionData = extensionHandshaker.NewRequestData(); + headerValue = WebSocketExtensionUtil.AppendExtension(headerValue, + extensionData.Name, extensionData.Parameters); + } + + request.Headers.Set(HttpHeaderNames.SecWebsocketExtensions, headerValue); + } + + return base.WriteAsync(ctx, msg); + } + + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + if (msg is IHttpResponse response + && WebSocketExtensionUtil.IsWebsocketUpgrade(response.Headers)) + { + string extensionsHeader = null; + if (response.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)) + { + extensionsHeader = value.ToString(); + } + + if (extensionsHeader != null) + { + List extensions = + WebSocketExtensionUtil.ExtractExtensions(extensionsHeader); + var validExtensions = new List(extensions.Count); + int rsv = 0; + + foreach (WebSocketExtensionData extensionData in extensions) + { + IWebSocketClientExtension validExtension = null; + foreach (IWebSocketClientExtensionHandshaker extensionHandshaker in this.extensionHandshakers) + { + validExtension = extensionHandshaker.HandshakeExtension(extensionData); + if (validExtension != null) + { + break; + } + } + + if (validExtension != null && (validExtension.Rsv & rsv) == 0) + { + rsv = rsv | validExtension.Rsv; + validExtensions.Add(validExtension); + } + else + { + throw new CodecException($"invalid WebSocket Extension handshake for \"{extensionsHeader}\""); + } + } + + foreach (IWebSocketClientExtension validExtension in validExtensions) + { + WebSocketExtensionDecoder decoder = validExtension.NewExtensionDecoder(); + WebSocketExtensionEncoder encoder = validExtension.NewExtensionEncoder(); + ctx.Channel.Pipeline.AddAfter(ctx.Name, decoder.GetType().Name, decoder); + ctx.Channel.Pipeline.AddAfter(ctx.Name, encoder.GetType().Name, encoder); + } + } + + ctx.Channel.Pipeline.Remove(ctx.Name); + } + + base.ChannelRead(ctx, msg); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionData.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionData.cs new file mode 100644 index 000000000..9196c5963 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionData.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + using System.Collections.Generic; + using System.Diagnostics.Contracts; + + public sealed class WebSocketExtensionData + { + readonly string name; + readonly Dictionary parameters; + + public WebSocketExtensionData(string name, IDictionary parameters) + { + Contract.Requires(name != null); + Contract.Requires(parameters != null); + + this.name = name; + this.parameters = new Dictionary(parameters); + } + + public string Name => this.name; + + public Dictionary Parameters => this.parameters; + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionDecoder.cs new file mode 100644 index 000000000..515d0db8d --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionDecoder.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + public abstract class WebSocketExtensionDecoder : MessageToMessageDecoder + { + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionEncoder.cs new file mode 100644 index 000000000..1e771e561 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionEncoder.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + public abstract class WebSocketExtensionEncoder : MessageToMessageEncoder + { + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionUtil.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionUtil.cs new file mode 100644 index 000000000..169fb1064 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketExtensionUtil.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + using System.Collections.Generic; + using System.Text; + using System.Text.RegularExpressions; + + public static class WebSocketExtensionUtil + { + const char ExtensionSeparator = ','; + const char ParameterSeparator = ';'; + const char ParameterEqual = '='; + + static readonly Regex Parameter = new Regex("^([^=]+)(=[\\\"]?([^\\\"]+)[\\\"]?)?$", RegexOptions.Compiled); + + internal static bool IsWebsocketUpgrade(HttpHeaders headers) => + headers.ContainsValue(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade, true) + && headers.Contains(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket, true); + + public static List ExtractExtensions(string extensionHeader) + { + string[] rawExtensions = extensionHeader.Split(ExtensionSeparator); + if (rawExtensions.Length > 0) + { + var extensions = new List(rawExtensions.Length); + foreach (string rawExtension in rawExtensions) + { + string[] extensionParameters = rawExtension.Split(ParameterSeparator); + string name = extensionParameters[0].Trim(); + Dictionary parameters; + if (extensionParameters.Length > 1) + { + parameters = new Dictionary(extensionParameters.Length - 1); + for (int i = 1; i < extensionParameters.Length; i++) + { + string parameter = extensionParameters[i].Trim(); + + Match match = Parameter.Match(parameter); + if (match.Success) + { + parameters.Add(match.Groups[1].Value, match.Groups[3].Value); + } + } + } + else + { + parameters = new Dictionary(); + } + extensions.Add(new WebSocketExtensionData(name, parameters)); + } + return extensions; + } + else + { + return new List(); + } + } + + internal static string AppendExtension(string currentHeaderValue, string extensionName, + Dictionary extensionParameters) + { + var newHeaderValue = new StringBuilder(currentHeaderValue?.Length ?? extensionName.Length + 1); + if (currentHeaderValue != null && currentHeaderValue.Trim() != string.Empty) + { + newHeaderValue.Append(currentHeaderValue); + newHeaderValue.Append(ExtensionSeparator); + } + newHeaderValue.Append(extensionName); + foreach (KeyValuePair extensionParameter in extensionParameters) + { + newHeaderValue.Append(ParameterSeparator); + newHeaderValue.Append(extensionParameter.Key); + if (extensionParameter.Value != null) + { + newHeaderValue.Append(ParameterEqual); + newHeaderValue.Append(extensionParameter.Value); + } + } + return newHeaderValue.ToString(); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketServerExtensionHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketServerExtensionHandler.cs new file mode 100644 index 000000000..e0b07e36b --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Extensions/WebSocketServerExtensionHandler.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets.Extensions +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Threading.Tasks; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class WebSocketServerExtensionHandler : ChannelHandlerAdapter + { + readonly List extensionHandshakers; + + List validExtensions; + + public WebSocketServerExtensionHandler(params IWebSocketServerExtensionHandshaker[] extensionHandshakers) + { + Contract.Requires(extensionHandshakers != null && extensionHandshakers.Length > 0); + + this.extensionHandshakers = new List(extensionHandshakers); + } + + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + if (msg is IHttpRequest request) + { + if (WebSocketExtensionUtil.IsWebsocketUpgrade(request.Headers)) + { + if (request.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value) + && value != null) + { + string extensionsHeader = value.ToString(); + List extensions = + WebSocketExtensionUtil.ExtractExtensions(extensionsHeader); + int rsv = 0; + + foreach (WebSocketExtensionData extensionData in extensions) + { + IWebSocketServerExtension validExtension = null; + foreach (IWebSocketServerExtensionHandshaker extensionHandshaker in this.extensionHandshakers) + { + validExtension = extensionHandshaker.HandshakeExtension(extensionData); + if (validExtension != null) + { + break; + } + } + + if (validExtension != null && (validExtension.Rsv & rsv) == 0) + { + if (this.validExtensions == null) + { + this.validExtensions = new List(1); + } + + rsv = rsv | validExtension.Rsv; + this.validExtensions.Add(validExtension); + } + } + } + } + } + + base.ChannelRead(ctx, msg); + } + + public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + { + Action continuationAction = null; + + if (msg is IHttpResponse response + && WebSocketExtensionUtil.IsWebsocketUpgrade(response.Headers) + && this.validExtensions != null) + { + string headerValue = null; + if (response.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)) + { + headerValue = value?.ToString(); + } + + foreach (IWebSocketServerExtension extension in this.validExtensions) + { + WebSocketExtensionData extensionData = extension.NewReponseData(); + headerValue = WebSocketExtensionUtil.AppendExtension(headerValue, + extensionData.Name, extensionData.Parameters); + } + + continuationAction = promise => + { + if (promise.Status == TaskStatus.RanToCompletion) + { + foreach (IWebSocketServerExtension extension in this.validExtensions) + { + WebSocketExtensionDecoder decoder = extension.NewExtensionDecoder(); + WebSocketExtensionEncoder encoder = extension.NewExtensionEncoder(); + ctx.Channel.Pipeline.AddAfter(ctx.Name, decoder.GetType().Name, decoder); + ctx.Channel.Pipeline.AddAfter(ctx.Name, encoder.GetType().Name, encoder); + } + } + ctx.Channel.Pipeline.Remove(ctx.Name); + }; + + if (headerValue != null) + { + response.Headers.Set(HttpHeaderNames.SecWebsocketExtensions, headerValue); + } + } + + return continuationAction == null + ? base.WriteAsync(ctx, msg) + : base.WriteAsync(ctx, msg) + .ContinueWith(continuationAction, TaskContinuationOptions.ExecuteSynchronously); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/IWebSocketFrameDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/IWebSocketFrameDecoder.cs new file mode 100644 index 000000000..a99957ac5 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/IWebSocketFrameDecoder.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using DotNetty.Transport.Channels; + + /// + /// Marker interface which all WebSocketFrame decoders need to implement. This makes it + /// easier to access the added encoder later in the + /// + public interface IWebSocketFrameDecoder : IChannelHandler + { + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/IWebSocketFrameEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/IWebSocketFrameEncoder.cs new file mode 100644 index 000000000..3314c6b20 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/IWebSocketFrameEncoder.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using DotNetty.Transport.Channels; + + /// + /// Marker interface which all WebSocketFrame encoders need to implement. This makes it + /// easier to access the added encoder later in the . + /// + public interface IWebSocketFrameEncoder : IChannelHandler + { + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/PingWebSocketFrame.cs b/src/DotNetty.Codecs.Http/WebSockets/PingWebSocketFrame.cs new file mode 100644 index 000000000..1408f9d2d --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/PingWebSocketFrame.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using DotNetty.Buffers; + + public class PingWebSocketFrame : WebSocketFrame + { + public PingWebSocketFrame() + : base(true, 0, Unpooled.Buffer(0)) + { + } + + public PingWebSocketFrame(IByteBuffer binaryData) + : base(binaryData) + { + } + + public PingWebSocketFrame(bool finalFragment, int rsv, IByteBuffer binaryData) + : base(finalFragment, rsv, binaryData) + { + } + + public override IByteBufferHolder Replace(IByteBuffer content) => new PingWebSocketFrame(this.IsFinalFragment, this.Rsv, content); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/PongWebSocketFrame.cs b/src/DotNetty.Codecs.Http/WebSockets/PongWebSocketFrame.cs new file mode 100644 index 000000000..83012f724 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/PongWebSocketFrame.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using DotNetty.Buffers; + + public class PongWebSocketFrame : WebSocketFrame + { + public PongWebSocketFrame() + : base(Unpooled.Buffer(0)) + { + } + + public PongWebSocketFrame(IByteBuffer binaryData) + : base(binaryData) + { + } + + public PongWebSocketFrame(bool finalFragment, int rsv, IByteBuffer binaryData) + : base(finalFragment, rsv, binaryData) + { + } + + public override IByteBufferHolder Replace(IByteBuffer content) => new PongWebSocketFrame(this.IsFinalFragment, this.Rsv, content); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/TextWebSocketFrame.cs b/src/DotNetty.Codecs.Http/WebSockets/TextWebSocketFrame.cs new file mode 100644 index 000000000..399267956 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/TextWebSocketFrame.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Text; + using DotNetty.Buffers; + + public class TextWebSocketFrame : WebSocketFrame + { + public TextWebSocketFrame() + : base(Unpooled.Buffer(0)) + { + } + + public TextWebSocketFrame(string text) + : base(FromText(text)) + { + } + + public TextWebSocketFrame(IByteBuffer binaryData) + : base(binaryData) + { + } + + public TextWebSocketFrame(bool finalFragment, int rsv, string text) + : base(finalFragment, rsv, FromText(text)) + { + } + + static IByteBuffer FromText(string text) => string.IsNullOrEmpty(text) + ? Unpooled.Empty : Unpooled.CopiedBuffer(text, Encoding.UTF8); + + public TextWebSocketFrame(bool finalFragment, int rsv, IByteBuffer binaryData) + : base(finalFragment, rsv, binaryData) + { + } + + public string Text() => this.Content.ToString(Encoding.UTF8); + + public override IByteBufferHolder Replace(IByteBuffer content) => new TextWebSocketFrame(this.IsFinalFragment, this.Rsv, content); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Utf8FrameValidator.cs b/src/DotNetty.Codecs.Http/WebSockets/Utf8FrameValidator.cs new file mode 100644 index 000000000..cdb1d40c1 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Utf8FrameValidator.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using DotNetty.Buffers; + using DotNetty.Transport.Channels; + + public class Utf8FrameValidator : ChannelHandlerAdapter + { + int fragmentedFramesCount; + Utf8Validator utf8Validator; + + public override void ChannelRead(IChannelHandlerContext ctx, object message) + { + if (message is WebSocketFrame frame) + { + // Processing for possible fragmented messages for text and binary + // frames + if (frame.IsFinalFragment) + { + // Final frame of the sequence. Apparently ping frames are + // allowed in the middle of a fragmented message + if (!(frame is PingWebSocketFrame)) + { + this.fragmentedFramesCount = 0; + + // Check text for UTF8 correctness + if (frame is TextWebSocketFrame + || (this.utf8Validator != null && this.utf8Validator.IsChecking)) + { + // Check UTF-8 correctness for this payload + this.CheckUtf8String(ctx, frame.Content); + + // This does a second check to make sure UTF-8 + // correctness for entire text message + this.utf8Validator.Finish(); + } + } + } + else + { + // Not final frame so we can expect more frames in the + // fragmented sequence + if (this.fragmentedFramesCount == 0) + { + // First text or binary frame for a fragmented set + if (frame is TextWebSocketFrame) + { + this.CheckUtf8String(ctx, frame.Content); + } + } + else + { + // Subsequent frames - only check if init frame is text + if (this.utf8Validator != null && this.utf8Validator.IsChecking) + { + this.CheckUtf8String(ctx, frame.Content); + } + } + + // Increment counter + this.fragmentedFramesCount++; + } + } + + base.ChannelRead(ctx, message); + } + + void CheckUtf8String(IChannelHandlerContext ctx, IByteBuffer buffer) + { + try + { + if (this.utf8Validator == null) + { + this.utf8Validator = new Utf8Validator(); + } + this.utf8Validator.Check(buffer); + } + catch (CorruptedFrameException) + { + if (ctx.Channel.Active) + { + ctx.WriteAndFlushAsync(Unpooled.Empty) + .ContinueWith(t => ctx.Channel.CloseAsync()); + } + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/Utf8Validator.cs b/src/DotNetty.Codecs.Http/WebSockets/Utf8Validator.cs new file mode 100644 index 000000000..90f5a229e --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/Utf8Validator.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Runtime.CompilerServices; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + sealed class Utf8Validator : IByteProcessor + { + const int Utf8Accept = 0; + const int Utf8Reject = 12; + + static readonly byte[] Types = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, + 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8 }; + + static readonly byte[] States = { 0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, + 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, + 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 12, 12, 36, + 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12 }; + + int state = Utf8Accept; + int codep; + bool checking; + + public void Check(IByteBuffer buffer) + { + this.checking = true; + buffer.ForEachByte(this); + } + + public void Finish() + { + this.checking = false; + this.codep = 0; + if (this.state != Utf8Accept) + { + this.state = Utf8Accept; + ThrowCorruptedFrameException(); + } + } + + public bool Process(byte b) + { + byte type = Types[b & 0xFF]; + + this.codep = this.state != Utf8Accept ? b & 0x3f | this.codep << 6 : 0xff >> type & b; + + this.state = States[this.state + type]; + + if (this.state == Utf8Reject) + { + this.checking = false; + ThrowCorruptedFrameException(); + } + return true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowCorruptedFrameException() => throw new CorruptedFrameException("bytes are not UTF-8"); + + public bool IsChecking => this.checking; + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocket00FrameDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocket00FrameDecoder.cs new file mode 100644 index 000000000..e349f6602 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocket00FrameDecoder.cs @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Collections.Generic; + using DotNetty.Buffers; + using DotNetty.Transport.Channels; + + using static Buffers.ByteBufferUtil; + + public class WebSocket00FrameDecoder : ReplayingDecoder, IWebSocketFrameDecoder + { + public enum Void + { + // Empty state + } + + const int DefaultMaxFrameSize = 16384; + + readonly long maxFrameSize; + bool receivedClosingHandshake; + + public WebSocket00FrameDecoder() : this(DefaultMaxFrameSize) + { + } + + public WebSocket00FrameDecoder(int maxFrameSize) : base(default(Void)) + { + this.maxFrameSize = maxFrameSize; + } + + protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List output) + { + // Discard all data received if closing handshake was received before. + if (this.receivedClosingHandshake) + { + input.SkipBytes(this.ActualReadableBytes); + return; + } + + // Decode a frame otherwise. + byte type = input.ReadByte(); + WebSocketFrame frame; + if ((type & 0x80) == 0x80) + { + // If the MSB on type is set, decode the frame length + frame = this.DecodeBinaryFrame(context, type, input); + } + else + { + // Decode a 0xff terminated UTF-8 string + frame = this.DecodeTextFrame(context, input); + } + + if (frame != null) + { + output.Add(frame); + } + } + + WebSocketFrame DecodeBinaryFrame(IChannelHandlerContext ctx, byte type, IByteBuffer buffer) + { + long frameSize = 0; + int lengthFieldSize = 0; + byte b; + do + { + b = buffer.ReadByte(); + frameSize <<= 7; + frameSize |= (uint)(b & 0x7f); + if (frameSize > this.maxFrameSize) + { + throw new TooLongFrameException(nameof(WebSocket00FrameDecoder)); + } + lengthFieldSize++; + if (lengthFieldSize > 8) + { + // Perhaps a malicious peer? + throw new TooLongFrameException(nameof(WebSocket00FrameDecoder)); + } + } while ((b & 0x80) == 0x80); + + if (type == 0xFF && frameSize == 0) + { + this.receivedClosingHandshake = true; + return new CloseWebSocketFrame(); + } + IByteBuffer payload = ReadBytes(ctx.Allocator, buffer, (int)frameSize); + return new BinaryWebSocketFrame(payload); + } + + WebSocketFrame DecodeTextFrame(IChannelHandlerContext ctx, IByteBuffer buffer) + { + int ridx = buffer.ReaderIndex; + int rbytes = this.ActualReadableBytes; + int delimPos = buffer.IndexOf(ridx, ridx + rbytes, 0xFF); + if (delimPos == -1) + { + // Frame delimiter (0xFF) not found + if (rbytes > this.maxFrameSize) + { + // Frame length exceeded the maximum + throw new TooLongFrameException(nameof(WebSocket00FrameDecoder)); + } + else + { + // Wait until more data is received + return null; + } + } + + int frameSize = delimPos - ridx; + if (frameSize > this.maxFrameSize) + { + throw new TooLongFrameException(nameof(WebSocket00FrameDecoder)); + } + + IByteBuffer binaryData = ReadBytes(ctx.Allocator, buffer, frameSize); + buffer.SkipBytes(1); + + int ffDelimPos = binaryData.IndexOf(binaryData.ReaderIndex, binaryData.WriterIndex, 0xFF); + if (ffDelimPos >= 0) + { + binaryData.Release(); + throw new ArgumentException("a text frame should not contain 0xFF."); + } + + return new TextWebSocketFrame(binaryData); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocket00FrameEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocket00FrameEncoder.cs new file mode 100644 index 000000000..c90543f87 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocket00FrameEncoder.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Collections.Generic; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class WebSocket00FrameEncoder : MessageToMessageEncoder, IWebSocketFrameEncoder + { + static readonly IByteBuffer _0X00 = Unpooled.UnreleasableBuffer( + Unpooled.DirectBuffer(1, 1).WriteByte(0x00)); + + static readonly IByteBuffer _0XFF = Unpooled.UnreleasableBuffer( + Unpooled.DirectBuffer(1, 1).WriteByte(0xFF)); + + static readonly IByteBuffer _0XFF_0X00 = Unpooled.UnreleasableBuffer( + Unpooled.DirectBuffer(2, 2).WriteByte(0xFF).WriteByte(0x00)); + + public override bool IsSharable => true; + + protected override void Encode(IChannelHandlerContext context, WebSocketFrame message, List output) + { + if (message is TextWebSocketFrame) + { + // Text frame + IByteBuffer data = message.Content; + + output.Add(_0X00.Duplicate()); + output.Add(data.Retain()); + output.Add(_0XFF.Duplicate()); + } + else if (message is CloseWebSocketFrame) + { + // Close frame, needs to call duplicate to allow multiple writes. + // See https://github.com/netty/netty/issues/2768 + output.Add(_0XFF_0X00.Duplicate()); + } + else + { + // Binary frame + IByteBuffer data = message.Content; + int dataLen = data.ReadableBytes; + + IByteBuffer buf = context.Allocator.Buffer(5); + bool release = true; + try + { + // Encode type. + buf.WriteByte(0x80); + + // Encode length. + int b1 = dataLen.RightUShift(28) & 0x7F; + int b2 = dataLen.RightUShift(14) & 0x7F; + int b3 = dataLen.RightUShift(7) & 0x7F; + int b4 = dataLen & 0x7F; + if (b1 == 0) + { + if (b2 == 0) + { + if (b3 == 0) + { + buf.WriteByte(b4); + } + else + { + buf.WriteByte(b3 | 0x80); + buf.WriteByte(b4); + } + } + else + { + buf.WriteByte(b2 | 0x80); + buf.WriteByte(b3 | 0x80); + buf.WriteByte(b4); + } + } + else + { + buf.WriteByte(b1 | 0x80); + buf.WriteByte(b2 | 0x80); + buf.WriteByte(b3 | 0x80); + buf.WriteByte(b4); + } + + // Encode binary data. + output.Add(buf); + output.Add(data.Retain()); + release = false; + } + finally + { + if (release) + { + buf.Release(); + } + } + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocket07FrameDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocket07FrameDecoder.cs new file mode 100644 index 000000000..a7be397c4 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocket07FrameDecoder.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + public class WebSocket07FrameDecoder : WebSocket08FrameDecoder + { + public WebSocket07FrameDecoder(bool expectMaskedFrames, bool allowExtensions, int maxFramePayloadLength) + : this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false) + { + } + + public WebSocket07FrameDecoder(bool expectMaskedFrames, bool allowExtensions, int maxFramePayloadLength, bool allowMaskMismatch) + : base(expectMaskedFrames, allowExtensions, maxFramePayloadLength, allowMaskMismatch) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocket07FrameEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocket07FrameEncoder.cs new file mode 100644 index 000000000..3a6fc077c --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocket07FrameEncoder.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + public class WebSocket07FrameEncoder : WebSocket08FrameEncoder + { + public WebSocket07FrameEncoder(bool maskPayload) + : base(maskPayload) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocket08FrameDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocket08FrameDecoder.cs new file mode 100644 index 000000000..418b8d252 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocket08FrameDecoder.cs @@ -0,0 +1,467 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable UseStringInterpolation +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Collections.Generic; + using System.Runtime.CompilerServices; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Internal.Logging; + using DotNetty.Transport.Channels; + + using static Buffers.ByteBufferUtil; + + public class WebSocket08FrameDecoder : ByteToMessageDecoder, IWebSocketFrameDecoder + { + enum State + { + ReadingFirst, + ReadingSecond, + ReadingSize, + MaskingKey, + Payload, + Corrupt + } + + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + const byte OpcodeCont = 0x0; + const byte OpcodeText = 0x1; + const byte OpcodeBinary = 0x2; + const byte OpcodeClose = 0x8; + const byte OpcodePing = 0x9; + const byte OpcodePong = 0xA; + + readonly long maxFramePayloadLength; + readonly bool allowExtensions; + readonly bool expectMaskedFrames; + readonly bool allowMaskMismatch; + + int fragmentedFramesCount; + bool frameFinalFlag; + bool frameMasked; + int frameRsv; + int frameOpcode; + long framePayloadLength; + byte[] maskingKey; + int framePayloadLen1; + bool receivedClosingHandshake; + State state = State.ReadingFirst; + + public WebSocket08FrameDecoder(bool expectMaskedFrames, bool allowExtensions, int maxFramePayloadLength) + : this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false) + { + } + + public WebSocket08FrameDecoder(bool expectMaskedFrames, bool allowExtensions, int maxFramePayloadLength, bool allowMaskMismatch) + { + this.expectMaskedFrames = expectMaskedFrames; + this.allowMaskMismatch = allowMaskMismatch; + this.allowExtensions = allowExtensions; + this.maxFramePayloadLength = maxFramePayloadLength; + } + + protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List output) + { + // Discard all data received if closing handshake was received before. + if (this.receivedClosingHandshake) + { + input.SkipBytes(this.ActualReadableBytes); + return; + } + + switch (this.state) + { + case State.ReadingFirst: + if (!input.IsReadable()) + { + return; + } + + this.framePayloadLength = 0; + + // FIN, RSV, OPCODE + byte b = input.ReadByte(); + this.frameFinalFlag = (b & 0x80) != 0; + this.frameRsv = (b & 0x70) >> 4; + this.frameOpcode = b & 0x0F; + + if (Logger.DebugEnabled) + { + Logger.Debug("Decoding WebSocket Frame opCode={}", this.frameOpcode); + } + + this.state = State.ReadingSecond; + goto case State.ReadingSecond; + case State.ReadingSecond: + if (!input.IsReadable()) + { + return; + } + + // MASK, PAYLOAD LEN 1 + b = input.ReadByte(); + this.frameMasked = (b & 0x80) != 0; + this.framePayloadLen1 = b & 0x7F; + + if (this.frameRsv != 0 && !this.allowExtensions) + { + this.ProtocolViolation(context, $"RSV != 0 and no extension negotiated, RSV:{this.frameRsv}"); + return; + } + + if (!this.allowMaskMismatch && this.expectMaskedFrames != this.frameMasked) + { + this.ProtocolViolation(context, "received a frame that is not masked as expected"); + return; + } + + // control frame (have MSB in opcode set) + if (this.frameOpcode > 7) + { + // control frames MUST NOT be fragmented + if (!this.frameFinalFlag) + { + this.ProtocolViolation(context, "fragmented control frame"); + return; + } + + // control frames MUST have payload 125 octets or less + if (this.framePayloadLen1 > 125) + { + this.ProtocolViolation(context, "control frame with payload length > 125 octets"); + return; + } + + // check for reserved control frame opcodes + if (!(this.frameOpcode == OpcodeClose + || this.frameOpcode == OpcodePing + || this.frameOpcode == OpcodePong)) + { + this.ProtocolViolation(context, $"control frame using reserved opcode {this.frameOpcode}"); + return; + } + + // close frame : if there is a body, the first two bytes of the + // body MUST be a 2-byte unsigned integer (in network byte + // order) representing a getStatus code + if (this.frameOpcode == 8 && this.framePayloadLen1 == 1) + { + this.ProtocolViolation(context, "received close control frame with payload len 1"); + return; + } + } + else // data frame + { + // check for reserved data frame opcodes + if (!(this.frameOpcode == OpcodeCont || this.frameOpcode == OpcodeText + || this.frameOpcode == OpcodeBinary)) + { + this.ProtocolViolation(context, $"data frame using reserved opcode {this.frameOpcode}"); + return; + } + + // check opcode vs message fragmentation state 1/2 + if (this.fragmentedFramesCount == 0 && this.frameOpcode == OpcodeCont) + { + this.ProtocolViolation(context, "received continuation data frame outside fragmented message"); + return; + } + + // check opcode vs message fragmentation state 2/2 + if (this.fragmentedFramesCount != 0 && this.frameOpcode != OpcodeCont && this.frameOpcode != OpcodePing) + { + this.ProtocolViolation(context, "received non-continuation data frame while inside fragmented message"); + return; + } + } + + this.state = State.ReadingSize; + goto case State.ReadingSize; + case State.ReadingSize: + // Read frame payload length + if (this.framePayloadLen1 == 126) + { + if (input.ReadableBytes < 2) + { + return; + } + this.framePayloadLength = input.ReadUnsignedShort(); + if (this.framePayloadLength < 126) + { + this.ProtocolViolation(context, "invalid data frame length (not using minimal length encoding)"); + return; + } + } + else if (this.framePayloadLen1 == 127) + { + if (input.ReadableBytes < 8) + { + return; + } + this.framePayloadLength = input.ReadLong(); + // TODO: check if it's bigger than 0x7FFFFFFFFFFFFFFF, Maybe + // just check if it's negative? + + if (this.framePayloadLength < 65536) + { + this.ProtocolViolation(context, "invalid data frame length (not using minimal length encoding)"); + return; + } + } + else + { + this.framePayloadLength = this.framePayloadLen1; + } + + if (this.framePayloadLength > this.maxFramePayloadLength) + { + this.ProtocolViolation(context, $"Max frame length of {this.maxFramePayloadLength} has been exceeded."); + return; + } + + if (Logger.DebugEnabled) + { + Logger.Debug("Decoding WebSocket Frame length={}", this.framePayloadLength); + } + + this.state = State.MaskingKey; + goto case State.MaskingKey; + case State.MaskingKey: + if (this.frameMasked) + { + if (input.ReadableBytes < 4) + { + return; + } + if (this.maskingKey == null) + { + this.maskingKey = new byte[4]; + } + input.ReadBytes(this.maskingKey); + } + this.state = State.Payload; + goto case State.Payload; + case State.Payload: + if (input.ReadableBytes < this.framePayloadLength) + { + return; + } + + IByteBuffer payloadBuffer = null; + try + { + payloadBuffer = ReadBytes(context.Allocator, input, ToFrameLength(this.framePayloadLength)); + + // Now we have all the data, the next checkpoint must be the next + // frame + this.state = State.ReadingFirst; + + // Unmask data if needed + if (this.frameMasked) + { + this.Unmask(payloadBuffer); + } + + // Processing ping/pong/close frames because they cannot be + // fragmented + if (this.frameOpcode == OpcodePing) + { + output.Add(new PingWebSocketFrame(this.frameFinalFlag, this.frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + if (this.frameOpcode == OpcodePong) + { + output.Add(new PongWebSocketFrame(this.frameFinalFlag, this.frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + if (this.frameOpcode == OpcodeClose) + { + this.receivedClosingHandshake = true; + this.CheckCloseFrameBody(context, payloadBuffer); + output.Add(new CloseWebSocketFrame(this.frameFinalFlag, this.frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + + // Processing for possible fragmented messages for text and binary + // frames + if (this.frameFinalFlag) + { + // Final frame of the sequence. Apparently ping frames are + // allowed in the middle of a fragmented message + if (this.frameOpcode != OpcodePing) + { + this.fragmentedFramesCount = 0; + } + } + else + { + // Increment counter + this.fragmentedFramesCount++; + } + + // Return the frame + if (this.frameOpcode == OpcodeText) + { + output.Add(new TextWebSocketFrame(this.frameFinalFlag, this.frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + else if (this.frameOpcode == OpcodeBinary) + { + output.Add(new BinaryWebSocketFrame(this.frameFinalFlag, this.frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + else if (this.frameOpcode == OpcodeCont) + { + output.Add(new ContinuationWebSocketFrame(this.frameFinalFlag, this.frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + else + { + ThrowNotSupportedException(this.frameOpcode); + break; + } + } + finally + { + payloadBuffer?.Release(); + } + case State.Corrupt: + if (input.IsReadable()) + { + // If we don't keep reading Netty will throw an exception saying + // we can't return null if no bytes read and state not changed. + input.ReadByte(); + } + return; + default: + throw new Exception("Shouldn't reach here."); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowNotSupportedException(int frameOpcode) + { + throw GetNotSupportedException(); + + NotSupportedException GetNotSupportedException() + { + return new NotSupportedException($"Cannot decode web socket frame with opcode: {frameOpcode}"); + } + } + + void Unmask(IByteBuffer frame) + { + int i = frame.ReaderIndex; + int end = frame.WriterIndex; + + int intMask = (this.maskingKey[0] << 24) + | (this.maskingKey[1] << 16) + | (this.maskingKey[2] << 8) + | this.maskingKey[3]; + + for (; i + 3 < end; i += 4) + { + int unmasked = frame.GetInt(i) ^ intMask; + frame.SetInt(i, unmasked); + } + for (; i < end; i++) + { + frame.SetByte(i, frame.GetByte(i) ^ this.maskingKey[i % 4]); + } + } + + void ProtocolViolation(IChannelHandlerContext ctx, string reason) => this.ProtocolViolation(ctx, new CorruptedFrameException(reason)); + + void ProtocolViolation(IChannelHandlerContext ctx, CorruptedFrameException ex) + { + this.state = State.Corrupt; + if (ctx.Channel.Active) + { + object closeMessage; + if (this.receivedClosingHandshake) + { + closeMessage = Unpooled.Empty; + } + else + { + closeMessage = new CloseWebSocketFrame(1002, null); + } + ctx.WriteAndFlushAsync(closeMessage) + .ContinueWith((t, c) => ((IChannel)c).CloseAsync(), + ctx.Channel, TaskContinuationOptions.ExecuteSynchronously); + } + throw ex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int ToFrameLength(long l) + { + if (l > int.MaxValue) + { + ThrowTooLongFrameException(l); + } + return (int)l; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowTooLongFrameException(long l) + { + throw GetTooLongFrameException(); + + TooLongFrameException GetTooLongFrameException() + { + return new TooLongFrameException(string.Format("Length: {0}", l)); + } + } + + protected void CheckCloseFrameBody(IChannelHandlerContext ctx, IByteBuffer buffer) + { + if (buffer == null || !buffer.IsReadable()) + { + return; + } + if (buffer.ReadableBytes == 1) + { + this.ProtocolViolation(ctx, "Invalid close frame body"); + } + + // Save reader index + int idx = buffer.ReaderIndex; + buffer.SetReaderIndex(0); + + // Must have 2 byte integer within the valid range + int statusCode = buffer.ReadShort(); + if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006 + || statusCode >= 1012 && statusCode <= 2999) + { + this.ProtocolViolation(ctx, $"Invalid close frame getStatus code: {statusCode}"); + } + + // May have UTF-8 message + if (buffer.IsReadable()) + { + try + { + new Utf8Validator().Check(buffer); + } + catch (CorruptedFrameException ex) + { + this.ProtocolViolation(ctx, ex); + } + } + + // Restore reader index + buffer.SetReaderIndex(idx); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocket08FrameEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocket08FrameEncoder.cs new file mode 100644 index 000000000..46769f166 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocket08FrameEncoder.cs @@ -0,0 +1,219 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Collections.Generic; + using System.Runtime.CompilerServices; + using DotNetty.Buffers; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class WebSocket08FrameEncoder : MessageToMessageEncoder, IWebSocketFrameEncoder + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + const byte OpcodeCont = 0x0; + const byte OpcodeText = 0x1; + const byte OpcodeBinary = 0x2; + const byte OpcodeClose = 0x8; + const byte OpcodePing = 0x9; + const byte OpcodePong = 0xA; + + /// + // The size threshold for gathering writes. Non-Masked messages bigger than this size will be be sent fragmented as + // a header and a content ByteBuf whereas messages smaller than the size will be merged into a single buffer and + // sent at once. + // Masked messages will always be sent at once. + // + const int GatheringWriteThreshold = 1024; + + readonly bool maskPayload; + readonly Random random; + + public WebSocket08FrameEncoder(bool maskPayload) + { + this.maskPayload = maskPayload; + this.random = new Random(); + } + + protected override unsafe void Encode(IChannelHandlerContext ctx, WebSocketFrame msg, List output) + { + IByteBuffer data = msg.Content; + var mask = stackalloc byte[4]; + + byte opcode = 0; + if (msg is TextWebSocketFrame) + { + opcode = OpcodeText; + } + else if (msg is PingWebSocketFrame) + { + opcode = OpcodePing; + } + else if (msg is PongWebSocketFrame) + { + opcode = OpcodePong; + } + else if (msg is CloseWebSocketFrame) + { + opcode = OpcodeClose; + } + else if (msg is BinaryWebSocketFrame) + { + opcode = OpcodeBinary; + } + else if (msg is ContinuationWebSocketFrame) + { + opcode = OpcodeCont; + } + else + { + ThrowNotSupportedException(msg); + } + + int length = data.ReadableBytes; + + if (Logger.DebugEnabled) + { + Logger.Debug($"Encoding WebSocket Frame opCode={opcode} length={length}"); + } + + int b0 = 0; + if (msg.IsFinalFragment) + { + b0 |= 1 << 7; + } + b0 |= msg.Rsv % 8 << 4; + b0 |= opcode % 128; + + if (opcode == OpcodePing && length > 125) + { + ThrowTooLongFrameException(length); + } + + bool release = true; + IByteBuffer buf = null; + + try + { + int maskLength = this.maskPayload ? 4 : 0; + if (length <= 125) + { + int size = 2 + maskLength; + if (this.maskPayload || length <= GatheringWriteThreshold) + { + size += length; + } + buf = ctx.Allocator.Buffer(size); + buf.WriteByte(b0); + byte b = (byte)(this.maskPayload ? 0x80 | (byte)length : (byte)length); + buf.WriteByte(b); + } + else if (length <= 0xFFFF) + { + int size = 4 + maskLength; + if (this.maskPayload || length <= GatheringWriteThreshold) + { + size += length; + } + buf = ctx.Allocator.Buffer(size); + buf.WriteByte(b0); + buf.WriteByte(this.maskPayload ? 0xFE : 126); + buf.WriteByte(length.RightUShift(8) & 0xFF); + buf.WriteByte(length & 0xFF); + } + else + { + int size = 10 + maskLength; + if (this.maskPayload || length <= GatheringWriteThreshold) + { + size += length; + } + buf = ctx.Allocator.Buffer(size); + buf.WriteByte(b0); + buf.WriteByte(this.maskPayload ? 0xFF : 127); + buf.WriteLong(length); + } + + // Write payload + if (this.maskPayload) + { + int intMask = (this.random.Next() * int.MaxValue); + + // Mask bytes in BE + uint unsignedValue = (uint)intMask; + *mask = (byte)(unsignedValue >> 24); + *(mask + 1) = (byte)(unsignedValue >> 16); + *(mask + 2) = (byte)(unsignedValue >> 8); + *(mask + 3) = (byte)unsignedValue; + + // Mask in BE + buf.WriteInt(intMask); + + int counter = 0; + int i = data.ReaderIndex; + int end = data.WriterIndex; + + for (; i + 3 < end; i += 4) + { + int intData = data.GetInt(i); + buf.WriteInt(intData ^ intMask); + } + for (; i < end; i++) + { + byte byteData = data.GetByte(i); + buf.WriteByte(byteData ^ mask[counter++ % 4]); + } + output.Add(buf); + } + else + { + if (buf.WritableBytes >= data.ReadableBytes) + { + // merge buffers as this is cheaper then a gathering write if the payload is small enough + buf.WriteBytes(data); + output.Add(buf); + } + else + { + output.Add(buf); + output.Add(data.Retain()); + } + } + release = false; + } + finally + { + if (release) + { + buf?.Release(); + } + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowTooLongFrameException(int length) + { + throw GetTooLongFrameException(); + + TooLongFrameException GetTooLongFrameException() + { + return new TooLongFrameException(string.Format("invalid payload for PING (payload length must be <= 125, was {0}", length)); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowNotSupportedException(WebSocketFrame msg) + { + throw GetNotSupportedException(); + + NotSupportedException GetNotSupportedException() + { + return new NotSupportedException(string.Format("Cannot encode frame of type: {0}", msg.GetType().Name)); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocket13FrameDecoder.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocket13FrameDecoder.cs new file mode 100644 index 000000000..cac28092b --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocket13FrameDecoder.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + public class WebSocket13FrameDecoder : WebSocket08FrameDecoder + { + public WebSocket13FrameDecoder(bool expectMaskedFrames, bool allowExtensions, int maxFramePayloadLength) + : this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false) + { + } + + public WebSocket13FrameDecoder(bool expectMaskedFrames, bool allowExtensions,int maxFramePayloadLength, + bool allowMaskMismatch) + : base(expectMaskedFrames, allowExtensions, maxFramePayloadLength, allowMaskMismatch) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocket13FrameEncoder.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocket13FrameEncoder.cs new file mode 100644 index 000000000..d2f79c9b3 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocket13FrameEncoder.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + public class WebSocket13FrameEncoder : WebSocket08FrameEncoder + { + public WebSocket13FrameEncoder(bool maskPayload) + : base(maskPayload) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketChunkedInput.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketChunkedInput.cs new file mode 100644 index 000000000..4826f7fc6 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketChunkedInput.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Diagnostics.Contracts; + using DotNetty.Buffers; + using DotNetty.Handlers.Streams; + + public sealed class WebSocketChunkedInput : IChunkedInput + { + readonly IChunkedInput input; + readonly int rsv; + + public WebSocketChunkedInput(IChunkedInput input) + : this(input, 0) + { + } + + public WebSocketChunkedInput(IChunkedInput input, int rsv) + { + Contract.Requires(input != null); + + this.input = input; + this.rsv = rsv; + } + + public bool IsEndOfInput => this.input.IsEndOfInput; + + public void Close() => this.input.Close(); + + public WebSocketFrame ReadChunk(IByteBufferAllocator allocator) + { + IByteBuffer buf = this.input.ReadChunk(allocator); + return buf != null ? new ContinuationWebSocketFrame(this.input.IsEndOfInput, this.rsv, buf) : null; + } + + public long Length => this.input.Length; + + public long Progress => this.input.Progress; + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker.cs new file mode 100644 index 000000000..a2b32ff67 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker.cs @@ -0,0 +1,425 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWhenPossible +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Diagnostics.Contracts; + using System.Threading.Tasks; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public abstract class WebSocketClientHandshaker + { + static readonly ClosedChannelException DefaultClosedChannelException = new ClosedChannelException(); + + static readonly string HttpSchemePrefix = HttpScheme.Http + "://"; + static readonly string HttpsSchemePrefix = HttpScheme.Https + "://"; + + readonly Uri uri; + + readonly WebSocketVersion version; + + volatile bool handshakeComplete; + + readonly string expectedSubprotocol; + + volatile string actualSubprotocol; + + protected readonly HttpHeaders CustomHeaders; + + readonly int maxFramePayloadLength; + + protected WebSocketClientHandshaker(Uri uri, WebSocketVersion version, string subprotocol, + HttpHeaders customHeaders, int maxFramePayloadLength) + { + this.uri = uri; + this.version = version; + this.expectedSubprotocol = subprotocol; + this.CustomHeaders = customHeaders; + this.maxFramePayloadLength = maxFramePayloadLength; + } + + public Uri Uri => this.uri; + + public WebSocketVersion Version => this.version; + + public int MaxFramePayloadLength => this.maxFramePayloadLength; + + public bool IsHandshakeComplete => this.handshakeComplete; + + void SetHandshakeComplete() => this.handshakeComplete = true; + + public string ExpectedSubprotocol => this.expectedSubprotocol; + + public string ActualSubprotocol + { + get => this.actualSubprotocol; + private set => this.actualSubprotocol = value; + } + + public Task HandshakeAsync(IChannel channel) + { + IFullHttpRequest request = this.NewHandshakeRequest(); + + var decoder = channel.Pipeline.Get(); + if (decoder == null) + { + var codec = channel.Pipeline.Get(); + if (codec == null) + { + return TaskEx.FromException(new InvalidOperationException("ChannelPipeline does not contain a HttpResponseDecoder or HttpClientCodec")); + } + } + + var completion = new TaskCompletionSource(); + channel.WriteAndFlushAsync(request).ContinueWith((t, state) => + { + var tcs = (TaskCompletionSource)state; + switch (t.Status) + { + case TaskStatus.RanToCompletion: + IChannelPipeline p = channel.Pipeline; + IChannelHandlerContext ctx = p.Context() ?? p.Context(); + if (ctx == null) + { + tcs.TrySetException(new InvalidOperationException("ChannelPipeline does not contain a HttpRequestEncoder or HttpClientCodec")); + return; + } + + p.AddAfter(ctx.Name, "ws-encoder", this.NewWebSocketEncoder()); + tcs.TryComplete(); + break; + case TaskStatus.Canceled: + tcs.TrySetCanceled(); + break; + case TaskStatus.Faulted: + tcs.TryUnwrap(t.Exception); + break; + default: + throw new ArgumentOutOfRangeException(); + } + }, + completion, + TaskContinuationOptions.ExecuteSynchronously); + + return completion.Task; + } + + protected internal abstract IFullHttpRequest NewHandshakeRequest(); + + public void FinishHandshake(IChannel channel, IFullHttpResponse response) + { + this.Verify(response); + + // Verify the subprotocol that we received from the server. + // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol + string receivedProtocol = null; + if (response.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out ICharSequence headerValue)) + { + receivedProtocol = headerValue.ToString().Trim(); + } + + string expectedProtocol = this.expectedSubprotocol ?? ""; + bool protocolValid = false; + + if (expectedProtocol.Length == 0 && receivedProtocol == null) + { + // No subprotocol required and none received + protocolValid = true; + this.ActualSubprotocol = this.expectedSubprotocol; // null or "" - we echo what the user requested + } + else if (expectedProtocol.Length > 0 && !string.IsNullOrEmpty(receivedProtocol)) + { + // We require a subprotocol and received one -> verify it + foreach (string protocol in expectedProtocol.Split(',')) + { + if (protocol.Trim().Equals(receivedProtocol)) + { + protocolValid = true; + this.ActualSubprotocol = receivedProtocol; + break; + } + } + } // else mixed cases - which are all errors + + if (!protocolValid) + { + throw new WebSocketHandshakeException($"Invalid subprotocol. Actual: {receivedProtocol}. Expected one of: {this.expectedSubprotocol}"); + } + + this.SetHandshakeComplete(); + + IChannelPipeline p = channel.Pipeline; + // Remove decompressor from pipeline if its in use + var decompressor = p.Get(); + if (decompressor != null) + { + p.Remove(decompressor); + } + + // Remove aggregator if present before + var aggregator = p.Get(); + if (aggregator != null) + { + p.Remove(aggregator); + } + + IChannelHandlerContext ctx = p.Context(); + if (ctx == null) + { + ctx = p.Context(); + if (ctx == null) + { + throw new InvalidOperationException("ChannelPipeline does not contain a HttpRequestEncoder or HttpClientCodec"); + } + + var codec = (HttpClientCodec)ctx.Handler; + // Remove the encoder part of the codec as the user may start writing frames after this method returns. + codec.RemoveOutboundHandler(); + + p.AddAfter(ctx.Name, "ws-decoder", this.NewWebSocketDecoder()); + + // Delay the removal of the decoder so the user can setup the pipeline if needed to handle + // WebSocketFrame messages. + // See https://github.com/netty/netty/issues/4533 + channel.EventLoop.Execute(() => p.Remove(codec)); + } + else + { + if (p.Get() != null) + { + // Remove the encoder part of the codec as the user may start writing frames after this method returns. + p.Remove(); + } + + IChannelHandlerContext context = ctx; + p.AddAfter(context.Name, "ws-decoder", this.NewWebSocketDecoder()); + + // Delay the removal of the decoder so the user can setup the pipeline if needed to handle + // WebSocketFrame messages. + // See https://github.com/netty/netty/issues/4533 + channel.EventLoop.Execute(() => p.Remove(context.Handler)); + } + } + + public Task ProcessHandshakeAsync(IChannel channel, IHttpResponse response) + { + var completionSource = new TaskCompletionSource(); + if (response is IFullHttpResponse res) + { + try + { + this.FinishHandshake(channel, res); + completionSource.TryComplete(); + } + catch (Exception cause) + { + completionSource.TrySetException(cause); + } + } + else + { + IChannelPipeline p = channel.Pipeline; + IChannelHandlerContext ctx = p.Context(); + if (ctx == null) + { + ctx = p.Context(); + if (ctx == null) + { + completionSource.TrySetException(new InvalidOperationException("ChannelPipeline does not contain a HttpResponseDecoder or HttpClientCodec")); + } + } + else + { + // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more + // then enough for the websockets handshake payload. + // + // TODO: Make handshake work without HttpObjectAggregator at all. + const string AggregatorName = "httpAggregator"; + p.AddAfter(ctx.Name, AggregatorName, new HttpObjectAggregator(8192)); + p.AddAfter(AggregatorName, "handshaker", new Handshaker(this, channel, completionSource)); + try + { + ctx.FireChannelRead(ReferenceCountUtil.Retain(response)); + } + catch (Exception cause) + { + completionSource.TrySetException(cause); + } + } + } + + return completionSource.Task; + } + + sealed class Handshaker : SimpleChannelInboundHandler + { + readonly WebSocketClientHandshaker clientHandshaker; + readonly IChannel channel; + readonly TaskCompletionSource completion; + + public Handshaker(WebSocketClientHandshaker clientHandshaker, IChannel channel, TaskCompletionSource completion) + { + this.clientHandshaker = clientHandshaker; + this.channel = channel; + this.completion = completion; + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, IFullHttpResponse msg) + { + // Remove and do the actual handshake + ctx.Channel.Pipeline.Remove(this); + try + { + this.clientHandshaker.FinishHandshake(this.channel, msg); + this.completion.TryComplete(); + } + catch (Exception cause) + { + this.completion.TrySetException(cause); + } + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + // Remove ourself and fail the handshake promise. + ctx.Channel.Pipeline.Remove(this); + this.completion.TrySetException(cause); + } + + public override void ChannelInactive(IChannelHandlerContext ctx) + { + // Fail promise if Channel was closed + this.completion.TrySetException(DefaultClosedChannelException); + ctx.FireChannelInactive(); + } + } + + protected abstract void Verify(IFullHttpResponse response); + + protected internal abstract IWebSocketFrameDecoder NewWebSocketDecoder(); + + protected internal abstract IWebSocketFrameEncoder NewWebSocketEncoder(); + + public Task CloseAsync(IChannel channel, CloseWebSocketFrame frame) + { + Contract.Requires(channel != null); + return channel.WriteAndFlushAsync(frame); + } + + internal static string RawPath(Uri wsUrl) => wsUrl.IsAbsoluteUri ? wsUrl.PathAndQuery : "/"; + + internal static string WebsocketHostValue(Uri wsUrl) + { + string scheme; + Uri uri; + if (wsUrl.IsAbsoluteUri) + { + scheme = wsUrl.Scheme; + uri = wsUrl; + } + else + { + scheme = null; + uri = AbsoluteUri(wsUrl); + } + + int port = OriginalPort(uri); + if (port == -1) + { + return uri.Host; + } + string host = uri.Host; + if (port == HttpScheme.Http.Port) + { + return HttpScheme.Http.Name.ContentEquals(scheme) + || WebSocketScheme.WS.Name.ContentEquals(scheme) + ? host : NetUtil.ToSocketAddressString(host, port); + } + if (port == HttpScheme.Https.Port) + { + return HttpScheme.Https.Name.ToString().Equals(scheme) + || WebSocketScheme.WSS.Name.ToString().Equals(scheme) + ? host : NetUtil.ToSocketAddressString(host, port); + } + + // if the port is not standard (80/443) its needed to add the port to the header. + // See http://tools.ietf.org/html/rfc6454#section-6.2 + return NetUtil.ToSocketAddressString(host, port); + } + + internal static string WebsocketOriginValue(Uri wsUrl) + { + string scheme; + Uri uri; + if (wsUrl.IsAbsoluteUri) + { + scheme = wsUrl.Scheme; + uri = wsUrl; + } + else + { + scheme = null; + uri = AbsoluteUri(wsUrl); + } + + string schemePrefix; + int port = uri.Port; + int defaultPort; + + if (WebSocketScheme.WSS.Name.ContentEquals(scheme) + || HttpScheme.Https.Name.ContentEquals(scheme) + || (scheme == null && port == WebSocketScheme.WSS.Port)) + { + + schemePrefix = HttpsSchemePrefix; + defaultPort = WebSocketScheme.WSS.Port; + } + else + { + schemePrefix = HttpSchemePrefix; + defaultPort = WebSocketScheme.WS.Port; + } + + // Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI") + string host = uri.Host.ToLower(); + + if (port != defaultPort && port != -1) + { + // if the port is not standard (80/443) its needed to add the port to the header. + // See http://tools.ietf.org/html/rfc6454#section-6.2 + return schemePrefix + NetUtil.ToSocketAddressString(host, port); + } + return schemePrefix + host; + } + + static Uri AbsoluteUri(Uri uri) + { + if (uri.IsAbsoluteUri) + { + return uri; + } + + string relativeUri = uri.OriginalString; + return new Uri(relativeUri.StartsWith("//") + ? HttpScheme.Http + ":" + relativeUri + : HttpSchemePrefix + relativeUri); + } + + static int OriginalPort(Uri uri) + { + int index = uri.Scheme.Length + 3 + uri.Host.Length; + + if (index < uri.OriginalString.Length + && uri.OriginalString[index] == ':') + { + return uri.Port; + } + return -1; + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker00.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker00.cs new file mode 100644 index 000000000..85e59f40b --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker00.cs @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Runtime.CompilerServices; + using DotNetty.Buffers; + using DotNetty.Common.Internal; + using DotNetty.Common.Utilities; + + public class WebSocketClientHandshaker00 : WebSocketClientHandshaker + { + static readonly AsciiString Websocket = AsciiString.Cached("WebSocket"); + + IByteBuffer expectedChallengeResponseBytes; + + public WebSocketClientHandshaker00(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + HttpHeaders customHeaders, int maxFramePayloadLength) + : base(webSocketUrl, version, subprotocol, customHeaders, maxFramePayloadLength) + { + } + + protected internal override unsafe IFullHttpRequest NewHandshakeRequest() + { + // Make keys + int spaces1 = WebSocketUtil.RandomNumber(1, 12); + int spaces2 = WebSocketUtil.RandomNumber(1, 12); + + int max1 = int.MaxValue / spaces1; + int max2 = int.MaxValue / spaces2; + + int number1 = WebSocketUtil.RandomNumber(0, max1); + int number2 = WebSocketUtil.RandomNumber(0, max2); + + int product1 = number1 * spaces1; + int product2 = number2 * spaces2; + + string key1 = Convert.ToString(product1); + string key2 = Convert.ToString(product2); + + key1 = InsertRandomCharacters(key1); + key2 = InsertRandomCharacters(key2); + + key1 = InsertSpaces(key1, spaces1); + key2 = InsertSpaces(key2, spaces2); + + byte[] key3 = WebSocketUtil.RandomBytes(8); + var challenge = new byte[16]; + fixed (byte* bytes = challenge) + { + Unsafe.WriteUnaligned(bytes, number1); + Unsafe.WriteUnaligned(bytes + 4, number2); + PlatformDependent.CopyMemory(key3, 0, bytes + 8, 8); + } + + this.expectedChallengeResponseBytes = Unpooled.WrappedBuffer(WebSocketUtil.Md5(challenge)); + + // Get path + Uri wsUrl = this.Uri; + string path = RawPath(wsUrl); + + // Format request + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, path); + HttpHeaders headers = request.Headers; + headers.Add(HttpHeaderNames.Upgrade, Websocket) + .Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade) + .Add(HttpHeaderNames.Host, WebsocketHostValue(wsUrl)) + .Add(HttpHeaderNames.Origin, WebsocketOriginValue(wsUrl)) + .Add(HttpHeaderNames.SecWebsocketKey1, key1) + .Add(HttpHeaderNames.SecWebsocketKey2, key2); + + string expectedSubprotocol = this.ExpectedSubprotocol; + if (!string.IsNullOrEmpty(expectedSubprotocol)) + { + headers.Add(HttpHeaderNames.SecWebsocketProtocol, expectedSubprotocol); + } + + if (this.CustomHeaders != null) + { + headers.Add(this.CustomHeaders); + } + + // Set Content-Length to workaround some known defect. + // See also: http://www.ietf.org/mail-archive/web/hybi/current/msg02149.html + headers.Set(HttpHeaderNames.ContentLength, key3.Length); + request.Content.WriteBytes(key3); + return request; + } + + protected override void Verify(IFullHttpResponse response) + { + if (!response.Status.Equals(HttpResponseStatus.SwitchingProtocols)) + { + throw new WebSocketHandshakeException($"Invalid handshake response getStatus: {response.Status}"); + } + + HttpHeaders headers = response.Headers; + + if (!headers.TryGet(HttpHeaderNames.Upgrade, out ICharSequence upgrade) + ||!Websocket.ContentEqualsIgnoreCase(upgrade)) + { + throw new WebSocketHandshakeException($"Invalid handshake response upgrade: {upgrade}"); + } + + if (!headers.ContainsValue(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade, true)) + { + headers.TryGet(HttpHeaderNames.Connection, out upgrade); + throw new WebSocketHandshakeException($"Invalid handshake response connection: {upgrade}"); + } + + IByteBuffer challenge = response.Content; + if (!challenge.Equals(this.expectedChallengeResponseBytes)) + { + throw new WebSocketHandshakeException("Invalid challenge"); + } + } + + static string InsertRandomCharacters(string key) + { + int count = WebSocketUtil.RandomNumber(1, 12); + + var randomChars = new char[count]; + int randCount = 0; + while (randCount < count) + { + int rand = unchecked((int)(WebSocketUtil.RandomNext() * 0x7e + 0x21)); + if (0x21 < rand && rand < 0x2f || 0x3a < rand && rand < 0x7e) + { + randomChars[randCount] = (char)rand; + randCount += 1; + } + } + + for (int i = 0; i < count; i++) + { + int split = WebSocketUtil.RandomNumber(0, key.Length); + string part1 = key.Substring(0, split); + string part2 = key.Substring(split); + key = part1 + randomChars[i] + part2; + } + + return key; + } + + static string InsertSpaces(string key, int spaces) + { + for (int i = 0; i < spaces; i++) + { + int split = WebSocketUtil.RandomNumber(1, key.Length - 1); + string part1 = key.Substring(0, split); + string part2 = key.Substring(split); + key = part1 + ' ' + part2; + } + + return key; + } + + protected internal override IWebSocketFrameDecoder NewWebSocketDecoder() => new WebSocket00FrameDecoder(this.MaxFramePayloadLength); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => new WebSocket00FrameEncoder(); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker07.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker07.cs new file mode 100644 index 000000000..218fd4b2e --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker07.cs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Text; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + + public class WebSocketClientHandshaker07 : WebSocketClientHandshaker + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + public static readonly string MagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + AsciiString expectedChallengeResponseString; + + readonly bool allowExtensions; + readonly bool performMasking; + readonly bool allowMaskMismatch; + + public WebSocketClientHandshaker07(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) + : this(webSocketUrl, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, false) + { + } + + public WebSocketClientHandshaker07(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + bool performMasking, bool allowMaskMismatch) + : base(webSocketUrl, version, subprotocol, customHeaders, maxFramePayloadLength) + { + this.allowExtensions = allowExtensions; + this.performMasking = performMasking; + this.allowMaskMismatch = allowMaskMismatch; + } + + protected internal override IFullHttpRequest NewHandshakeRequest() + { + // Get path + Uri wsUrl = this.Uri; + string path = RawPath(wsUrl); + + // Get 16 bit nonce and base 64 encode it + byte[] nonce = WebSocketUtil.RandomBytes(16); + string key = WebSocketUtil.Base64String(nonce); + + string acceptSeed = key + MagicGuid; + byte[] sha1 = WebSocketUtil.Sha1(Encoding.ASCII.GetBytes(acceptSeed)); + this.expectedChallengeResponseString = new AsciiString(WebSocketUtil.Base64String(sha1)); + + if (Logger.DebugEnabled) + { + Logger.Debug("WebSocket version 07 client handshake key: {}, expected response: {}", + key, this.expectedChallengeResponseString); + } + + // Format request + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, path); + HttpHeaders headers = request.Headers; + + headers.Add(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket) + .Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade) + .Add(HttpHeaderNames.SecWebsocketKey, key) + .Add(HttpHeaderNames.Host, WebsocketHostValue(wsUrl)) + .Add(HttpHeaderNames.SecWebsocketOrigin, WebsocketOriginValue(wsUrl)); + + string expectedSubprotocol = this.ExpectedSubprotocol; + if (!string.IsNullOrEmpty(expectedSubprotocol)) + { + headers.Add(HttpHeaderNames.SecWebsocketProtocol, expectedSubprotocol); + } + + headers.Add(HttpHeaderNames.SecWebsocketVersion, "7"); + + if (this.CustomHeaders != null) + { + headers.Add(this.CustomHeaders); + } + return request; + } + + protected override void Verify(IFullHttpResponse response) + { + HttpResponseStatus status = HttpResponseStatus.SwitchingProtocols; + HttpHeaders headers = response.Headers; + + if (!response.Status.Equals(status)) + { + throw new WebSocketHandshakeException($"Invalid handshake response getStatus: {response.Status}"); + } + + if (headers.TryGet(HttpHeaderNames.Upgrade, out ICharSequence upgrade) + || !HttpHeaderValues.Websocket.ContentEqualsIgnoreCase(upgrade)) + { + throw new WebSocketHandshakeException($"Invalid handshake response upgrade: {upgrade}"); + } + + if (!headers.ContainsValue(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade, true)) + { + headers.TryGet(HttpHeaderNames.Connection, out upgrade); + throw new WebSocketHandshakeException($"Invalid handshake response connection: {upgrade}"); + } + + if (headers.TryGet(HttpHeaderNames.SecWebsocketAccept, out ICharSequence accept) + || !accept.Equals(this.expectedChallengeResponseString)) + { + throw new WebSocketHandshakeException($"Invalid challenge. Actual: {accept}. Expected: {this.expectedChallengeResponseString}"); + } + } + + protected internal override IWebSocketFrameDecoder NewWebSocketDecoder() => new WebSocket07FrameDecoder( + false, this.allowExtensions, this.MaxFramePayloadLength, this.allowMaskMismatch); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => new WebSocket07FrameEncoder(this.performMasking); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker08.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker08.cs new file mode 100644 index 000000000..8a05fb014 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker08.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Text; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + + public class WebSocketClientHandshaker08 : WebSocketClientHandshaker + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + public static readonly string MagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + AsciiString expectedChallengeResponseString; + + readonly bool allowExtensions; + readonly bool performMasking; + readonly bool allowMaskMismatch; + + public WebSocketClientHandshaker08(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) + : this(webSocketUrl, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, false) + { + } + + public WebSocketClientHandshaker08(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + bool performMasking, bool allowMaskMismatch) + : base(webSocketUrl, version, subprotocol, customHeaders, maxFramePayloadLength) + { + this.allowExtensions = allowExtensions; + this.performMasking = performMasking; + this.allowMaskMismatch = allowMaskMismatch; + } + + protected internal override IFullHttpRequest NewHandshakeRequest() + { + // Get path + Uri wsUrl = this.Uri; + string path = RawPath(wsUrl); + + // Get 16 bit nonce and base 64 encode it + byte[] nonce = WebSocketUtil.RandomBytes(16); + string key = WebSocketUtil.Base64String(nonce); + + string acceptSeed = key + MagicGuid; + byte[] sha1 = WebSocketUtil.Sha1(Encoding.ASCII.GetBytes(acceptSeed)); + this.expectedChallengeResponseString = new AsciiString(WebSocketUtil.Base64String(sha1)); + + if (Logger.DebugEnabled) + { + Logger.Debug("WebSocket version 08 client handshake key: {}, expected response: {}", + key, this.expectedChallengeResponseString); + } + + // Format request + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, path); + HttpHeaders headers = request.Headers; + + headers.Add(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket) + .Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade) + .Add(HttpHeaderNames.SecWebsocketKey, key) + .Add(HttpHeaderNames.Host, WebsocketHostValue(wsUrl)) + .Add(HttpHeaderNames.SecWebsocketOrigin, WebsocketOriginValue(wsUrl)); + + string expectedSubprotocol = this.ExpectedSubprotocol; + if (!string.IsNullOrEmpty(expectedSubprotocol)) + { + headers.Add(HttpHeaderNames.SecWebsocketProtocol, expectedSubprotocol); + } + + headers.Add(HttpHeaderNames.SecWebsocketVersion, "8"); + + if (this.CustomHeaders != null) + { + headers.Add(this.CustomHeaders); + } + return request; + } + + protected override void Verify(IFullHttpResponse response) + { + HttpResponseStatus status = HttpResponseStatus.SwitchingProtocols; + HttpHeaders headers = response.Headers; + + if (!response.Status.Equals(status)) + { + throw new WebSocketHandshakeException($"Invalid handshake response getStatus: {response.Status}"); + } + + if (!headers.TryGet(HttpHeaderNames.Upgrade, out ICharSequence upgrade) + || !HttpHeaderValues.Websocket.ContentEqualsIgnoreCase(upgrade)) + { + throw new WebSocketHandshakeException($"Invalid handshake response upgrade: {upgrade}"); + } + + if (!headers.ContainsValue(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade, true)) + { + headers.TryGet(HttpHeaderNames.Connection, out upgrade); + throw new WebSocketHandshakeException($"Invalid handshake response connection: {upgrade}"); + } + + if (!headers.TryGet(HttpHeaderNames.SecWebsocketAccept, out ICharSequence accept) + || !accept.Equals(this.expectedChallengeResponseString)) + { + throw new WebSocketHandshakeException($"Invalid challenge. Actual: {accept}. Expected: {this.expectedChallengeResponseString}"); + } + } + + protected internal override IWebSocketFrameDecoder NewWebSocketDecoder() => new WebSocket08FrameDecoder( + false, this.allowExtensions, this.MaxFramePayloadLength, this.allowMaskMismatch); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => new WebSocket08FrameEncoder(this.performMasking); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker13.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker13.cs new file mode 100644 index 000000000..bce560e8e --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshaker13.cs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Text; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + + public class WebSocketClientHandshaker13 : WebSocketClientHandshaker + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + public static readonly string MagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + AsciiString expectedChallengeResponseString; + + readonly bool allowExtensions; + readonly bool performMasking; + readonly bool allowMaskMismatch; + + public WebSocketClientHandshaker13(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) + : this(webSocketUrl, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, false) + { + } + + public WebSocketClientHandshaker13(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + bool performMasking, bool allowMaskMismatch) + : base(webSocketUrl, version, subprotocol, customHeaders, maxFramePayloadLength) + { + + this.allowExtensions = allowExtensions; + this.performMasking = performMasking; + this.allowMaskMismatch = allowMaskMismatch; + } + + protected internal override IFullHttpRequest NewHandshakeRequest() + { + // Get path + Uri wsUrl = this.Uri; + string path = RawPath(wsUrl); + + // Get 16 bit nonce and base 64 encode it + byte[] nonce = WebSocketUtil.RandomBytes(16); + string key = WebSocketUtil.Base64String(nonce); + + string acceptSeed = key + MagicGuid; + byte[] sha1 = WebSocketUtil.Sha1(Encoding.ASCII.GetBytes(acceptSeed)); + this.expectedChallengeResponseString = new AsciiString(WebSocketUtil.Base64String(sha1)); + + if (Logger.DebugEnabled) + { + Logger.Debug("WebSocket version 13 client handshake key: {}, expected response: {}", + key, this.expectedChallengeResponseString); + } + + // Format request + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, path); + HttpHeaders headers = request.Headers; + + headers.Add(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket) + .Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade) + .Add(HttpHeaderNames.SecWebsocketKey, key) + .Add(HttpHeaderNames.Host, WebsocketHostValue(wsUrl)) + .Add(HttpHeaderNames.SecWebsocketOrigin, WebsocketOriginValue(wsUrl)); + + string expectedSubprotocol = this.ExpectedSubprotocol; + if (!string.IsNullOrEmpty(expectedSubprotocol)) + { + headers.Add(HttpHeaderNames.SecWebsocketProtocol, expectedSubprotocol); + } + + headers.Add(HttpHeaderNames.SecWebsocketVersion, "13"); + + if (this.CustomHeaders != null) + { + headers.Add(this.CustomHeaders); + } + return request; + } + + protected override void Verify(IFullHttpResponse response) + { + HttpResponseStatus status = HttpResponseStatus.SwitchingProtocols; + HttpHeaders headers = response.Headers; + + if (!response.Status.Equals(status)) + { + throw new WebSocketHandshakeException($"Invalid handshake response getStatus: {response.Status}"); + } + + if (!headers.TryGet(HttpHeaderNames.Upgrade, out ICharSequence upgrade) + || !HttpHeaderValues.Websocket.ContentEqualsIgnoreCase(upgrade)) + { + throw new WebSocketHandshakeException($"Invalid handshake response upgrade: {upgrade}"); + } + + if (!headers.ContainsValue(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade, true)) + { + headers.TryGet(HttpHeaderNames.Connection, out upgrade); + throw new WebSocketHandshakeException($"Invalid handshake response connection: {upgrade}"); + } + + if (!headers.TryGet(HttpHeaderNames.SecWebsocketAccept, out ICharSequence accept) + || !accept.Equals(this.expectedChallengeResponseString)) + { + throw new WebSocketHandshakeException($"Invalid challenge. Actual: {accept}. Expected: {this.expectedChallengeResponseString}"); + } + } + + protected internal override IWebSocketFrameDecoder NewWebSocketDecoder() => new WebSocket13FrameDecoder( + false, this.allowExtensions, this.MaxFramePayloadLength, this.allowMaskMismatch); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => new WebSocket13FrameEncoder(this.performMasking); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshakerFactory.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshakerFactory.cs new file mode 100644 index 000000000..7aca65aa9 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientHandshakerFactory.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + + using static WebSocketVersion; + + public static class WebSocketClientHandshakerFactory + { + public static WebSocketClientHandshaker NewHandshaker(Uri webSocketUrl, WebSocketVersion version, string subprotocol, bool allowExtensions, HttpHeaders customHeaders) => + NewHandshaker(webSocketUrl, version, subprotocol, allowExtensions, customHeaders, 65536); + + public static WebSocketClientHandshaker NewHandshaker(Uri webSocketUrl, WebSocketVersion version, string subprotocol, bool allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) => + NewHandshaker(webSocketUrl, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, false); + + public static WebSocketClientHandshaker NewHandshaker( + Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + bool performMasking, bool allowMaskMismatch) + { + if (version == V13) + { + return new WebSocketClientHandshaker13( + webSocketUrl, V13, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch); + } + if (version == V08) + { + return new WebSocketClientHandshaker08( + webSocketUrl, V08, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch); + } + if (version == V07) + { + return new WebSocketClientHandshaker07( + webSocketUrl, V07, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch); + } + if (version == V00) + { + return new WebSocketClientHandshaker00( + webSocketUrl, V00, subprotocol, customHeaders, maxFramePayloadLength); + } + + throw new WebSocketHandshakeException($"Protocol version {version}not supported."); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientProtocolHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientProtocolHandler.cs new file mode 100644 index 000000000..62a1d59ab --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientProtocolHandler.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable once ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoPropertyWhenPossible +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Collections.Generic; + using DotNetty.Transport.Channels; + + public class WebSocketClientProtocolHandler : WebSocketProtocolHandler + { + readonly WebSocketClientHandshaker handshaker; + readonly bool handleCloseFrames; + + public WebSocketClientHandshaker Handshaker => this.handshaker; + + /// + /// Events that are fired to notify about handshake status + /// + public enum ClientHandshakeStateEvent + { + /// + /// The Handshake was started but the server did not response yet to the request + /// + HandshakeIssued, + + /// + /// The Handshake was complete succesful and so the channel was upgraded to websockets + /// + HandshakeComplete + } + + public WebSocketClientProtocolHandler(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, + int maxFramePayloadLength, bool handleCloseFrames, + bool performMasking, bool allowMaskMismatch) + : this(WebSocketClientHandshakerFactory.NewHandshaker(webSocketUrl, version, subprotocol, + allowExtensions, customHeaders, maxFramePayloadLength, + performMasking, allowMaskMismatch), handleCloseFrames) + { + } + + public WebSocketClientProtocolHandler(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, + int maxFramePayloadLength, bool handleCloseFrames) + : this(webSocketUrl, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, + handleCloseFrames, true, false) + { + } + + public WebSocketClientProtocolHandler(Uri webSocketUrl, WebSocketVersion version, string subprotocol, + bool allowExtensions, HttpHeaders customHeaders, + int maxFramePayloadLength) + : this(webSocketUrl, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true) + { + } + + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, bool handleCloseFrames) + { + this.handshaker = handshaker; + this.handleCloseFrames = handleCloseFrames; + } + + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) + : this(handshaker, true) + { + } + + protected override void Decode(IChannelHandlerContext ctx, WebSocketFrame frame, List output) + { + if (this.handleCloseFrames && frame is CloseWebSocketFrame) + { + ctx.CloseAsync(); + return; + } + + base.Decode(ctx, frame, output); + } + + public override void HandlerAdded(IChannelHandlerContext ctx) + { + IChannelPipeline cp = ctx.Channel.Pipeline; + if (cp.Get() == null) + { + // Add the WebSocketClientProtocolHandshakeHandler before this one. + ctx.Channel.Pipeline.AddBefore(ctx.Name, nameof(WebSocketClientProtocolHandshakeHandler), + new WebSocketClientProtocolHandshakeHandler(this.handshaker)); + } + if (cp.Get() == null) + { + // Add the UFT8 checking before this one. + ctx.Channel.Pipeline.AddBefore(ctx.Name, nameof(Utf8FrameValidator), + new Utf8FrameValidator()); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientProtocolHandshakeHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientProtocolHandshakeHandler.cs new file mode 100644 index 000000000..977863c0d --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketClientProtocolHandshakeHandler.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Threading.Tasks; + using DotNetty.Transport.Channels; + + class WebSocketClientProtocolHandshakeHandler : ChannelHandlerAdapter + { + readonly WebSocketClientHandshaker handshaker; + + internal WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker) + { + this.handshaker = handshaker; + } + + public override void ChannelActive(IChannelHandlerContext context) + { + base.ChannelActive(context); + this.handshaker.HandshakeAsync(context.Channel) + .ContinueWith((t, state) => + { + var ctx = (IChannelHandlerContext)state; + if (t.Status == TaskStatus.RanToCompletion) + { + ctx.FireUserEventTriggered(WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HandshakeIssued); + } + else + { + ctx.FireExceptionCaught(t.Exception); + } + }, + context, + TaskContinuationOptions.ExecuteSynchronously); + } + + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + if (!(msg is IFullHttpResponse)) + { + ctx.FireChannelRead(msg); + return; + } + + var response = (IFullHttpResponse)msg; + try + { + if (!this.handshaker.IsHandshakeComplete) + { + this.handshaker.FinishHandshake(ctx.Channel, response); + ctx.FireUserEventTriggered(WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HandshakeComplete); + ctx.Channel.Pipeline.Remove(this); + return; + } + + throw new InvalidOperationException("WebSocketClientHandshaker should have been finished yet"); + } + finally + { + response.Release(); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketFrame.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketFrame.cs new file mode 100644 index 000000000..e4783745c --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketFrame.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http.WebSockets +{ + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + public abstract class WebSocketFrame : DefaultByteBufferHolder + { + // Flag to indicate if this frame is the final fragment in a message. The first fragment (frame) may also be the + // final fragment. + readonly bool finalFragment; + + // RSV1, RSV2, RSV3 used for extensions + readonly int rsv; + + protected WebSocketFrame(IByteBuffer binaryData) + : this(true, 0, binaryData) + { + } + + protected WebSocketFrame(bool finalFragment, int rsv, IByteBuffer binaryData) + : base(binaryData) + { + this.finalFragment = finalFragment; + this.rsv = rsv; + } + + /// + /// Flag to indicate if this frame is the final fragment in a message. The first fragment (frame) + /// may also be the final fragment. + /// + public bool IsFinalFragment => this.finalFragment; + + /// + /// RSV1, RSV2, RSV3 used for extensions + /// + public int Rsv => this.rsv; + + public override string ToString() => StringUtil.SimpleClassName(this) + "(data: " + this.ContentToString() + ')'; + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketFrameAggregator.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketFrameAggregator.cs new file mode 100644 index 000000000..8f95ebe6a --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketFrameAggregator.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using DotNetty.Buffers; + using DotNetty.Codecs; + using DotNetty.Transport.Channels; + + public class WebSocketFrameAggregator : MessageAggregator + { + public WebSocketFrameAggregator(int maxContentLength) + : base(maxContentLength) + { + } + + protected override bool IsStartMessage(WebSocketFrame msg) => msg is TextWebSocketFrame || msg is BinaryWebSocketFrame; + + protected override bool IsContentMessage(WebSocketFrame msg) => msg is ContinuationWebSocketFrame; + + protected override bool IsLastContentMessage(ContinuationWebSocketFrame msg) => this.IsContentMessage(msg) && msg.IsFinalFragment; + + protected override bool IsAggregated(WebSocketFrame msg) + { + if (msg.IsFinalFragment) + { + return !this.IsContentMessage(msg); + } + + return !this.IsStartMessage(msg) && !this.IsContentMessage(msg); + } + + protected override bool IsContentLengthInvalid(WebSocketFrame start, int maxContentLength) => false; + + protected override object NewContinueResponse(WebSocketFrame start, int maxContentLength, IChannelPipeline pipeline) => null; + + protected override bool CloseAfterContinueResponse(object msg) => throw new NotSupportedException(); + + protected override bool IgnoreContentAfterContinueResponse(object msg) => throw new NotSupportedException(); + + protected override WebSocketFrame BeginAggregation(WebSocketFrame start, IByteBuffer content) + { + if (start is TextWebSocketFrame) + { + return new TextWebSocketFrame(true, start.Rsv, content); + } + + if (start is BinaryWebSocketFrame) + { + return new BinaryWebSocketFrame(true, start.Rsv, content); + } + + // Should not reach here. + throw new Exception("Unkonw WebSocketFrame type, must be either TextWebSocketFrame or BinaryWebSocketFrame"); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketHandshakeException.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketHandshakeException.cs new file mode 100644 index 000000000..168b1371f --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketHandshakeException.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + + public class WebSocketHandshakeException : Exception + { + public WebSocketHandshakeException(string message, Exception innereException) + : base(message, innereException) + { + } + + public WebSocketHandshakeException(string message) + : base(message) + { + } + + public WebSocketHandshakeException(Exception innerException) + : base(null, innerException) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketProtocolHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketProtocolHandler.cs new file mode 100644 index 000000000..36ec0fe61 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketProtocolHandler.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Collections.Generic; + using DotNetty.Transport.Channels; + + public abstract class WebSocketProtocolHandler : MessageToMessageDecoder + { + readonly bool dropPongFrames; + + internal WebSocketProtocolHandler() : this(true) + { + } + + internal WebSocketProtocolHandler(bool dropPongFrames) + { + this.dropPongFrames = dropPongFrames; + } + + protected override void Decode(IChannelHandlerContext ctx, WebSocketFrame frame, List output) + { + if (frame is PingWebSocketFrame) + { + frame.Content.Retain(); + ctx.Channel.WriteAndFlushAsync(new PongWebSocketFrame(frame.Content)); + return; + } + + if (frame is PongWebSocketFrame && this.dropPongFrames) + { + // Pong frames need to get ignored + return; + } + + output.Add(frame.Retain()); + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + ctx.FireExceptionCaught(cause); + ctx.CloseAsync(); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketScheme.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketScheme.cs new file mode 100644 index 000000000..534e8ea63 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketScheme.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Http.WebSockets +{ + using DotNetty.Common.Utilities; + + public sealed class WebSocketScheme + { + // Scheme for non-secure WebSocket connection. + public static readonly WebSocketScheme WS = new WebSocketScheme(80, "ws"); + + // Scheme for secure WebSocket connection. + public static readonly WebSocketScheme WSS = new WebSocketScheme(443, "wss"); + + readonly int port; + readonly AsciiString name; + + WebSocketScheme(int port, string name) + { + this.port = port; + this.name = AsciiString.Cached(name); + } + + public AsciiString Name => this.name; + + public int Port => this.port; + + public override bool Equals(object obj) => obj is WebSocketScheme other + && other.port == this.port && other.name.Equals(this.name); + + public override int GetHashCode() => this.port * 31 + this.name.GetHashCode(); + + public override string ToString() => this.name.ToString(); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker.cs new file mode 100644 index 000000000..e5ec847d8 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker.cs @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Threading.Tasks; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Internal; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public abstract class WebSocketServerHandshaker + { + protected static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + static readonly ClosedChannelException ClosedChannelException = new ClosedChannelException(); + + readonly string uri; + + readonly string[] subprotocols; + + readonly WebSocketVersion version; + + readonly int maxFramePayloadLength; + + string selectedSubprotocol; + + // Use this as wildcard to support all requested sub-protocols + public static readonly string SubProtocolWildcard = "*"; + + protected WebSocketServerHandshaker(WebSocketVersion version, string uri, string subprotocols, int maxFramePayloadLength) + { + this.version = version; + this.uri = uri; + if (subprotocols != null) + { + string[] subprotocolArray = subprotocols.Split(','); + for (int i = 0; i < subprotocolArray.Length; i++) + { + subprotocolArray[i] = subprotocolArray[i].Trim(); + } + this.subprotocols = subprotocolArray; + } + else + { + this.subprotocols = EmptyArrays.EmptyStrings; + } + this.maxFramePayloadLength = maxFramePayloadLength; + } + + public string Uri => this.uri; + + public ISet Subprotocols() + { + var ret = new HashSet(this.subprotocols); + return ret; + } + + public WebSocketVersion Version => this. version; + + public int MaxFramePayloadLength => this.maxFramePayloadLength; + + public Task HandshakeAsync(IChannel channel, IFullHttpRequest req) => this.HandshakeAsync(channel, req, null); + + public Task HandshakeAsync(IChannel channel, IFullHttpRequest req, HttpHeaders responseHeaders) + { + var completion = new TaskCompletionSource(); + this.Handshake(channel, req, responseHeaders, completion); + return completion.Task; + } + + public void Handshake(IChannel channel, IFullHttpRequest req, HttpHeaders responseHeaders, TaskCompletionSource completion) + { + if (Logger.DebugEnabled) + { + Logger.Debug("{} WebSocket version {} server handshake", channel, this.version); + } + + IFullHttpResponse response = this.NewHandshakeResponse(req, responseHeaders); + IChannelPipeline p = channel.Pipeline; + if (p.Get() != null) + { + p.Remove(); + } + + if (p.Get() != null) + { + p.Remove(); + } + + IChannelHandlerContext ctx = p.Context(); + string encoderName; + if (ctx == null) + { + // this means the user use a HttpServerCodec + ctx = p.Context(); + if (ctx == null) + { + completion.TrySetException(new InvalidOperationException("No HttpDecoder and no HttpServerCodec in the pipeline")); + return; + } + + p.AddBefore(ctx.Name, "wsdecoder", this.NewWebsocketDecoder()); + p.AddBefore(ctx.Name, "wsencoder", this.NewWebSocketEncoder()); + encoderName = ctx.Name; + } + else + { + p.Replace(ctx.Name, "wsdecoder", this.NewWebsocketDecoder()); + + encoderName = p.Context().Name; + p.AddBefore(encoderName, "wsencoder", this.NewWebSocketEncoder()); + } + + channel.WriteAndFlushAsync(response).ContinueWith(t => + { + if (t.Status == TaskStatus.RanToCompletion) + { + p.Remove(encoderName); + completion.TryComplete(); + } + else + { + completion.TrySetException(t.Exception); + } + }); + } + + public Task HandshakeAsync(IChannel channel, IHttpRequest req, HttpHeaders responseHeaders) + { + if (req is IFullHttpRequest request) + { + return this.HandshakeAsync(channel, request, responseHeaders); + } + if (Logger.DebugEnabled) + { + Logger.Debug("{} WebSocket version {} server handshake", channel, this.version); + } + IChannelPipeline p = channel.Pipeline; + IChannelHandlerContext ctx = p.Context(); + if (ctx == null) + { + // this means the user use a HttpServerCodec + ctx = p.Context(); + if (ctx == null) + { + return TaskEx.FromException(new InvalidOperationException("No HttpDecoder and no HttpServerCodec in the pipeline")); + } + } + + // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then + // enough for the websockets handshake payload. + // + // TODO: Make handshake work without HttpObjectAggregator at all. + string aggregatorName = "httpAggregator"; + p.AddAfter(ctx.Name, aggregatorName, new HttpObjectAggregator(8192)); + var completion = new TaskCompletionSource(); + p.AddAfter(aggregatorName, "handshaker", new Handshaker(this, channel, responseHeaders, completion)); + try + { + ctx.FireChannelRead(ReferenceCountUtil.Retain(req)); + } + catch (Exception cause) + { + completion.TrySetException(cause); + } + return completion.Task; + } + + sealed class Handshaker : SimpleChannelInboundHandler + { + readonly WebSocketServerHandshaker serverHandshaker; + readonly IChannel channel; + readonly HttpHeaders responseHeaders; + readonly TaskCompletionSource completion; + + public Handshaker(WebSocketServerHandshaker serverHandshaker, IChannel channel, HttpHeaders responseHeaders, TaskCompletionSource completion) + { + this.serverHandshaker = serverHandshaker; + this.channel = channel; + this.responseHeaders = responseHeaders; + this.completion = completion; + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, IFullHttpRequest msg) + { + // Remove ourself and do the actual handshake + ctx.Channel.Pipeline.Remove(this); + this.serverHandshaker.Handshake(this.channel, msg, this.responseHeaders, this.completion); + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + // Remove ourself and fail the handshake promise. + ctx.Channel.Pipeline.Remove(this); + this.completion.TrySetException(cause); + ctx.FireExceptionCaught(cause); + } + + public override void ChannelInactive(IChannelHandlerContext ctx) + { + // Fail promise if Channel was closed + this.completion.TrySetException(ClosedChannelException); + ctx.FireChannelInactive(); + } + } + + protected abstract IFullHttpResponse NewHandshakeResponse(IFullHttpRequest req, HttpHeaders responseHeaders); + + public virtual Task CloseAsync(IChannel channel, CloseWebSocketFrame frame) + { + Contract.Requires(channel != null); + + return channel.WriteAndFlushAsync(frame).ContinueWith((t, s) => ((IChannel)s).CloseAsync(), + channel, TaskContinuationOptions.ExecuteSynchronously); + } + + protected string SelectSubprotocol(string requestedSubprotocols) + { + if (requestedSubprotocols == null || this.subprotocols.Length == 0) + { + return null; + } + + string[] requestedSubprotocolArray = requestedSubprotocols.Split(','); + foreach (string p in requestedSubprotocolArray) + { + string requestedSubprotocol = p.Trim(); + + foreach (string supportedSubprotocol in this.subprotocols) + { + if (SubProtocolWildcard.Equals(supportedSubprotocol) + || requestedSubprotocol.Equals(supportedSubprotocol)) + { + this.selectedSubprotocol = requestedSubprotocol; + return requestedSubprotocol; + } + } + } + + // No match found + return null; + } + + public string SelectedSubprotocol => this.selectedSubprotocol; + + protected internal abstract IWebSocketFrameDecoder NewWebsocketDecoder(); + + protected internal abstract IWebSocketFrameEncoder NewWebSocketEncoder(); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker00.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker00.cs new file mode 100644 index 000000000..b2fe99bbd --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker00.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Diagnostics; + using System.Text.RegularExpressions; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class WebSocketServerHandshaker00 : WebSocketServerHandshaker + { + static readonly Regex BeginningDigit = new Regex("[^0-9]", RegexOptions.Compiled); + static readonly Regex BeginningSpace = new Regex("[^ ]", RegexOptions.Compiled); + + public WebSocketServerHandshaker00(string webSocketUrl, string subprotocols, int maxFramePayloadLength) + : base(WebSocketVersion.V00, webSocketUrl, subprotocols, maxFramePayloadLength) + { + } + + protected override IFullHttpResponse NewHandshakeResponse(IFullHttpRequest req, HttpHeaders headers) + { + // Serve the WebSocket handshake request. + if (!req.Headers.ContainsValue(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade, true) + || !req.Headers.TryGet(HttpHeaderNames.Upgrade, out ICharSequence value) + || !HttpHeaderValues.Websocket.ContentEqualsIgnoreCase(value)) + { + throw new WebSocketHandshakeException("not a WebSocket handshake request: missing upgrade"); + } + + // Hixie 75 does not contain these headers while Hixie 76 does + bool isHixie76 = req.Headers.Contains(HttpHeaderNames.SecWebsocketKey1) + && req.Headers.Contains(HttpHeaderNames.SecWebsocketKey2); + + // Create the WebSocket handshake response. + var res = new DefaultFullHttpResponse(HttpVersion.Http11, + new HttpResponseStatus(101, new AsciiString(isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake"))); + if (headers != null) + { + res.Headers.Add(headers); + } + + res.Headers.Add(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket); + res.Headers.Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade); + + // Fill in the headers and contents depending on handshake getMethod. + if (isHixie76) + { + // New handshake getMethod with a challenge: + value = req.Headers.Get(HttpHeaderNames.Origin, null); + Debug.Assert(value != null); + res.Headers.Add(HttpHeaderNames.SecWebsocketOrigin, value); + res.Headers.Add(HttpHeaderNames.SecWebsocketLocation, this.Uri); + + if (req.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out ICharSequence subprotocols)) + { + string selectedSubprotocol = this.SelectSubprotocol(subprotocols.ToString()); + if (selectedSubprotocol == null) + { + if (Logger.DebugEnabled) + { + Logger.Debug("Requested subprotocol(s) not supported: {}", subprotocols); + } + } + else + { + res.Headers.Add(HttpHeaderNames.SecWebsocketProtocol, selectedSubprotocol); + } + } + + // Calculate the answer of the challenge. + value = req.Headers.Get(HttpHeaderNames.SecWebsocketKey1, null); + Debug.Assert(value != null, $"{HttpHeaderNames.SecWebsocketKey1} must exist"); + string key1 = value.ToString(); + value = req.Headers.Get(HttpHeaderNames.SecWebsocketKey2, null); + Debug.Assert(value != null, $"{HttpHeaderNames.SecWebsocketKey2} must exist"); + string key2 = value.ToString(); + int a = (int)(long.Parse(BeginningDigit.Replace(key1, "")) / + BeginningSpace.Replace(key1, "").Length); + int b = (int)(long.Parse(BeginningDigit.Replace(key2, "")) / + BeginningSpace.Replace(key2, "").Length); + long c = req.Content.ReadLong(); + IByteBuffer input = Unpooled.Buffer(16); + input.WriteInt(a); + input.WriteInt(b); + input.WriteLong(c); + res.Content.WriteBytes(WebSocketUtil.Md5(input.Array)); + } + else + { + // Old Hixie 75 handshake getMethod with no challenge: + value = req.Headers.Get(HttpHeaderNames.Origin, null); + Debug.Assert(value != null); + res.Headers.Add(HttpHeaderNames.WebsocketOrigin, value); + res.Headers.Add(HttpHeaderNames.WebsocketLocation, this.Uri); + + if (req.Headers.TryGet(HttpHeaderNames.WebsocketProtocol, out ICharSequence protocol)) + { + res.Headers.Add(HttpHeaderNames.WebsocketProtocol, this.SelectSubprotocol(protocol.ToString())); + } + } + + return res; + } + + public override Task CloseAsync(IChannel channel, CloseWebSocketFrame frame) => channel.WriteAndFlushAsync(frame); + + protected internal override IWebSocketFrameDecoder NewWebsocketDecoder() => new WebSocket00FrameDecoder(this.MaxFramePayloadLength); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => new WebSocket00FrameEncoder(); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker07.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker07.cs new file mode 100644 index 000000000..86f7a782e --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker07.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Text; + using DotNetty.Common.Utilities; + + public class WebSocketServerHandshaker07 : WebSocketServerHandshaker + { + public static readonly string Websocket07AcceptGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + readonly bool allowExtensions; + readonly bool allowMaskMismatch; + + public WebSocketServerHandshaker07(string webSocketUrl, string subprotocols, bool allowExtensions, int maxFramePayloadLength) + : this(webSocketUrl, subprotocols, allowExtensions, maxFramePayloadLength, false) + { + } + + public WebSocketServerHandshaker07(string webSocketUrl, string subprotocols, bool allowExtensions, int maxFramePayloadLength, + bool allowMaskMismatch) + : base(WebSocketVersion.V07, webSocketUrl, subprotocols, maxFramePayloadLength) + { + this.allowExtensions = allowExtensions; + this.allowMaskMismatch = allowMaskMismatch; + } + + protected override IFullHttpResponse NewHandshakeResponse(IFullHttpRequest req, HttpHeaders headers) + { + var res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.SwitchingProtocols); + + if (headers != null) + { + res.Headers.Add(headers); + } + + if (!req.Headers.TryGet(HttpHeaderNames.SecWebsocketKey, out ICharSequence key) + || key == null) + { + throw new WebSocketHandshakeException("not a WebSocket request: missing key"); + } + string acceptSeed = key + Websocket07AcceptGuid; + byte[] sha1 = WebSocketUtil.Sha1(Encoding.ASCII.GetBytes(acceptSeed)); + string accept = WebSocketUtil.Base64String(sha1); + + if (Logger.DebugEnabled) + { + Logger.Debug("WebSocket version 07 server handshake key: {}, response: {}.", key, accept); + } + + res.Headers.Add(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket); + res.Headers.Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade); + res.Headers.Add(HttpHeaderNames.SecWebsocketAccept, accept); + + + if (req.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out ICharSequence subprotocols) + && subprotocols != null) + { + string selectedSubprotocol = this.SelectSubprotocol(subprotocols.ToString()); + if (selectedSubprotocol == null) + { + if (Logger.DebugEnabled) + { + Logger.Debug("Requested subprotocol(s) not supported: {}", subprotocols); + } + } + else + { + res.Headers.Add(HttpHeaderNames.SecWebsocketProtocol, selectedSubprotocol); + } + } + return res; + } + + protected internal override IWebSocketFrameDecoder NewWebsocketDecoder() => new WebSocket07FrameDecoder( + true, this.allowExtensions, this.MaxFramePayloadLength, this.allowMaskMismatch); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => new WebSocket07FrameEncoder(false); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker08.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker08.cs new file mode 100644 index 000000000..7dd3731cf --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker08.cs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Text; + using DotNetty.Common.Utilities; + + public class WebSocketServerHandshaker08 : WebSocketServerHandshaker + { + public static readonly string Websocket08AcceptGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + readonly bool allowExtensions; + readonly bool allowMaskMismatch; + + public WebSocketServerHandshaker08(string webSocketUrl, string subprotocols, bool allowExtensions, int maxFramePayloadLength) + : this(webSocketUrl, subprotocols, allowExtensions, maxFramePayloadLength, false) + { + } + + public WebSocketServerHandshaker08(string webSocketUrl, string subprotocols, bool allowExtensions, int maxFramePayloadLength, + bool allowMaskMismatch) + : base(WebSocketVersion.V08, webSocketUrl, subprotocols, maxFramePayloadLength) + { + this.allowExtensions = allowExtensions; + this.allowMaskMismatch = allowMaskMismatch; + } + + protected override IFullHttpResponse NewHandshakeResponse(IFullHttpRequest req, HttpHeaders headers) + { + var res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.SwitchingProtocols); + + if (headers != null) + { + res.Headers.Add(headers); + } + + if (!req.Headers.TryGet(HttpHeaderNames.SecWebsocketKey, out ICharSequence key) + || key == null) + { + throw new WebSocketHandshakeException("not a WebSocket request: missing key"); + } + string acceptSeed = key + Websocket08AcceptGuid; + byte[] sha1 = WebSocketUtil.Sha1(Encoding.ASCII.GetBytes(acceptSeed)); + string accept = WebSocketUtil.Base64String(sha1); + + if (Logger.DebugEnabled) + { + Logger.Debug("WebSocket version 08 server handshake key: {}, response: {}", key, accept); + } + + res.Headers.Add(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket); + res.Headers.Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade); + res.Headers.Add(HttpHeaderNames.SecWebsocketAccept, accept); + + if (req.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out ICharSequence subprotocols) + && subprotocols != null) + { + string selectedSubprotocol = this.SelectSubprotocol(subprotocols.ToString()); + if (selectedSubprotocol == null) + { + if (Logger.DebugEnabled) + { + Logger.Debug("Requested subprotocol(s) not supported: {}", subprotocols); + } + } + else + { + res.Headers.Add(HttpHeaderNames.SecWebsocketProtocol, selectedSubprotocol); + } + } + return res; + } + + protected internal override IWebSocketFrameDecoder NewWebsocketDecoder() => new WebSocket08FrameDecoder( + true, this.allowExtensions, this.MaxFramePayloadLength, this.allowMaskMismatch); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => new WebSocket08FrameEncoder(false); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker13.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker13.cs new file mode 100644 index 000000000..888de19da --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshaker13.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Text; + using DotNetty.Common.Utilities; + + public class WebSocketServerHandshaker13 : WebSocketServerHandshaker + { + public static readonly string Websocket13AcceptGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + readonly bool allowExtensions; + readonly bool allowMaskMismatch; + + public WebSocketServerHandshaker13(string webSocketUrl, string subprotocols, bool allowExtensions, int maxFramePayloadLength) + : this(webSocketUrl, subprotocols, allowExtensions, maxFramePayloadLength, false) + { + } + + public WebSocketServerHandshaker13(string webSocketUrl, string subprotocols, bool allowExtensions, int maxFramePayloadLength, + bool allowMaskMismatch) + : base(WebSocketVersion.V13, webSocketUrl, subprotocols, maxFramePayloadLength) + { + this.allowExtensions = allowExtensions; + this.allowMaskMismatch = allowMaskMismatch; + } + + protected override IFullHttpResponse NewHandshakeResponse(IFullHttpRequest req, HttpHeaders headers) + { + var res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.SwitchingProtocols); + if (headers != null) + { + res.Headers.Add(headers); + } + + if (!req.Headers.TryGet(HttpHeaderNames.SecWebsocketKey, out ICharSequence key) + || key == null) + { + throw new WebSocketHandshakeException("not a WebSocket request: missing key"); + } + string acceptSeed = key.ToString() + Websocket13AcceptGuid; + byte[] sha1 = WebSocketUtil.Sha1(Encoding.ASCII.GetBytes(acceptSeed)); + string accept = WebSocketUtil.Base64String(sha1); + + if (Logger.DebugEnabled) + { + Logger.Debug("WebSocket version 13 server handshake key: {}, response: {}", key, accept); + } + + res.Headers.Add(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket); + res.Headers.Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade); + res.Headers.Add(HttpHeaderNames.SecWebsocketAccept, accept); + + if (req.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out ICharSequence subprotocols) + && subprotocols != null) + { + string selectedSubprotocol = this.SelectSubprotocol(subprotocols.ToString()); + if (selectedSubprotocol == null) + { + if (Logger.DebugEnabled) + { + Logger.Debug("Requested subprotocol(s) not supported: {}", subprotocols); + } + } + else + { + res.Headers.Add(HttpHeaderNames.SecWebsocketProtocol, selectedSubprotocol); + } + } + return res; + } + + protected internal override IWebSocketFrameDecoder NewWebsocketDecoder() => new WebSocket13FrameDecoder( + true, this.allowExtensions, this.MaxFramePayloadLength, this.allowMaskMismatch); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => new WebSocket13FrameEncoder(false); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshakerFactory.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshakerFactory.cs new file mode 100644 index 000000000..3c3b4cbc1 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerHandshakerFactory.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Threading.Tasks; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class WebSocketServerHandshakerFactory + { + readonly string webSocketUrl; + + readonly string subprotocols; + + readonly bool allowExtensions; + + readonly int maxFramePayloadLength; + + readonly bool allowMaskMismatch; + + public WebSocketServerHandshakerFactory(string webSocketUrl, string subprotocols, bool allowExtensions) + : this(webSocketUrl, subprotocols, allowExtensions, 65536) + { + } + + public WebSocketServerHandshakerFactory(string webSocketUrl, string subprotocols, bool allowExtensions, + int maxFramePayloadLength) + : this(webSocketUrl, subprotocols, allowExtensions, maxFramePayloadLength, false) + { + } + + public WebSocketServerHandshakerFactory(string webSocketUrl, string subprotocols, bool allowExtensions, + int maxFramePayloadLength, bool allowMaskMismatch) + { + this.webSocketUrl = webSocketUrl; + this.subprotocols = subprotocols; + this.allowExtensions = allowExtensions; + this.maxFramePayloadLength = maxFramePayloadLength; + this.allowMaskMismatch = allowMaskMismatch; + } + + public WebSocketServerHandshaker NewHandshaker(IHttpRequest req) + { + if (req.Headers.TryGet(HttpHeaderNames.SecWebsocketVersion, out ICharSequence version) + && version != null) + { + if (version.Equals(WebSocketVersion.V13.ToHttpHeaderValue())) + { + // Version 13 of the wire protocol - RFC 6455 (version 17 of the draft hybi specification). + return new WebSocketServerHandshaker13( + this.webSocketUrl, + this.subprotocols, + this.allowExtensions, + this.maxFramePayloadLength, + this.allowMaskMismatch); + } + else if (version.Equals(WebSocketVersion.V08.ToHttpHeaderValue())) + { + // Version 8 of the wire protocol - version 10 of the draft hybi specification. + return new WebSocketServerHandshaker08( + this.webSocketUrl, + this.subprotocols, + this.allowExtensions, + this.maxFramePayloadLength, + this.allowMaskMismatch); + } + else if (version.Equals(WebSocketVersion.V07.ToHttpHeaderValue())) + { + // Version 8 of the wire protocol - version 07 of the draft hybi specification. + return new WebSocketServerHandshaker07( + this.webSocketUrl, + this.subprotocols, + this.allowExtensions, + this.maxFramePayloadLength, + this.allowMaskMismatch); + } + else + { + return null; + } + } + else + { + // Assume version 00 where version header was not specified + return new WebSocketServerHandshaker00(this.webSocketUrl, this.subprotocols, this.maxFramePayloadLength); + } + } + + public static Task SendUnsupportedVersionResponse(IChannel channel) + { + var res = new DefaultFullHttpResponse( + HttpVersion.Http11, + HttpResponseStatus.UpgradeRequired); + res.Headers.Set(HttpHeaderNames.SecWebsocketVersion, WebSocketVersion.V13.ToHttpHeaderValue()); + HttpUtil.SetContentLength(res, 0); + return channel.WriteAndFlushAsync(res); + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerProtocolHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerProtocolHandler.cs new file mode 100644 index 000000000..dee9fe198 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerProtocolHandler.cs @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Collections.Generic; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + using static HttpVersion; + + public class WebSocketServerProtocolHandler : WebSocketProtocolHandler + { + public sealed class HandshakeComplete + { + readonly string requestUri; + readonly HttpHeaders requestHeaders; + readonly string selectedSubprotocol; + + internal HandshakeComplete(string requestUri, HttpHeaders requestHeaders, string selectedSubprotocol) + { + this.requestUri = requestUri; + this.requestHeaders = requestHeaders; + this.selectedSubprotocol = selectedSubprotocol; + } + + public string RequestUri => this.requestUri; + + public HttpHeaders RequestHeaders => this.requestHeaders; + + public string SelectedSubprotocol => this.selectedSubprotocol; + } + + static readonly AttributeKey HandshakerAttrKey = + AttributeKey.ValueOf("HANDSHAKER"); + + readonly string websocketPath; + readonly string subprotocols; + readonly bool allowExtensions; + readonly int maxFramePayloadLength; + readonly bool allowMaskMismatch; + readonly bool checkStartsWith; + + public WebSocketServerProtocolHandler(string websocketPath) + : this(websocketPath, null, false) + { + } + + public WebSocketServerProtocolHandler(string websocketPath, bool checkStartsWith) + : this(websocketPath, null, false, 65536, false, checkStartsWith) + { + } + + public WebSocketServerProtocolHandler(string websocketPath, string subprotocols) + : this(websocketPath, subprotocols, false) + { + } + + public WebSocketServerProtocolHandler(string websocketPath, string subprotocols, bool allowExtensions) + : this(websocketPath, subprotocols, allowExtensions, 65536) + { + } + + public WebSocketServerProtocolHandler(string websocketPath, string subprotocols, + bool allowExtensions, int maxFrameSize) + : this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false) + { + } + + public WebSocketServerProtocolHandler(string websocketPath, string subprotocols, + bool allowExtensions, int maxFrameSize, bool allowMaskMismatch) + : this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false) + { + } + + public WebSocketServerProtocolHandler(string websocketPath, string subprotocols, bool allowExtensions, + int maxFrameSize, bool allowMaskMismatch, bool checkStartsWith) + : this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true) + { + } + + public WebSocketServerProtocolHandler(string websocketPath, string subprotocols, + bool allowExtensions, int maxFrameSize, bool allowMaskMismatch, bool checkStartsWith, bool dropPongFrames) + : base(dropPongFrames) + { + this.websocketPath = websocketPath; + this.subprotocols = subprotocols; + this.allowExtensions = allowExtensions; + this.maxFramePayloadLength = maxFrameSize; + this.allowMaskMismatch = allowMaskMismatch; + this.checkStartsWith = checkStartsWith; + } + + public override void HandlerAdded(IChannelHandlerContext ctx) + { + IChannelPipeline cp = ctx.Channel.Pipeline; + if (cp.Get() == null) + { + // Add the WebSocketHandshakeHandler before this one. + ctx.Channel.Pipeline.AddBefore(ctx.Name, nameof(WebSocketServerProtocolHandshakeHandler), + new WebSocketServerProtocolHandshakeHandler( + this.websocketPath, + this.subprotocols, + this.allowExtensions, + this.maxFramePayloadLength, + this.allowMaskMismatch, + this.checkStartsWith)); + } + + if (cp.Get() == null) + { + // Add the UFT8 checking before this one. + ctx.Channel.Pipeline.AddBefore(ctx.Name, nameof(Utf8FrameValidator), new Utf8FrameValidator()); + } + } + + protected override void Decode(IChannelHandlerContext ctx, WebSocketFrame frame, List output) + { + if (frame is CloseWebSocketFrame socketFrame) + { + WebSocketServerHandshaker handshaker = GetHandshaker(ctx.Channel); + if (handshaker != null) + { + frame.Retain(); + handshaker.CloseAsync(ctx.Channel, socketFrame); + } + else + { + ctx.WriteAndFlushAsync(Unpooled.Empty) + .ContinueWith((t, c) => ((IChannelHandlerContext)c).CloseAsync(), + ctx, TaskContinuationOptions.ExecuteSynchronously); + } + + return; + } + + base.Decode(ctx, frame, output); + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + if (cause is WebSocketHandshakeException) + { + var response = new DefaultFullHttpResponse(Http11, HttpResponseStatus.BadRequest, + Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes(cause.Message))); + ctx.Channel.WriteAndFlushAsync(response) + .ContinueWith((t, c) => ((IChannelHandlerContext)c).CloseAsync(), + ctx, TaskContinuationOptions.ExecuteSynchronously); + } + else + { + ctx.FireExceptionCaught(cause); + ctx.CloseAsync(); + } + } + + internal static WebSocketServerHandshaker GetHandshaker(IChannel channel) => channel.GetAttribute(HandshakerAttrKey).Get(); + + internal static void SetHandshaker(IChannel channel, WebSocketServerHandshaker handshaker) => channel.GetAttribute(HandshakerAttrKey).Set(handshaker); + + internal static IChannelHandler ForbiddenHttpRequestResponder() => new ForbiddenResponseHandler(); + + sealed class ForbiddenResponseHandler : ChannelHandlerAdapter + { + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + if (msg is IFullHttpRequest request) + { + request.Release(); + var response = new DefaultFullHttpResponse(Http11, HttpResponseStatus.Forbidden); + ctx.Channel.WriteAndFlushAsync(response); + } + else + { + ctx.FireChannelRead(msg); + } + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerProtocolHandshakeHandler.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerProtocolHandshakeHandler.cs new file mode 100644 index 000000000..abd3ad136 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketServerProtocolHandshakeHandler.cs @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System.Threading.Tasks; + using DotNetty.Common.Utilities; + using DotNetty.Handlers.Tls; + using DotNetty.Transport.Channels; + + using static HttpUtil; + using static HttpMethod; + using static HttpVersion; + using static HttpResponseStatus; + + class WebSocketServerProtocolHandshakeHandler : ChannelHandlerAdapter + { + readonly string websocketPath; + readonly string subprotocols; + readonly bool allowExtensions; + readonly int maxFramePayloadSize; + readonly bool allowMaskMismatch; + readonly bool checkStartsWith; + + internal WebSocketServerProtocolHandshakeHandler(string websocketPath, string subprotocols, + bool allowExtensions, int maxFrameSize, bool allowMaskMismatch) + : this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false) + { + } + + internal WebSocketServerProtocolHandshakeHandler(string websocketPath, string subprotocols, + bool allowExtensions, int maxFrameSize, bool allowMaskMismatch, bool checkStartsWith) + { + this.websocketPath = websocketPath; + this.subprotocols = subprotocols; + this.allowExtensions = allowExtensions; + this.maxFramePayloadSize = maxFrameSize; + this.allowMaskMismatch = allowMaskMismatch; + this.checkStartsWith = checkStartsWith; + } + + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + var req = (IFullHttpRequest)msg; + if (this.IsNotWebSocketPath(req)) + { + ctx.FireChannelRead(msg); + return; + } + + try + { + if (!Equals(req.Method, Get)) + { + SendHttpResponse(ctx, req, new DefaultFullHttpResponse(Http11, Forbidden)); + return; + } + + var wsFactory = new WebSocketServerHandshakerFactory( + GetWebSocketLocation(ctx.Channel.Pipeline, req, this.websocketPath), this.subprotocols, + this.allowExtensions, this.maxFramePayloadSize, this.allowMaskMismatch); + WebSocketServerHandshaker handshaker = wsFactory.NewHandshaker(req); + if (handshaker == null) + { + WebSocketServerHandshakerFactory.SendUnsupportedVersionResponse(ctx.Channel); + } + else + { + Task task = handshaker.HandshakeAsync(ctx.Channel, req); + task.ContinueWith(t => + { + if (t.Status != TaskStatus.RanToCompletion) + { + ctx.FireExceptionCaught(t.Exception); + } + else + { + ctx.FireUserEventTriggered(new WebSocketServerProtocolHandler.HandshakeComplete( + req.Uri, req.Headers, handshaker.SelectedSubprotocol)); + } + }, + TaskContinuationOptions.ExecuteSynchronously); + + WebSocketServerProtocolHandler.SetHandshaker(ctx.Channel, handshaker); + ctx.Channel.Pipeline.Replace(this, "WS403Responder", + WebSocketServerProtocolHandler.ForbiddenHttpRequestResponder()); + } + } + finally + { + req.Release(); + } + } + + bool IsNotWebSocketPath(IFullHttpRequest req) => this.checkStartsWith + ? !req.Uri.StartsWith(this.websocketPath) + : !req.Uri.Equals(this.websocketPath); + + static void SendHttpResponse(IChannelHandlerContext ctx, IHttpRequest req, IHttpResponse res) + { + Task task = ctx.Channel.WriteAndFlushAsync(res); + if (!IsKeepAlive(req) || res.Status.Code != 200) + { + task.ContinueWith((t, c) => ((IChannel)c).CloseAsync(), + ctx.Channel, TaskContinuationOptions.ExecuteSynchronously); + } + } + + static string GetWebSocketLocation(IChannelPipeline cp, IHttpRequest req, string path) + { + string protocol = "ws"; + if (cp.Get() != null) + { + // SSL in use so use Secure WebSockets + protocol = "wss"; + } + + string host = null; + if (req.Headers.TryGet(HttpHeaderNames.Host, out ICharSequence value)) + { + host = value.ToString(); + } + return $"{protocol}://{host}{path}"; + } + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketUtil.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketUtil.cs new file mode 100644 index 000000000..1e11742e0 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketUtil.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Security.Cryptography; + using System.Text; + using DotNetty.Codecs.Base64; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Internal; + + static class WebSocketUtil + { + static readonly Random Random = PlatformDependent.GetThreadLocalRandom(); + + static readonly ThreadLocalMD5 LocalMd5 = new ThreadLocalMD5(); + + sealed class ThreadLocalMD5 : FastThreadLocal + { + protected override MD5 GetInitialValue() => MD5.Create(); + } + + static readonly ThreadLocalSha1 LocalSha1 = new ThreadLocalSha1(); + + sealed class ThreadLocalSha1 : FastThreadLocal + { + protected override SHA1 GetInitialValue() => SHA1.Create(); + } + + internal static byte[] Md5(byte[] data) + { + MD5 md5 = LocalMd5.Value; + md5.Initialize(); + return md5.ComputeHash(data); + } + + internal static byte[] Sha1(byte[] data) + { + SHA1 sha1 = LocalSha1.Value; + sha1.Initialize(); + return sha1.ComputeHash(data); + } + + internal static string Base64String(byte[] data) + { + IByteBuffer encodedData = Unpooled.WrappedBuffer(data); + IByteBuffer encoded = Base64.Encode(encodedData); + string encodedString = encoded.ToString(Encoding.UTF8); + encoded.Release(); + return encodedString; + } + + internal static byte[] RandomBytes(int size) + { + var bytes = new byte[size]; + Random.NextBytes(bytes); + return bytes; + } + + internal static int RandomNumber(int minimum, int maximum) => unchecked((int)(Random.NextDouble() * maximum + minimum)); + + // Math.Random() + internal static double RandomNext() => Random.NextDouble(); + } +} diff --git a/src/DotNetty.Codecs.Http/WebSockets/WebSocketVersion.cs b/src/DotNetty.Codecs.Http/WebSockets/WebSocketVersion.cs new file mode 100644 index 000000000..0af355e62 --- /dev/null +++ b/src/DotNetty.Codecs.Http/WebSockets/WebSocketVersion.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ParameterOnlyUsedForPreconditionCheck.Local +namespace DotNetty.Codecs.Http.WebSockets +{ + using System; + using System.Runtime.CompilerServices; + using DotNetty.Common.Utilities; + + public sealed class WebSocketVersion + { + public static readonly WebSocketVersion Unknown = new WebSocketVersion(""); + + // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 + // draft-ietf-hybi-thewebsocketprotocol- 00. + public static readonly WebSocketVersion V00 = new WebSocketVersion("0"); + + // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07 + // draft-ietf-hybi-thewebsocketprotocol- 07 + public static readonly WebSocketVersion V07 = new WebSocketVersion("7"); + + // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 + // draft-ietf-hybi-thewebsocketprotocol- 10 + public static readonly WebSocketVersion V08 = new WebSocketVersion("8"); + + // http://tools.ietf.org/html/rfc6455 This was originally + // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 + //draft-ietf-hybi-thewebsocketprotocol- 17> + public static readonly WebSocketVersion V13 = new WebSocketVersion("13"); + + readonly AsciiString value; + + WebSocketVersion(string value) + { + this.value = AsciiString.Cached(value); + } + + public override string ToString() => this.value.ToString(); + + public AsciiString ToHttpHeaderValue() + { + ThrowIfUnknown(this); + return this.value; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowIfUnknown(WebSocketVersion webSocketVersion) + { + if (webSocketVersion == Unknown) + { + throw new InvalidOperationException("Unknown web socket version"); + } + } + } +} diff --git a/src/DotNetty.Codecs/ByteToMessageDecoder.cs b/src/DotNetty.Codecs/ByteToMessageDecoder.cs index 1f676e978..1a04bd423 100644 --- a/src/DotNetty.Codecs/ByteToMessageDecoder.cs +++ b/src/DotNetty.Codecs/ByteToMessageDecoder.cs @@ -363,6 +363,14 @@ protected virtual void CallDecode(IChannelHandlerContext context, IByteBuffer in } } - protected virtual void DecodeLast(IChannelHandlerContext context, IByteBuffer input, List output) => this.Decode(context, input, output); + protected virtual void DecodeLast(IChannelHandlerContext context, IByteBuffer input, List output) + { + if (input.IsReadable()) + { + // Only call decode() if there is something left in the buffer to decode. + // See https://github.com/netty/netty/issues/4386 + this.Decode(context, input, output); + } + } } } \ No newline at end of file diff --git a/src/DotNetty.Codecs/Compression/JZlibEncoder.cs b/src/DotNetty.Codecs/Compression/JZlibEncoder.cs index c4600175f..901bfcf80 100644 --- a/src/DotNetty.Codecs/Compression/JZlibEncoder.cs +++ b/src/DotNetty.Codecs/Compression/JZlibEncoder.cs @@ -72,6 +72,7 @@ public JZlibEncoder(ZlibWrapper wrapper, int compressionLevel, int windowBits, i this.wrapperOverhead = ZlibUtil.WrapperOverhead(wrapper); } + public JZlibEncoder(byte[] dictionary) : this(6, dictionary) { } diff --git a/src/DotNetty.Codecs/Compression/ZlibCodecFactory.cs b/src/DotNetty.Codecs/Compression/ZlibCodecFactory.cs index adb22ce54..955a7ec62 100644 --- a/src/DotNetty.Codecs/Compression/ZlibCodecFactory.cs +++ b/src/DotNetty.Codecs/Compression/ZlibCodecFactory.cs @@ -1,10 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +// ReSharper disable ConvertToAutoProperty namespace DotNetty.Codecs.Compression { public static class ZlibCodecFactory { + public static bool IsSupportingWindowSizeAndMemLevel => true; + public static ZlibEncoder NewZlibEncoder(int compressionLevel) => new JZlibEncoder(compressionLevel); public static ZlibEncoder NewZlibEncoder(ZlibWrapper wrapper) => new JZlibEncoder(wrapper); diff --git a/src/DotNetty.Codecs/Compression/ZlibUtil.cs b/src/DotNetty.Codecs/Compression/ZlibUtil.cs index 3b2d96f40..f192d2cac 100644 --- a/src/DotNetty.Codecs/Compression/ZlibUtil.cs +++ b/src/DotNetty.Codecs/Compression/ZlibUtil.cs @@ -46,7 +46,11 @@ public static int WrapperOverhead(ZlibWrapper wrapper) int overhead; switch (wrapper) { + case ZlibWrapper.None: + overhead = 0; + break; case ZlibWrapper.Zlib: + case ZlibWrapper.ZlibOrNone: overhead = 2; break; case ZlibWrapper.Gzip: diff --git a/src/DotNetty.Common/Utilities/AsciiString.cs b/src/DotNetty.Common/Utilities/AsciiString.cs index 15945ab06..2d7ca6f68 100644 --- a/src/DotNetty.Common/Utilities/AsciiString.cs +++ b/src/DotNetty.Common/Utilities/AsciiString.cs @@ -757,6 +757,37 @@ public AsciiString Trim() return new AsciiString(this.value, start, end - start + 1, false); } + public unsafe bool ContentEquals(string a) + { + if (a == null) + { + return false; + } + if (this.stringValue != null) + { + return this.stringValue.Equals(a); + } + if (this.length != a.Length) + { + return false; + } + + if (this.length > 0) + { + fixed (char* p = a) + fixed (byte* b = &this.value[this.offset]) + for (int i = 0; i < this.length; ++i) + { + if (CharToByte(*(p + i)) != *(b + i) ) + { + return false; + } + } + } + + return true; + } + public bool ContentEquals(ICharSequence a) { if (a == null || a.Count != this.length) diff --git a/src/DotNetty.Common/Utilities/NetUtil.cs b/src/DotNetty.Common/Utilities/NetUtil.cs new file mode 100644 index 000000000..b3995909d --- /dev/null +++ b/src/DotNetty.Common/Utilities/NetUtil.cs @@ -0,0 +1,251 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Utilities +{ + using System; + using System.Runtime.CompilerServices; + using System.Text; + + public static class NetUtil + { + public static string ToSocketAddressString(string host, int port) + { + string portStr = Convert.ToString(port); + return NewSocketAddressStringBuilder(host, portStr, + !IsValidIpV6Address(host)).Append(':').Append(portStr).ToString(); + } + + static StringBuilder NewSocketAddressStringBuilder(string host, string port, bool ipv4) + { + int hostLen = host.Length; + if (ipv4) + { + // Need to include enough space for hostString:port. + return new StringBuilder(hostLen + 1 + port.Length).Append(host); + } + + // Need to include enough space for [hostString]:port. + var stringBuilder = new StringBuilder(hostLen + 3 + port.Length); + if (hostLen > 1 && host[0] == '[' && host[hostLen - 1] == ']') + { + return stringBuilder.Append(host); + } + + return stringBuilder.Append('[').Append(host).Append(']'); + } + + public static bool IsValidIpV6Address(string ip) + { + int end = ip.Length; + if (end < 2) + { + return false; + } + + // strip "[]" + int start; + char c = ip[0]; + if (c == '[') + { + end--; + if (ip[end] != ']') + { + // must have a close ] + return false; + } + + start = 1; + c = ip[1]; + } + else + { + start = 0; + } + + int colons; + int compressBegin; + if (c == ':') + { + // an IPv6 address can start with "::" or with a number + if (ip[start + 1] != ':') + { + return false; + } + + colons = 2; + compressBegin = start; + start += 2; + } + else + { + colons = 0; + compressBegin = -1; + } + + int wordLen = 0; + for (int i = start; i < end; i++) + { + c = ip[i]; + if (IsValidHexChar(c)) + { + if (wordLen < 4) + { + wordLen++; + continue; + } + + return false; + } + + switch (c) + { + case ':': + if (colons > 7) + { + return false; + } + + if (ip[i - 1] == ':') + { + if (compressBegin >= 0) + { + return false; + } + + compressBegin = i - 1; + } + else + { + wordLen = 0; + } + + colons++; + break; + case '.': + // case for the last 32-bits represented as IPv4 x:x:x:x:x:x:d.d.d.d + + // check a normal case (6 single colons) + if (compressBegin < 0 && colons != 6 || + // a special case ::1:2:3:4:5:d.d.d.d allows 7 colons with an + // IPv4 ending, otherwise 7 :'s is bad + (colons == 7 && compressBegin >= start || colons > 7)) + { + return false; + } + + // Verify this address is of the correct structure to contain an IPv4 address. + // It must be IPv4-Mapped or IPv4-Compatible + // (see https://tools.ietf.org/html/rfc4291#section-2.5.5). + int ipv4Start = i - wordLen; + int j = ipv4Start - 2; // index of character before the previous ':'. + if (IsValidIPv4MappedChar(ip[j])) + { + if (!IsValidIPv4MappedChar(ip[j - 1]) || + !IsValidIPv4MappedChar(ip[j - 2]) || + !IsValidIPv4MappedChar(ip[j - 3])) + { + return false; + } + + j -= 5; + } + + for (; j >= start; --j) + { + char tmpChar = ip[j]; + if (tmpChar != '0' && tmpChar != ':') + { + return false; + } + } + + // 7 - is minimum IPv4 address length + int ipv4End = ip.IndexOf('%', ipv4Start + 7); + if (ipv4End < 0) + { + ipv4End = end; + } + + return IsValidIpV4Address(ip, ipv4Start, ipv4End); + case '%': + // strip the interface name/index after the percent sign + end = i; + goto loop; + default: + return false; + } + + loop: + // normal case without compression + if (compressBegin < 0) + { + return colons == 7 && wordLen > 0; + } + + return compressBegin + 2 == end || + // 8 colons is valid only if compression in start or end + wordLen > 0 && (colons < 8 || compressBegin <= start); + } + + // normal case without compression + if (compressBegin < 0) + { + return colons == 7 && wordLen > 0; + } + + return compressBegin + 2 == end || + // 8 colons is valid only if compression in start or end + wordLen > 0 && (colons < 8 || compressBegin <= start); + } + + static bool IsValidIpV4Address(string ip, int from, int toExcluded) + { + int len = toExcluded - from; + int i; + return len <= 15 && len >= 7 && + (i = ip.IndexOf('.', from + 1)) > 0 && IsValidIpV4Word(ip, from, i) && + (i = ip.IndexOf('.', from = i + 2)) > 0 && IsValidIpV4Word(ip, from - 1, i) && + (i = ip.IndexOf('.', from = i + 2)) > 0 && IsValidIpV4Word(ip, from - 1, i) && + IsValidIpV4Word(ip, i + 1, toExcluded); + } + + static bool IsValidIpV4Word(string word, int from, int toExclusive) + { + int len = toExclusive - from; + char c0, c1, c2; + if (len < 1 || len > 3 || (c0 = word[from]) < '0') + { + return false; + } + + if (len == 3) + { + return (c1 = word[from + 1]) >= '0' + && (c2 = word[from + 2]) >= '0' + && (c0 <= '1' && c1 <= '9' && c2 <= '9' + || c0 == '2' && c1 <= '5' && (c2 <= '5' || c1 < '5' && c2 <= '9')); + } + + return c0 <= '9' && (len == 1 || IsValidNumericChar(word[from + 1])); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool IsValidHexChar(char c) + { + return c >= '0' && c <= '9' || c >= 'A' && c <= 'F' || c >= 'a' && c <= 'f'; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool IsValidNumericChar(char c) + { + return c >= '0' && c <= '9'; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool IsValidIPv4MappedChar(char c) + { + return c == 'f' || c == 'F'; + } + } +} diff --git a/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs b/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs index 5b854e73b..0f5dce1a9 100644 --- a/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs +++ b/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs @@ -66,7 +66,6 @@ void Validate(TIn inbound, TOut outbound) } } - void CheckAdded() { if (!this.handlerAdded) diff --git a/test/DotNetty.Codecs.Http.Tests/DotNetty.Codecs.Http.Tests.csproj b/test/DotNetty.Codecs.Http.Tests/DotNetty.Codecs.Http.Tests.csproj index 860ece271..b15ef7373 100644 --- a/test/DotNetty.Codecs.Http.Tests/DotNetty.Codecs.Http.Tests.csproj +++ b/test/DotNetty.Codecs.Http.Tests/DotNetty.Codecs.Http.Tests.csproj @@ -11,6 +11,7 @@ + diff --git a/test/DotNetty.Codecs.Http.Tests/HttpServerUpgradeHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpServerUpgradeHandlerTest.cs new file mode 100644 index 000000000..cb8f3a6a2 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpServerUpgradeHandlerTest.cs @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Collections.Generic; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public class HttpServerUpgradeHandlerTest + { + sealed class TestUpgradeCodec : HttpServerUpgradeHandler.IUpgradeCodec + { + public ICollection RequiredUpgradeHeaders => new List(); + + public bool PrepareUpgradeResponse(IChannelHandlerContext ctx, IFullHttpRequest upgradeRequest, HttpHeaders upgradeHeaders) => true; + + public void UpgradeTo(IChannelHandlerContext ctx, IFullHttpRequest upgradeRequest) + { + // Ensure that the HttpServerUpgradeHandler is still installed when this is called + Assert.Equal(ctx.Channel.Pipeline.Context(), ctx); + Assert.NotNull(ctx.Channel.Pipeline.Get()); + + // Add a marker handler to signal that the upgrade has happened + ctx.Channel.Pipeline.AddAfter(ctx.Name, "marker", new ChannelHandlerAdapter()); + } + } + + sealed class UpgradeFactory : HttpServerUpgradeHandler.IUpgradeCodecFactory + { + public HttpServerUpgradeHandler.IUpgradeCodec NewUpgradeCodec(ICharSequence protocol) => + new TestUpgradeCodec(); + } + + sealed class ChannelHandler : ChannelDuplexHandler + { + // marker boolean to signal that we're in the `channelRead` method + bool inReadCall; + bool writeUpgradeMessage; + bool writeFlushed; + + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + Assert.False(this.inReadCall); + Assert.False(this.writeUpgradeMessage); + + this.inReadCall = true; + try + { + base.ChannelRead(ctx, msg); + // All in the same call stack, the upgrade codec should receive the message, + // written the upgrade response, and upgraded the pipeline. + Assert.True(this.writeUpgradeMessage); + Assert.False(this.writeFlushed); + //Assert.Null(ctx.Channel.Pipeline.Get()); + //Assert.NotNull(ctx.Channel.Pipeline.Get("marker")); + } + finally + { + this.inReadCall = false; + } + } + + public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + { + // We ensure that we're in the read call and defer the write so we can + // make sure the pipeline was reformed irrespective of the flush completing. + Assert.True(this.inReadCall); + this.writeUpgradeMessage = true; + + var completion = new TaskCompletionSource(); + ctx.Channel.EventLoop.Execute(() => + { + ctx.WriteAsync(msg) + .ContinueWith(t => + { + if (t.Status == TaskStatus.RanToCompletion) + { + this.writeFlushed = true; + completion.TryComplete(); + return; + } + completion.TrySetException(new InvalidOperationException($"Invalid WriteAsync task status {t.Status}")); + }, + TaskContinuationOptions.ExecuteSynchronously); + }); + return completion.Task; + } + } + + [Fact] + public void UpgradesPipelineInSameMethodInvocation() + { + var httpServerCodec = new HttpServerCodec(); + var factory = new UpgradeFactory(); + var testInStackFrame = new ChannelHandler(); + + var upgradeHandler = new HttpServerUpgradeHandler(httpServerCodec, factory); + var channel = new EmbeddedChannel(testInStackFrame, httpServerCodec, upgradeHandler); + + const string UpgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: Upgrade, HTTP2-Settings\r\n" + + "Upgrade: nextprotocol\r\n" + + "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; + IByteBuffer upgrade = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(UpgradeString)); + + Assert.False(channel.WriteInbound(upgrade)); + //Assert.Null(channel.Pipeline.Get()); + //Assert.NotNull(channel.Pipeline.Get("marker")); + + channel.Flush(); + Assert.Null(channel.Pipeline.Get()); + Assert.NotNull(channel.Pipeline.Get("marker")); + + var upgradeMessage = channel.ReadOutbound(); + const string ExpectedHttpResponse = "HTTP/1.1 101 Switching Protocols\r\n" + + "connection: upgrade\r\n" + + "upgrade: nextprotocol\r\n\r\n"; + Assert.Equal(ExpectedHttpResponse, upgradeMessage.ToString(Encoding.ASCII)); + Assert.True(upgradeMessage.Release()); + Assert.False(channel.FinishAndReleaseAll()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/DeflateFrameClientExtensionHandshakerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/DeflateFrameClientExtensionHandshakerTest.cs new file mode 100644 index 000000000..735467108 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/DeflateFrameClientExtensionHandshakerTest.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using Xunit; + + using static Http.WebSockets.Extensions.Compression.DeflateFrameServerExtensionHandshaker; + + public sealed class DeflateFrameClientExtensionHandshakerTest + { + [Fact] + public void WebkitDeflateFrameData() + { + var handshaker = new DeflateFrameClientExtensionHandshaker(true); + + WebSocketExtensionData data = handshaker.NewRequestData(); + + Assert.Equal(XWebkitDeflateFrameExtension, data.Name); + Assert.Empty(data.Parameters); + } + + [Fact] + public void DeflateFrameData() + { + var handshaker = new DeflateFrameClientExtensionHandshaker(false); + + WebSocketExtensionData data = handshaker.NewRequestData(); + + Assert.Equal(DeflateFrameExtension, data.Name); + Assert.Empty(data.Parameters); + } + + [Fact] + public void NormalHandshake() + { + var handshaker = new DeflateFrameClientExtensionHandshaker(false); + + IWebSocketClientExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(DeflateFrameExtension, new Dictionary())); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + } + + [Fact] + public void FailedHandshake() + { + var handshaker = new DeflateFrameClientExtensionHandshaker(false); + + var parameters = new Dictionary + { + { "invalid", "12" } + }; + IWebSocketClientExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(DeflateFrameExtension, parameters)); + + Assert.Null(extension); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/DeflateFrameServerExtensionHandshakerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/DeflateFrameServerExtensionHandshakerTest.cs new file mode 100644 index 000000000..0f7c873bb --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/DeflateFrameServerExtensionHandshakerTest.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using Xunit; + + using static Http.WebSockets.Extensions.Compression.DeflateFrameServerExtensionHandshaker; + + public sealed class DeflateFrameServerExtensionHandshakerTest + { + [Fact] + public void NormalHandshake() + { + var handshaker = new DeflateFrameServerExtensionHandshaker(); + IWebSocketServerExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(DeflateFrameExtension, new Dictionary())); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + } + + [Fact] + public void WebkitHandshake() + { + var handshaker = new DeflateFrameServerExtensionHandshaker(); + IWebSocketServerExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(XWebkitDeflateFrameExtension, new Dictionary())); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + } + + [Fact] + public void FailedHandshake() + { + var handshaker = new DeflateFrameServerExtensionHandshaker(); + var parameters = new Dictionary + { + { "unknown", "11" } + }; + IWebSocketServerExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(DeflateFrameExtension, parameters)); + + Assert.Null(extension); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerFrameDeflateDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerFrameDeflateDecoderTest.cs new file mode 100644 index 000000000..bd975ed13 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerFrameDeflateDecoderTest.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System; + using DotNetty.Buffers; + using DotNetty.Codecs.Compression; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class PerFrameDeflateDecoderTest + { + readonly Random random; + + public PerFrameDeflateDecoderTest() + { + this.random = new Random(); + } + + [Fact] + public void CompressedFrame() + { + var encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibEncoder(ZlibWrapper.None, 9, 15, 8)); + var decoderChannel = new EmbeddedChannel(new PerFrameDeflateDecoder(false)); + + var payload = new byte[300]; + this.random.NextBytes(payload); + + encoderChannel.WriteOutbound(Unpooled.WrappedBuffer(payload)); + var compressedPayload = encoderChannel.ReadOutbound(); + + var compressedFrame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, + compressedPayload.Slice(0, compressedPayload.ReadableBytes - 4)); + + decoderChannel.WriteInbound(compressedFrame); + var uncompressedFrame = decoderChannel.ReadInbound(); + + Assert.NotNull(uncompressedFrame); + Assert.NotNull(uncompressedFrame.Content); + Assert.IsType(uncompressedFrame); + Assert.Equal(WebSocketRsv.Rsv3, uncompressedFrame.Rsv); + Assert.Equal(300, uncompressedFrame.Content.ReadableBytes); + + var finalPayload = new byte[300]; + uncompressedFrame.Content.ReadBytes(finalPayload); + Assert.Equal(payload, finalPayload); + uncompressedFrame.Release(); + } + + [Fact] + public void NormalFrame() + { + var decoderChannel = new EmbeddedChannel(new PerFrameDeflateDecoder(false)); + + var payload = new byte[300]; + this.random.NextBytes(payload); + + var frame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload)); + + decoderChannel.WriteInbound(frame); + var newFrame = decoderChannel.ReadInbound(); + + Assert.NotNull(newFrame); + Assert.NotNull(newFrame.Content); + Assert.IsType(newFrame); + Assert.Equal(WebSocketRsv.Rsv3, newFrame.Rsv); + Assert.Equal(300, newFrame.Content.ReadableBytes); + + var finalPayload = new byte[300]; + newFrame.Content.ReadBytes(finalPayload); + Assert.Equal(payload, finalPayload); + newFrame.Release(); + } + + // See https://github.com/netty/netty/issues/4348 + [Fact] + public void CompressedEmptyFrame() + { + var encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibEncoder(ZlibWrapper.None, 9, 15, 8)); + var decoderChannel = new EmbeddedChannel(new PerFrameDeflateDecoder(false)); + + encoderChannel.WriteOutbound(Unpooled.Empty); + var compressedPayload = encoderChannel.ReadOutbound(); + var compressedFrame = + new BinaryWebSocketFrame(true, WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, compressedPayload); + + decoderChannel.WriteInbound(compressedFrame); + var uncompressedFrame = decoderChannel.ReadInbound(); + + Assert.NotNull(uncompressedFrame); + Assert.NotNull(uncompressedFrame.Content); + Assert.IsType(uncompressedFrame); + Assert.Equal(WebSocketRsv.Rsv3, uncompressedFrame.Rsv); + Assert.Equal(0, uncompressedFrame.Content.ReadableBytes); + uncompressedFrame.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerFrameDeflateEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerFrameDeflateEncoderTest.cs new file mode 100644 index 000000000..04a939a44 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerFrameDeflateEncoderTest.cs @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System; + using DotNetty.Buffers; + using DotNetty.Codecs.Compression; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class PerFrameDeflateEncoderTest + { + readonly Random random; + + public PerFrameDeflateEncoderTest() + { + this.random = new Random(); + } + + [Fact] + public void CompressedFrame() + { + var encoderChannel = new EmbeddedChannel(new PerFrameDeflateEncoder(9, 15, false)); + var decoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibDecoder(ZlibWrapper.None)); + + var payload = new byte[300]; + this.random.NextBytes(payload); + var frame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload)); + + encoderChannel.WriteOutbound(frame); + var compressedFrame = encoderChannel.ReadOutbound(); + + Assert.NotNull(compressedFrame); + Assert.NotNull(compressedFrame.Content); + Assert.IsType(compressedFrame); + Assert.Equal(WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, compressedFrame.Rsv); + + decoderChannel.WriteInbound(compressedFrame.Content); + decoderChannel.WriteInbound(DeflateDecoder.FrameTail); + var uncompressedPayload = decoderChannel.ReadInbound(); + Assert.Equal(300, uncompressedPayload.ReadableBytes); + + var finalPayload = new byte[300]; + uncompressedPayload.ReadBytes(finalPayload); + + Assert.Equal(payload, finalPayload); + uncompressedPayload.Release(); + } + + [Fact] + public void AlreadyCompressedFrame() + { + var encoderChannel = new EmbeddedChannel(new PerFrameDeflateEncoder(9, 15, false)); + var payload = new byte[300]; + this.random.NextBytes(payload); + + var frame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv3 | WebSocketRsv.Rsv1, Unpooled.WrappedBuffer(payload)); + + encoderChannel.WriteOutbound(frame); + var newFrame = encoderChannel.ReadOutbound(); + + Assert.NotNull(newFrame); + Assert.NotNull(newFrame.Content); + Assert.IsType(newFrame); + Assert.Equal(WebSocketRsv.Rsv3 | WebSocketRsv.Rsv1, newFrame.Rsv); + Assert.Equal(300, newFrame.Content.ReadableBytes); + + var finalPayload = new byte[300]; + newFrame.Content.ReadBytes(finalPayload); + Assert.Equal(payload, finalPayload); + newFrame.Release(); + } + + [Fact] + public void FramementedFrame() + { + var encoderChannel = new EmbeddedChannel(new PerFrameDeflateEncoder(9, 15, false)); + var decoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibDecoder(ZlibWrapper.None)); + + var payload1 = new byte[100]; + this.random.NextBytes(payload1); + var payload2 = new byte[100]; + this.random.NextBytes(payload2); + var payload3 = new byte[100]; + this.random.NextBytes(payload3); + + var frame1 = new BinaryWebSocketFrame(false, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload1)); + var frame2 = new ContinuationWebSocketFrame(false, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload2)); + var frame3 = new ContinuationWebSocketFrame(true, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload3)); + + encoderChannel.WriteOutbound(frame1); + encoderChannel.WriteOutbound(frame2); + encoderChannel.WriteOutbound(frame3); + var compressedFrame1 = encoderChannel.ReadOutbound(); + var compressedFrame2 = encoderChannel.ReadOutbound(); + var compressedFrame3 = encoderChannel.ReadOutbound(); + + Assert.NotNull(compressedFrame1); + Assert.NotNull(compressedFrame2); + Assert.NotNull(compressedFrame3); + Assert.Equal(WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, compressedFrame1.Rsv); + Assert.Equal(WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, compressedFrame2.Rsv); + Assert.Equal(WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, compressedFrame3.Rsv); + Assert.False(compressedFrame1.IsFinalFragment); + Assert.False(compressedFrame2.IsFinalFragment); + Assert.True(compressedFrame3.IsFinalFragment); + + decoderChannel.WriteInbound(compressedFrame1.Content); + decoderChannel.WriteInbound(Unpooled.WrappedBuffer(DeflateDecoder.FrameTail)); + var uncompressedPayload1 = decoderChannel.ReadInbound(); + var finalPayload1 = new byte[100]; + uncompressedPayload1.ReadBytes(finalPayload1); + Assert.Equal(payload1, finalPayload1); + uncompressedPayload1.Release(); + + decoderChannel.WriteInbound(compressedFrame2.Content); + decoderChannel.WriteInbound(Unpooled.WrappedBuffer(DeflateDecoder.FrameTail)); + var uncompressedPayload2 = decoderChannel.ReadInbound(); + var finalPayload2 = new byte[100]; + uncompressedPayload2.ReadBytes(finalPayload2); + Assert.Equal(payload2, finalPayload2); + uncompressedPayload2.Release(); + + decoderChannel.WriteInbound(compressedFrame3.Content); + decoderChannel.WriteInbound(Unpooled.WrappedBuffer(DeflateDecoder.FrameTail)); + var uncompressedPayload3 = decoderChannel.ReadInbound(); + var finalPayload3 = new byte[100]; + uncompressedPayload3.ReadBytes(finalPayload3); + Assert.Equal(payload3, finalPayload3); + uncompressedPayload3.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateClientExtensionHandshakerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateClientExtensionHandshakerTest.cs new file mode 100644 index 000000000..fadb505cc --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateClientExtensionHandshakerTest.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Codecs.Compression; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using Xunit; + + using static Http.WebSockets.Extensions.Compression.PerMessageDeflateServerExtensionHandshaker; + + public sealed class PerMessageDeflateClientExtensionHandshakerTest + { + [Fact] + public void NormalData() + { + var handshaker = new PerMessageDeflateClientExtensionHandshaker(); + WebSocketExtensionData data = handshaker.NewRequestData(); + + Assert.Equal(PerMessageDeflateExtension, data.Name); + Assert.Equal(ZlibCodecFactory.IsSupportingWindowSizeAndMemLevel ? 1 : 0, data.Parameters.Count); + } + + [Fact] + public void CustomData() + { + var handshaker = new PerMessageDeflateClientExtensionHandshaker(6, true, 10, true, true); + WebSocketExtensionData data = handshaker.NewRequestData(); + + Assert.Equal(PerMessageDeflateExtension, data.Name); + Assert.Contains(ClientMaxWindow, data.Parameters.Keys); + Assert.Contains(ServerMaxWindow, data.Parameters.Keys); + Assert.Equal("10", data.Parameters[ServerMaxWindow]); + } + + [Fact] + public void NormalHandshake() + { + var handshaker = new PerMessageDeflateClientExtensionHandshaker(); + + IWebSocketClientExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, new Dictionary())); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + } + + [Fact] + public void CustomHandshake() + { + var handshaker = new PerMessageDeflateClientExtensionHandshaker(6, true, 10, true, true); + + var parameters = new Dictionary + { + { ClientMaxWindow, "12" }, + { ServerMaxWindow, "10" }, + { ClientNoContext, null }, + { ServerNoContext, null } + }; + IWebSocketClientExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, parameters)); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + + parameters = new Dictionary + { + { ServerMaxWindow, "10" }, + { ServerNoContext, null } + }; + extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, parameters)); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + + parameters = new Dictionary(); + extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, parameters)); + + Assert.Null(extension); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateDecoderTest.cs new file mode 100644 index 000000000..47287bbb7 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateDecoderTest.cs @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System; + using DotNetty.Buffers; + using DotNetty.Codecs.Compression; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class PerMessageDeflateDecoderTest + { + readonly Random random; + + public PerMessageDeflateDecoderTest() + { + this.random = new Random(); + } + + [Fact] + public void CompressedFrame() + { + var encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibEncoder(ZlibWrapper.None, 9, 15, 8)); + var decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + var payload = new byte[300]; + this.random.NextBytes(payload); + + encoderChannel.WriteOutbound(Unpooled.WrappedBuffer(payload)); + var compressedPayload = encoderChannel.ReadOutbound(); + + var compressedFrame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, + compressedPayload.Slice(0, compressedPayload.ReadableBytes - 4)); + + decoderChannel.WriteInbound(compressedFrame); + var uncompressedFrame = decoderChannel.ReadInbound(); + + Assert.NotNull(uncompressedFrame); + Assert.NotNull(uncompressedFrame.Content); + Assert.IsType(uncompressedFrame); + Assert.Equal(WebSocketRsv.Rsv3, uncompressedFrame.Rsv); + Assert.Equal(300, uncompressedFrame.Content.ReadableBytes); + + var finalPayload = new byte[300]; + uncompressedFrame.Content.ReadBytes(finalPayload); + Assert.Equal(payload, finalPayload); + uncompressedFrame.Release(); + } + + [Fact] + public void NormalFrame() + { + var decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + var payload = new byte[300]; + this.random.NextBytes(payload); + + var frame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload)); + + decoderChannel.WriteInbound(frame); + var newFrame = decoderChannel.ReadInbound(); + + Assert.NotNull(newFrame); + Assert.NotNull(newFrame.Content); + Assert.IsType(newFrame); + Assert.Equal(WebSocketRsv.Rsv3, newFrame.Rsv); + Assert.Equal(300, newFrame.Content.ReadableBytes); + + var finalPayload = new byte[300]; + newFrame.Content.ReadBytes(finalPayload); + Assert.Equal(payload, finalPayload); + newFrame.Release(); + } + + [Fact] + public void FramementedFrame() + { + var encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibEncoder(ZlibWrapper.None, 9, 15, 8)); + var decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + var payload = new byte[300]; + this.random.NextBytes(payload); + + encoderChannel.WriteOutbound(Unpooled.WrappedBuffer(payload)); + var compressedPayload = encoderChannel.ReadOutbound(); + compressedPayload = compressedPayload.Slice(0, compressedPayload.ReadableBytes - 4); + + int oneThird = compressedPayload.ReadableBytes / 3; + var compressedFrame1 = new BinaryWebSocketFrame(false, + WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, + compressedPayload.Slice(0, oneThird)); + var compressedFrame2 = new ContinuationWebSocketFrame(false, + WebSocketRsv.Rsv3, compressedPayload.Slice(oneThird, oneThird)); + var compressedFrame3 = new ContinuationWebSocketFrame(true, + WebSocketRsv.Rsv3, compressedPayload.Slice(oneThird * 2, + compressedPayload.ReadableBytes - oneThird * 2)); + + decoderChannel.WriteInbound(compressedFrame1.Retain()); + decoderChannel.WriteInbound(compressedFrame2.Retain()); + decoderChannel.WriteInbound(compressedFrame3); + var uncompressedFrame1 = decoderChannel.ReadInbound(); + var uncompressedFrame2 = decoderChannel.ReadInbound(); + var uncompressedFrame3 = decoderChannel.ReadInbound(); + + Assert.NotNull(uncompressedFrame1); + Assert.NotNull(uncompressedFrame2); + Assert.NotNull(uncompressedFrame3); + Assert.Equal(WebSocketRsv.Rsv3, uncompressedFrame1.Rsv); + Assert.Equal(WebSocketRsv.Rsv3, uncompressedFrame2.Rsv); + Assert.Equal(WebSocketRsv.Rsv3, uncompressedFrame3.Rsv); + + IByteBuffer finalPayloadWrapped = Unpooled.WrappedBuffer(uncompressedFrame1.Content, + uncompressedFrame2.Content, uncompressedFrame3.Content); + Assert.Equal(300, finalPayloadWrapped.ReadableBytes); + + var finalPayload = new byte[300]; + finalPayloadWrapped.ReadBytes(finalPayload); + Assert.Equal(payload, finalPayload); + finalPayloadWrapped.Release(); + } + + [Fact] + public void MultiCompressedPayloadWithinFrame() + { + var encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibEncoder(ZlibWrapper.None, 9, 15, 8)); + var decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + var payload1 = new byte[100]; + this.random.NextBytes(payload1); + var payload2 = new byte[100]; + this.random.NextBytes(payload2); + + encoderChannel.WriteOutbound(Unpooled.WrappedBuffer(payload1)); + var compressedPayload1 = encoderChannel.ReadOutbound(); + encoderChannel.WriteOutbound(Unpooled.WrappedBuffer(payload2)); + var compressedPayload2 = encoderChannel.ReadOutbound(); + + var compressedFrame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, + Unpooled.WrappedBuffer( + compressedPayload1, + compressedPayload2.Slice(0, compressedPayload2.ReadableBytes - 4))); + + decoderChannel.WriteInbound(compressedFrame); + var uncompressedFrame = decoderChannel.ReadInbound(); + + Assert.NotNull(uncompressedFrame); + Assert.NotNull(uncompressedFrame.Content); + Assert.IsType(uncompressedFrame); + Assert.Equal(WebSocketRsv.Rsv3, uncompressedFrame.Rsv); + Assert.Equal(200, uncompressedFrame.Content.ReadableBytes); + + var finalPayload1 = new byte[100]; + uncompressedFrame.Content.ReadBytes(finalPayload1); + Assert.Equal(payload1, finalPayload1); + var finalPayload2 = new byte[100]; + uncompressedFrame.Content.ReadBytes(finalPayload2); + Assert.Equal(payload2, finalPayload2); + uncompressedFrame.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateEncoderTest.cs new file mode 100644 index 000000000..e4379af27 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateEncoderTest.cs @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System; + using DotNetty.Buffers; + using DotNetty.Codecs.Compression; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class PerMessageDeflateEncoderTest + { + readonly Random random; + + public PerMessageDeflateEncoderTest() + { + this.random = new Random(); + } + + [Fact] + public void CompressedFrame() + { + var encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false)); + var decoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibDecoder(ZlibWrapper.None)); + + var payload = new byte[300]; + this.random.NextBytes(payload); + var frame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload)); + + encoderChannel.WriteOutbound(frame); + var compressedFrame = encoderChannel.ReadOutbound(); + + Assert.NotNull(compressedFrame); + Assert.NotNull(compressedFrame.Content); + Assert.IsType(compressedFrame); + Assert.Equal(WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, compressedFrame.Rsv); + + decoderChannel.WriteInbound(compressedFrame.Content); + decoderChannel.WriteInbound(DeflateDecoder.FrameTail); + var uncompressedPayload = decoderChannel.ReadInbound(); + Assert.Equal(300, uncompressedPayload.ReadableBytes); + + var finalPayload = new byte[300]; + uncompressedPayload.ReadBytes(finalPayload); + Assert.Equal(payload, finalPayload); + uncompressedPayload.Release(); + } + + [Fact] + public void AlreadyCompressedFrame() + { + var encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false)); + + var payload = new byte[300]; + this.random.NextBytes(payload); + + var frame = new BinaryWebSocketFrame(true, + WebSocketRsv.Rsv3 | WebSocketRsv.Rsv1, Unpooled.WrappedBuffer(payload)); + + encoderChannel.WriteOutbound(frame); + var newFrame = encoderChannel.ReadOutbound(); + + Assert.NotNull(newFrame); + Assert.NotNull(newFrame.Content); + Assert.IsType(newFrame); + Assert.Equal(WebSocketRsv.Rsv3 | WebSocketRsv.Rsv1, newFrame.Rsv); + Assert.Equal(300, newFrame.Content.ReadableBytes); + + var finalPayload = new byte[300]; + newFrame.Content.ReadBytes(finalPayload); + Assert.Equal(payload, finalPayload); + newFrame.Release(); + } + + [Fact] + public void FramementedFrame() + { + var encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false)); + var decoderChannel = new EmbeddedChannel( + ZlibCodecFactory.NewZlibDecoder(ZlibWrapper.None)); + + var payload1 = new byte[100]; + this.random.NextBytes(payload1); + var payload2 = new byte[100]; + this.random.NextBytes(payload2); + var payload3 = new byte[100]; + this.random.NextBytes(payload3); + + var frame1 = new BinaryWebSocketFrame(false, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload1)); + var frame2 = new ContinuationWebSocketFrame(false, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload2)); + var frame3 = new ContinuationWebSocketFrame(true, + WebSocketRsv.Rsv3, Unpooled.WrappedBuffer(payload3)); + + encoderChannel.WriteOutbound(frame1); + encoderChannel.WriteOutbound(frame2); + encoderChannel.WriteOutbound(frame3); + var compressedFrame1 = encoderChannel.ReadOutbound(); + var compressedFrame2 = encoderChannel.ReadOutbound(); + var compressedFrame3 = encoderChannel.ReadOutbound(); + + Assert.NotNull(compressedFrame1); + Assert.NotNull(compressedFrame2); + Assert.NotNull(compressedFrame3); + Assert.Equal(WebSocketRsv.Rsv1 | WebSocketRsv.Rsv3, compressedFrame1.Rsv); + Assert.Equal(WebSocketRsv.Rsv3, compressedFrame2.Rsv); + Assert.Equal(WebSocketRsv.Rsv3, compressedFrame3.Rsv); + Assert.False(compressedFrame1.IsFinalFragment); + Assert.False(compressedFrame2.IsFinalFragment); + Assert.True(compressedFrame3.IsFinalFragment); + + decoderChannel.WriteInbound(compressedFrame1.Content); + var uncompressedPayload1 = decoderChannel.ReadInbound(); + var finalPayload1 = new byte[100]; + uncompressedPayload1.ReadBytes(finalPayload1); + Assert.Equal(payload1, finalPayload1); + uncompressedPayload1.Release(); + + decoderChannel.WriteInbound(compressedFrame2.Content); + var uncompressedPayload2 = decoderChannel.ReadInbound(); + var finalPayload2 = new byte[100]; + uncompressedPayload2.ReadBytes(finalPayload2); + Assert.Equal(payload2, finalPayload2); + uncompressedPayload2.Release(); + + decoderChannel.WriteInbound(compressedFrame3.Content); + decoderChannel.WriteInbound(DeflateDecoder.FrameTail); + var uncompressedPayload3 = decoderChannel.ReadInbound(); + var finalPayload3 = new byte[100]; + uncompressedPayload3.ReadBytes(finalPayload3); + Assert.Equal(payload3, finalPayload3); + uncompressedPayload3.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateServerExtensionHandshakerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateServerExtensionHandshakerTest.cs new file mode 100644 index 000000000..423fa0ce4 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/PerMessageDeflateServerExtensionHandshakerTest.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using Xunit; + + using static Http.WebSockets.Extensions.Compression.PerMessageDeflateServerExtensionHandshaker; + + public sealed class PerMessageDeflateServerExtensionHandshakerTest + { + [Fact] + public void NormalHandshake() + { + var handshaker = new PerMessageDeflateServerExtensionHandshaker(); + IWebSocketServerExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, new Dictionary())); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + + WebSocketExtensionData data = extension.NewReponseData(); + + Assert.Equal(PerMessageDeflateExtension, data.Name); + Assert.Empty(data.Parameters); + + var parameters = new Dictionary + { + { ClientMaxWindow, null }, + { ClientNoContext, null } + }; + + extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, parameters)); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + + data = extension.NewReponseData(); + + Assert.Equal(PerMessageDeflateExtension, data.Name); + Assert.Empty(data.Parameters); + + parameters = new Dictionary + { + { ServerMaxWindow, "12" }, + { ServerNoContext, null } + }; + + extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, parameters)); + Assert.Null(extension); + } + + [Fact] + public void CustomHandshake() + { + var handshaker = new PerMessageDeflateServerExtensionHandshaker(6, true, 10, true, true); + + var parameters = new Dictionary + { + { ClientMaxWindow, null }, + { ServerMaxWindow, "12" }, + { ClientNoContext, null }, + { ServerNoContext, null } + }; + + IWebSocketServerExtension extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, parameters)); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + + WebSocketExtensionData data = extension.NewReponseData(); + + Assert.Equal(PerMessageDeflateExtension, data.Name); + Assert.Contains(ClientMaxWindow, data.Parameters.Keys); + Assert.Equal("10", data.Parameters[ClientMaxWindow]); + Assert.Contains(ServerMaxWindow, data.Parameters.Keys); + Assert.Equal("12", data.Parameters[ServerMaxWindow]); + + parameters = new Dictionary + { + { ServerMaxWindow, "12" }, + { ServerNoContext, null } + }; + extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, parameters)); + + Assert.NotNull(extension); + Assert.Equal(WebSocketRsv.Rsv1, extension.Rsv); + Assert.IsType(extension.NewExtensionDecoder()); + Assert.IsType(extension.NewExtensionEncoder()); + + data = extension.NewReponseData(); + + Assert.Equal(PerMessageDeflateExtension, data.Name); + Assert.Equal(2, data.Parameters.Count); + Assert.Contains(ServerMaxWindow, data.Parameters.Keys); + Assert.Equal("12", data.Parameters[ServerMaxWindow]); + Assert.Contains(ServerNoContext, data.Parameters.Keys); + + parameters = new Dictionary(); + extension = handshaker.HandshakeExtension( + new WebSocketExtensionData(PerMessageDeflateExtension, parameters)); + Assert.NotNull(extension); + + data = extension.NewReponseData(); + Assert.Equal(PerMessageDeflateExtension, data.Name); + Assert.Empty(data.Parameters); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/WebSocketServerCompressionHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/WebSocketServerCompressionHandlerTest.cs new file mode 100644 index 000000000..6814321cc --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/Compression/WebSocketServerCompressionHandlerTest.cs @@ -0,0 +1,195 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions.Compression +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Codecs.Http.WebSockets.Extensions.Compression; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + using static Http.WebSockets.Extensions.Compression.PerMessageDeflateServerExtensionHandshaker; + using static WebSocketExtensionTestUtil; + + public sealed class WebSocketServerCompressionHandlerTest + { + [Fact] + public void NormalSuccess() + { + var ch = new EmbeddedChannel(new WebSocketServerCompressionHandler()); + + IHttpRequest req = NewUpgradeRequest(PerMessageDeflateExtension); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List exts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(PerMessageDeflateExtension, exts[0].Name); + Assert.Empty(exts[0].Parameters); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + } + + [Fact] + public void ClientWindowSizeSuccess() + { + var ch = new EmbeddedChannel( + new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, false, 10, false, false))); + + IHttpRequest req = NewUpgradeRequest(PerMessageDeflateExtension + "; " + ClientMaxWindow); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List exts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(PerMessageDeflateExtension, exts[0].Name); + Assert.Equal("10", exts[0].Parameters[ClientMaxWindow]); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + } + + [Fact] + public void ClientWindowSizeUnavailable() + { + var ch = new EmbeddedChannel( + new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, false, 10, false, false))); + + IHttpRequest req = NewUpgradeRequest(PerMessageDeflateExtension); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List exts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(PerMessageDeflateExtension, exts[0].Name); + Assert.Empty(exts[0].Parameters); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + } + + [Fact] + public void ServerWindowSizeSuccess() + { + var ch = new EmbeddedChannel( + new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, true, 15, false, false))); + + IHttpRequest req = NewUpgradeRequest(PerMessageDeflateExtension + "; " + ServerMaxWindow + "=10"); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List exts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(PerMessageDeflateExtension, exts[0].Name); + Assert.Equal("10", exts[0].Parameters[ServerMaxWindow]); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + } + + [Fact] + public void ServerWindowSizeDisable() + { + var ch = new EmbeddedChannel( + new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, false, 15, false, false))); + + IHttpRequest req = NewUpgradeRequest(PerMessageDeflateExtension + "; " + ServerMaxWindow + "=10"); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + + Assert.False(res2.Headers.Contains(HttpHeaderNames.SecWebsocketExtensions)); + Assert.Null(ch.Pipeline.Get()); + Assert.Null(ch.Pipeline.Get()); + } + + [Fact] + public void ServerNoContext() + { + var ch = new EmbeddedChannel(new WebSocketServerCompressionHandler()); + + IHttpRequest req = NewUpgradeRequest( + PerMessageDeflateExtension + "; " + + PerMessageDeflateServerExtensionHandshaker.ServerNoContext); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + + Assert.False(res2.Headers.Contains(HttpHeaderNames.SecWebsocketExtensions)); + Assert.Null(ch.Pipeline.Get()); + Assert.Null(ch.Pipeline.Get()); + } + + [Fact] + public void ClientNoContext() + { + var ch = new EmbeddedChannel(new WebSocketServerCompressionHandler()); + + IHttpRequest req = NewUpgradeRequest( + PerMessageDeflateExtension + "; " + + PerMessageDeflateServerExtensionHandshaker.ClientNoContext); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List exts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(PerMessageDeflateExtension, exts[0].Name); + Assert.Empty(exts[0].Parameters); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + } + + [Fact] + public void ServerWindowSizeDisableThenFallback() + { + var ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, false, 15, false, false))); + + IHttpRequest req = NewUpgradeRequest( + PerMessageDeflateExtension + "; " + ServerMaxWindow + "=10, " + + PerMessageDeflateExtension); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List exts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(PerMessageDeflateExtension, exts[0].Name); + Assert.Empty(exts[0].Parameters); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketClientExtensionHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketClientExtensionHandlerTest.cs new file mode 100644 index 000000000..eae3bb177 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketClientExtensionHandlerTest.cs @@ -0,0 +1,239 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Moq; + using Xunit; + + using static WebSocketExtensionTestUtil; + + public sealed class WebSocketClientExtensionHandlerTest + { + readonly Mock mainHandshaker; + readonly Mock fallbackHandshaker; + readonly Mock mainExtension; + readonly Mock fallbackExtension; + + public WebSocketClientExtensionHandlerTest() + { + this.mainHandshaker = new Mock(MockBehavior.Strict); + this.fallbackHandshaker = new Mock(MockBehavior.Strict); + this.mainExtension = new Mock(MockBehavior.Strict); + this.fallbackExtension = new Mock(MockBehavior.Strict); + } + + [Fact] + public void MainSuccess() + { + this.mainHandshaker.Setup(x => x.NewRequestData()) + .Returns(new WebSocketExtensionData("main", new Dictionary())); + this.mainHandshaker.Setup(x => x.HandshakeExtension(It.IsAny())) + .Returns(this.mainExtension.Object); + + this.fallbackHandshaker.Setup(x => x.NewRequestData()) + .Returns(new WebSocketExtensionData("fallback", new Dictionary())); + + this.mainExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv1); + this.mainExtension.Setup(x => x.NewExtensionEncoder()).Returns(new DummyEncoder()); + this.mainExtension.Setup(x => x.NewExtensionDecoder()).Returns(new DummyDecoder()); + + var ch = new EmbeddedChannel( + new WebSocketClientExtensionHandler( + this.mainHandshaker.Object, + this.fallbackHandshaker.Object)); + + IHttpRequest req = NewUpgradeRequest(null); + ch.WriteOutbound(req); + + var req2 = ch.ReadOutbound(); + Assert.True(req2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List reqExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + IHttpResponse res = NewUpgradeResponse("main"); + ch.WriteInbound(res); + + var res2 = ch.ReadInbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out value)); + List resExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(2, reqExts.Count); + Assert.Equal("main", reqExts[0].Name); + Assert.Equal("fallback", reqExts[1].Name); + + Assert.Single(resExts); + Assert.Equal("main", resExts[0].Name); + Assert.Empty(resExts[0].Parameters); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + + this.mainExtension.Verify(x => x.Rsv, Times.AtLeastOnce); + } + + [Fact] + public void FallbackSuccess() + { + this.mainHandshaker.Setup(x => x.NewRequestData()) + .Returns(new WebSocketExtensionData("main", new Dictionary())); + this.mainHandshaker.Setup(x => x.HandshakeExtension(It.IsAny())) + .Returns(default(IWebSocketClientExtension)); + + this.fallbackHandshaker.Setup(x => x.NewRequestData()) + .Returns(new WebSocketExtensionData("fallback", new Dictionary())); + this.fallbackHandshaker.Setup(x => x.HandshakeExtension(It.IsAny())) + .Returns(this.fallbackExtension.Object); + + this.fallbackExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv1); + this.fallbackExtension.Setup(x => x.NewExtensionEncoder()).Returns(new DummyEncoder()); + this.fallbackExtension.Setup(x => x.NewExtensionDecoder()).Returns(new DummyDecoder()); + + var ch = new EmbeddedChannel( + new WebSocketClientExtensionHandler( + this.mainHandshaker.Object, + this.fallbackHandshaker.Object)); + + IHttpRequest req = NewUpgradeRequest(null); + ch.WriteOutbound(req); + + var req2 = ch.ReadOutbound(); + Assert.True(req2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List reqExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + IHttpResponse res = NewUpgradeResponse("fallback"); + ch.WriteInbound(res); + + var res2 = ch.ReadInbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out value)); + List resExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(2, reqExts.Count); + Assert.Equal("main", reqExts[0].Name); + Assert.Equal("fallback", reqExts[1].Name); + + Assert.Single(resExts); + Assert.Equal("fallback", resExts[0].Name); + Assert.Empty(resExts[0].Parameters); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + + this.fallbackExtension.Verify(x => x.Rsv, Times.AtLeastOnce); + } + + [Fact] + public void AllSuccess() + { + this.mainHandshaker.Setup(x => x.NewRequestData()) + .Returns(new WebSocketExtensionData("main", new Dictionary())); + this.mainHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main")))) + .Returns(this.mainExtension.Object); + this.mainHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback")))) + .Returns(default(IWebSocketClientExtension)); + this.fallbackHandshaker.Setup(x => x.NewRequestData()) + .Returns(new WebSocketExtensionData("fallback", new Dictionary())); + this.fallbackHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main")))) + .Returns(default(IWebSocketClientExtension)); + this.fallbackHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback")))) + .Returns(this.fallbackExtension.Object); + + var mainEncoder = new DummyEncoder(); + var mainDecoder = new DummyDecoder(); + this.mainExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv1); + this.mainExtension.Setup(x => x.NewExtensionEncoder()).Returns(mainEncoder); + this.mainExtension.Setup(x => x.NewExtensionDecoder()).Returns(mainDecoder); + + var fallbackEncoder = new Dummy2Encoder(); + var fallbackDecoder = new Dummy2Decoder(); + this.fallbackExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv2); + this.fallbackExtension.Setup(x => x.NewExtensionEncoder()).Returns(fallbackEncoder); + this.fallbackExtension.Setup(x => x.NewExtensionDecoder()).Returns(fallbackDecoder); + + var ch = new EmbeddedChannel(new WebSocketClientExtensionHandler( + this.mainHandshaker.Object, this.fallbackHandshaker.Object)); + + IHttpRequest req = NewUpgradeRequest(null); + ch.WriteOutbound(req); + + var req2 = ch.ReadOutbound(); + Assert.True(req2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List reqExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + IHttpResponse res = NewUpgradeResponse("main, fallback"); + ch.WriteInbound(res); + + var res2 = ch.ReadInbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out value)); + List resExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(2, reqExts.Count); + Assert.Equal("main", reqExts[0].Name); + Assert.Equal("fallback", reqExts[1].Name); + + Assert.Equal(2, resExts.Count); + Assert.Equal("main", resExts[0].Name); + Assert.Equal("fallback", resExts[1].Name); + Assert.NotNull(ch.Pipeline.Context(mainEncoder)); + Assert.NotNull(ch.Pipeline.Context(mainDecoder)); + Assert.NotNull(ch.Pipeline.Context(fallbackEncoder)); + Assert.NotNull(ch.Pipeline.Context(fallbackDecoder)); + + this.mainExtension.Verify(x => x.Rsv, Times.AtLeastOnce); + this.fallbackExtension.Verify(x => x.Rsv, Times.AtLeastOnce); + } + + [Fact] + public void MainAndFallbackUseRsv1WillFail() + { + this.mainHandshaker.Setup(x => x.NewRequestData()) + .Returns(new WebSocketExtensionData("main", new Dictionary())); + this.mainHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main")))) + .Returns(this.mainExtension.Object); + this.mainHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback")))) + .Returns(default(IWebSocketClientExtension)); + this.fallbackHandshaker.Setup(x => x.NewRequestData()) + .Returns(new WebSocketExtensionData("fallback", new Dictionary())); + this.fallbackHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback")))) + .Returns(this.fallbackExtension.Object); + this.mainExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv1); + this.fallbackExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv1); + + var ch = new EmbeddedChannel(new WebSocketClientExtensionHandler( + this.mainHandshaker.Object, this.fallbackHandshaker.Object)); + + IHttpRequest req = NewUpgradeRequest(null); + ch.WriteOutbound(req); + + var req2 = ch.ReadOutbound(); + Assert.True(req2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List reqExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + IHttpResponse res = NewUpgradeResponse("main, fallback"); + Assert.Throws(() => ch.WriteInbound(res)); + + Assert.Equal(2, reqExts.Count); + Assert.Equal("main", reqExts[0].Name); + Assert.Equal("fallback", reqExts[1].Name); + + this.mainHandshaker.Verify(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main"))), Times.AtLeastOnce); + this.mainHandshaker.Verify(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback"))), Times.AtLeastOnce); + + this.fallbackHandshaker.Verify(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback"))), Times.AtLeastOnce); + + this.mainExtension.Verify(x => x.Rsv, Times.AtLeastOnce); + this.fallbackExtension.Verify(x => x.Rsv, Times.AtLeastOnce); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketExtensionTestUtil.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketExtensionTestUtil.cs new file mode 100644 index 000000000..6b05427a5 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketExtensionTestUtil.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Transport.Channels; + + static class WebSocketExtensionTestUtil + { + public static IHttpRequest NewUpgradeRequest(string ext) + { + var req = new DefaultHttpRequest( + HttpVersion.Http11, + HttpMethod.Get, + "/chat"); + + req.Headers.Set(HttpHeaderNames.Host, "server.example.com"); + req.Headers.Set(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket.ToString().ToLower()); + req.Headers.Set(HttpHeaderNames.Connection, "Upgrade"); + req.Headers.Set(HttpHeaderNames.Origin, "http://example.com"); + if (ext != null) + { + req.Headers.Set(HttpHeaderNames.SecWebsocketExtensions, ext); + } + + return req; + } + + public static IHttpResponse NewUpgradeResponse(string ext) + { + var res = new DefaultHttpResponse( + HttpVersion.Http11, + HttpResponseStatus.SwitchingProtocols); + + res.Headers.Set(HttpHeaderNames.Host, "server.example.com"); + res.Headers.Set(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket.ToString().ToLower()); + res.Headers.Set(HttpHeaderNames.Connection, "Upgrade"); + res.Headers.Set(HttpHeaderNames.Origin, "http://example.com"); + if (ext != null) + { + res.Headers.Set(HttpHeaderNames.SecWebsocketExtensions, ext); + } + + return res; + } + + internal class DummyEncoder : WebSocketExtensionEncoder + { + protected override void Encode(IChannelHandlerContext ctx, WebSocketFrame msg, List ouput) + { + // unused + } + } + + internal class DummyDecoder : WebSocketExtensionDecoder + { + protected override void Decode(IChannelHandlerContext ctx, WebSocketFrame msg, List output) + { + // unused + } + } + + internal class Dummy2Encoder : WebSocketExtensionEncoder + { + protected override void Encode(IChannelHandlerContext ctx, WebSocketFrame msg, List ouput) + { + // unused + } + } + + internal class Dummy2Decoder : WebSocketExtensionDecoder + { + protected override void Decode(IChannelHandlerContext ctx, WebSocketFrame msg, List output) + { + // unused + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketExtensionUtilTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketExtensionUtilTest.cs new file mode 100644 index 000000000..1944b8bef --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketExtensionUtilTest.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions +{ + using DotNetty.Codecs.Http.WebSockets.Extensions; + using Xunit; + + public sealed class WebSocketExtensionUtilTest + { + [Fact] + public void IsWebsocketUpgrade() + { + HttpHeaders headers = new DefaultHttpHeaders(); + Assert.False(WebSocketExtensionUtil.IsWebsocketUpgrade(headers)); + + headers.Add(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket); + Assert.False(WebSocketExtensionUtil.IsWebsocketUpgrade(headers)); + + headers.Add(HttpHeaderNames.Connection, "Keep-Alive, Upgrade"); + Assert.True(WebSocketExtensionUtil.IsWebsocketUpgrade(headers)); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketServerExtensionHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketServerExtensionHandlerTest.cs new file mode 100644 index 000000000..bdc774cc2 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/Extensions/WebSocketServerExtensionHandlerTest.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets.Extensions +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.WebSockets.Extensions; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Moq; + using Xunit; + + using static WebSocketExtensionTestUtil; + + public sealed class WebSocketServerExtensionHandlerTest + { + readonly Mock mainHandshaker; + readonly Mock fallbackHandshaker; + readonly Mock mainExtension; + readonly Mock fallbackExtension; + + public WebSocketServerExtensionHandlerTest() + { + this.mainHandshaker = new Mock(); + this.fallbackHandshaker = new Mock(); + this.mainExtension = new Mock(); + this.fallbackExtension = new Mock(); + } + + [Fact] + public void MainSuccess() + { + this.mainHandshaker.Setup( + x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main")))) + .Returns(this.mainExtension.Object); + this.mainHandshaker.Setup( + x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback")))) + .Returns(default(IWebSocketServerExtension)); + + this.fallbackHandshaker.Setup( + x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback")))) + .Returns(this.fallbackExtension.Object); + this.fallbackHandshaker.Setup( + x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main")))) + .Returns(default(IWebSocketServerExtension)); + + this.mainExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv1); + this.mainExtension.Setup(x => x.NewReponseData()).Returns( + new WebSocketExtensionData("main", new Dictionary())); + this.mainExtension.Setup(x => x.NewExtensionEncoder()).Returns(new DummyEncoder()); + this.mainExtension.Setup(x => x.NewExtensionDecoder()).Returns(new DummyDecoder()); + + this.fallbackExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv1); + + var ch = new EmbeddedChannel( + new WebSocketServerExtensionHandler( + this.mainHandshaker.Object, + this.fallbackHandshaker.Object)); + + IHttpRequest req = NewUpgradeRequest("main, fallback"); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List resExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Single(resExts); + Assert.Equal("main", resExts[0].Name); + Assert.Empty(resExts[0].Parameters); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + + this.mainHandshaker.Verify( + x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main"))), + Times.AtLeastOnce); + this.mainHandshaker.Verify( + x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback"))), + Times.AtLeastOnce); + this.fallbackHandshaker.Verify( + x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback"))), + Times.AtLeastOnce); + + this.mainExtension.Verify(x => x.Rsv, Times.AtLeastOnce); + this.fallbackExtension.Verify(x => x.Rsv, Times.AtLeastOnce); + } + + [Fact] + public void CompatibleExtensionTogetherSuccess() + { + this.mainHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main")))) + .Returns(this.mainExtension.Object); + this.mainHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback")))) + .Returns(default(IWebSocketServerExtension)); + + this.fallbackHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback")))) + .Returns(this.fallbackExtension.Object); + this.fallbackHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main")))) + .Returns(default(IWebSocketServerExtension)); + + this.mainExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv1); + this.mainExtension.Setup(x => x.NewReponseData()).Returns( + new WebSocketExtensionData("main", new Dictionary())); + this.mainExtension.Setup(x => x.NewExtensionEncoder()).Returns(new DummyEncoder()); + this.mainExtension.Setup(x => x.NewExtensionDecoder()).Returns(new DummyDecoder()); + + this.fallbackExtension.Setup(x => x.Rsv).Returns(WebSocketRsv.Rsv2); + this.fallbackExtension.Setup(x => x.NewReponseData()).Returns( + new WebSocketExtensionData("fallback", new Dictionary())); + this.fallbackExtension.Setup(x => x.NewExtensionEncoder()).Returns(new Dummy2Encoder()); + this.fallbackExtension.Setup(x => x.NewExtensionDecoder()).Returns(new Dummy2Decoder()); + + var ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( + this.mainHandshaker.Object, this.fallbackHandshaker.Object)); + + IHttpRequest req = NewUpgradeRequest("main, fallback"); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + Assert.True(res2.Headers.TryGet(HttpHeaderNames.SecWebsocketExtensions, out ICharSequence value)); + List resExts = WebSocketExtensionUtil.ExtractExtensions(value.ToString()); + + Assert.Equal(2, resExts.Count); + Assert.Equal("main", resExts[0].Name); + Assert.Equal("fallback", resExts[1].Name); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + Assert.NotNull(ch.Pipeline.Get()); + + this.mainHandshaker.Verify(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("main"))), + Times.AtLeastOnce); + this.mainHandshaker.Verify(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback"))), + Times.AtLeastOnce); + this.fallbackHandshaker.Verify(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("fallback"))), + Times.AtLeastOnce); + + this.mainExtension.Verify(x => x.Rsv, Times.Exactly(2)); + this.fallbackExtension.Verify(x => x.Rsv, Times.Exactly(2)); + } + + [Fact] + public void NoneExtensionMatchingSuccess() + { + this.mainHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("unknown")))). + Returns(default(IWebSocketServerExtension)); + this.mainHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("unknown2")))). + Returns(default(IWebSocketServerExtension)); + + this.fallbackHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("unknown")))). + Returns(default(IWebSocketServerExtension)); + this.fallbackHandshaker.Setup(x => x.HandshakeExtension( + It.Is(v => v.Name.Equals("unknown2")))). + Returns(default(IWebSocketServerExtension)); + + var ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( + this.mainHandshaker.Object, this.fallbackHandshaker.Object)); + + IHttpRequest req = NewUpgradeRequest("unknown, unknown2"); + ch.WriteInbound(req); + + IHttpResponse res = NewUpgradeResponse(null); + ch.WriteOutbound(res); + + var res2 = ch.ReadOutbound(); + + Assert.False(res2.Headers.Contains(HttpHeaderNames.SecWebsocketExtensions)); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket00FrameEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket00FrameEncoderTest.cs new file mode 100644 index 000000000..8bc032b9e --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket00FrameEncoderTest.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class WebSocket00FrameEncoderTest + { + // Test for https://github.com/netty/netty/issues/2768 + [Fact] + public void MultipleWebSocketCloseFrames() + { + var channel = new EmbeddedChannel(new WebSocket00FrameEncoder()); + Assert.True(channel.WriteOutbound(new CloseWebSocketFrame())); + Assert.True(channel.WriteOutbound(new CloseWebSocketFrame())); + Assert.True(channel.Finish()); + AssertCloseWebSocketFrame(channel); + AssertCloseWebSocketFrame(channel); + Assert.Null(channel.ReadOutbound()); + } + + static void AssertCloseWebSocketFrame(EmbeddedChannel channel) + { + var buf = channel.ReadOutbound(); + Assert.Equal(2, buf.ReadableBytes); + Assert.Equal((byte)0xFF, buf.ReadByte()); + Assert.Equal((byte)0x00, buf.ReadByte()); + buf.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket08EncoderDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket08EncoderDecoderTest.cs new file mode 100644 index 000000000..32506ab97 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket08EncoderDecoderTest.cs @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class WebSocket08EncoderDecoderTest : IDisposable + { + readonly IByteBuffer binTestData; + readonly string strTestData; + + public WebSocket08EncoderDecoderTest() + { + const int MaxTestdataLength = 100 * 1024; + + this.binTestData = Unpooled.Buffer(MaxTestdataLength); + byte j = 0; + for (int i = 0; i < MaxTestdataLength; i++) + { + this.binTestData.Array[i] = j; + j++; + } + + var s = new StringBuilder(); + char c = 'A'; + for (int i = 0; i < MaxTestdataLength; i++) + { + s.Append(c); + c++; + if (c == 'Z') + { + c = 'A'; + } + } + this.strTestData = s.ToString(); + } + + [Fact] + public void WebSocketEncodingAndDecoding() + { + // Test without masking + var outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(false)); + var inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(false, false, 1024 * 1024, false)); + this.ExecuteTests(outChannel, inChannel); + + // Test with activated masking + outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(true, false, 1024 * 1024, false)); + this.ExecuteTests(outChannel, inChannel); + + // Test with activated masking and an unmasked expecting but forgiving decoder + outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(false, false, 1024 * 1024, true)); + this.ExecuteTests(outChannel, inChannel); + } + + void ExecuteTests(EmbeddedChannel outChannel, EmbeddedChannel inChannel) + { + // Test at the boundaries of each message type, because this shifts the position of the mask field + // Test min. 4 lengths to check for problems related to an uneven frame length + this.ExecuteTests(outChannel, inChannel, 0); + this.ExecuteTests(outChannel, inChannel, 1); + this.ExecuteTests(outChannel, inChannel, 2); + this.ExecuteTests(outChannel, inChannel, 3); + this.ExecuteTests(outChannel, inChannel, 4); + this.ExecuteTests(outChannel, inChannel, 5); + + this.ExecuteTests(outChannel, inChannel, 125); + this.ExecuteTests(outChannel, inChannel, 126); + this.ExecuteTests(outChannel, inChannel, 127); + this.ExecuteTests(outChannel, inChannel, 128); + this.ExecuteTests(outChannel, inChannel, 129); + + this.ExecuteTests(outChannel, inChannel, 65535); + this.ExecuteTests(outChannel, inChannel, 65536); + this.ExecuteTests(outChannel, inChannel, 65537); + this.ExecuteTests(outChannel, inChannel, 65538); + this.ExecuteTests(outChannel, inChannel, 65539); + } + + void ExecuteTests(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) + { + this.TextWithLen(outChannel, inChannel, testDataLength); + this.BinaryWithLen(outChannel, inChannel, testDataLength); + } + + void TextWithLen(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) + { + string testStr = this.strTestData.Substring(0, testDataLength); + outChannel.WriteOutbound(new TextWebSocketFrame(testStr)); + + // Transfer encoded data into decoder + // Loop because there might be multiple frames (gathering write) + while (true) + { + var encoded = outChannel.ReadOutbound(); + if (encoded != null) + { + inChannel.WriteInbound(encoded); + } + else + { + break; + } + } + + var txt = inChannel.ReadInbound(); + Assert.NotNull(txt); + Assert.Equal(testStr, txt.Text()); + txt.Release(); + } + + void BinaryWithLen(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) + { + this.binTestData.Retain(); // need to retain for sending and still keeping it + this.binTestData.SetIndex(0, testDataLength); // Send only len bytes + outChannel.WriteOutbound(new BinaryWebSocketFrame(this.binTestData)); + + // Transfer encoded data into decoder + // Loop because there might be multiple frames (gathering write) + while (true) + { + var encoded = outChannel.ReadOutbound(); + if (encoded != null) + { + inChannel.WriteInbound(encoded); + } + else + { + break; + } + } + + var binFrame = inChannel.ReadInbound(); + Assert.NotNull(binFrame); + int readable = binFrame.Content.ReadableBytes; + Assert.Equal(readable, testDataLength); + for (int i = 0; i < testDataLength; i++) + { + Assert.Equal(this.binTestData.GetByte(i), binFrame.Content.GetByte(i)); + } + + binFrame.Release(); + } + + public void Dispose() + { + this.binTestData.SafeRelease(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket08FrameDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket08FrameDecoderTest.cs new file mode 100644 index 000000000..ff04a874d --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocket08FrameDecoderTest.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Transport.Channels; + using Moq; + using Xunit; + + public sealed class WebSocket08FrameDecoderTest + { + [Fact] + public void ChannelInactive() + { + var decoder = new WebSocket08FrameDecoder(true, true, 65535, false); + var ctx = new Mock(MockBehavior.Strict); + ctx.Setup(x => x.FireChannelInactive()).Returns(ctx.Object); + + decoder.ChannelInactive(ctx.Object); + ctx.Verify(x => x.FireChannelInactive(), Times.Once); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker00Test.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker00Test.cs new file mode 100644 index 000000000..d9876556f --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker00Test.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + + public class WebSocketClientHandshaker00Test : WebSocketClientHandshakerTest + { + protected override WebSocketClientHandshaker NewHandshaker(Uri uri) => + new WebSocketClientHandshaker00(uri, WebSocketVersion.V00, null, null, 1024); + + protected override AsciiString GetOriginHeaderName() => HttpHeaderNames.Origin; + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker07Test.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker07Test.cs new file mode 100644 index 000000000..c779213bd --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker07Test.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + + public class WebSocketClientHandshaker07Test : WebSocketClientHandshakerTest + { + protected override WebSocketClientHandshaker NewHandshaker(Uri uri) => new WebSocketClientHandshaker07(uri, WebSocketVersion.V07, null, false, null, 1024); + + protected override AsciiString GetOriginHeaderName() => HttpHeaderNames.SecWebsocketOrigin; + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker08Test.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker08Test.cs new file mode 100644 index 000000000..e1fa0df3d --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker08Test.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using DotNetty.Codecs.Http.WebSockets; + + public class WebSocketClientHandshaker08Test : WebSocketClientHandshaker07Test + { + protected override WebSocketClientHandshaker NewHandshaker(Uri uri) => new WebSocketClientHandshaker08(uri, WebSocketVersion.V08, null, false, null, 1024); + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker13Test.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker13Test.cs new file mode 100644 index 000000000..41b6901cd --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshaker13Test.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using DotNetty.Codecs.Http.WebSockets; + + public class WebSocketClientHandshaker13Test : WebSocketClientHandshaker07Test + { + protected override WebSocketClientHandshaker NewHandshaker(Uri uri) => new WebSocketClientHandshaker13(uri, WebSocketVersion.V13, null, false, null, 1024); + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshakerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshakerTest.cs new file mode 100644 index 000000000..1d300ae6c --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketClientHandshakerTest.cs @@ -0,0 +1,315 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Internal; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public abstract class WebSocketClientHandshakerTest + { + protected abstract WebSocketClientHandshaker NewHandshaker(Uri uri); + + protected abstract AsciiString GetOriginHeaderName(); + + [Fact] + public void HostHeaderWs() + { + foreach (string scheme in new[] { "ws://", "http://" }) + { + foreach (string host in new[] { /*"localhost", "127.0.0.1", "[::1]",*/ "Netty.io" }) + { + string enter = scheme + host; + + this.HostHeader(enter, host); + this.HostHeader(enter + '/', host); + this.HostHeader(enter + ":80", host); + this.HostHeader(enter + ":443", host + ":443"); + this.HostHeader(enter + ":9999", host + ":9999"); + this.HostHeader(enter + "/path", host); + this.HostHeader(enter + ":80/path", host); + this.HostHeader(enter + ":443/path", host + ":443"); + this.HostHeader(enter + ":9999/path", host + ":9999"); + } + } + } + + [Fact] + public void HostHeaderWss() + { + foreach (string scheme in new[] { "wss://", "https://" }) + { + foreach (string host in new[] { "localhost", "127.0.0.1", "[::1]", "Netty.io" }) + { + string enter = scheme + host; + + this.HostHeader(enter, host); + this.HostHeader(enter + '/', host); + this.HostHeader(enter + ":80", host + ":80"); + this.HostHeader(enter + ":443", host); + this.HostHeader(enter + ":9999", host + ":9999"); + this.HostHeader(enter + "/path", host); + this.HostHeader(enter + ":80/path", host + ":80"); + this.HostHeader(enter + ":443/path", host); + this.HostHeader(enter + ":9999/path", host + ":9999"); + } + } + } + + [Fact] + public void HostHeaderWithoutScheme() + { + this.HostHeader("//localhost/", "localhost"); + this.HostHeader("//localhost/path", "localhost"); + this.HostHeader("//localhost:80/", "localhost:80"); + this.HostHeader("//localhost:443/", "localhost:443"); + this.HostHeader("//localhost:9999/", "localhost:9999"); + } + + [Fact] + public void OriginHeaderWs() + { + foreach (string scheme in new[] { "ws://", "http://" }) + { + foreach (string host in new[] { "localhost", "127.0.0.1", "[::1]", "NETTY.IO" }) + { + string enter = scheme + host; + string expect = "http://" + host.ToLower(); + + this.OriginHeader(enter, expect); + this.OriginHeader(enter + '/', expect); + this.OriginHeader(enter + ":80", expect); + this.OriginHeader(enter + ":443", expect + ":443"); + this.OriginHeader(enter + ":9999", expect + ":9999"); + this.OriginHeader(enter + "/path%20with%20ws", expect); + this.OriginHeader(enter + ":80/path%20with%20ws", expect); + this.OriginHeader(enter + ":443/path%20with%20ws", expect + ":443"); + this.OriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999"); + } + } + } + + [Fact] + public void OriginHeaderWss() + { + foreach (string scheme in new[] { "wss://", "https://" }) + { + foreach (string host in new[] { "localhost", "127.0.0.1", "[::1]", "NETTY.IO" }) + { + string enter = scheme + host; + string expect = "https://" + host.ToLower(); + + this.OriginHeader(enter, expect); + this.OriginHeader(enter + '/', expect); + this.OriginHeader(enter + ":80", expect + ":80"); + this.OriginHeader(enter + ":443", expect); + this.OriginHeader(enter + ":9999", expect + ":9999"); + this.OriginHeader(enter + "/path%20with%20ws", expect); + this.OriginHeader(enter + ":80/path%20with%20ws", expect + ":80"); + this.OriginHeader(enter + ":443/path%20with%20ws", expect); + this.OriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999"); + } + } + } + + [Fact] + public void OriginHeaderWithoutScheme() + { + this.OriginHeader("//localhost/", "http://localhost"); + this.OriginHeader("//localhost/path", "http://localhost"); + + // http scheme by port + this.OriginHeader("//localhost:80/", "http://localhost"); + this.OriginHeader("//localhost:80/path", "http://localhost"); + + // https scheme by port + this.OriginHeader("//localhost:443/", "https://localhost"); + this.OriginHeader("//localhost:443/path", "https://localhost"); + + // http scheme for non standard port + this.OriginHeader("//localhost:9999/", "http://localhost:9999"); + this.OriginHeader("//localhost:9999/path", "http://localhost:9999"); + + // convert host to lower case + this.OriginHeader("//LOCALHOST/", "http://localhost"); + } + + void HostHeader(string uri, string expected) => + this.HeaderDefaultHttp(uri, HttpHeaderNames.Host, expected); + + void OriginHeader(string uri, string expected) => + this.HeaderDefaultHttp(uri, this.GetOriginHeaderName(), expected); + + protected void HeaderDefaultHttp(string uri, AsciiString header, string expectedValue) + { + Assert.True(Uri.TryCreate(uri, UriKind.RelativeOrAbsolute, out Uri originalUri)); + WebSocketClientHandshaker handshaker = this.NewHandshaker(originalUri); + IFullHttpRequest request = handshaker.NewHandshakeRequest(); + try + { + Assert.True(request.Headers.TryGet(header, out ICharSequence value)); + Assert.Equal(expectedValue, value.ToString(), true); + } + finally + { + request.Release(); + } + } + + [Fact] + public void RawPath() + { + var uri = new Uri("ws://localhost:9999/path%20with%20ws"); + WebSocketClientHandshaker handshaker = this.NewHandshaker(uri); + IFullHttpRequest request = handshaker.NewHandshakeRequest(); + try + { + Assert.Equal("/path%20with%20ws", request.Uri); + } + finally + { + request.Release(); + } + } + + [Fact] + public void RawPathWithQuery() + { + var uri = new Uri("ws://localhost:9999/path%20with%20ws?a=b%20c"); + WebSocketClientHandshaker handshaker = this.NewHandshaker(uri); + IFullHttpRequest request = handshaker.NewHandshakeRequest(); + try + { + Assert.Equal("/path%20with%20ws?a=b%20c", request.Uri); + } + finally + { + request.Release(); + } + } + + [Fact] + public void HttpResponseAndFrameInSameBuffer() => this.TestHttpResponseAndFrameInSameBuffer(false); + + [Fact] + public void HttpResponseAndFrameInSameBufferCodec() => this.TestHttpResponseAndFrameInSameBuffer(true); + + void TestHttpResponseAndFrameInSameBuffer(bool codec) + { + string url = "ws://localhost:9999/ws"; + WebSocketClientHandshaker shaker = this.NewHandshaker(new Uri(url)); + var handshaker = new Handshaker(shaker); + + var data = new byte[24]; + PlatformDependent.GetThreadLocalRandom().NextBytes(data); + + // Create a EmbeddedChannel which we will use to encode a BinaryWebsocketFrame to bytes and so use these + // to test the actual handshaker. + var factory = new WebSocketServerHandshakerFactory(url, null, false); + WebSocketServerHandshaker socketServerHandshaker = factory.NewHandshaker(shaker.NewHandshakeRequest()); + var websocketChannel = new EmbeddedChannel(socketServerHandshaker.NewWebSocketEncoder(), + socketServerHandshaker.NewWebsocketDecoder()); + Assert.True(websocketChannel.WriteOutbound(new BinaryWebSocketFrame(Unpooled.WrappedBuffer(data)))); + + byte[] bytes = Encoding.ASCII.GetBytes("HTTP/1.1 101 Switching Protocols\r\nContent-Length: 0\r\n\r\n"); + + CompositeByteBuffer compositeByteBuf = Unpooled.CompositeBuffer(); + compositeByteBuf.AddComponent(true, Unpooled.WrappedBuffer(bytes)); + for (;;) + { + var frameBytes = websocketChannel.ReadOutbound(); + if (frameBytes == null) + { + break; + } + compositeByteBuf.AddComponent(true, frameBytes); + } + + var ch = new EmbeddedChannel(new HttpObjectAggregator(int.MaxValue), new Handler(handshaker)); + if (codec) + { + ch.Pipeline.AddFirst(new HttpClientCodec()); + } + else + { + ch.Pipeline.AddFirst(new HttpRequestEncoder(), new HttpResponseDecoder()); + } + + // We need to first write the request as HttpClientCodec will fail if we receive a response before a request + // was written. + shaker.HandshakeAsync(ch).Wait(); + for (;;) + { + // Just consume the bytes, we are not interested in these. + var buf = ch.ReadOutbound(); + if (buf == null) + { + break; + } + buf.Release(); + } + Assert.True(ch.WriteInbound(compositeByteBuf)); + Assert.True(ch.Finish()); + + var frame = ch.ReadInbound(); + IByteBuffer expect = Unpooled.WrappedBuffer(data); + try + { + Assert.Equal(expect, frame.Content); + Assert.True(frame.IsFinalFragment); + Assert.Equal(0, frame.Rsv); + } + finally + { + expect.Release(); + frame.Release(); + } + } + + sealed class Handshaker : WebSocketClientHandshaker + { + readonly WebSocketClientHandshaker shaker; + + public Handshaker(WebSocketClientHandshaker shaker) + : base(shaker.Uri, shaker.Version, null, EmptyHttpHeaders.Default, int.MaxValue) + { + this.shaker = shaker; + } + + protected internal override IFullHttpRequest NewHandshakeRequest() => this.shaker.NewHandshakeRequest(); + + protected override void Verify(IFullHttpResponse response) + { + // Not do any verification, so we do not need to care sending the correct headers etc in the test, + // which would just make things more complicated. + } + + protected internal override IWebSocketFrameDecoder NewWebSocketDecoder() => this.shaker.NewWebSocketDecoder(); + + protected internal override IWebSocketFrameEncoder NewWebSocketEncoder() => this.shaker.NewWebSocketEncoder(); + } + + sealed class Handler : SimpleChannelInboundHandler + { + readonly Handshaker handshaker; + + public Handler(Handshaker handshaker) + { + this.handshaker = handshaker; + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, IFullHttpResponse msg) + { + this.handshaker.FinishHandshake(ctx.Channel, msg); + ctx.Channel.Pipeline.Remove(this); + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketFrameAggregatorTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketFrameAggregatorTest.cs new file mode 100644 index 000000000..4c0db8aeb --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketFrameAggregatorTest.cs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public class WebSocketFrameAggregatorTest + { + readonly byte[] content1 = Encoding.UTF8.GetBytes("Content1"); + readonly byte[] content2 = Encoding.UTF8.GetBytes("Content2"); + readonly byte[] content3 = Encoding.UTF8.GetBytes("Content3"); + readonly byte[] aggregatedContent; + + public WebSocketFrameAggregatorTest() + { + this.aggregatedContent = new byte[this.content1.Length + this.content2.Length + this.content3.Length]; + Array.Copy(this.content1, 0, this.aggregatedContent, 0, this.content1.Length); + Array.Copy(this.content2, 0, this.aggregatedContent, this.content1.Length, this.content2.Length); + Array.Copy(this.content3, 0, this.aggregatedContent, this.content1.Length + this.content2.Length, this.content3.Length); + } + + [Fact] + public void AggregationBinary() + { + var channel = new EmbeddedChannel(new WebSocketFrameAggregator(int.MaxValue)); + channel.WriteInbound(new BinaryWebSocketFrame(true, 1, Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new BinaryWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content2))); + channel.WriteInbound(new PingWebSocketFrame(Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new PongWebSocketFrame(Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new ContinuationWebSocketFrame(true, 0, Unpooled.WrappedBuffer(this.content3))); + + Assert.True(channel.Finish()); + + var frame = channel.ReadInbound(); + Assert.True(frame.IsFinalFragment); + Assert.Equal(1, frame.Rsv); + Assert.Equal(this.content1, ToBytes(frame.Content)); + + var frame2 = channel.ReadInbound(); + Assert.True(frame2.IsFinalFragment); + Assert.Equal(0, frame2.Rsv); + Assert.Equal(this.content1, ToBytes(frame2.Content)); + + var frame3 = channel.ReadInbound(); + Assert.True(frame3.IsFinalFragment); + Assert.Equal(0, frame3.Rsv); + Assert.Equal(this.content1, ToBytes(frame3.Content)); + + var frame4 = channel.ReadInbound(); + Assert.True(frame4.IsFinalFragment); + Assert.Equal(0, frame4.Rsv); + Assert.Equal(this.aggregatedContent, ToBytes(frame4.Content)); + + Assert.Null(channel.ReadInbound()); + } + + [Fact] + public void AggregationText() + { + var channel = new EmbeddedChannel(new WebSocketFrameAggregator(int.MaxValue)); + channel.WriteInbound(new TextWebSocketFrame(true, 1, Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new TextWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content2))); + channel.WriteInbound(new PingWebSocketFrame(Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new PongWebSocketFrame(Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new ContinuationWebSocketFrame(true, 0, Unpooled.WrappedBuffer(this.content3))); + + Assert.True(channel.Finish()); + + var frame = channel.ReadInbound(); + Assert.True(frame.IsFinalFragment); + Assert.Equal(1, frame.Rsv); + Assert.Equal(this.content1, ToBytes(frame.Content)); + + var frame2 = channel.ReadInbound(); + Assert.True(frame2.IsFinalFragment); + Assert.Equal(0, frame2.Rsv); + Assert.Equal(this.content1, ToBytes(frame2.Content)); + + var frame3 = channel.ReadInbound(); + Assert.True(frame3.IsFinalFragment); + Assert.Equal(0, frame3.Rsv); + Assert.Equal(this.content1, ToBytes(frame3.Content)); + + var frame4 = channel.ReadInbound(); + Assert.True(frame4.IsFinalFragment); + Assert.Equal(0, frame4.Rsv); + Assert.Equal(this.aggregatedContent, ToBytes(frame4.Content)); + + Assert.Null(channel.ReadInbound()); + } + + [Fact] + public void TextFrameTooBig() + { + var channel = new EmbeddedChannel(new WebSocketFrameAggregator(8)); + channel.WriteInbound(new BinaryWebSocketFrame(true, 1, Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new BinaryWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content1))); + Assert.Throws(() => + channel.WriteInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content2)))); + + channel.WriteInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content2))); + channel.WriteInbound(new ContinuationWebSocketFrame(true, 0, Unpooled.WrappedBuffer(this.content2))); + + channel.WriteInbound(new BinaryWebSocketFrame(true, 1, Unpooled.WrappedBuffer(this.content1))); + channel.WriteInbound(new BinaryWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content1))); + Assert.Throws(() => + channel.WriteInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content2)))); + + channel.WriteInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.WrappedBuffer(this.content2))); + channel.WriteInbound(new ContinuationWebSocketFrame(true, 0, Unpooled.WrappedBuffer(this.content2))); + for (;;) + { + var msg = channel.ReadInbound(); + if (msg == null) + { + break; + } + ReferenceCountUtil.Release(msg); + } + channel.Finish(); + } + + static byte[] ToBytes(IByteBuffer buf) + { + var bytes = new byte[buf.ReadableBytes]; + buf.ReadBytes(bytes); + buf.Release(); + return bytes; + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketHandshakeHandOverTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketHandshakeHandOverTest.cs new file mode 100644 index 000000000..cf527aad6 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketHandshakeHandOverTest.cs @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public class WebSocketHandshakeHandOverTest + { + bool serverReceivedHandshake; + WebSocketServerProtocolHandler.HandshakeComplete serverHandshakeComplete; + bool clientReceivedHandshake; + bool clientReceivedMessage; + + public WebSocketHandshakeHandOverTest() + { + this.serverReceivedHandshake = false; + this.serverHandshakeComplete = null; + this.clientReceivedHandshake = false; + this.clientReceivedMessage = false; + } + + [Fact] + public void Handover() + { + var serverHandler = new ServerHandler(this); + EmbeddedChannel serverChannel = CreateServerChannel(serverHandler); + EmbeddedChannel clientChannel = CreateClientChannel(new ClientHandler(this)); + + // Transfer the handshake from the client to the server + TransferAllDataWithMerge(clientChannel, serverChannel); + Assert.True(serverHandler.Completion.Wait(TimeSpan.FromSeconds(1))); + + Assert.True(this.serverReceivedHandshake); + Assert.NotNull(this.serverHandshakeComplete); + Assert.Equal("/test", this.serverHandshakeComplete.RequestUri); + Assert.Equal(8, this.serverHandshakeComplete.RequestHeaders.Size); + Assert.Equal("test-proto-2", this.serverHandshakeComplete.SelectedSubprotocol); + + // Transfer the handshake response and the websocket message to the client + TransferAllDataWithMerge(serverChannel, clientChannel); + Assert.True(this.clientReceivedHandshake); + Assert.True(this.clientReceivedMessage); + } + + sealed class ServerHandler : SimpleChannelInboundHandler + { + readonly WebSocketHandshakeHandOverTest owner; + readonly TaskCompletionSource completion; + + public ServerHandler(WebSocketHandshakeHandOverTest owner) + { + this.owner = owner; + this.completion = new TaskCompletionSource(); + } + + public override void UserEventTriggered(IChannelHandlerContext context, object evt) + { + if (evt is WebSocketServerProtocolHandler.HandshakeComplete complete) + { + this.owner.serverReceivedHandshake = true; + this.owner.serverHandshakeComplete = complete; + + // immediately send a message to the client on connect + context.WriteAndFlushAsync(new TextWebSocketFrame("abc")) + .LinkOutcome(this.completion); + } + } + + public Task Completion => this.completion.Task; + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + // Empty + } + } + + sealed class ClientHandler : SimpleChannelInboundHandler + { + readonly WebSocketHandshakeHandOverTest owner; + + public ClientHandler(WebSocketHandshakeHandOverTest owner) + { + this.owner = owner; + } + + public override void UserEventTriggered(IChannelHandlerContext context, object evt) + { + if (evt is WebSocketClientProtocolHandler.ClientHandshakeStateEvent stateEvent + && stateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HandshakeComplete) + { + this.owner.clientReceivedHandshake = true; + } + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + if (msg is TextWebSocketFrame) + { + this.owner.clientReceivedMessage = true; + } + } + } + + static void TransferAllDataWithMerge(EmbeddedChannel srcChannel, EmbeddedChannel dstChannel) + { + IByteBuffer mergedBuffer = null; + for (;;) + { + var srcData = srcChannel.ReadOutbound(); + if (srcData != null) + { + Assert.IsAssignableFrom(srcData); + var srcBuf = (IByteBuffer)srcData; + try + { + if (mergedBuffer == null) + { + mergedBuffer = Unpooled.Buffer(); + } + mergedBuffer.WriteBytes(srcBuf); + } + finally + { + srcBuf.Release(); + } + } + else + { + break; + } + } + + if (mergedBuffer != null) + { + dstChannel.WriteInbound(mergedBuffer); + } + } + + static EmbeddedChannel CreateClientChannel(IChannelHandler handler) => new EmbeddedChannel( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + new WebSocketClientProtocolHandler( + new Uri("ws://localhost:1234/test"), + WebSocketVersion.V13, + "test-proto-2", + false, + null, + 65536), + handler); + + static EmbeddedChannel CreateServerChannel(IChannelHandler handler) => new EmbeddedChannel( + new HttpServerCodec(), + new HttpObjectAggregator(8192), + new WebSocketServerProtocolHandler("/test", "test-proto-1, test-proto-2", false), + handler); + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketProtocolHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketProtocolHandlerTest.cs new file mode 100644 index 000000000..10353d35a --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketProtocolHandlerTest.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public class WebSocketProtocolHandlerTest + { + [Fact] + public void PingFrame() + { + IByteBuffer pingData = Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("Hello, world")); + var channel = new EmbeddedChannel(new Handler()); + + var inputMessage = new PingWebSocketFrame(pingData); + Assert.False(channel.WriteInbound(inputMessage)); // the message was not propagated inbound + + // a Pong frame was written to the channel + var response = channel.ReadOutbound(); + Assert.Equal(pingData, response.Content); + + pingData.Release(); + Assert.False(channel.Finish()); + } + + [Fact] + public void PongFrameDropFrameFalse() + { + var channel = new EmbeddedChannel(new Handler(false)); + + var pingResponse = new PongWebSocketFrame(); + Assert.True(channel.WriteInbound(pingResponse)); + + AssertPropagatedInbound(pingResponse, channel); + + pingResponse.Release(); + Assert.False(channel.Finish()); + } + + [Fact] + public void PongFrameDropFrameTrue() + { + var channel = new EmbeddedChannel(new Handler()); + + var pingResponse = new PongWebSocketFrame(); + Assert.False(channel.WriteInbound(pingResponse)); // message was not propagated inbound + } + + [Fact] + public void TextFrame() + { + var channel = new EmbeddedChannel(new Handler()); + + var textFrame = new TextWebSocketFrame(); + Assert.True(channel.WriteInbound(textFrame)); + + AssertPropagatedInbound(textFrame, channel); + + textFrame.Release(); + Assert.False(channel.Finish()); + } + + static void AssertPropagatedInbound(T message, EmbeddedChannel channel) + where T : WebSocketFrame + { + var propagatedResponse = channel.ReadInbound(); + Assert.Equal(message, propagatedResponse); + } + + sealed class Handler : WebSocketProtocolHandler + { + public Handler(bool dropPongFrames = true) : base(dropPongFrames) + { + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketRequestBuilder.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketRequestBuilder.cs new file mode 100644 index 000000000..e4f32f133 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketRequestBuilder.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + + using static Http.HttpVersion; + + public class WebSocketRequestBuilder + { + HttpVersion httpVersion; + HttpMethod method; + string uri; + string host; + string upgrade; + string connection; + string key; + string origin; + WebSocketVersion version; + + public WebSocketRequestBuilder HttpVersion(HttpVersion value) + { + this.httpVersion = value; + return this; + } + + public WebSocketRequestBuilder Method(HttpMethod value) + { + this.method = value; + return this; + } + + public WebSocketRequestBuilder Uri(string value) + { + this.uri = value; + return this; + } + + public WebSocketRequestBuilder Host(string value) + { + this.host = value; + return this; + } + + public WebSocketRequestBuilder Upgrade(string value) + { + this.upgrade = value; + return this; + } + + public WebSocketRequestBuilder Upgrade(AsciiString value) + { + this.upgrade = this.upgrade == null ? null : value.ToString(); + return this; + } + + public WebSocketRequestBuilder Connection(string value) + { + this.connection = value; + return this; + } + + public WebSocketRequestBuilder Key(string value) + { + this.key = value; + return this; + } + + public WebSocketRequestBuilder Origin(string value) + { + this.origin = value; + return this; + } + + public WebSocketRequestBuilder Version13() + { + this.version = WebSocketVersion.V13; + return this; + } + + public WebSocketRequestBuilder Version8() + { + this.version = WebSocketVersion.V08; + return this; + } + + public WebSocketRequestBuilder Version00() + { + this.version = null; + return this; + } + + public WebSocketRequestBuilder NoVersion() + { + return this; + } + + public IFullHttpRequest Build() + { + var req = new DefaultFullHttpRequest(this.httpVersion, this.method, this.uri); + HttpHeaders headers = req.Headers; + + if (this.host != null) + { + headers.Set(HttpHeaderNames.Host, this.host); + } + if (this.upgrade != null) + { + headers.Set(HttpHeaderNames.Upgrade, this.upgrade); + } + if (this.connection != null) + { + headers.Set(HttpHeaderNames.Connection, this.connection); + } + if (this.key != null) + { + headers.Set(HttpHeaderNames.SecWebsocketKey, this.key); + } + if (this.origin != null) + { + headers.Set(HttpHeaderNames.SecWebsocketOrigin, this.origin); + } + if (this.version != null) + { + headers.Set(HttpHeaderNames.SecWebsocketVersion, this.version.ToHttpHeaderValue()); + } + return req; + } + + public static IHttpRequest Successful() => new WebSocketRequestBuilder() + .HttpVersion(Http11) + .Method(HttpMethod.Get) + .Uri("/test") + .Host("server.example.com") + .Upgrade(HttpHeaderValues.Websocket) + .Key("dGhlIHNhbXBsZSBub25jZQ==") + .Origin("http://example.com") + .Version13() + .Build(); + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker00Test.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker00Test.cs new file mode 100644 index 000000000..9778bd0cd --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker00Test.cs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + using static HttpVersion; + + public class WebSocketServerHandshaker00Test + { + [Fact] + public void PerformOpeningHandshake() => PerformOpeningHandshake0(true); + + [Fact] + public void PerformOpeningHandshakeSubProtocolNotSupported() => PerformOpeningHandshake0(false); + + static void PerformOpeningHandshake0(bool subProtocol) + { + var ch = new EmbeddedChannel( + new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); + + var req = new DefaultFullHttpRequest(Http11, HttpMethod.Get, "/chat", + Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("^n:ds[4U"))); + + req.Headers.Set(HttpHeaderNames.Host, "server.example.com"); + req.Headers.Set(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket); + req.Headers.Set(HttpHeaderNames.Connection, "Upgrade"); + req.Headers.Set(HttpHeaderNames.Origin, "http://example.com"); + req.Headers.Set(HttpHeaderNames.SecWebsocketKey1, "4 @1 46546xW%0l 1 5"); + req.Headers.Set(HttpHeaderNames.SecWebsocketKey2, "12998 5 Y3 1 .P00"); + req.Headers.Set(HttpHeaderNames.SecWebsocketProtocol, "chat, superchat"); + + WebSocketServerHandshaker00 handshaker; + if (subProtocol) + { + handshaker = new WebSocketServerHandshaker00("ws://example.com/chat", "chat", int.MaxValue); + } + else + { + handshaker = new WebSocketServerHandshaker00("ws://example.com/chat", null, int.MaxValue); + } + Assert.True(handshaker.HandshakeAsync(ch, req).Wait(TimeSpan.FromSeconds(2))); + + var ch2 = new EmbeddedChannel(new HttpResponseDecoder()); + ch2.WriteInbound(ch.ReadOutbound()); + var res = ch2.ReadInbound(); + + Assert.True(res.Headers.TryGet(HttpHeaderNames.SecWebsocketLocation, out ICharSequence value)); + Assert.Equal("ws://example.com/chat", value.ToString()); + + if (subProtocol) + { + Assert.True(res.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out value)); + Assert.Equal("chat", value.ToString()); + } + else + { + Assert.False(res.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out value)); + } + var content = ch2.ReadInbound(); + + Assert.Equal("8jKS'y:G*Co,Wxa-", content.Content.ToString(Encoding.ASCII)); + content.Release(); + req.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker08Test.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker08Test.cs new file mode 100644 index 000000000..cc24d5a54 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker08Test.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + using static HttpVersion; + + public class WebSocketServerHandshaker08Test + { + [Fact] + public void PerformOpeningHandshake() => PerformOpeningHandshake0(true); + + [Fact] + public void PerformOpeningHandshakeSubProtocolNotSupported() => PerformOpeningHandshake0(false); + + static void PerformOpeningHandshake0(bool subProtocol) + { + var ch = new EmbeddedChannel( + new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); + + var req = new DefaultFullHttpRequest(Http11, HttpMethod.Get, "/chat"); + req.Headers.Set(HttpHeaderNames.Host, "server.example.com"); + req.Headers.Set(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket); + req.Headers.Set(HttpHeaderNames.Connection, "Upgrade"); + req.Headers.Set(HttpHeaderNames.SecWebsocketKey, "dGhlIHNhbXBsZSBub25jZQ=="); + req.Headers.Set(HttpHeaderNames.SecWebsocketOrigin, "http://example.com"); + req.Headers.Set(HttpHeaderNames.SecWebsocketProtocol, "chat, superchat"); + req.Headers.Set(HttpHeaderNames.SecWebsocketVersion, "8"); + + WebSocketServerHandshaker08 handshaker; + if (subProtocol) + { + handshaker = new WebSocketServerHandshaker08( + "ws://example.com/chat", "chat", false, int.MaxValue, false); + } + else + { + handshaker = new WebSocketServerHandshaker08( + "ws://example.com/chat", null, false, int.MaxValue, false); + } + + Assert.True(handshaker.HandshakeAsync(ch, req).Wait(TimeSpan.FromSeconds(2))); + + var resBuf = ch.ReadOutbound(); + + var ch2 = new EmbeddedChannel(new HttpResponseDecoder()); + ch2.WriteInbound(resBuf); + var res = ch2.ReadInbound(); + + Assert.True(res.Headers.TryGet(HttpHeaderNames.SecWebsocketAccept, out ICharSequence value)); + Assert.Equal("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", value.ToString()); + if (subProtocol) + { + Assert.True(res.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out value)); + Assert.Equal("chat", value.ToString()); + } + else + { + Assert.False(res.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out value)); + } + ReferenceCountUtil.Release(res); + req.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker13Test.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker13Test.cs new file mode 100644 index 000000000..22deb9391 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshaker13Test.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + using static HttpVersion; + + public class WebSocketServerHandshaker13Test + { + [Fact] + public void PerformOpeningHandshake() => PerformOpeningHandshake0(true); + + [Fact] + public void PerformOpeningHandshakeSubProtocolNotSupported() => PerformOpeningHandshake0(false); + + static void PerformOpeningHandshake0(bool subProtocol) + { + var ch = new EmbeddedChannel( + new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); + + var req = new DefaultFullHttpRequest(Http11, HttpMethod.Get, "/chat"); + req.Headers.Set(HttpHeaderNames.Host, "server.example.com"); + req.Headers.Set(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket); + req.Headers.Set(HttpHeaderNames.Connection, "Upgrade"); + req.Headers.Set(HttpHeaderNames.SecWebsocketKey, "dGhlIHNhbXBsZSBub25jZQ=="); + req.Headers.Set(HttpHeaderNames.SecWebsocketOrigin, "http://example.com"); + req.Headers.Set(HttpHeaderNames.SecWebsocketProtocol, "chat, superchat"); + req.Headers.Set(HttpHeaderNames.SecWebsocketVersion, "13"); + + WebSocketServerHandshaker13 handshaker; + if (subProtocol) + { + handshaker = new WebSocketServerHandshaker13( + "ws://example.com/chat", "chat", false, int.MaxValue, false); + } + else + { + handshaker = new WebSocketServerHandshaker13( + "ws://example.com/chat", null, false, int.MaxValue, false); + } + + Assert.True(handshaker.HandshakeAsync(ch, req).Wait(TimeSpan.FromSeconds(2))); + + var resBuf = ch.ReadOutbound(); + + var ch2 = new EmbeddedChannel(new HttpResponseDecoder()); + ch2.WriteInbound(resBuf); + var res = ch2.ReadInbound(); + + Assert.True(res.Headers.TryGet(HttpHeaderNames.SecWebsocketAccept, out ICharSequence value)); + Assert.Equal("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", value.ToString()); + if (subProtocol) + { + Assert.True(res.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out value)); + Assert.Equal("chat", value.ToString()); + } + else + { + Assert.False(res.Headers.TryGet(HttpHeaderNames.SecWebsocketProtocol, out value)); + } + ReferenceCountUtil.Release(res); + req.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshakerFactoryTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshakerFactoryTest.cs new file mode 100644 index 000000000..b7c0002fb --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerHandshakerFactoryTest.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public class WebSocketServerHandshakerFactoryTest + { + [Fact] + public void UnsupportedVersion() + { + var ch = new EmbeddedChannel(); + WebSocketServerHandshakerFactory.SendUnsupportedVersionResponse(ch); + ch.RunPendingTasks(); + var response = ch.ReadOutbound(); + Assert.NotNull(response); + + Assert.Equal(HttpResponseStatus.UpgradeRequired, response.Status); + Assert.True(response.Headers.TryGet(HttpHeaderNames.SecWebsocketVersion, out ICharSequence value)); + Assert.Equal(WebSocketVersion.V13.ToHttpHeaderValue(), value); + Assert.True(HttpUtil.IsContentLengthSet(response)); + Assert.Equal(0, HttpUtil.GetContentLength(response)); + + ReferenceCountUtil.Release(response); + Assert.False(ch.Finish()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerProtocolHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerProtocolHandlerTest.cs new file mode 100644 index 000000000..9fe1705a3 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/WebSockets/WebSocketServerProtocolHandlerTest.cs @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.WebSockets +{ + using System; + using System.Collections.Generic; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Codecs.Http.WebSockets; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + using static HttpResponseStatus; + using static HttpVersion; + + public class WebSocketServerProtocolHandlerTest : IDisposable + { + readonly Queue responses; + + public WebSocketServerProtocolHandlerTest() + { + this.responses = new Queue(); + } + + [Fact] + public void HttpUpgradeRequest() + { + EmbeddedChannel ch = this.CreateChannel(new MockOutboundHandler(this)); + IChannelHandlerContext handshakerCtx = ch.Pipeline.Context(); + WriteUpgradeRequest(ch); + + IFullHttpResponse response = this.responses.Dequeue(); + Assert.Equal(SwitchingProtocols, response.Status); + response.Release(); + Assert.NotNull(WebSocketServerProtocolHandler.GetHandshaker(handshakerCtx.Channel)); + } + + [Fact] + public void SubsequentHttpRequestsAfterUpgradeShouldReturn403() + { + EmbeddedChannel ch = this.CreateChannel(); + + WriteUpgradeRequest(ch); + + IFullHttpResponse response = this.responses.Dequeue(); + Assert.Equal(SwitchingProtocols, response.Status); + response.Release(); + + ch.WriteInbound(new DefaultFullHttpRequest(Http11, HttpMethod.Get, "/test")); + response = this.responses.Dequeue(); + Assert.Equal(Forbidden, response.Status); + response.Release(); + } + + [Fact] + public void HttpUpgradeRequestInvalidUpgradeHeader() + { + EmbeddedChannel ch = this.CreateChannel(); + IFullHttpRequest httpRequestWithEntity = new WebSocketRequestBuilder() + .HttpVersion(Http11) + .Method(HttpMethod.Get) + .Uri("/test") + .Connection("Upgrade") + .Version00() + .Upgrade("BogusSocket") + .Build(); + + ch.WriteInbound(httpRequestWithEntity); + + IFullHttpResponse response = this.responses.Dequeue(); + Assert.Equal(BadRequest, response.Status); + Assert.Equal("not a WebSocket handshake request: missing upgrade", GetResponseMessage(response)); + response.Release(); + } + + [Fact] + public void HttpUpgradeRequestMissingWsKeyHeader() + { + EmbeddedChannel ch = this.CreateChannel(); + IHttpRequest httpRequest = new WebSocketRequestBuilder() + .HttpVersion(Http11) + .Method(HttpMethod.Get) + .Uri("/test") + .Key(null) + .Connection("Upgrade") + .Upgrade(HttpHeaderValues.Websocket) + .Version13() + .Build(); + + ch.WriteInbound(httpRequest); + + IFullHttpResponse response = this.responses.Dequeue(); + Assert.Equal(BadRequest, response.Status); + Assert.Equal("not a WebSocket request: missing key", GetResponseMessage(response)); + response.Release(); + } + + [Fact] + public void HandleTextFrame() + { + var customTextFrameHandler = new CustomTextFrameHandler(); + EmbeddedChannel ch = this.CreateChannel(customTextFrameHandler); + WriteUpgradeRequest(ch); + + if (ch.Pipeline.Context() != null) + { + // Removing the HttpRequestDecoder because we are writing a TextWebSocketFrame and thus + // decoding is not necessary. + ch.Pipeline.Remove(); + } + + ch.WriteInbound(new TextWebSocketFrame("payload")); + + Assert.Equal("processed: payload", customTextFrameHandler.Content); + } + + EmbeddedChannel CreateChannel() => this.CreateChannel(null); + + EmbeddedChannel CreateChannel(IChannelHandler handler) => + new EmbeddedChannel( + new WebSocketServerProtocolHandler("/test", null, false), + new HttpRequestDecoder(), + new HttpResponseEncoder(), + new MockOutboundHandler(this), + handler); + + static void WriteUpgradeRequest(EmbeddedChannel ch) => ch.WriteInbound(WebSocketRequestBuilder.Successful()); + + static string GetResponseMessage(IFullHttpResponse response) => Encoding.ASCII.GetString(response.Content.Array); + + sealed class MockOutboundHandler : ChannelHandlerAdapter + { + readonly WebSocketServerProtocolHandlerTest owner; + + public MockOutboundHandler(WebSocketServerProtocolHandlerTest owner) + { + this.owner = owner; + } + + public override Task WriteAsync(IChannelHandlerContext ctx, object msg) + { + this.owner.responses.Enqueue((IFullHttpResponse)msg); + return TaskEx.Completed; + } + + public override void Flush(IChannelHandlerContext ctx) + { + } + } + + sealed class CustomTextFrameHandler : ChannelHandlerAdapter + { + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + Assert.Null(this.Content); + this.Content = "processed: " + ((TextWebSocketFrame)msg).Text(); + ReferenceCountUtil.Release(msg); + } + + public string Content { get; private set; } + } + + public void Dispose() => this.responses.Clear(); + } +}