From da8348a7b493a9298ed93f8faf5ce74e40eb17b9 Mon Sep 17 00:00:00 2001 From: Muthuveer Somanathan <41929942+msomanathan@users.noreply.github.com> Date: Thu, 14 Oct 2021 21:59:07 -0700 Subject: [PATCH] Streaming Library Refactor (#5908) --- CodeCoverage.runsettings | 1 + Microsoft.Bot.Builder.sln | 58 ++ .../CoreBotAdapter.cs | 2 +- .../Microsoft.Bot.Builder/CloudAdapterBase.cs | 21 +- .../Microsoft.Bot.Builder.csproj | 3 + .../Streaming/BotFrameworkHttpAdapterBase.cs | 43 +- .../Streaming/StreamingRequestHandler.cs | 108 ++-- .../Application/LegacyStreamingConnection.cs | 160 ++++++ .../Application/StreamingConnection.cs | 37 ++ .../Application/TimerAwaitable.cs | 161 ++++++ .../Application/WebSocketClient.cs | 323 +++++++++++ .../WebSocketStreamingConnection.cs | 125 +++++ .../AssemblyInfo.cs | 11 + .../Microsoft.Bot.Connector.Streaming.csproj | 46 ++ .../Session/StreamingSession.cs | 523 ++++++++++++++++++ .../Transport/DuplexPipe.cs | 44 ++ .../Transport/Payloads/RequestPayload.cs | 32 ++ .../Transport/Payloads/ResponsePayload.cs | 26 + .../Transport/TransportHandler.cs | 323 +++++++++++ .../Transport/WebSocketExtensions.cs | 48 ++ .../Transport/WebSocketTransport.cs | 382 +++++++++++++ .../PasswordServiceClientCredentialFactory.cs | 2 +- .../Microsoft.Bot.Streaming/AssemblyInfo.cs | 11 + .../StreamingRequest.cs | 2 +- .../Transport/WebSocket/WebSocketClient.cs | 30 + .../BotFrameworkHttpAdapter.cs | 43 +- .../CloudAdapter.cs | 102 +++- ...Bot.Builder.Integration.AspNet.Core.csproj | 6 + ...rosoft.Bot.Connector.Streaming.Perf.csproj | 18 + .../Program.cs | 13 + ...ot.Connector.Streaming.Tests.Client.csproj | 17 + .../Program.cs | 241 ++++++++ .../Controllers/BotController.cs | 36 ++ ...ot.Connector.Streaming.Tests.Server.csproj | 27 + .../Program.cs | 29 + .../Startup.cs | 57 ++ .../appsettings.json | 12 + .../snapshot.dialog | 19 + .../LegacyStreamingConnectionTests.cs | 216 ++++++++ .../Application/ObjectWithTimerAwaitable.cs | 41 ++ .../Application/TimerAwaitableTests.cs | 38 ++ ...pplicationToApplicationIntegrationTests.cs | 140 +++++ .../Integration/EndToEndMiniLoadTests.cs | 13 + .../InteropApplicationIntegrationTests.cs | 317 +++++++++++ .../WebSocketTransportClientServerTests.cs | 65 +++ ...osoft.Bot.Connector.Streaming.Tests.csproj | 47 ++ .../Session/ProtocolDispatcherTests.cs | 194 +++++++ .../Session/StreamingSessionTests.cs | 369 ++++++++++++ .../Tools/MemorySegment{T}.cs | 28 + .../Tools/SyncPoint.cs | 86 +++ .../Tools/TaskExtensions.cs | 62 +++ .../Tools/TestTransportObserver.cs | 30 + .../Tools/TestWebSocketConnectionFeature.cs | 283 ++++++++++ .../Tools/XUnitLogger.cs | 56 ++ .../Transport/TransportHandlerTests.cs | 424 ++++++++++++++ .../Transport/WebSocketTransportTests.cs | 515 +++++++++++++++++ .../RequestTests.cs | 8 +- .../CloudAdapterTests.cs | 170 ++++++ tests/tests.uischema | 3 +- 59 files changed, 6178 insertions(+), 69 deletions(-) create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Application/LegacyStreamingConnection.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Application/StreamingConnection.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Application/TimerAwaitable.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Application/WebSocketClient.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Application/WebSocketStreamingConnection.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/AssemblyInfo.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Microsoft.Bot.Connector.Streaming.csproj create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Session/StreamingSession.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Transport/DuplexPipe.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Transport/Payloads/RequestPayload.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Transport/Payloads/ResponsePayload.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Transport/TransportHandler.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Transport/WebSocketExtensions.cs create mode 100644 libraries/Microsoft.Bot.Connector.Streaming/Transport/WebSocketTransport.cs create mode 100644 libraries/Microsoft.Bot.Streaming/AssemblyInfo.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Perf/Microsoft.Bot.Connector.Streaming.Perf.csproj create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Perf/Program.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests.Client/Microsoft.Bot.Connector.Streaming.Tests.Client.csproj create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests.Client/Program.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Controllers/BotController.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Microsoft.Bot.Connector.Streaming.Tests.Server.csproj create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Program.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Startup.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests.Server/appsettings.json create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests.Server/snapshot.dialog create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Application/LegacyStreamingConnectionTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Application/ObjectWithTimerAwaitable.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Application/TimerAwaitableTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/ApplicationToApplicationIntegrationTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/EndToEndMiniLoadTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/InteropApplicationIntegrationTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/WebSocketTransportClientServerTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Microsoft.Bot.Connector.Streaming.Tests.csproj create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Session/ProtocolDispatcherTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Session/StreamingSessionTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/MemorySegment{T}.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/SyncPoint.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TaskExtensions.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TestTransportObserver.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TestWebSocketConnectionFeature.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/XUnitLogger.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Transport/TransportHandlerTests.cs create mode 100644 tests/Microsoft.Bot.Connector.Streaming.Tests/Transport/WebSocketTransportTests.cs diff --git a/CodeCoverage.runsettings b/CodeCoverage.runsettings index 7bc289dbb2..316bc2003b 100644 --- a/CodeCoverage.runsettings +++ b/CodeCoverage.runsettings @@ -40,6 +40,7 @@ .*\Microsoft.Bot.Connector.dll$ .*\Microsoft.Bot.Schema.dll$ + .*\Microsoft.Bot.Connector.Streaming.dll$ .*\Microsoft.Bot.Streaming.dll$ diff --git a/Microsoft.Bot.Builder.sln b/Microsoft.Bot.Builder.sln index 58797e7779..3412dffbfb 100644 --- a/Microsoft.Bot.Builder.sln +++ b/Microsoft.Bot.Builder.sln @@ -225,6 +225,18 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bot.Builder.Dialo EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bot.Builder.Dialogs.Adaptive.Runtime", "libraries\Microsoft.Bot.Builder.Dialogs.Adaptive.Runtime\Microsoft.Bot.Builder.Dialogs.Adaptive.Runtime.csproj", "{2DB4E5B0-3209-425E-A912-005A330CC66A}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bot.Connector.Streaming", "libraries\Microsoft.Bot.Connector.Streaming\Microsoft.Bot.Connector.Streaming.csproj", "{80FA0E50-8F81-4C60-B265-1039391C1CEE}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bot.Connector.Streaming.Tests", "tests\Microsoft.Bot.Connector.Streaming.Tests\Microsoft.Bot.Connector.Streaming.Tests.csproj", "{9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Streaming", "Streaming", "{EBFEF03F-9ACE-4312-89D7-2C8A147CDF9C}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bot.Connector.Streaming.Tests.Server", "tests\Microsoft.Bot.Connector.Streaming.Tests.Server\Microsoft.Bot.Connector.Streaming.Tests.Server.csproj", "{FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bot.Connector.Streaming.Tests.Client", "tests\Microsoft.Bot.Connector.Streaming.Tests.Client\Microsoft.Bot.Connector.Streaming.Tests.Client.csproj", "{2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.Bot.Connector.Streaming.Perf", "tests\Microsoft.Bot.Connector.Streaming.Perf\Microsoft.Bot.Connector.Streaming.Perf.csproj", "{B49A3201-5BEE-426C-A082-D92D52172E06}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -881,6 +893,46 @@ Global {2DB4E5B0-3209-425E-A912-005A330CC66A}.Release|Any CPU.Build.0 = Release|Any CPU {2DB4E5B0-3209-425E-A912-005A330CC66A}.Release-Windows|Any CPU.ActiveCfg = Release|Any CPU {2DB4E5B0-3209-425E-A912-005A330CC66A}.Release-Windows|Any CPU.Build.0 = Release|Any CPU + {80FA0E50-8F81-4C60-B265-1039391C1CEE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {80FA0E50-8F81-4C60-B265-1039391C1CEE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {80FA0E50-8F81-4C60-B265-1039391C1CEE}.Debug-Windows|Any CPU.ActiveCfg = Debug|Any CPU + {80FA0E50-8F81-4C60-B265-1039391C1CEE}.Debug-Windows|Any CPU.Build.0 = Debug|Any CPU + {80FA0E50-8F81-4C60-B265-1039391C1CEE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {80FA0E50-8F81-4C60-B265-1039391C1CEE}.Release|Any CPU.Build.0 = Release|Any CPU + {80FA0E50-8F81-4C60-B265-1039391C1CEE}.Release-Windows|Any CPU.ActiveCfg = Release|Any CPU + {80FA0E50-8F81-4C60-B265-1039391C1CEE}.Release-Windows|Any CPU.Build.0 = Release|Any CPU + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}.Debug-Windows|Any CPU.ActiveCfg = Debug|Any CPU + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}.Debug-Windows|Any CPU.Build.0 = Debug|Any CPU + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}.Release|Any CPU.Build.0 = Release|Any CPU + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}.Release-Windows|Any CPU.ActiveCfg = Release|Any CPU + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7}.Release-Windows|Any CPU.Build.0 = Release|Any CPU + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}.Debug-Windows|Any CPU.ActiveCfg = Debug|Any CPU + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}.Debug-Windows|Any CPU.Build.0 = Debug|Any CPU + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}.Release|Any CPU.ActiveCfg = Release|Any CPU + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}.Release|Any CPU.Build.0 = Release|Any CPU + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}.Release-Windows|Any CPU.ActiveCfg = Release|Any CPU + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71}.Release-Windows|Any CPU.Build.0 = Release|Any CPU + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}.Debug-Windows|Any CPU.ActiveCfg = Debug|Any CPU + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}.Debug-Windows|Any CPU.Build.0 = Debug|Any CPU + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}.Release|Any CPU.Build.0 = Release|Any CPU + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}.Release-Windows|Any CPU.ActiveCfg = Release|Any CPU + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161}.Release-Windows|Any CPU.Build.0 = Release|Any CPU + {B49A3201-5BEE-426C-A082-D92D52172E06}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B49A3201-5BEE-426C-A082-D92D52172E06}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B49A3201-5BEE-426C-A082-D92D52172E06}.Debug-Windows|Any CPU.ActiveCfg = Debug|Any CPU + {B49A3201-5BEE-426C-A082-D92D52172E06}.Debug-Windows|Any CPU.Build.0 = Debug|Any CPU + {B49A3201-5BEE-426C-A082-D92D52172E06}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B49A3201-5BEE-426C-A082-D92D52172E06}.Release|Any CPU.Build.0 = Release|Any CPU + {B49A3201-5BEE-426C-A082-D92D52172E06}.Release-Windows|Any CPU.ActiveCfg = Release|Any CPU + {B49A3201-5BEE-426C-A082-D92D52172E06}.Release-Windows|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -979,6 +1031,12 @@ Global {0BF5E92D-D034-4D80-8921-07627F55F412} = {C40A300C-8988-4733-A760-A776C6309B57} {D611AC03-9859-4EB6-BAB9-C26F493DFDB3} = {AD743B78-D61F-4FBF-B620-FA83CE599A50} {2DB4E5B0-3209-425E-A912-005A330CC66A} = {4269F3C3-6B42-419B-B64A-3E6DC0F1574A} + {80FA0E50-8F81-4C60-B265-1039391C1CEE} = {4269F3C3-6B42-419B-B64A-3E6DC0F1574A} + {9EBA6EDB-7D67-4BC5-9F94-E0162A538CC7} = {EBFEF03F-9ACE-4312-89D7-2C8A147CDF9C} + {EBFEF03F-9ACE-4312-89D7-2C8A147CDF9C} = {AD743B78-D61F-4FBF-B620-FA83CE599A50} + {FB7ADCDF-C0A5-49EA-8ADC-CC77B6FB9D71} = {EBFEF03F-9ACE-4312-89D7-2C8A147CDF9C} + {2E5AD07C-4F6E-4B6B-BEFE-9FBE9F789161} = {EBFEF03F-9ACE-4312-89D7-2C8A147CDF9C} + {B49A3201-5BEE-426C-A082-D92D52172E06} = {EBFEF03F-9ACE-4312-89D7-2C8A147CDF9C} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {7173C9F3-A7F9-496E-9078-9156E35D6E16} diff --git a/libraries/Microsoft.Bot.Builder.Dialogs.Adaptive.Runtime/CoreBotAdapter.cs b/libraries/Microsoft.Bot.Builder.Dialogs.Adaptive.Runtime/CoreBotAdapter.cs index 2756412e8f..b834a31671 100644 --- a/libraries/Microsoft.Bot.Builder.Dialogs.Adaptive.Runtime/CoreBotAdapter.cs +++ b/libraries/Microsoft.Bot.Builder.Dialogs.Adaptive.Runtime/CoreBotAdapter.cs @@ -13,7 +13,7 @@ internal class CoreBotAdapter : CloudAdapter public CoreBotAdapter( BotFrameworkAuthentication botFrameworkAuthentication, IEnumerable middlewares, - ILogger logger = null) + ILogger logger = null) : base(botFrameworkAuthentication, logger) { // Pick up feature based middlewares such as telemetry or transcripts diff --git a/libraries/Microsoft.Bot.Builder/CloudAdapterBase.cs b/libraries/Microsoft.Bot.Builder/CloudAdapterBase.cs index d04a9d049e..ad10362487 100644 --- a/libraries/Microsoft.Bot.Builder/CloudAdapterBase.cs +++ b/libraries/Microsoft.Bot.Builder/CloudAdapterBase.cs @@ -233,6 +233,23 @@ public override async Task CreateConversationAsync(string botAppId, string chann } } + /// + /// Gets the correct streaming connector factory that is processing the given activity. + /// + /// The activity that is being processed. + /// The Streaming Connector Factory responsible for processing the activity. + /// + /// For HTTP requests, we usually create a new connector factory and reply to the activity over a new HTTP request. + /// However, when processing activities over a streaming connection, we need to reply over the same connection that is talking to a web socket. + /// This method will look up all active streaming connections in cloud adapter and return the connector factory that is processing the activity. + /// Messages between bot and channel go through the StreamingConnection (bot -> channel) and RequestHandler (channel -> bot), both created by the adapter. + /// However, proactive messages don't know which connection to talk to, so this method is designed to aid in the connection resolution for such proactive messages. + /// + protected virtual ConnectorFactory GetStreamingConnectorFactory(Activity activity) + { + throw new NotImplementedException(); + } + /// /// The implementation for continue conversation. /// @@ -247,7 +264,9 @@ protected async Task ProcessProactiveAsync(ClaimsIdentity claimsIdentity, Activi Logger.LogInformation($"ProcessProactiveAsync for Conversation Id: {continuationActivity.Conversation.Id}"); // Create the connector factory. - var connectorFactory = BotFrameworkAuthentication.CreateConnectorFactory(claimsIdentity); + var connectorFactory = continuationActivity.IsFromStreamingConnection() + ? GetStreamingConnectorFactory(continuationActivity) + : BotFrameworkAuthentication.CreateConnectorFactory(claimsIdentity); // Create the connector client to use for outbound requests. using (var connectorClient = await connectorFactory.CreateAsync(continuationActivity.ServiceUrl, audience, cancellationToken).ConfigureAwait(false)) diff --git a/libraries/Microsoft.Bot.Builder/Microsoft.Bot.Builder.csproj b/libraries/Microsoft.Bot.Builder/Microsoft.Bot.Builder.csproj index 90e9764329..b3776a82d8 100644 --- a/libraries/Microsoft.Bot.Builder/Microsoft.Bot.Builder.csproj +++ b/libraries/Microsoft.Bot.Builder/Microsoft.Bot.Builder.csproj @@ -24,6 +24,8 @@ + + @@ -33,6 +35,7 @@ + diff --git a/libraries/Microsoft.Bot.Builder/Streaming/BotFrameworkHttpAdapterBase.cs b/libraries/Microsoft.Bot.Builder/Streaming/BotFrameworkHttpAdapterBase.cs index 9e9e9c2c56..21a1f3b218 100644 --- a/libraries/Microsoft.Bot.Builder/Streaming/BotFrameworkHttpAdapterBase.cs +++ b/libraries/Microsoft.Bot.Builder/Streaming/BotFrameworkHttpAdapterBase.cs @@ -22,8 +22,10 @@ namespace Microsoft.Bot.Builder.Streaming /// /// An HTTP adapter base class. /// - public class BotFrameworkHttpAdapterBase : BotFrameworkAdapter, IStreamingActivityProcessor + public class BotFrameworkHttpAdapterBase : BotFrameworkAdapter, IStreamingActivityProcessor, IDisposable { + private bool _disposedValue; + /// /// Initializes a new instance of the class. /// @@ -220,7 +222,9 @@ public async Task SendStreamingActivityAsync(Activity activity var host = uri[uri.Length - 1]; await connection.ConnectAsync(new Uri(protocol + host + "/api/messages"), cancellationToken).ConfigureAwait(false); +#pragma warning disable CA2000 // Dispose objects before losing scope (We'll dispose this when the adapter gets disposed or when elements are removed) var handler = new StreamingRequestHandler(ConnectedBot, this, connection, Logger); +#pragma warning restore CA2000 // Dispose objects before losing scope if (RequestHandlers == null) { @@ -259,12 +263,49 @@ public async Task ConnectNamedPipeAsync(string pipeName, IBot bot, string audien RequestHandlers = new List(); } +#pragma warning disable CA2000 // Dispose objects before losing scope (We'll dispose this when the adapter gets disposed or when elements are removed) var requestHandler = new StreamingRequestHandler(bot, this, pipeName, audience, Logger); +#pragma warning restore CA2000 // Dispose objects before losing scope RequestHandlers.Add(requestHandler); await requestHandler.ListenAsync().ConfigureAwait(false); } + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + /// + /// Disposes resources of the . + /// + /// Whether we are disposing managed resources. + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + if (RequestHandlers != null) + { + foreach (var handler in RequestHandlers) + { + if (handler is IDisposable disposable) + { + handler.Dispose(); + } + } + } + } + + RequestHandlers = null; + _disposedValue = true; + } + } + /// /// Evaluates if processing an outgoing activity is possible. /// diff --git a/libraries/Microsoft.Bot.Builder/Streaming/StreamingRequestHandler.cs b/libraries/Microsoft.Bot.Builder/Streaming/StreamingRequestHandler.cs index 672c75b81e..b3ed0b8044 100644 --- a/libraries/Microsoft.Bot.Builder/Streaming/StreamingRequestHandler.cs +++ b/libraries/Microsoft.Bot.Builder/Streaming/StreamingRequestHandler.cs @@ -14,11 +14,10 @@ using System.Threading.Tasks; using Microsoft.Bot.Connector; using Microsoft.Bot.Connector.Authentication; +using Microsoft.Bot.Connector.Streaming.Application; using Microsoft.Bot.Schema; using Microsoft.Bot.Streaming; using Microsoft.Bot.Streaming.Transport; -using Microsoft.Bot.Streaming.Transport.NamedPipes; -using Microsoft.Bot.Streaming.Transport.WebSockets; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Newtonsoft.Json; @@ -29,19 +28,34 @@ namespace Microsoft.Bot.Builder.Streaming /// A request handler that processes incoming requests sent over an IStreamingTransport /// and adheres to the Bot Framework Protocol v3 with Streaming Extensions. /// - public class StreamingRequestHandler : RequestHandler + public class StreamingRequestHandler : RequestHandler, IDisposable { - private static ConcurrentDictionary _requestHandlers = new ConcurrentDictionary(); - private readonly string _instanceId = Guid.NewGuid().ToString(); - private readonly IBot _bot; private readonly ILogger _logger; private readonly IStreamingActivityProcessor _activityProcessor; - private readonly string _userAgent; - private readonly ConcurrentDictionary _conversations; - private readonly IStreamingTransportServer _server; + private readonly string _userAgent = GetUserAgent(); + private readonly ConcurrentDictionary _conversations = new ConcurrentDictionary(); + private readonly StreamingConnection _innerConnection; - private bool _serverIsConnected; + private bool _disposedValue; + + /// + /// Initializes a new instance of the class. + /// + /// The bot for which we handle requests. + /// The processor for incoming requests. + /// Connection used to send requests to the transport. + /// The specified recipient of all outgoing activities. + /// Logger implementation for tracing and debugging information. + public StreamingRequestHandler(IBot bot, IStreamingActivityProcessor activityProcessor, StreamingConnection connection, string audience = null, ILogger logger = null) + { + _bot = bot ?? throw new ArgumentNullException(nameof(bot)); + _activityProcessor = activityProcessor ?? throw new ArgumentNullException(nameof(activityProcessor)); + _innerConnection = connection ?? throw new ArgumentNullException(nameof(connection)); + _logger = logger ?? NullLogger.Instance; + + Audience = audience; + } /// /// Initializes a new instance of the class and @@ -85,11 +99,7 @@ public StreamingRequestHandler(IBot bot, IStreamingActivityProcessor activityPro Audience = audience; _logger = logger ?? NullLogger.Instance; - _conversations = new ConcurrentDictionary(); - _userAgent = GetUserAgent(); - _server = new WebSocketServer(socket, this); - _serverIsConnected = true; - _server.Disconnected += ServerDisconnected; + _innerConnection = new LegacyStreamingConnection(socket, _logger, ServerDisconnected); } /// @@ -134,11 +144,7 @@ public StreamingRequestHandler(IBot bot, IStreamingActivityProcessor activityPro } Audience = audience; - _conversations = new ConcurrentDictionary(); - _userAgent = GetUserAgent(); - _server = new NamedPipeServer(pipeName, this); - _serverIsConnected = true; - _server.Disconnected += ServerDisconnected; + _innerConnection = new LegacyStreamingConnection(pipeName, _logger, ServerDisconnected); } /// @@ -165,11 +171,19 @@ public StreamingRequestHandler(IBot bot, IStreamingActivityProcessor activityPro /// A task that completes once the server is no longer listening. public virtual async Task ListenAsync() { - await _server.StartAsync().ConfigureAwait(false); - _logger.LogInformation("Streaming request handler started listening"); + await ListenAsync(CancellationToken.None).ConfigureAwait(false); + } - // add ourselves to a global collection to ensure a reference is maintained if we are connected - _requestHandlers.TryAdd(_instanceId, this); + /// + /// Begins listening for incoming requests over this StreamingRequestHandler's server. + /// + /// Cancellation token. + /// A task that completes once the server is no longer listening. + public async Task ListenAsync(CancellationToken cancellationToken) + { + _logger.LogInformation("Streaming request handler started listening"); + await _innerConnection.ListenAsync(this, cancellationToken).ConfigureAwait(false); + _logger.LogInformation("Streaming request handler completed listening"); } /// @@ -403,15 +417,7 @@ public virtual async Task SendActivityAsync(Activity activity, } } - if (!_serverIsConnected) - { - throw new InvalidOperationException("Error while attempting to send: Streaming transport is disconnected."); - } - - // Attempt to send the request. If send fails, we let the original exception get thrown so that the - // upper layers can handle it and trigger OnError. This is consistent with error handling in http and proactive - // paths, making all 3 paths consistent in terms of error handling. - var serverResponse = await _server.SendAsync(request, cancellationToken).ConfigureAwait(false); + var serverResponse = await _innerConnection.SendStreamingRequestAsync(request, cancellationToken).ConfigureAwait(false); if (serverResponse.StatusCode == (int)HttpStatusCode.OK) { @@ -431,12 +437,35 @@ public virtual async Task SendActivityAsync(Activity activity, /// A task that resolves to a . public Task SendStreamingRequestAsync(StreamingRequest request, CancellationToken cancellationToken = default) { - if (!_serverIsConnected) + return _innerConnection.SendStreamingRequestAsync(request, cancellationToken); + } + + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + /// + /// Disposes resources of the . + /// + /// Whether we are disposing managed resources. + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) { - throw new InvalidOperationException("Error while attempting to send: Streaming transport is disconnected."); - } + if (disposing) + { + if (_innerConnection is IDisposable disposable) + { + disposable?.Dispose(); + } + } - return _server.SendAsync(request, cancellationToken); + _disposedValue = true; + } } /// @@ -446,10 +475,7 @@ public Task SendStreamingRequestAsync(StreamingRequest request, /// The arguments specified by the disconnection event. protected virtual void ServerDisconnected(object sender, DisconnectedEventArgs e) { - _serverIsConnected = false; - - // remove ourselves from the global collection - _requestHandlers.TryRemove(_instanceId, out var _); + // Subtypes can override this method to add logging when an underlying transport server is disconnected } /// diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Application/LegacyStreamingConnection.cs b/libraries/Microsoft.Bot.Connector.Streaming/Application/LegacyStreamingConnection.cs new file mode 100644 index 0000000000..41908604b5 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Application/LegacyStreamingConnection.cs @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Streaming; +using Microsoft.Bot.Streaming.Transport; +using Microsoft.Bot.Streaming.Transport.NamedPipes; +using Microsoft.Bot.Streaming.Transport.WebSockets; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.Bot.Connector.Streaming.Application +{ + /// + /// The to be used by legacy bots. + /// + [Obsolete("Use `WebSocketStreamingConnection` instead.", false)] + public class LegacyStreamingConnection : StreamingConnection, IDisposable + { + private readonly WebSocket _socket; + private readonly string _pipeName; + private readonly ILogger _logger; + + private readonly DisconnectedEventHandler _onServerDisconnect; + + private IStreamingTransportServer _server; + private bool _serverIsConnected; + private bool _disposedValue; + + /// + /// Initializes a new instance of the class that uses web sockets. + /// + /// The instance to use for legacy streaming connection. + /// Logger implementation for tracing and debugging information. + /// Additional handling code to be run when the transport server is disconnected. + public LegacyStreamingConnection(WebSocket socket, ILogger logger, DisconnectedEventHandler onServerDisconnect = null) + { + _socket = socket ?? throw new ArgumentNullException(nameof(socket)); + _logger = logger ?? NullLogger.Instance; + _onServerDisconnect = onServerDisconnect; + } + + /// + /// Initializes a new instance of the class that uses named pipes. + /// + /// The name of the named pipe. + /// Logger implementation for tracing and debugging information. + /// Additional handling code to be run when the transport server is disconnected. + public LegacyStreamingConnection(string pipeName, ILogger logger, DisconnectedEventHandler onServerDisconnect = null) + { + if (string.IsNullOrWhiteSpace(pipeName)) + { + throw new ArgumentNullException(nameof(pipeName)); + } + + _pipeName = pipeName; + _logger = logger ?? NullLogger.Instance; + _onServerDisconnect = onServerDisconnect; + } + + /// + public override async Task ListenAsync(RequestHandler requestHandler, CancellationToken cancellationToken = default) + { + _server = CreateStreamingTransportServer(requestHandler); + _serverIsConnected = true; + _server.Disconnected += Server_Disconnected; + + if (_onServerDisconnect != null) + { + _server.Disconnected += _onServerDisconnect; + } + + await _server.StartAsync().ConfigureAwait(false); + } + + /// + public override async Task SendStreamingRequestAsync(StreamingRequest request, CancellationToken cancellationToken = default) + { + if (!_serverIsConnected) + { + throw new InvalidOperationException("Error while attempting to send: Streaming transport is disconnected."); + } + + return await _server.SendAsync(request, cancellationToken).ConfigureAwait(false); + } + + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + internal virtual IStreamingTransportServer CreateStreamingTransportServer(RequestHandler requestHandler) + { + if (_socket != null) + { + return new WebSocketServer(_socket, requestHandler); + } + + if (!string.IsNullOrWhiteSpace(_pipeName)) + { + return new NamedPipeServer(_pipeName, requestHandler); + } + + throw new ApplicationException("Neither web socket, nor named pipe found to instantiate a streaming transport server!"); + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "We want to catch all exceptions while disconnecting.")] +#pragma warning disable CA1063 // Implement IDisposable Correctly + private void Dispose(bool disposing) +#pragma warning restore CA1063 // Implement IDisposable Correctly + { + if (!_disposedValue) + { + if (disposing) + { + try + { + if (_server != null) + { + if (_server is WebSocketServer webSocketServer) + { + webSocketServer.Disconnect(); + } + else if (_server is NamedPipeServer namedPipeServer) + { + namedPipeServer.Disconnect(); + } + + if (_server is IDisposable disposable) + { + disposable.Dispose(); + } + + _server.Disconnected -= Server_Disconnected; + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to gracefully disconnect server while tearing down streaming connection."); + } + } + + // TODO: free unmanaged resources (unmanaged objects) and override finalizer + // TODO: set large fields to null + _disposedValue = true; + } + } + + private void Server_Disconnected(object sender, DisconnectedEventArgs e) + { + _serverIsConnected = false; + } + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Application/StreamingConnection.cs b/libraries/Microsoft.Bot.Connector.Streaming/Application/StreamingConnection.cs new file mode 100644 index 0000000000..ab6aa919ff --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Application/StreamingConnection.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Streaming; + +namespace Microsoft.Bot.Connector.Streaming.Application +{ + /// + /// A streaming based connection that can listen for incoming requests and send them to a , + /// and can also send requests to the other end of the connection. + /// + public abstract class StreamingConnection + { + /// + /// Sends a streaming request through the connection. + /// + /// to be sent. + /// to cancel the send process. + /// The returned from the client. + public abstract Task SendStreamingRequestAsync(StreamingRequest request, CancellationToken cancellationToken = default(CancellationToken)); + + /// + /// Opens the and listens for incoming requests, which will + /// be assembled and sent to the provided . + /// + /// to which incoming requests will be sent. + /// that signals the need to stop the connection. + /// Once the token is cancelled, the connection will be gracefully shut down, finishing pending sends and receives. + /// A representing the asynchronous operation. + public abstract Task ListenAsync(RequestHandler requestHandler, CancellationToken cancellationToken = default(CancellationToken)); + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Application/TimerAwaitable.cs b/libraries/Microsoft.Bot.Connector.Streaming/Application/TimerAwaitable.cs new file mode 100644 index 0000000000..5f992170d3 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Application/TimerAwaitable.cs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Bot.Connector.Streaming.Application +{ + // Reusing internal awaitable timer from https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/common/Shared/TimerAwaitable.cs + internal class TimerAwaitable : IDisposable, INotifyCompletion + { + private static readonly Action _callbackCompleted = () => { }; + + private Timer _timer; + private Action _callback; + + private readonly TimeSpan _period; + + private readonly TimeSpan _dueTime; + private readonly object _lockObj = new object(); + private bool _disposed; + private bool _running = true; + + public TimerAwaitable(TimeSpan dueTime, TimeSpan period) + { + _dueTime = dueTime; + _period = period; + } + + public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); + + public void Start() + { + if (_timer == null) + { + lock (_lockObj) + { + if (_disposed) + { + return; + } + + if (_timer == null) + { + // This fixes the cycle by using a WeakReference to the state object. The object graph now looks like this: + // Timer -> TimerHolder -> TimerQueueTimer -> WeakReference -> Timer -> ... + // If TimerAwaitable falls out of scope, the timer should be released. + _timer = NonCapturingTimer.Create( + state => + { + var weakRef = (WeakReference)state!; + if (weakRef.TryGetTarget(out var thisRef)) + { + thisRef.Tick(); + } + }, + state: new WeakReference(this), + dueTime: _dueTime, + period: _period); + } + } + } + } + + public TimerAwaitable GetAwaiter() => this; + + public bool GetResult() + { + _callback = null; + + return _running; + } + + public void OnCompleted(Action continuation) + { + if (ReferenceEquals(_callback, _callbackCompleted) || + ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted)) + { + _ = Task.Run(continuation); + } + } + + public void UnsafeOnCompleted(Action continuation) + { + OnCompleted(continuation); + } + + public void Stop() + { + lock (_lockObj) + { + // Stop should be used to trigger the call to end the loop which disposes + if (_disposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + + _running = false; + } + + // Call tick here to make sure that we yield the callback, + // if it's currently waiting, we don't need to wait for the next period + Tick(); + } + + void IDisposable.Dispose() + { + lock (_lockObj) + { + _disposed = true; + + _timer?.Dispose(); + + _timer = null; + } + } + + private void Tick() + { + var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted); + continuation?.Invoke(); + } + + // A convenience API for interacting with System.Threading.Timer in a way + // that doesn't capture the ExecutionContext. We should be using this (or equivalent) + // everywhere we use timers to avoid rooting any values stored in asynclocals. + private static class NonCapturingTimer + { + public static Timer Create(TimerCallback callback, object state, TimeSpan dueTime, TimeSpan period) + { + if (callback == null) + { + throw new ArgumentNullException(nameof(callback)); + } + + // Don't capture the current ExecutionContext and its AsyncLocals onto the timer + bool restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + return new Timer(callback, state, dueTime, period); + } + finally + { + // Restore the current ExecutionContext + if (restoreFlow) + { + ExecutionContext.RestoreFlow(); + } + } + } + } + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Application/WebSocketClient.cs b/libraries/Microsoft.Bot.Connector.Streaming/Application/WebSocketClient.cs new file mode 100644 index 0000000000..5cd5f4a7aa --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Application/WebSocketClient.cs @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Session; +using Microsoft.Bot.Connector.Streaming.Transport; +using Microsoft.Bot.Streaming; +using Microsoft.Bot.Streaming.Transport; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using static Microsoft.Bot.Connector.Streaming.Transport.DuplexPipe; + +namespace Microsoft.Bot.Connector.Streaming.Application +{ + /// + /// Web socket client. + /// + public class WebSocketClient : IStreamingTransportClient + { + private readonly string _url; + private readonly RequestHandler _requestHandler; + private readonly ILogger _logger; + private readonly TimeSpan _closeTimeout; + private readonly TimeSpan? _keepAlive; + + private CancellationTokenSource _disconnectCts; + private StreamingSession _session; + private TransportHandler _transportHandler; + private DuplexPipePair _duplexPipePair; + private volatile bool _disposed = false; + + /// + /// Initializes a new instance of the class. + /// + /// Url to connect to. + /// Request handler that will receive incoming requests to this client instance. + /// Optional time out for closing the client connection. + /// Optional spacing between keep alives for proactive disconnection detection. If null is provided, no keep alives will be sent. + /// for the client. + public WebSocketClient(string url, RequestHandler requestHandler, TimeSpan? closeTimeOut = null, TimeSpan? keepAlive = null, ILogger logger = null) + { + if (string.IsNullOrEmpty(url)) + { + throw new ArgumentNullException(nameof(url)); + } + + _url = url; + _requestHandler = requestHandler ?? throw new ArgumentNullException(nameof(requestHandler)); + _logger = logger ?? NullLogger.Instance; + _closeTimeout = closeTimeOut ?? TimeSpan.FromSeconds(15); + _keepAlive = keepAlive; + } + + internal WebSocketClient(RequestHandler requestHandler, TimeSpan? closeTimeOut = null, TimeSpan? keepAlive = null, ILogger logger = null) + { + _requestHandler = requestHandler ?? throw new ArgumentNullException(nameof(requestHandler)); + _logger = logger ?? NullLogger.Instance; + _closeTimeout = closeTimeOut ?? TimeSpan.FromSeconds(15); + _keepAlive = keepAlive; + } + + /// + public event DisconnectedEventHandler Disconnected; + + /// + public bool IsConnected { get; set; } = false; + + /// + public async Task ConnectAsync() + { + await ConnectAsync(new Dictionary(), CancellationToken.None).ConfigureAwait(false); + } + + /// + public async Task ConnectAsync(IDictionary requestHeaders) + { + await ConnectAsync(requestHeaders, CancellationToken.None).ConfigureAwait(false); + } + + /// + /// Establishes the connection. + /// + /// Request headers. + /// for the client connection. + /// A representing the asynchronous operation. + public async Task ConnectAsync(IDictionary requestHeaders, CancellationToken cancellationToken) + { + await ConnectInternalAsync( + connectFunc: transport => transport.ConnectAsync(_url, requestHeaders, CancellationToken.None), + cancellationToken: cancellationToken).ConfigureAwait(false); + } + + /// + public async Task SendAsync(StreamingRequest message, CancellationToken cancellationToken = default) + { + CheckDisposed(); + + if (_session == null) + { + throw new InvalidOperationException("Session not established. Call ConnectAsync() in order to send requests through this client."); + } + + if (message == null) + { + throw new ArgumentNullException(nameof(message)); + } + + return await _session.SendRequestAsync(message, cancellationToken).ConfigureAwait(false); + } + + /// + public void Disconnect() + { + CheckDisposed(); + DisconnectAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + } + + /// + /// Disconnects. + /// + /// A representing the asynchronous operation. + public async Task DisconnectAsync() + { + CheckDisposed(); + await _transportHandler.StopAsync().ConfigureAwait(false); + IsConnected = false; + } + + /// + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + internal async Task ConnectInternalAsync(WebSocket clientSocket, CancellationToken cancellationToken) + { + await ConnectInternalAsync( + connectFunc: transport => transport.ProcessSocketAsync(clientSocket, CancellationToken.None), + cancellationToken: cancellationToken).ConfigureAwait(false); + } + + internal async Task ConnectInternalAsync(Func connectFunc, CancellationToken cancellationToken) + { + CheckDisposed(); + + TimerAwaitable timer = null; + Task timerTask = null; + + try + { + // Pipes + _duplexPipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + // Transport + var transport = new WebSocketTransport(_duplexPipePair.Application, _logger); + + // Application + _transportHandler = new TransportHandler(_duplexPipePair.Transport, _logger); + + // Session + _session = new StreamingSession(_requestHandler, _transportHandler, _logger); + + // Set up cancellation + _disconnectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + // Start transport and application + var transportTask = connectFunc(transport); + var applicationTask = _transportHandler.ListenAsync(_disconnectCts.Token); + var combinedTask = Task.WhenAll(transportTask, applicationTask); + + Log.ClientStarted(_logger, _url ?? string.Empty); + + // Periodic task: keep alive + // Disposed with `timer.Stop()` in the finally block below + if (_keepAlive.HasValue) + { + timer = new TimerAwaitable(_keepAlive.Value, _keepAlive.Value); + timerTask = TimerLoopAsync(timer); + } + + // We are connected! + IsConnected = true; + + // Block until transport or application ends. + await combinedTask.ConfigureAwait(false); + + // Signal that we're done + _disconnectCts.Cancel(); + Log.ClientTransportApplicationCompleted(_logger, _url); + } + finally + { + timer?.Stop(); + + if (timerTask != null) + { + await timerTask.ConfigureAwait(false); + } + } + + Log.ClientCompleted(_logger, _url ?? string.Empty); + } + + internal async Task TimerLoopAsync(TimerAwaitable timer) + { + timer.Start(); + + using (timer) + { + // await returns True until `timer.Stop()` is called in the `finally` block of `ReceiveLoop` + while (await timer) + { + try + { + // Ping server + var response = await _session.SendRequestAsync(StreamingRequest.CreateGet("/api/version"), _disconnectCts.Token).ConfigureAwait(false); + + if (!IsSuccessResponse(response)) + { + Log.ClientKeepAliveFail(_logger, _url, response.StatusCode); + + IsConnected = false; + + Disconnected?.Invoke(this, new DisconnectedEventArgs() { Reason = $"Received failure from server heartbeat: {response.StatusCode}." }); + } + else + { + Log.ClientKeepAliveSucceed(_logger, _url); + } + } +#pragma warning disable CA1031 // Do not catch general exception types + catch (Exception e) +#pragma warning restore CA1031 // Do not catch general exception types + { + Log.ClientKeepAliveFail(_logger, _url, 0, e); + IsConnected = false; + Disconnected?.Invoke(this, new DisconnectedEventArgs() { Reason = $"Received failure from server heartbeat: {e}." }); + } + } + } + } + + /// + /// Disposes objected used by the class. + /// + /// A Boolean that indicates whether the method call comes from a Dispose method (its value is true) or from a finalizer (its value is false). + /// + /// The disposing parameter should be false when called from a finalizer, and true when called from the IDisposable.Dispose method. + /// In other words, it is true when deterministically called and false when non-deterministically called. + /// + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + try + { + Disconnect(); + _disconnectCts.Cancel(); + } + finally + { + _transportHandler.Dispose(); + _disconnectCts.Dispose(); + } + } + + _disposed = true; + } + + private static bool IsSuccessResponse(ReceiveResponse response) + { + return response != null && response.StatusCode >= 200 && response.StatusCode <= 299; + } + + private void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + } + + private class Log + { + private static readonly Action _clientStarted = + LoggerMessage.Define(LogLevel.Information, new EventId(1, nameof(ClientStarted)), "WebSocket client connected to {string}."); + + private static readonly Action _clientCompleted = + LoggerMessage.Define(LogLevel.Information, new EventId(2, nameof(ClientKeepAliveSucceed)), "WebSocket client connection to {string} closed."); + + private static readonly Action _clientKeepAliveSucceed = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, nameof(ClientStarted)), "WebSocket client heartbeat to {string} succeeded."); + + private static readonly Action _clientKeepAliveFail = + LoggerMessage.Define(LogLevel.Error, new EventId(4, nameof(ClientKeepAliveFail)), "WebSocket client heartbeat to {string} failed with status code {int}."); + + private static readonly Action _clientTransportApplicationCompleted = + LoggerMessage.Define(LogLevel.Debug, new EventId(5, nameof(ClientTransportApplicationCompleted)), "WebSocket client heartbeat to {string} completed transport and application tasks."); + + public static void ClientStarted(ILogger logger, string url) => _clientStarted(logger, url ?? string.Empty, null); + + public static void ClientCompleted(ILogger logger, string url) => _clientCompleted(logger, url ?? string.Empty, null); + + public static void ClientKeepAliveSucceed(ILogger logger, string url) => _clientKeepAliveSucceed(logger, url ?? string.Empty, null); + + public static void ClientKeepAliveFail(ILogger logger, string url, int statusCode = 0, Exception e = null) => _clientKeepAliveFail(logger, url ?? string.Empty, statusCode, e); + + public static void ClientTransportApplicationCompleted(ILogger logger, string url) => _clientTransportApplicationCompleted(logger, url ?? string.Empty, null); + } + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Application/WebSocketStreamingConnection.cs b/libraries/Microsoft.Bot.Connector.Streaming/Application/WebSocketStreamingConnection.cs new file mode 100644 index 0000000000..180e230821 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Application/WebSocketStreamingConnection.cs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Bot.Connector.Streaming.Session; +using Microsoft.Bot.Connector.Streaming.Transport; +using Microsoft.Bot.Streaming; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.Bot.Connector.Streaming.Application +{ + /// + /// Default implementation of for WebSocket transport. + /// + public class WebSocketStreamingConnection : StreamingConnection + { + private readonly ILogger _logger; + private readonly HttpContext _httpContext; + private readonly TaskCompletionSource _sessionInitializedTask = new TaskCompletionSource(); + + private StreamingSession _session; + private CancellationToken _cancellationToken; + + /// + /// Initializes a new instance of the class. + /// + /// instance on which to accept the web socket. + /// for the connection. + public WebSocketStreamingConnection(HttpContext httpContext, ILogger logger) + : this(logger) + { + _httpContext = httpContext ?? throw new ArgumentNullException(nameof(httpContext)); + } + + internal WebSocketStreamingConnection(ILogger logger) + { + _logger = logger ?? NullLogger.Instance; + } + + /// + public override async Task SendStreamingRequestAsync(StreamingRequest request, CancellationToken cancellationToken = default(CancellationToken)) + { + if (request == null) + { + throw new ArgumentNullException(nameof(request)); + } + + // This request could come fast while the session, transport and application are still being set up. + // Wait for the session to signal that application and transport started before using the session. + await _sessionInitializedTask.Task.ConfigureAwait(false); + + if (_session == null) + { + throw new InvalidOperationException("Cannot send streaming request since the session is not set up."); + } + + return await _session.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); + } + + /// + public override async Task ListenAsync(RequestHandler requestHandler, CancellationToken cancellationToken = default(CancellationToken)) + { + if (requestHandler == null) + { + throw new ArgumentNullException(nameof(requestHandler)); + } + + await ListenImplAsync( + socketConnectFunc: t => t.ConnectAsync(_httpContext, CancellationToken.None), + requestHandler: requestHandler, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + + internal async Task ListenInternalAsync(WebSocket webSocket, RequestHandler requestHandler, CancellationToken cancellationToken = default(CancellationToken)) + { + if (requestHandler == null) + { + throw new ArgumentNullException(nameof(requestHandler)); + } + + if (requestHandler == null) + { + throw new ArgumentNullException(nameof(requestHandler)); + } + + await ListenImplAsync( + socketConnectFunc: t => t.ProcessSocketAsync(webSocket, cancellationToken), + requestHandler: requestHandler, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + + private async Task ListenImplAsync(Func socketConnectFunc, RequestHandler requestHandler, CancellationToken cancellationToken = default(CancellationToken)) + { + var duplexPipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + _cancellationToken = cancellationToken; + + // Create transport and application + var transport = new WebSocketTransport(duplexPipePair.Application, _logger); + var application = new TransportHandler(duplexPipePair.Transport, _logger); + + // Create session + _session = new StreamingSession(requestHandler, application, _logger, _cancellationToken); + + // Start transport and application + var transportTask = socketConnectFunc(transport); + var applicationTask = application.ListenAsync(cancellationToken); + + var tasks = new List() { transportTask, applicationTask }; + + // Signal that session is ready to be used + _sessionInitializedTask.SetResult(true); + + // Let application and transport run + await Task.WhenAll(tasks).ConfigureAwait(false); + } + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/AssemblyInfo.cs b/libraries/Microsoft.Bot.Connector.Streaming/AssemblyInfo.cs new file mode 100644 index 0000000000..0fd6a13f1b --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/AssemblyInfo.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Runtime.CompilerServices; + +#if SIGNASSEMBLY +[assembly: InternalsVisibleTo("Microsoft.Bot.Connector.Streaming.Tests", PublicKey=0024000004800000940000000602000000240000525341310004000001000100b5fc90e7027f67871e773a8fde8938c81dd402ba65b9201d60593e96c492651e889cc13f1415ebb53fac1131ae0bd333c5ee6021672d9718ea31a8aebd0da0072f25d87dba6fc90ffd598ed4da35e44c398c454307e8e33b8426143daec9f596836f97c8f74750e5975c64e2189f45def46b2a2b1247adc3652bf5c308055da9")] +#else +[assembly: InternalsVisibleTo("Microsoft.Bot.Connector.Streaming.Tests")] +[assembly: InternalsVisibleTo("DynamicProxyGenAssembly2")] +#endif diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Microsoft.Bot.Connector.Streaming.csproj b/libraries/Microsoft.Bot.Connector.Streaming/Microsoft.Bot.Connector.Streaming.csproj new file mode 100644 index 0000000000..680ad2d2b0 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Microsoft.Bot.Connector.Streaming.csproj @@ -0,0 +1,46 @@ + + + + $(LocalPackageVersion) + $(ReleasePackageVersion) + $(LocalPackageVersion) + $(ReleasePackageVersion) + Debug;Release + $(AllowedOutputExtensionsInPackageBuildOutputFolder);.pdb + bin\$(Configuration)\$(TargetFramework)\Microsoft.Bot.Connector.Streaming.xml + + + + netstandard2.0 + Full + Microsoft.Bot.Connector.Streaming + Streaming library for the Bot Framework SDK + Streaming library for the Bot Framework SDK + + + + Full + true + + + + $(NoWarn); + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Session/StreamingSession.cs b/libraries/Microsoft.Bot.Connector.Streaming/Session/StreamingSession.cs new file mode 100644 index 0000000000..8644486850 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Session/StreamingSession.cs @@ -0,0 +1,523 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Payloads; +using Microsoft.Bot.Connector.Streaming.Transport; +using Microsoft.Bot.Streaming; +using Microsoft.Bot.Streaming.Payloads; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Net.Http.Headers; +using Newtonsoft.Json; +using JsonSerializer = Newtonsoft.Json.JsonSerializer; + +namespace Microsoft.Bot.Connector.Streaming.Session +{ + internal class StreamingSession + { + // Utf byte order mark constant as defined + // Dotnet runtime: https://github.com/dotnet/runtime/blob/main/src/libraries/System.Text.Json/src/System/Text/Json/JsonConstants.cs#L35 + // Unicode.org spec: https://www.unicode.org/faq/utf_bom.html#bom5 + private static byte[] _utf8Bom = { 0xEF, 0xBB, 0xBF }; + + private readonly Dictionary _streamDefinitions = new Dictionary(); + private readonly Dictionary _requests = new Dictionary(); + private readonly Dictionary _responses = new Dictionary(); + private readonly ConcurrentDictionary> _pendingResponses = new ConcurrentDictionary>(); + + private readonly RequestHandler _receiver; + private readonly TransportHandler _sender; + + private readonly ILogger _logger; + private readonly CancellationToken _connectionCancellationToken; + + private readonly object _receiveSync = new object(); + + public StreamingSession(RequestHandler receiver, TransportHandler sender, ILogger logger, CancellationToken connectionCancellationToken = default) + { + _receiver = receiver ?? throw new ArgumentNullException(nameof(receiver)); + _sender = sender ?? throw new ArgumentNullException(nameof(sender)); + _sender.Subscribe(new ProtocolDispatcher(this)); + + _logger = logger ?? NullLogger.Instance; + _connectionCancellationToken = connectionCancellationToken; + } + + public async Task SendRequestAsync(StreamingRequest request, CancellationToken cancellationToken) + { + if (request == null) + { + throw new ArgumentNullException(nameof(request)); + } + + var payload = new RequestPayload() + { + Verb = request.Verb, + Path = request.Path, + }; + + if (request.Streams != null) + { + payload.Streams = new List(); + foreach (var contentStream in request.Streams) + { + var description = GetStreamDescription(contentStream); + + payload.Streams.Add(description); + } + } + + var requestId = Guid.NewGuid(); + + var responseCompletionSource = new TaskCompletionSource(); + _pendingResponses.TryAdd(requestId, responseCompletionSource); + + // Send request + await _sender.SendRequestAsync(requestId, payload, cancellationToken).ConfigureAwait(false); + + foreach (var stream in request.Streams) + { + await _sender.SendStreamAsync(stream.Id, await stream.Content.ReadAsStreamAsync().ConfigureAwait(false), cancellationToken).ConfigureAwait(false); + } + + // Timeout: We could be waiting for this TaskCompletionSource forever if the connection is broken + // before this response gets back, blocking termination on this thread. + using (var timeoutCancellationTokenSource = new CancellationTokenSource()) + { + var completedTask = await Task.WhenAny(responseCompletionSource.Task, Task.Delay(TimeSpan.FromSeconds(5), timeoutCancellationTokenSource.Token)).ConfigureAwait(false); + if (completedTask == responseCompletionSource.Task) + { + timeoutCancellationTokenSource.Cancel(); + return await responseCompletionSource.Task.ConfigureAwait(false); + } + else + { + throw new TimeoutException($"The operation has timed out"); + } + } + } + + public async Task SendResponseAsync(Header header, StreamingResponse response, CancellationToken cancellationToken) + { + if (header == null) + { + throw new ArgumentNullException(nameof(header)); + } + + if (header.Type != PayloadTypes.Response) + { + throw new InvalidOperationException($"StreamingSession SendResponseAsync expected Response payload, but instead received a payload of type {header.Type}"); + } + + if (response == null) + { + throw new ArgumentNullException(nameof(response)); + } + + var payload = new ResponsePayload() + { + StatusCode = response.StatusCode, + }; + + if (response.Streams != null) + { + payload.Streams = new List(); + foreach (var contentStream in response.Streams) + { + var description = GetStreamDescription(contentStream); + + payload.Streams.Add(description); + } + } + + await _sender.SendResponseAsync(header.Id, payload, cancellationToken).ConfigureAwait(false); + } + + public virtual void ReceiveRequest(Header header, ReceiveRequest request) + { + if (header == null) + { + throw new ArgumentNullException(nameof(header)); + } + + if (header.Type != PayloadTypes.Request) + { + throw new InvalidOperationException($"StreamingSession cannot receive payload of type {header.Type} as request."); + } + + if (request == null) + { + throw new ArgumentNullException(nameof(request)); + } + + Log.PayloadReceived(_logger, header); + + lock (_receiveSync) + { + _requests.Add(header.Id, request); + + if (request.Streams.Any()) + { + foreach (var streamDefinition in request.Streams) + { + _streamDefinitions.Add(streamDefinition.Id, streamDefinition as StreamDefinition); + } + } + else + { + ProcessRequest(header.Id, request); + } + } + } + + public virtual void ReceiveResponse(Header header, ReceiveResponse response) + { + if (header == null) + { + throw new ArgumentNullException(nameof(header)); + } + + if (header.Type != PayloadTypes.Response) + { + throw new InvalidOperationException($"StreamingSession cannot receive payload of type {header.Type} as response"); + } + + if (response == null) + { + throw new ArgumentNullException(nameof(response)); + } + + Log.PayloadReceived(_logger, header); + + lock (_receiveSync) + { + if (!response.Streams.Any()) + { + if (_pendingResponses.TryGetValue(header.Id, out TaskCompletionSource responseTask)) + { + responseTask.SetResult(response); + _pendingResponses.TryRemove(header.Id, out TaskCompletionSource removedResponse); + } + } + else + { + _responses.Add(header.Id, response); + + foreach (var streamDefinition in response.Streams) + { + _streamDefinitions.Add(streamDefinition.Id, streamDefinition as StreamDefinition); + } + } + } + } + + public virtual void ReceiveStream(Header header, ArraySegment payload) + { + if (header == null) + { + throw new ArgumentNullException(nameof(header)); + } + + if (header.Type != PayloadTypes.Stream) + { + throw new InvalidOperationException($"StreamingSession cannot receive payload of type {header.Type} as stream"); + } + + if (payload == null) + { + throw new ArgumentNullException(nameof(payload)); + } + + Log.PayloadReceived(_logger, header); + + // Find request for incoming stream header + if (_streamDefinitions.TryGetValue(header.Id, out StreamDefinition streamDefinition)) + { + streamDefinition.Stream.Write(payload.Array, payload.Offset, payload.Count); + + // Is this the end of this stream? + if (header.End) + { + // Mark this stream as completed + if (streamDefinition is StreamDefinition streamDef) + { + streamDef.Complete = true; + streamDef.Stream.Seek(0, SeekOrigin.Begin); + + List streams = null; + + // Find the request / response + if (streamDef.PayloadType == PayloadTypes.Request) + { + if (_requests.TryGetValue(streamDef.PayloadId, out ReceiveRequest req)) + { + streams = req.Streams; + } + } + else if (streamDef.PayloadType == PayloadTypes.Response) + { + if (_responses.TryGetValue(streamDef.PayloadId, out ReceiveResponse res)) + { + streams = res.Streams; + } + } + + if (streams != null) + { + lock (_receiveSync) + { + // Have we completed all the streams we expect for this request? + bool allStreamsDone = streams.All(s => s is StreamDefinition streamDef && streamDef.Complete); + + // If we received all the streams, then it's time to pass this request to the request handler! + // For example, if this request is a send activity, the request handler will deserialize the first stream + // into an activity and pass to the adapter. + if (allStreamsDone) + { + if (streamDef.PayloadType == PayloadTypes.Request) + { + if (_requests.TryGetValue(streamDef.PayloadId, out ReceiveRequest request)) + { + ProcessRequest(streamDef.PayloadId, request); + _requests.Remove(streamDef.PayloadId); + } + } + else if (streamDef.PayloadType == PayloadTypes.Response) + { + if (_responses.TryGetValue(streamDef.PayloadId, out ReceiveResponse response)) + { + if (_pendingResponses.TryGetValue(streamDef.PayloadId, out TaskCompletionSource responseTask)) + { + responseTask.SetResult(response); + _responses.Remove(streamDef.PayloadId); + _pendingResponses.TryRemove(streamDef.PayloadId, out TaskCompletionSource removedResponse); + } + } + } + } + } + } + } + } + } + else + { + Log.OrphanedStream(_logger, header); + } + } + + private static StreamDescription GetStreamDescription(ResponseMessageStream stream) + { + var description = new StreamDescription() + { + Id = stream.Id.ToString("D"), + }; + + if (stream.Content.Headers.TryGetValues(HeaderNames.ContentType, out IEnumerable contentType)) + { + description.ContentType = contentType?.FirstOrDefault(); + } + + if (stream.Content.Headers.TryGetValues(HeaderNames.ContentLength, out IEnumerable contentLength)) + { + var value = contentLength?.FirstOrDefault(); + if (value != null && int.TryParse(value, out int length)) + { + description.Length = length; + } + } + else + { + description.Length = (int?)stream.Content.Headers.ContentLength; + } + + return description; + } + + private static ArraySegment GetArraySegment(ReadOnlySequence sequence) + { + if (sequence.IsSingleSegment) + { + if (MemoryMarshal.TryGetArray(sequence.First, out ArraySegment segment)) + { + return segment; + } + } + + // Can be optimized by not copying but should be uncommon. If perf data shows that we are hitting this + // code branch, then we can optimize and avoid copies and heap allocations. + return new ArraySegment(sequence.ToArray()); + } + + private void ProcessRequest(Guid id, ReceiveRequest request) + { + _ = Task.Run(async () => + { + var streamingResponse = await _receiver.ProcessRequestAsync(request, null).ConfigureAwait(false); + await SendResponseAsync(new Header() { Id = id, Type = PayloadTypes.Response }, streamingResponse, _connectionCancellationToken).ConfigureAwait(false); + + request.Streams.ForEach(s => _streamDefinitions.Remove(s.Id)); + }); + } + + internal class ProtocolDispatcher : IObserver<(Header Header, ReadOnlySequence Payload)> + { + private readonly StreamingSession _streamingSession; + + public ProtocolDispatcher(StreamingSession streamingSession) + { + _streamingSession = streamingSession ?? throw new ArgumentNullException(nameof(streamingSession)); + } + + public void OnCompleted() + { + throw new NotImplementedException(); + } + + public void OnError(Exception error) + { + throw new NotImplementedException(); + } + + public void OnNext((Header Header, ReadOnlySequence Payload) frame) + { + var header = frame.Header; + var payload = frame.Payload; + + switch (header.Type) + { + case PayloadTypes.Stream: + _streamingSession.ReceiveStream(header, GetArraySegment(payload)); + + break; + case PayloadTypes.Request: + + var requestPayload = DeserializeTo(payload); + var request = new ReceiveRequest() + { + Verb = requestPayload.Verb, + Path = requestPayload.Path, + Streams = new List(), + }; + + CreatePlaceholderStreams(header, request.Streams, requestPayload.Streams); + _streamingSession.ReceiveRequest(header, request); + + break; + + case PayloadTypes.Response: + + var responsePayload = DeserializeTo(payload); + var response = new ReceiveResponse() + { + StatusCode = responsePayload.StatusCode, + Streams = new List(), + }; + + CreatePlaceholderStreams(header, response.Streams, responsePayload.Streams); + _streamingSession.ReceiveResponse(header, response); + + break; + + case PayloadTypes.CancelAll: + break; + + case PayloadTypes.CancelStream: + break; + } + } + + private static T DeserializeTo(ReadOnlySequence payload) + { + // The payload here will likely have a UTF-8 byte-order-mark (BOM). + // The JsonSerializer and UtfJsonReader explicitly expect no BOM in this overload that takes a ReadOnlySequence. + // With that in mind, we check for a UTF-8 BOM and remove it if present. The main reason to call this specific flow instead of + // the stream version or using Json.Net is that the ReadOnlySequence API allows us to do a no-copy deserialization. + // The ReadOnlySequence was allocated from the memory pool by the transport layer and gets sent all the way here without copies. + + // Check for UTF-8 BOM and remove if present: https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-use-dom-utf8jsonreader-utf8jsonwriter?pivots=dotnet-5-0#filter-data-using-utf8jsonreader + var potentialBomSequence = payload.Slice(payload.Start, _utf8Bom.Length); + var potentialBomSpan = potentialBomSequence.IsSingleSegment + ? potentialBomSequence.First.Span + : potentialBomSequence.ToArray(); + + ReadOnlySequence mainPayload = payload; + + if (potentialBomSpan.StartsWith(_utf8Bom)) + { + mainPayload = payload.Slice(_utf8Bom.Length); + } + + var reader = new Utf8JsonReader(mainPayload); + return System.Text.Json.JsonSerializer.Deserialize( + ref reader, + new JsonSerializerOptions() { IgnoreNullValues = true, PropertyNameCaseInsensitive = true }); + } + + private static void CreatePlaceholderStreams(Header header, List placeholders, List streamInfo) + { + if (streamInfo != null) + { + foreach (var streamDescription in streamInfo) + { + if (!Guid.TryParse(streamDescription.Id, out Guid id)) + { + throw new InvalidDataException($"Stream description id '{streamDescription.Id}' is not a Guid"); + } + + placeholders.Add(new StreamDefinition() + { + ContentType = streamDescription.ContentType, + Length = streamDescription.Length, + Id = Guid.Parse(streamDescription.Id), + Stream = new MemoryStream(), + PayloadType = header.Type, + PayloadId = header.Id + }); + } + } + } + } + + internal class StreamDefinition : IContentStream + { + public Guid Id { get; set; } + + public string ContentType { get; set; } + + public int? Length { get; set; } + + public Stream Stream { get; set; } + + public bool Complete { get; set; } + + public char PayloadType { get; set; } + + public Guid PayloadId { get; set; } + } + + private class Log + { + private static readonly Action _orphanedStream = + LoggerMessage.Define(LogLevel.Error, new EventId(1, nameof(OrphanedStream)), "Stream has no associated payload. Header: ID {Guid} Type {char} Payload length:{int}. End :{bool}."); + + private static readonly Action _payloadReceived = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, nameof(PayloadReceived)), "Payload received in session. Header: ID {Guid} Type {char} Payload length:{int}. End :{bool}.."); + + public static void OrphanedStream(ILogger logger, Header header) => _orphanedStream(logger, header.Id, header.Type, header.PayloadLength, header.End, null); + + public static void PayloadReceived(ILogger logger, Header header) => _payloadReceived(logger, header.Id, header.Type, header.PayloadLength, header.End, null); + } + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Transport/DuplexPipe.cs b/libraries/Microsoft.Bot.Connector.Streaming/Transport/DuplexPipe.cs new file mode 100644 index 0000000000..8ec3ccb557 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Transport/DuplexPipe.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.IO.Pipelines; + +namespace Microsoft.Bot.Connector.Streaming.Transport +{ + internal class DuplexPipe : IDuplexPipe + { + public DuplexPipe(PipeReader reader, PipeWriter writer) + { + Input = reader; + Output = writer; + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + var transport = new DuplexPipe(output.Reader, input.Writer); + var application = new DuplexPipe(input.Reader, output.Writer); + + return new DuplexPipePair(transport, application); + } + + internal readonly struct DuplexPipePair + { + public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + Transport = transport; + Application = application; + } + + public IDuplexPipe Transport { get; } + + public IDuplexPipe Application { get; } + } + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Transport/Payloads/RequestPayload.cs b/libraries/Microsoft.Bot.Connector.Streaming/Transport/Payloads/RequestPayload.cs new file mode 100644 index 0000000000..364c55b85c --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Transport/Payloads/RequestPayload.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using Microsoft.Bot.Streaming.Payloads; +using Newtonsoft.Json; + +namespace Microsoft.Bot.Connector.Streaming.Payloads +{ + internal class RequestPayload + { +#pragma warning disable SA1609 + /// + /// Gets or sets request verb, null on responses. + /// + [JsonProperty("verb")] + public string Verb { get; set; } + + /// + /// Gets or sets request path; null on responses. + /// + [JsonProperty("path")] + public string Path { get; set; } + + /// + /// Gets or sets assoicated stream descriptions. + /// + [JsonProperty("streams")] + public List Streams { get; set; } +#pragma warning restore SA1609 + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Transport/Payloads/ResponsePayload.cs b/libraries/Microsoft.Bot.Connector.Streaming/Transport/Payloads/ResponsePayload.cs new file mode 100644 index 0000000000..d99ce28950 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Transport/Payloads/ResponsePayload.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using Microsoft.Bot.Streaming.Payloads; +using Newtonsoft.Json; + +namespace Microsoft.Bot.Connector.Streaming.Payloads +{ + internal class ResponsePayload + { +#pragma warning disable SA1609 + /// + /// Gets or sets status - The Response Status. + /// + [JsonProperty("statusCode")] + public int StatusCode { get; set; } + + /// + /// Gets or sets assoicated stream descriptions. + /// + [JsonProperty("streams")] + public List Streams { get; set; } +#pragma warning restore SA1609 + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Transport/TransportHandler.cs b/libraries/Microsoft.Bot.Connector.Streaming/Transport/TransportHandler.cs new file mode 100644 index 0000000000..dc6e48b955 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Transport/TransportHandler.cs @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.IO; +using System.IO.Pipelines; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Payloads; +using Microsoft.Bot.Streaming.Payloads; +using Microsoft.Bot.Streaming.Transport; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; + +namespace Microsoft.Bot.Connector.Streaming.Transport +{ + internal class TransportHandler : IObservable<(Header Header, ReadOnlySequence Payload)>, IDisposable + { + private readonly IDuplexPipe _transport; + private readonly ILogger _logger; + + private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1); + private readonly TimeSpan _semaphoreTimeout = TimeSpan.FromSeconds(10); + private readonly byte[] _sendHeaderBuffer = new byte[TransportConstants.MaxHeaderLength]; + + private IObserver<(Header, ReadOnlySequence)> _observer; + private bool _disposedValue; + + public TransportHandler(IDuplexPipe transport, ILogger logger) + { + _transport = transport; + _logger = logger; + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "We want to catch all exceptions in the message loop.")] + public async Task ListenAsync(CancellationToken cancellationToken) + { + var input = _transport.Input; + bool aborted = false; + + while (!cancellationToken.IsCancellationRequested) + { + ReadResult result; + + result = await input.ReadAsync().ConfigureAwait(false); + + var buffer = result.Buffer; + + try + { + if (result.IsCanceled) + { + break; + } + + if (!buffer.IsEmpty) + { + while (TryParseHeader(ref buffer, out Header header)) + { + Log.PayloadReceived(_logger, header); + + ReadOnlySequence payload = ReadOnlySequence.Empty; + + if (header.PayloadLength > 0) + { + if (buffer.Length < header.PayloadLength) + { + input.AdvanceTo(buffer.Start, buffer.End); + + result = await input.ReadAsync().ConfigureAwait(false); + + if (result.IsCanceled) + { + break; + } + + buffer = result.Buffer; + } + + if (buffer.Length >= header.PayloadLength) + { + payload = buffer.Slice(buffer.Start, header.PayloadLength); + buffer = buffer.Slice(header.PayloadLength); + } + else + { + break; + } + } + + _observer.OnNext((header, payload)); + } + } + + if (result.IsCompleted) + { + if (buffer.IsEmpty) + { + break; + } + } + } + catch (OperationCanceledException) + { + // Don't treat OperationCanceledException as an error, it's basically a "control flow" + // exception to stop things from running. + } + catch (Exception ex) + { + Log.ReadFrameFailed(_logger, ex); + + // This failure means we are tearing down the connection, so return and let the cancellation + // and draining take place. + await input.CompleteAsync(ex).ConfigureAwait(false); + aborted = true; + + Log.ListenError(_logger, ex); + + return; + } + finally + { + if (!aborted) + { + input.AdvanceTo(buffer.Start, buffer.End); + } + } + } + + await input.CompleteAsync().ConfigureAwait(false); + + await _transport.Output.CompleteAsync().ConfigureAwait(false); + Log.ListenCompleted(_logger); + } + + public Task StopAsync() + { + _transport.Input.CancelPendingRead(); + return Task.CompletedTask; + } + + public virtual async Task SendResponseAsync(Guid id, ResponsePayload response, CancellationToken cancellationToken = default) + { + if (response == null) + { + throw new ArgumentNullException(nameof(response)); + } + + var responseBytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(response)); + + var responseHeader = new Header() + { + Type = PayloadTypes.Response, + Id = id, + PayloadLength = (int)responseBytes.Length, + End = true, + }; + + await WriteAsync( + header: responseHeader, + writeFunc: async pipeWriter => await pipeWriter.WriteAsync(responseBytes).ConfigureAwait(false)).ConfigureAwait(false); + } + + public virtual async Task SendRequestAsync(Guid id, RequestPayload request, CancellationToken cancellationToken) + { + if (request == null) + { + throw new ArgumentNullException(nameof(request)); + } + + var requestBytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(request)); + + var requestHeader = new Header() + { + Type = PayloadTypes.Request, + Id = id, + PayloadLength = (int)requestBytes.Length, + End = true, + }; + + await WriteAsync( + header: requestHeader, + writeFunc: async pipeWriter => await pipeWriter.WriteAsync(requestBytes).ConfigureAwait(false)).ConfigureAwait(false); + } + + public virtual async Task SendStreamAsync(Guid id, Stream stream, CancellationToken cancellationToken) + { + if (stream == null) + { + throw new ArgumentNullException(nameof(stream)); + } + + var streamHeader = new Header() + { + Type = PayloadTypes.Stream, + Id = id, + PayloadLength = (int)stream.Length, + End = true, + }; + + await WriteAsync(streamHeader, pipeWriter => stream.CopyToAsync(pipeWriter)).ConfigureAwait(false); + } + + public IDisposable Subscribe(IObserver<(Header, ReadOnlySequence)> observer) + { + if (_observer != null) + { + throw new InvalidOperationException("The protocol expects only a single observer."); + } + + _observer = observer ?? throw new ArgumentNullException(nameof(observer)); + + return null; + } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _writeLock?.Dispose(); + } + + _disposedValue = true; + } + } + + private static bool TryParseHeader(ref ReadOnlySequence buffer, out Header header) + { + if (buffer.IsEmpty) + { + header = null; + return false; + } + + var length = Math.Min(TransportConstants.MaxHeaderLength, buffer.Length); + var headerBuffer = buffer.Slice(0, length); + + if (headerBuffer.Length != TransportConstants.MaxHeaderLength) + { + header = null; + return false; + } + + // Optimization opportunity: instead of headerBuffer.ToArray() which does a 48 byte heap allocation, + // do a best effort attempt to use MemoryMashal.TryGetArray. Since it has a lot of corner cases, + // keeping it simple for now and we can optimize further if data says we required it. + // Alternatively we can have a 48 byte buffer that we reuse, considering that we always + // have a single thread running a given transportHandler instance. + header = HeaderSerializer.Deserialize(headerBuffer.ToArray(), 0, TransportConstants.MaxHeaderLength); + + buffer = buffer.Slice(TransportConstants.MaxHeaderLength); + + return true; + } + + private async Task WriteAsync(Header header, Func writeFunc, CancellationToken cancellationToken = default) + { + var output = _transport.Output; + Log.SendingPayload(_logger, header); + + if (await _writeLock.WaitAsync(_semaphoreTimeout, cancellationToken).ConfigureAwait(false)) + { + try + { + HeaderSerializer.Serialize(header, _sendHeaderBuffer, 0); + await output.WriteAsync(_sendHeaderBuffer).ConfigureAwait(false); + await writeFunc(output).ConfigureAwait(false); + } + finally + { + _writeLock.Release(); + } + } + else + { + Log.SemaphoreTimeOut(_logger, header); + } + } + + private static class Log + { + private static readonly Action _payloadReceived = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, nameof(PayloadReceived)), "Payload received. Header: ID {Guid} Type {char} Payload length:{int}. End :{bool}."); + + private static readonly Action _readFrameFailed = + LoggerMessage.Define(LogLevel.Error, new EventId(2, nameof(ReadFrameFailed)), "Failed to read frame from transport."); + + private static readonly Action _payloadSending = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, nameof(SendingPayload)), "Sending Payload. Header: ID {Guid} Type {char} Payload length:{int}. End :{bool}."); + + private static readonly Action _semaphoreTimeOut = + LoggerMessage.Define(LogLevel.Error, new EventId(4, nameof(SemaphoreTimeOut)), "Timed out trying to acquire write semaphore. Header: ID {Guid} Type {char} Payload length:{int}. End :{bool}."); + + private static readonly Action _listenError = + LoggerMessage.Define(LogLevel.Error, new EventId(5, nameof(ListenError)), "TransportHandler encountered an error and will stop listening."); + + private static readonly Action _listenCompleted = + LoggerMessage.Define(LogLevel.Information, new EventId(6, nameof(ListenCompleted)), "TransportHandler listen task completed."); + + public static void PayloadReceived(ILogger logger, Header header) => _payloadReceived(logger, header.Id, header.Type, header.PayloadLength, header.End, null); + + public static void ReadFrameFailed(ILogger logger, Exception ex) => _readFrameFailed(logger, ex); + + public static void SendingPayload(ILogger logger, Header header) => _payloadSending(logger, header.Id, header.Type, header.PayloadLength, header.End, null); + + public static void SemaphoreTimeOut(ILogger logger, Header header) => _semaphoreTimeOut(logger, header.Id, header.Type, header.PayloadLength, header.End, null); + + public static void ListenError(ILogger logger, Exception ex) => _listenError(logger, ex); + + public static void ListenCompleted(ILogger logger) => _listenCompleted(logger, null); + } + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Transport/WebSocketExtensions.cs b/libraries/Microsoft.Bot.Connector.Streaming/Transport/WebSocketExtensions.cs new file mode 100644 index 0000000000..6ef8a297ad --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Transport/WebSocketExtensions.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Buffers; +using System.Net.WebSockets; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Bot.Connector.Streaming.Transport +{ + internal static class WebSocketExtensions + { + public static ValueTask SendAsync(this WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) + { + if (buffer.IsSingleSegment) + { + var isArray = MemoryMarshal.TryGetArray(buffer.First, out var segment); + return new ValueTask(webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: true, cancellationToken)); + } + else + { + return SendMultiSegmentAsync(webSocket, buffer, webSocketMessageType, cancellationToken); + } + } + + private static async ValueTask SendMultiSegmentAsync(WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) + { + var position = buffer.Start; + + // Get a segment before the loop so we can be one segment behind while writing + // This allows us to do a non-zero byte write for the endOfMessage = true send + buffer.TryGet(ref position, out var prevSegment); + while (buffer.TryGet(ref position, out var segment)) + { + var isArray = MemoryMarshal.TryGetArray(prevSegment, out var arraySegment); + await webSocket.SendAsync(arraySegment, webSocketMessageType, endOfMessage: false, cancellationToken).ConfigureAwait(false); + prevSegment = segment; + } + + // End of message frame + if (MemoryMarshal.TryGetArray(prevSegment, out var arraySegmentEnd)) + { + await webSocket.SendAsync(arraySegmentEnd, webSocketMessageType, endOfMessage: true, cancellationToken).ConfigureAwait(false); + } + } + } +} diff --git a/libraries/Microsoft.Bot.Connector.Streaming/Transport/WebSocketTransport.cs b/libraries/Microsoft.Bot.Connector.Streaming/Transport/WebSocketTransport.cs new file mode 100644 index 0000000000..ddc0eda223 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector.Streaming/Transport/WebSocketTransport.cs @@ -0,0 +1,382 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Net.WebSockets; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.Bot.Connector.Streaming.Transport +{ + internal class WebSocketTransport + { + private readonly IDuplexPipe _application; + private readonly ILogger _logger; + + private volatile bool _aborted; + + public WebSocketTransport(IDuplexPipe application, ILogger logger) + { + _application = application ?? throw new ArgumentNullException(nameof(application)); + _logger = logger ?? NullLogger.Instance; + } + + public async Task ConnectAsync(HttpContext context, CancellationToken cancellationToken) + { + using (var ws = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false)) + { + Log.SocketOpened(_logger); + + try + { + await ProcessSocketAsync(ws, cancellationToken).ConfigureAwait(false); + } + finally + { + Log.SocketClosed(_logger); + } + } + } + + public async Task ConnectAsync(string url, IDictionary requestHeaders = null, CancellationToken cancellationToken = default) + { + using (var ws = new ClientWebSocket()) + { + Log.SocketOpened(_logger); + + try + { + if (requestHeaders != null) + { + foreach (var key in requestHeaders.Keys) + { + ws.Options.SetRequestHeader(key, requestHeaders[key]); + } + } + + await ws.ConnectAsync(new Uri(url), cancellationToken).ConfigureAwait(false); + + await ProcessSocketAsync(ws, cancellationToken).ConfigureAwait(false); + } + finally + { + Log.SocketClosed(_logger); + } + } + } + + internal async Task ProcessSocketAsync(WebSocket socket, CancellationToken cancellationToken) + { + // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. + var receiving = StartReceivingAsync(socket, cancellationToken); + var sending = StartSendingAsync(socket); + + // Wait for send or receive to complete + var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false); + + if (trigger == receiving) + { + Log.WaitingForSend(_logger); + + // We're waiting for the application to finish and there are 2 things it could be doing + // 1. Waiting for application data + // 2. Waiting for a websocket send to complete + + // Cancel the application so that ReadAsync yields + _application.Input.CancelPendingRead(); + + using (var delayCts = new CancellationTokenSource()) + { + // TODO: flow this timeout to allow draining + var resultTask = await Task.WhenAny(sending, Task.Delay(TimeSpan.FromSeconds(1), delayCts.Token)).ConfigureAwait(false); + + if (resultTask != sending) + { + // We timed out so now we're in ungraceful shutdown mode + Log.CloseTimedOut(_logger); + + // Abort the websocket if we're stuck in a pending send to the client + _aborted = true; + + socket.Abort(); + } + else + { + delayCts.Cancel(); + } + } + } + else + { + Log.WaitingForClose(_logger); + + // We're waiting on the websocket to close and there are 2 things it could be doing + // 1. Waiting for websocket data + // 2. Waiting on a flush to complete (backpressure being applied) + + using (var delayCts = new CancellationTokenSource()) + { + var resultTask = await Task.WhenAny(receiving, Task.Delay(TimeSpan.FromSeconds(1), delayCts.Token)).ConfigureAwait(false); + + if (resultTask != receiving) + { + // Abort the websocket if we're stuck in a pending receive from the client + _aborted = true; + + socket.Abort(); + + // Cancel any pending flush so that we can quit + _application.Output.CancelPendingFlush(); + } + else + { + delayCts.Cancel(); + } + } + } + } + + private static ArraySegment GetArraySegment(ReadOnlyMemory memory) + { + if (!MemoryMarshal.TryGetArray(memory, out var result)) + { + throw new InvalidOperationException("Buffer backed by array was expected"); + } + + return result; + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "We want to catch all exceptions in the message loop.")] + private async Task StartReceivingAsync(WebSocket socket, CancellationToken cancellationToken) + { + try + { + while (!cancellationToken.IsCancellationRequested) + { + // Do a 0 byte read so that idle connections don't allocate a buffer when waiting for a read + var result = await socket.ReceiveAsync(GetArraySegment(Memory.Empty), cancellationToken).ConfigureAwait(false); + + if (result.MessageType == WebSocketMessageType.Close) + { + return; + } + + var memory = _application.Output.GetMemory(); + + var arraySegment = GetArraySegment(memory); + var receiveResult = await socket.ReceiveAsync(arraySegment, cancellationToken).ConfigureAwait(false); + + // Need to check again for netcoreapp3.0 and later because a close can happen between a 0-byte read and the actual read + if (receiveResult.MessageType == WebSocketMessageType.Close) + { + return; + } + + Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); + + _application.Output.Advance(receiveResult.Count); + var flushResult = await _application.Output.FlushAsync().ConfigureAwait(false); + + // We canceled in the middle of applying back pressure + // or if the consumer is done + if (flushResult.IsCanceled || flushResult.IsCompleted) + { + break; + } + } + } + catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) + { + // Client has closed the WebSocket connection without completing the close handshake + Log.ClosedPrematurely(_logger, ex); + } + catch (OperationCanceledException) + { + // Ignore aborts, don't treat them like transport errors + } + catch (Exception ex) + { + if (!_aborted && !cancellationToken.IsCancellationRequested) + { + await _application.Output.CompleteAsync(ex).ConfigureAwait(false); + Log.TransportError(_logger, ex); + } + } + finally + { + // We're done writing. + await _application.Output.CompleteAsync().ConfigureAwait(false); + Log.ReceivingCompleted(_logger); + } + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "We want to catch all exceptions in the message loop.")] + private async Task StartSendingAsync(WebSocket socket) + { + Exception error = null; + + try + { + while (true) + { + var result = await _application.Input.ReadAsync().ConfigureAwait(false); + var buffer = result.Buffer; + + // Get a frame from the application + try + { + if (result.IsCanceled) + { + break; + } + + if (!buffer.IsEmpty) + { + try + { + Log.SendPayload(_logger, buffer.Length); + + if (WebSocketCanSend(socket)) + { + await socket.SendAsync(buffer, WebSocketMessageType.Binary, CancellationToken.None).ConfigureAwait(false); + } + else + { + break; + } + } + catch (Exception ex) + { + if (!_aborted) + { + Log.ErrorWritingFrame(_logger, ex); + } + + break; + } + } + else if (result.IsCompleted) + { + break; + } + } + finally + { + _application.Input.AdvanceTo(buffer.End); + } + } + } + catch (Exception ex) + { + error = ex; + } + finally + { + // Send the close frame before calling into user code + if (WebSocketCanSend(socket)) + { + try + { + // We're done sending, send the close frame to the client if the websocket is still open + await socket.CloseOutputAsync(error != null ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); + } + catch (Exception ex) + { + Log.ClosingWebSocketFailed(_logger, ex); + } + } + + Log.SendingCompleted(_logger); + await _application.Input.CompleteAsync().ConfigureAwait(false); + } + } + + private bool WebSocketCanSend(WebSocket ws) + { + return !(ws.State == WebSocketState.Aborted || + ws.State == WebSocketState.Closed || + ws.State == WebSocketState.CloseSent); + } + + /// + /// Log messages for . + /// + /// + /// Messages implementred using to maximize performance. + /// For more information, see https://docs.microsoft.com/en-us/aspnet/core/fundamentals/logging/loggermessage?view=aspnetcore-5.0. + /// + private static class Log + { + private static readonly Action _socketOpened = + LoggerMessage.Define(LogLevel.Information, new EventId(1, nameof(SocketOpened)), "Socket transport connection opened."); + + private static readonly Action _socketClosed = + LoggerMessage.Define(LogLevel.Information, new EventId(2, nameof(SocketClosed)), "Socket transport connection closed."); + + private static readonly Action _waitingForSend = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, nameof(WaitingForSend)), "Waiting for the application to finish sending data."); + + private static readonly Action _waitingForClose = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, nameof(WaitingForClose)), "Waiting for the client to close the socket."); + + private static readonly Action _closeTimedOut = + LoggerMessage.Define(LogLevel.Debug, new EventId(5, nameof(CloseTimedOut)), "Timed out waiting for client to send the close frame, aborting the connection."); + + private static readonly Action _messageReceived = + LoggerMessage.Define(LogLevel.Trace, new EventId(6, nameof(MessageReceived)), "Message received. Type: {MessageType}, size: {Size}, EndOfMessage: {EndOfMessage}."); + + private static readonly Action _sendPayload = + LoggerMessage.Define(LogLevel.Trace, new EventId(7, nameof(SendPayload)), "Sending payload: {Size} bytes."); + + private static readonly Action _errorWritingFrame = + LoggerMessage.Define(LogLevel.Debug, new EventId(8, nameof(ErrorWritingFrame)), "Error writing frame."); + + private static readonly Action _closedPrematurely = + LoggerMessage.Define(LogLevel.Debug, new EventId(9, nameof(ClosedPrematurely)), "Socket connection closed prematurely."); + + private static readonly Action _closingWebSocketFailed = + LoggerMessage.Define(LogLevel.Debug, new EventId(10, nameof(ClosingWebSocketFailed)), "Closing webSocket failed."); + + private static readonly Action _sendingCompleted = + LoggerMessage.Define(LogLevel.Information, new EventId(11, nameof(SendingCompleted)), "Socket transport sending task completed."); + + private static readonly Action _receivingCompleted = + LoggerMessage.Define(LogLevel.Information, new EventId(12, nameof(ReceivingCompleted)), "Socket transport receiving task completed."); + + private static readonly Action _transportError = + LoggerMessage.Define(LogLevel.Error, new EventId(13, nameof(TransportError)), "Transport error deteted."); + + public static void SocketOpened(ILogger logger) => _socketOpened(logger, null); + + public static void SocketClosed(ILogger logger) => _socketClosed(logger, null); + + public static void WaitingForSend(ILogger logger) => _waitingForSend(logger, null); + + public static void WaitingForClose(ILogger logger) => _waitingForClose(logger, null); + + public static void CloseTimedOut(ILogger logger) => _closeTimedOut(logger, null); + + public static void MessageReceived(ILogger logger, WebSocketMessageType type, int size, bool endOfMessage) => _messageReceived(logger, type, size, endOfMessage, null); + + public static void SendPayload(ILogger logger, long size) => _sendPayload(logger, size, null); + + public static void ErrorWritingFrame(ILogger logger, Exception ex) => _errorWritingFrame(logger, ex); + + public static void ClosedPrematurely(ILogger logger, Exception ex) => _closedPrematurely(logger, ex); + + public static void ClosingWebSocketFailed(ILogger logger, Exception ex) => _closingWebSocketFailed(logger, ex); + + public static void SendingCompleted(ILogger logger) => _sendingCompleted(logger, null); + + public static void ReceivingCompleted(ILogger logger) => _receivingCompleted(logger, null); + + public static void TransportError(ILogger logger, Exception ex) => _transportError(logger, ex); + } + } +} diff --git a/libraries/Microsoft.Bot.Connector/Authentication/PasswordServiceClientCredentialFactory.cs b/libraries/Microsoft.Bot.Connector/Authentication/PasswordServiceClientCredentialFactory.cs index f8740d71ec..ca4050ea90 100644 --- a/libraries/Microsoft.Bot.Connector/Authentication/PasswordServiceClientCredentialFactory.cs +++ b/libraries/Microsoft.Bot.Connector/Authentication/PasswordServiceClientCredentialFactory.cs @@ -90,7 +90,7 @@ public override Task CreateCredentialsAsync(string app if (appId != AppId) { - throw new InvalidOperationException("Invalid appId"); + throw new InvalidOperationException($"Invalid appId {appId} does not match expected {AppId}"); } if (loginEndpoint.StartsWith(AuthenticationConstants.ToChannelFromBotLoginUrlTemplate, StringComparison.OrdinalIgnoreCase)) diff --git a/libraries/Microsoft.Bot.Streaming/AssemblyInfo.cs b/libraries/Microsoft.Bot.Streaming/AssemblyInfo.cs new file mode 100644 index 0000000000..0fd6a13f1b --- /dev/null +++ b/libraries/Microsoft.Bot.Streaming/AssemblyInfo.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Runtime.CompilerServices; + +#if SIGNASSEMBLY +[assembly: InternalsVisibleTo("Microsoft.Bot.Connector.Streaming.Tests", PublicKey=0024000004800000940000000602000000240000525341310004000001000100b5fc90e7027f67871e773a8fde8938c81dd402ba65b9201d60593e96c492651e889cc13f1415ebb53fac1131ae0bd333c5ee6021672d9718ea31a8aebd0da0072f25d87dba6fc90ffd598ed4da35e44c398c454307e8e33b8426143daec9f596836f97c8f74750e5975c64e2189f45def46b2a2b1247adc3652bf5c308055da9")] +#else +[assembly: InternalsVisibleTo("Microsoft.Bot.Connector.Streaming.Tests")] +[assembly: InternalsVisibleTo("DynamicProxyGenAssembly2")] +#endif diff --git a/libraries/Microsoft.Bot.Streaming/StreamingRequest.cs b/libraries/Microsoft.Bot.Streaming/StreamingRequest.cs index b2479ecd83..965b5c0fe6 100644 --- a/libraries/Microsoft.Bot.Streaming/StreamingRequest.cs +++ b/libraries/Microsoft.Bot.Streaming/StreamingRequest.cs @@ -58,7 +58,7 @@ public class StreamingRequest /// A of items associated with this request. /// #pragma warning disable CA2227 // Collection properties should be read only (we can't change this without breaking binary compat) - public List Streams { get; set; } + public List Streams { get; set; } = new List(); #pragma warning restore CA2227 // Collection properties should be read only /// diff --git a/libraries/Microsoft.Bot.Streaming/Transport/WebSocket/WebSocketClient.cs b/libraries/Microsoft.Bot.Streaming/Transport/WebSocket/WebSocketClient.cs index 372dc2abcb..7562d9dbb9 100644 --- a/libraries/Microsoft.Bot.Streaming/Transport/WebSocket/WebSocketClient.cs +++ b/libraries/Microsoft.Bot.Streaming/Transport/WebSocket/WebSocketClient.cs @@ -187,6 +187,36 @@ public void Dispose() GC.SuppressFinalize(this); } + /// + /// Establish a connection with injected web socket for more control in tests. + /// + /// A for the client which msut already be . + /// A that will not resolve until the client stops listening for incoming messages. + internal Task ConnectInternalAsync(WebSocket socket) + { + if (IsConnected) + { + return Task.CompletedTask; + } + + // We don't dispose the websocket, since WebSocketTransport is now + // the owner of the web socket. +#pragma warning disable CA2000 // Dispose objects before losing scope + var socketTransport = new WebSocketTransport(socket); +#pragma warning restore CA2000 // Dispose objects before losing scope + + // Listen for disconnected events. + _sender.Disconnected += OnConnectionDisconnected; + _receiver.Disconnected += OnConnectionDisconnected; + + _sender.Connect(socketTransport); + _receiver.Connect(socketTransport); + + IsConnected = true; + + return Task.CompletedTask; + } + /// /// Disposes objected used by the class. /// diff --git a/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/BotFrameworkHttpAdapter.cs b/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/BotFrameworkHttpAdapter.cs index e8452091f8..a2331edfa1 100644 --- a/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/BotFrameworkHttpAdapter.cs +++ b/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/BotFrameworkHttpAdapter.cs @@ -207,6 +207,30 @@ public virtual StreamingRequestHandler CreateStreamingRequestHandler(IBot bot, W return new StreamingRequestHandler(bot, this, socket, audience, Logger); } + /// + /// Create the for processing for a new Web Socket connection request. + /// + /// The implementation which will process the request. + /// The instance on which to accept the web socket. + /// The authorized audience of the incoming connection request. + /// Returns a new implementation. + protected virtual async Task CreateStreamingRequestHandlerAsync(IBot bot, HttpContext context, string audience) + { + if (bot == null) + { + throw new ArgumentNullException(nameof(bot)); + } + + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false); + + return CreateStreamingRequestHandler(bot, socket, audience); + } + private static async Task WriteUnauthorizedResponseAsync(string headerName, HttpRequest httpRequest) { httpRequest.HttpContext.Response.StatusCode = (int)HttpStatusCode.Unauthorized; @@ -253,12 +277,10 @@ private async Task ConnectWebSocketAsync(IBot bot, HttpRequest httpRequest, Http try { - var socket = await httpRequest.HttpContext.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false); - // Set ClaimsIdentity on Adapter to enable Skills and User OAuth in WebSocket-based streaming scenarios. var audience = GetAudience(claimsIdentity); - var requestHandler = CreateStreamingRequestHandler(bot, socket, audience); + var requestHandler = await CreateStreamingRequestHandlerAsync(bot, httpRequest.HttpContext, audience).ConfigureAwait(false); if (RequestHandlers == null) { @@ -267,7 +289,9 @@ private async Task ConnectWebSocketAsync(IBot bot, HttpRequest httpRequest, Http RequestHandlers.Add(requestHandler); + Log.WebSocketConnectionStarted(Logger); await requestHandler.ListenAsync().ConfigureAwait(false); + Log.WebSocketConnectionCompleted(Logger); } catch (Exception ex) { @@ -355,5 +379,18 @@ private string GetAudience(ClaimsIdentity claimsIdentity) return null; } + + private class Log + { + private static readonly Action _webSocketConnectionStarted = + LoggerMessage.Define(LogLevel.Information, new EventId(1, nameof(WebSocketConnectionStarted)), "WebSocket connection started."); + + private static readonly Action _webSocketConnectionCompleted = + LoggerMessage.Define(LogLevel.Information, new EventId(2, nameof(WebSocketConnectionCompleted)), "WebSocket connection completed."); + + public static void WebSocketConnectionStarted(ILogger logger) => _webSocketConnectionStarted(logger, null); + + public static void WebSocketConnectionCompleted(ILogger logger) => _webSocketConnectionCompleted(logger, null); + } } } diff --git a/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/CloudAdapter.cs b/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/CloudAdapter.cs index 5e2a9f5627..ce98aba175 100644 --- a/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/CloudAdapter.cs +++ b/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/CloudAdapter.cs @@ -2,9 +2,9 @@ // Licensed under the MIT License. using System; +using System.Collections.Concurrent; using System.Net; using System.Net.Http; -using System.Net.WebSockets; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; @@ -12,6 +12,7 @@ using Microsoft.Bot.Builder.Streaming; using Microsoft.Bot.Connector; using Microsoft.Bot.Connector.Authentication; +using Microsoft.Bot.Connector.Streaming.Application; using Microsoft.Bot.Schema; using Microsoft.Bot.Streaming; using Microsoft.Extensions.Configuration; @@ -24,6 +25,8 @@ namespace Microsoft.Bot.Builder.Integration.AspNet.Core /// public class CloudAdapter : CloudAdapterBase, IBotFrameworkHttpAdapter { + private readonly ConcurrentDictionary _streamingConnections = new ConcurrentDictionary(); + /// /// Initializes a new instance of the class. (Public cloud. No auth. For testing.) /// @@ -146,10 +149,42 @@ public async Task ConnectNamedPipeAsync(string pipeName, IBot bot, string appId, }; // Tie the authentication results, the named pipe, the adapter and the bot together to be ready to handle any inbound activities - var streamingActivityProcessor = new StreamingActivityProcessor(authenticationRequestResult, pipeName, this, bot); + using (var streamingActivityProcessor = new StreamingActivityProcessor(authenticationRequestResult, pipeName, this, bot)) + { + // Start receiving activities on the named pipe + // TODO /*_applicationLifetime?.ApplicationStopped ?? */ + await streamingActivityProcessor.ListenAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + /// + protected override ConnectorFactory GetStreamingConnectorFactory(Activity activity) + { + foreach (var connection in _streamingConnections.Values) + { + if (connection.HandlesActivity(activity)) + { + return connection.GetConnectorFactory(); + } + } - // Start receiving activities on the named pipe - await streamingActivityProcessor.ListenAsync().ConfigureAwait(false); + throw new ApplicationException($"No streaming connection found for activity: {activity}"); + } + + /// + /// Creates a that uses web sockets. + /// + /// instance on which to accept the web socket. + /// Logger implementation for tracing and debugging information. + /// that uses web socket. + protected virtual StreamingConnection CreateWebSocketConnection(HttpContext httpContext, ILogger logger) + { + if (httpContext == null) + { + throw new ArgumentNullException(nameof(httpContext)); + } + + return new WebSocketStreamingConnection(httpContext, logger); } private async Task ConnectAsync(HttpRequest httpRequest, IBot bot, CancellationToken cancellationToken) @@ -164,29 +199,37 @@ private async Task ConnectAsync(HttpRequest httpRequest, IBot bot, CancellationT var authenticationRequestResult = await BotFrameworkAuthentication.AuthenticateStreamingRequestAsync(authHeader, channelIdHeader, cancellationToken).ConfigureAwait(false); - // Transition the request to a WebSocket connection - var socket = await httpRequest.HttpContext.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false); - - // Tie the authentication results, the socket, the adapter and the bot together to be ready to handle any inbound activities - var streamingActivityProcessor = new StreamingActivityProcessor(authenticationRequestResult, socket, this, bot); + var connectionId = Guid.NewGuid(); + using (var scope = Logger.BeginScope(connectionId)) + { + var connection = CreateWebSocketConnection(httpRequest.HttpContext, Logger); - // Start receiving activities on the socket - await streamingActivityProcessor.ListenAsync().ConfigureAwait(false); + using (var streamingActivityProcessor = new StreamingActivityProcessor(authenticationRequestResult, connection, this, bot)) + { + // Start receiving activities on the socket + // TODO: pass asp.net core lifetime for cancellation here. + _streamingConnections.TryAdd(connectionId, streamingActivityProcessor); + Log.WebSocketConnectionStarted(Logger); + await streamingActivityProcessor.ListenAsync(CancellationToken.None).ConfigureAwait(false); + _streamingConnections.TryRemove(connectionId, out _); + Log.WebSocketConnectionCompleted(Logger); + } + } } - private class StreamingActivityProcessor : IStreamingActivityProcessor + private class StreamingActivityProcessor : IStreamingActivityProcessor, IDisposable { private readonly AuthenticateRequestResult _authenticateRequestResult; private readonly CloudAdapter _adapter; private readonly StreamingRequestHandler _requestHandler; - public StreamingActivityProcessor(AuthenticateRequestResult authenticateRequestResult, WebSocket socket, CloudAdapter adapter, IBot bot) + public StreamingActivityProcessor(AuthenticateRequestResult authenticateRequestResult, StreamingConnection connection, CloudAdapter adapter, IBot bot) { _authenticateRequestResult = authenticateRequestResult; _adapter = adapter; // Internal reuse of the existing StreamingRequestHandler class - _requestHandler = new StreamingRequestHandler(bot, this, socket, _authenticateRequestResult.Audience, adapter.Logger); + _requestHandler = new StreamingRequestHandler(bot, this, connection, authenticateRequestResult.Audience, logger: adapter.Logger); // Fix up the connector factory so connector create from it will send over this connection _authenticateRequestResult.ConnectorFactory = new StreamingConnectorFactory(_requestHandler); @@ -204,7 +247,23 @@ public StreamingActivityProcessor(AuthenticateRequestResult authenticateRequestR _authenticateRequestResult.ConnectorFactory = new StreamingConnectorFactory(_requestHandler); } - public Task ListenAsync() => _requestHandler.ListenAsync(); + public bool HandlesActivity(Activity activity) + { + return _requestHandler.ServiceUrl.Equals(activity.ServiceUrl, StringComparison.OrdinalIgnoreCase) && + _requestHandler.HasConversation(activity.Conversation.Id); + } + + public ConnectorFactory GetConnectorFactory() + { + return _authenticateRequestResult.ConnectorFactory; + } + + public void Dispose() + { + ((IDisposable)_requestHandler)?.Dispose(); + } + + public Task ListenAsync(CancellationToken cancellationToken) => _requestHandler.ListenAsync(cancellationToken); Task IStreamingActivityProcessor.ProcessStreamingActivityAsync(Activity activity, BotCallbackHandler callback, CancellationToken cancellationToken) => _adapter.ProcessActivityAsync(_authenticateRequestResult, activity, callback, cancellationToken); @@ -274,5 +333,18 @@ private async Task CreateHttpResponseAsync(ReceiveResponse } } } + + private class Log + { + private static readonly Action _webSocketConnectionStarted = + LoggerMessage.Define(LogLevel.Information, new EventId(1, nameof(WebSocketConnectionStarted)), "WebSocket connection started."); + + private static readonly Action _webSocketConnectionCompleted = + LoggerMessage.Define(LogLevel.Information, new EventId(2, nameof(WebSocketConnectionCompleted)), "WebSocket connection completed."); + + public static void WebSocketConnectionStarted(ILogger logger) => _webSocketConnectionStarted(logger, null); + + public static void WebSocketConnectionCompleted(ILogger logger) => _webSocketConnectionCompleted(logger, null); + } } } diff --git a/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/Microsoft.Bot.Builder.Integration.AspNet.Core.csproj b/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/Microsoft.Bot.Builder.Integration.AspNet.Core.csproj index 9183761947..988208afe1 100644 --- a/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/Microsoft.Bot.Builder.Integration.AspNet.Core.csproj +++ b/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/Microsoft.Bot.Builder.Integration.AspNet.Core.csproj @@ -34,6 +34,10 @@ + + + + @@ -44,5 +48,7 @@ + + \ No newline at end of file diff --git a/tests/Microsoft.Bot.Connector.Streaming.Perf/Microsoft.Bot.Connector.Streaming.Perf.csproj b/tests/Microsoft.Bot.Connector.Streaming.Perf/Microsoft.Bot.Connector.Streaming.Perf.csproj new file mode 100644 index 0000000000..343a438ea0 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Perf/Microsoft.Bot.Connector.Streaming.Perf.csproj @@ -0,0 +1,18 @@ + + + + + Exe + netcoreapp3.1 + + + + + + + + + + + + diff --git a/tests/Microsoft.Bot.Connector.Streaming.Perf/Program.cs b/tests/Microsoft.Bot.Connector.Streaming.Perf/Program.cs new file mode 100644 index 0000000000..25649169fb --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Perf/Program.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using BenchmarkDotNet.Running; + +namespace Microsoft.Bot.Connector.Streaming.Perf +{ + public class Program + { + public static void Main(string[] args) + => BenchmarkSwitcher.FromAssembly(typeof(Program).Assembly).Run(args); + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests.Client/Microsoft.Bot.Connector.Streaming.Tests.Client.csproj b/tests/Microsoft.Bot.Connector.Streaming.Tests.Client/Microsoft.Bot.Connector.Streaming.Tests.Client.csproj new file mode 100644 index 0000000000..a330a61e30 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests.Client/Microsoft.Bot.Connector.Streaming.Tests.Client.csproj @@ -0,0 +1,17 @@ + + + + Exe + netcoreapp3.1 + + + + + + + + + + + + diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests.Client/Program.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests.Client/Program.cs new file mode 100644 index 0000000000..2806b67c9d --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests.Client/Program.cs @@ -0,0 +1,241 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Authentication; +using Microsoft.Bot.Connector.Streaming.Application; +using Microsoft.Bot.Schema; +using Microsoft.Bot.Streaming; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Console; +using Microsoft.Extensions.Options; +using Newtonsoft.Json; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Client +{ + public class Program + { + private static WebSocketClient _client; + private static Task _clientTask; + private static CancellationTokenSource _cancellationSource; + private static string _conversationId; + + public static void Main(string[] args) + { + Menu(); + + do + { + try + { + DispatchAsync(Console.ReadLine()).GetAwaiter().GetResult(); + } + catch (Exception ex) + { + var originalForegroundColor = Console.ForegroundColor; + WriteLine($"Error: {ex}", ConsoleColor.Red); + } + } + while (true); + } + + private static async Task DispatchAsync(string command) + { + switch (command) + { + case "c": + await ConnectAsync(); + break; + + case "car": + await ConnectAsync(true); + break; + + case "m": + await MessageAsync(); + break; + + case "msplit": + await MessageSplitAsync(); + break; + + case "sd": + await ForceServerDisconnectAsync(); + break; + + case "d": + await DisconnectClientAsync(); + break; + + case "h": + Menu(); + break; + } + } + + private static string AskUser(string message) + { + Console.WriteLine(message); + return Console.ReadLine(); + } + + private static async Task ConnectAsync(bool automaticallyReconnect = false) + { + var url = AskUser("Bot url:"); + var appId = AskUser("Bot app id:"); + var appPassword = AskUser("Bot app password:"); + + var headers = new Dictionary() { { "channelId", "Test" } }; + + if (!string.IsNullOrEmpty(appId) && !string.IsNullOrEmpty(appPassword)) + { + var credentials = new MicrosoftAppCredentials(appId, appPassword); + var token = await credentials.GetTokenAsync(); + + headers.Add("Authorization", $"Bearer {token}"); + } + + var configureNamedOptions = new ConfigureNamedOptions(string.Empty, null); + var optionsFactory = new OptionsFactory(new[] { configureNamedOptions }, Enumerable.Empty>()); + var optionsMonitor = new OptionsMonitor(optionsFactory, Enumerable.Empty>(), new OptionsCache()); + + // Improvement opportunity: expose command / argument to control log level. + var loggerFactory = new LoggerFactory(new[] { new ConsoleLoggerProvider(optionsMonitor) }, new LoggerFilterOptions { MinLevel = LogLevel.Debug }); + + _cancellationSource = new CancellationTokenSource(); + + _client = new WebSocketClient(url, new ConsoleRequestHandler(), logger: loggerFactory.CreateLogger("WebSocketClient")); + _client.Disconnected += Client_Disconnected; + _clientTask = _client.ConnectAsync(headers, _cancellationSource.Token); + } + + private static void Client_Disconnected(object sender, Bot.Streaming.Transport.DisconnectedEventArgs e) + { + WriteLine($"[Program] Client disconnected. Reason: {e?.Reason}.", foregroundColor: ConsoleColor.Yellow); + var response = AskUser("Attempt to reconnect the existing connection? y / n"); + + // Let the client gracefully finish + WriteLine("[Program] Waiting for graceful completion...", foregroundColor: ConsoleColor.Yellow); + _clientTask.GetAwaiter().GetResult(); + + if (response == "y") + { + WriteLine("[Program] Reconnecting..."); + ConnectAsync().GetAwaiter().GetResult(); + } + else + { + WriteLine("[Program] Client shut down completed gracefully"); + } + } + + private static async Task MessageAsync() + { + if (_client == null || !_client.IsConnected) + { + WriteLine("[Program] Client is not connected, connect before sending messages."); + } + + var text = AskUser("[Program] Enter text:"); + + WriteLine($"[User]: {text}", ConsoleColor.Cyan); + + if (string.IsNullOrEmpty(_conversationId)) + { + _conversationId = Guid.NewGuid().ToString(); + } + + var activity = new Schema.Activity() + { + Id = Guid.NewGuid().ToString(), + Type = ActivityTypes.Message, + From = new ChannelAccount { Id = "testUser" }, + Conversation = new ConversationAccount { Id = _conversationId }, + Recipient = new ChannelAccount { Id = "testBot" }, + ServiceUrl = "wss://InvalidServiceUrl/api/messages", + ChannelId = "Test", + Text = text, + }; + + var request = StreamingRequest.CreatePost("/api/messages", new StringContent(JsonConvert.SerializeObject(activity), Encoding.UTF8, "application/json")); + + var stopwatch = Stopwatch.StartNew(); + + var response = await _client.SendAsync(request, CancellationToken.None); + } + + private static Task MessageSplitAsync() + { + throw new NotImplementedException(); + } + + private static Task ForceServerDisconnectAsync() + { + throw new NotImplementedException(); + } + + private static async Task DisconnectClientAsync() + { + await _client.DisconnectAsync(); + if (_cancellationSource != null) + { + _cancellationSource.Cancel(); + + _cancellationSource.Dispose(); + _cancellationSource = null; + } + + await _clientTask; + } + + private static void Menu() + { + Console.WriteLine("Welcome to the streaming client."); + Console.WriteLine("Commands:"); + Console.WriteLine("c - Connect client"); + Console.WriteLine("car - Connect client with automatic reconnect"); + Console.WriteLine("m - Send message activity to bot"); + Console.WriteLine("msplit - Send message activity to bot, split between request and stream, allowing commands in between."); + Console.WriteLine("sd - Force server disconnect"); + Console.WriteLine("d - Disconnect client"); + Console.WriteLine("h - Help"); + } + + private static void WriteLine(string message, ConsoleColor foregroundColor = ConsoleColor.White, ConsoleColor backgroundColor = ConsoleColor.Black) + { + // Save original color. + //var originalForegroundColor = Console.ForegroundColor; + //var originalBackgroundColor = Console.BackgroundColor; + + // Set requested color. + Console.ForegroundColor = foregroundColor; + Console.BackgroundColor = backgroundColor; + + // Write message. + Console.WriteLine(message); + + // Restore original colors. + Console.ResetColor(); + + //var Console.ForegroundColor = originalForegroundColor; + //var Console.BackgroundColor = originalBackgroundColor; + } + + private class ConsoleRequestHandler : RequestHandler + { + public override async Task ProcessRequestAsync(ReceiveRequest request, ILogger logger, object context = null, CancellationToken cancellationToken = default) + { + var response = await request.ReadBodyAsJsonAsync().ConfigureAwait(false); + System.Console.WriteLine($"[Bot]: {response?.Text}"); + return await Task.FromResult(StreamingResponse.OK()).ConfigureAwait(false); + } + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Controllers/BotController.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Controllers/BotController.cs new file mode 100644 index 0000000000..61390784c6 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Controllers/BotController.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Server +{ + // This ASP Controller is created to handle a request. Dependency Injection will provide the Adapter and IBot + // implementation at runtime. Multiple different IBot implementations running at different endpoints can be + // achieved by specifying a more specific type for the bot constructor argument. + [Route("api/messages")] + [ApiController] + public class BotController : ControllerBase + { + private readonly IBotFrameworkHttpAdapter _adapter; + private readonly IBot _bot; + + public BotController(IBotFrameworkHttpAdapter adapter, IBot bot) + { + _adapter = adapter; + _bot = bot; + } + + [HttpPost] + [HttpGet] + public async Task PostAsync() + { + // Delegate the processing of the HTTP POST to the adapter. + // The adapter will invoke the bot. + await _adapter.ProcessAsync(Request, Response, _bot); + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Microsoft.Bot.Connector.Streaming.Tests.Server.csproj b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Microsoft.Bot.Connector.Streaming.Tests.Server.csproj new file mode 100644 index 0000000000..a128eb6727 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Microsoft.Bot.Connector.Streaming.Tests.Server.csproj @@ -0,0 +1,27 @@ + + + + netcoreapp3.1 + latest + + + + + + + + + + + + + Always + + + + + + Always + + + diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Program.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Program.cs new file mode 100644 index 0000000000..f69bc91961 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Program.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Server +{ + public class Program + { + public static void Main(string[] args) + { + CreateHostBuilder(args).Build().Run(); + } + + public static IHostBuilder CreateHostBuilder(string[] args) => + Host.CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.ConfigureLogging((logging) => + { + logging.AddDebug(); + logging.AddConsole(); + }); + webBuilder.UseStartup(); + }); + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Startup.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Startup.cs new file mode 100644 index 0000000000..a42a129f93 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/Startup.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Bot.Builder.Dialogs.Adaptive.Runtime.Extensions; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Server +{ + public class Startup + { + private readonly IConfiguration _configuration; + + //1private readonly IHostApplicationLifetime _hostAppLifetime; + + public Startup(IConfiguration configuration) + { + _configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + + //_hostAppLifetime = appLifetime ?? throw new ArgumentNullException(nameof(appLifetime)); + } + + // This method gets called by the runtime. Use this method to add services to the container. + // For more information on how to configure your application, visit https://go.microsoft.com/fwlink/?LinkID=398940 + public void ConfigureServices(IServiceCollection services) + { + services.AddControllers().AddNewtonsoftJson(); + services.AddBotRuntime(_configuration); + + services.Configure( + opts => opts.ShutdownTimeout = TimeSpan.FromSeconds(30)); + } + + // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. + public void Configure(IApplicationBuilder app, IWebHostEnvironment env) + { + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.UseDefaultFiles() + .UseStaticFiles() + .UseWebSockets() + .UseRouting() + .UseAuthorization() + .UseEndpoints(endpoints => + { + endpoints.MapControllers(); + }); + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/appsettings.json b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/appsettings.json new file mode 100644 index 0000000000..23754fdd26 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/appsettings.json @@ -0,0 +1,12 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Debug" + } + }, + "AllowedHosts": "*", + "MicrosoftAppId": "", + "MicrosoftAppPassword": "", + "ConnectionName": "", + "defaultRootDialog": "snapshot.dialog" +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/snapshot.dialog b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/snapshot.dialog new file mode 100644 index 0000000000..6dd2dfe3d1 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests.Server/snapshot.dialog @@ -0,0 +1,19 @@ +{ + "$kind": "Microsoft.AdaptiveDialog", + "triggers": [ + { + "$kind": "Microsoft.OnMessageActivity", + "actions": [ + { + "$kind": "Microsoft.SendActivity", + "activity": "Hello world" + }, + { + "$kind": "Microsoft.SetProperty", + "property": "turn.x", + "value": "y" + } + ] + } + ] +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/LegacyStreamingConnectionTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/LegacyStreamingConnectionTests.cs new file mode 100644 index 0000000000..0e26ee86da --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/LegacyStreamingConnectionTests.cs @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Net.WebSockets; +using System.Runtime.InteropServices.ComTypes; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Application; +using Microsoft.Bot.Streaming; +using Microsoft.Bot.Streaming.Payloads; +using Microsoft.Bot.Streaming.Transport; +using Microsoft.Bot.Streaming.Transport.NamedPipes; +using Microsoft.Bot.Streaming.Transport.WebSockets; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Application +{ + public class LegacyStreamingConnectionTests + { + [Fact] + public void ConstructorTests() + { + var webSocketConnection = new LegacyStreamingConnection(new TestWebSocket(), null); + var namedPipeConnection = new LegacyStreamingConnection("test", null); + } + + [Fact] + public void CannotCreateWithoutValidWebSocket() + { + Assert.Throws(() => + { + WebSocket socket = null; + _ = new LegacyStreamingConnection(socket, NullLogger.Instance); + }); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" ")] + public void CannotCreateWithoutValidPipeName(string pipeName) + { + Assert.Throws(() => + { + _ = new LegacyStreamingConnection(pipeName, NullLogger.Instance); + }); + } + + [Fact] + public void CanCreateWebSocketServer() + { + var socket = new TestWebSocket(); + var requestHandler = new TestRequestHandler(); + + var sut = new LegacyStreamingConnection(socket, NullLogger.Instance); + + var server = sut.CreateStreamingTransportServer(requestHandler); + Assert.True(server is WebSocketServer); + } + + [Fact] + public void CanCreateNamedPipeServer() + { + var requestHandler = new TestRequestHandler(); + + var sut = new LegacyStreamingConnection("test", NullLogger.Instance); + + var server = sut.CreateStreamingTransportServer(requestHandler); + Assert.True(server is NamedPipeServer); + } + + [Fact] + public void CanSendStreamingRequest() + { + var socket = new TestWebSocket(); + var requestHandler = new TestRequestHandler(); + + using (var sut = new TestLegacyStreamingConnection(socket, NullLogger.Instance)) + { + sut.ListenAsync(requestHandler).Wait(); + + var request = new StreamingRequest + { + Verb = "POST", + Path = "/api/messages", + Streams = new List + { + new ResponseMessageStream { Content = new StringContent("foo") } + } + }; + + var response = sut.SendStreamingRequestAsync(request).Result; + + Assert.Equal(request.Streams.Count, response.Streams.Count); + Assert.Equal(request.Streams[0].Id, response.Streams[0].Id); + } + } + + private class TestLegacyStreamingConnection : LegacyStreamingConnection + { + public TestLegacyStreamingConnection(WebSocket socket, ILogger logger, DisconnectedEventHandler onServerDisconnect = null) + : base(socket, logger, onServerDisconnect) + { + } + + public TestLegacyStreamingConnection(string pipeName, ILogger logger, DisconnectedEventHandler onServerDisconnect = null) + : base(pipeName, logger, onServerDisconnect) + { + } + + internal override IStreamingTransportServer CreateStreamingTransportServer(RequestHandler requestHandler) + { + return new TestStreamingTransportServer(); + } + } + + private class TestStreamingTransportServer : IStreamingTransportServer, IDisposable + { + public event DisconnectedEventHandler Disconnected; + + public Task StartAsync() + { + return Task.CompletedTask; + } + + public Task SendAsync(StreamingRequest request, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(new ReceiveResponse + { + StatusCode = 200, + Streams = new List(request.Streams.Select(s => new TestContentStream(s.Id))) + }); + } + + public void Dispose() + { + if (Disconnected != null) + { + Disconnected(this, DisconnectedEventArgs.Empty); + } + } + } + + private class TestWebSocket : WebSocket + { + public override WebSocketCloseStatus? CloseStatus { get; } + + public override string CloseStatusDescription { get; } + + public override WebSocketState State { get; } + + public override string SubProtocol { get; } + + public override void Abort() + { + throw new NotImplementedException(); + } + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + public override void Dispose() + { + throw new NotImplementedException(); + } + + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + } + + private class TestRequestHandler : RequestHandler + { + public override Task ProcessRequestAsync(ReceiveRequest request, ILogger logger, object context = null, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(new StreamingResponse { StatusCode = 200 }); + } + } + + private class TestContentStream : IContentStream + { + public TestContentStream(Guid id) + { + Id = id; + } + + public Guid Id { get; } + + public string ContentType { get; set; } + + public int? Length { get; set; } + + public Stream Stream { get; } + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/ObjectWithTimerAwaitable.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/ObjectWithTimerAwaitable.cs new file mode 100644 index 0000000000..c2e360efbf --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/ObjectWithTimerAwaitable.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Application; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Application +{ + // This object holds onto a TimerAwaitable referencing the callback (the async continuation is the callback) + // it also has a finalizer that triggers a tcs so callers can be notified when this object is being cleaned up. + public class ObjectWithTimerAwaitable + { + private readonly TimerAwaitable _timer; + private readonly TaskCompletionSource _tcs; + + public ObjectWithTimerAwaitable(TaskCompletionSource tcs) + { + _tcs = tcs; + _timer = new TimerAwaitable(TimeSpan.FromSeconds(30), TimeSpan.FromSeconds(1)); + _timer.Start(); + } + + ~ObjectWithTimerAwaitable() + { + _tcs.TrySetResult(true); + } + + public async Task Start() + { + using (_timer) + { + while (await _timer) + { + } + } + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/TimerAwaitableTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/TimerAwaitableTests.cs new file mode 100644 index 0000000000..7821480ee1 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Application/TimerAwaitableTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Application; +using Microsoft.Bot.Connector.Streaming.Tests.Tools; +using Xunit; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Application +{ + public class TimerAwaitableTests + { + [Fact] + public async Task FinalizerRunsIfTimerAwaitableReferencesObject() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + UseTimerAwaitableAndUnref(tcs); + + GC.Collect(); + GC.WaitForPendingFinalizers(); + + // Make sure the finalizer runs + await tcs.Task.TimeoutAfter(TimeSpan.FromSeconds(30)); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void UseTimerAwaitableAndUnref(TaskCompletionSource tcs) + { + _ = new ObjectWithTimerAwaitable(tcs).Start(); + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/ApplicationToApplicationIntegrationTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/ApplicationToApplicationIntegrationTests.cs new file mode 100644 index 0000000000..9bdb56e867 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/ApplicationToApplicationIntegrationTests.cs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Application; +using Microsoft.Bot.Connector.Streaming.Tests.Features; +using Microsoft.Bot.Connector.Streaming.Tests.Tools; +using Microsoft.Bot.Streaming; +using Moq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + public class ApplicationToApplicationIntegrationTests + { + private readonly ITestOutputHelper _outputHelper; + + public ApplicationToApplicationIntegrationTests(ITestOutputHelper outputHelper) + { + _outputHelper = outputHelper; + } + + [Fact] + public async Task Integration_ListenSendShutDownServer() + { + // TODO: Transform this test into a theory and do multi-message, multi-thread, multi-client, etc. + var logger = XUnitLogger.CreateLogger(_outputHelper); + + using (var webSocketFeature = new TestWebSocketConnectionFeature()) + { + // Bot / server setup + var botRequestHandler = new Mock(); + + botRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + var connection = new WebSocketStreamingConnection(logger); + + var socket = await webSocketFeature.AcceptAsync().ConfigureAwait(false); + var serverTask = Task.Run(() => connection.ListenInternalAsync(socket, botRequestHandler.Object)); + + // Client / channel setup + var clientRequestHandler = new Mock(); + + clientRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + var client = new WebSocketClient(clientRequestHandler.Object, logger: logger); + + var clientTask = Task.Run(() => client.ConnectInternalAsync(webSocketFeature.Client, CancellationToken.None)); + + // Send request bot (server) -> channel (client) + const string path = "api/version"; + const string botToClientPayload = "Hello human, I'm Bender!"; + var request = StreamingRequest.CreatePost(path, new StringContent(botToClientPayload)); + + var responseFromClient = await connection.SendStreamingRequestAsync(request).ConfigureAwait(false); + + Assert.Equal(200, responseFromClient.StatusCode); + + const string clientToBotPayload = "Hello bot, I'm Calculon!"; + var clientRequest = StreamingRequest.CreatePost(path, new StringContent(clientToBotPayload)); + + // Send request bot channel (client) -> (server) + var clientToBotResult = await client.SendAsync(clientRequest).ConfigureAwait(false); + + Assert.Equal(200, clientToBotResult.StatusCode); + + await client.DisconnectAsync().ConfigureAwait(false); + + await clientTask.ConfigureAwait(false); + await serverTask.ConfigureAwait(false); + } + } + + [Fact] + public async Task Integration_KeepAlive() + { + // TODO: Transform this test into a theory and do multi-message, multi-thread, multi-client, etc. + var logger = XUnitLogger.CreateLogger(_outputHelper); + var cts = new CancellationTokenSource(); + + using (var webSocketFeature = new TestWebSocketConnectionFeature()) + { + // Bot / server setup + var botRequestHandler = new Mock(); + + botRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + var connection = new WebSocketStreamingConnection(logger); + + var socket = await webSocketFeature.AcceptAsync().ConfigureAwait(false); + var serverTask = Task.Run(() => connection.ListenInternalAsync(socket, botRequestHandler.Object, cts.Token)); + + // Client / channel setup + var clientRequestHandler = new Mock(); + + clientRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + var client = new WebSocketClient(clientRequestHandler.Object, logger: logger, closeTimeOut: TimeSpan.FromSeconds(10), keepAlive: TimeSpan.FromMilliseconds(200)); + + var clientTask = Task.Run(() => client.ConnectInternalAsync(webSocketFeature.Client, CancellationToken.None)); + + // Send request bot (server) -> channel (client) + const string path = "api/version"; + const string botToClientPayload = "Hello human, I'm Bender!"; + var request = StreamingRequest.CreatePost(path, new StringContent(botToClientPayload)); + + var responseFromClient = await connection.SendStreamingRequestAsync(request).ConfigureAwait(false); + + Assert.Equal(200, responseFromClient.StatusCode); + + const string clientToBotPayload = "Hello bot, I'm Calculon!"; + var clientRequest = StreamingRequest.CreatePost(path, new StringContent(clientToBotPayload)); + + // Send request bot channel (client) -> (server) + var clientToBotResult = await client.SendAsync(clientRequest).ConfigureAwait(false); + + Assert.Equal(200, clientToBotResult.StatusCode); + + await Task.Delay(TimeSpan.FromSeconds(3)).ConfigureAwait(false); + + cts.Cancel(); + + await clientTask.ConfigureAwait(false); + await serverTask.ConfigureAwait(false); + } + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/EndToEndMiniLoadTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/EndToEndMiniLoadTests.cs new file mode 100644 index 0000000000..d2bec28ef4 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/EndToEndMiniLoadTests.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + public class EndToEndMiniLoadTests + { + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/InteropApplicationIntegrationTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/InteropApplicationIntegrationTests.cs new file mode 100644 index 0000000000..6e0544bb18 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/InteropApplicationIntegrationTests.cs @@ -0,0 +1,317 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Application; +using Microsoft.Bot.Connector.Streaming.Tests.Features; +using Microsoft.Bot.Connector.Streaming.Tests.Tools; +using Microsoft.Bot.Streaming; +using Microsoft.Extensions.Logging; +using Moq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + public class InteropApplicationIntegrationTests + { + private readonly ITestOutputHelper _outputHelper; + + public InteropApplicationIntegrationTests(ITestOutputHelper outputHelper) + { + _outputHelper = outputHelper; + } + + [Fact] + public async Task Integration_Interop_LegacyClient() + { + // TODO: Transform this test into a theory and do multi-message, multi-thread, multi-client, etc. + var logger = XUnitLogger.CreateLogger(_outputHelper); + + using (var webSocketFeature = new TestWebSocketConnectionFeature()) + { + // Bot / server setup + var botRequestHandler = new Mock(); + + botRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + var connection = new WebSocketStreamingConnection(logger); + + var socket = await webSocketFeature.AcceptAsync().ConfigureAwait(false); + var serverTask = Task.Run(() => connection.ListenInternalAsync(socket, botRequestHandler.Object)); + + // Client / channel setup + var clientRequestHandler = new Mock(); + + clientRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + using (var client = new Microsoft.Bot.Streaming.Transport.WebSockets.WebSocketClient("wss://test", clientRequestHandler.Object)) + { + await client.ConnectInternalAsync(webSocketFeature.Client).ConfigureAwait(false); + + // Send request bot (server) -> channel (client) + const string path = "api/version"; + const string botToClientPayload = "Hello human, I'm Bender!"; + var request = StreamingRequest.CreatePost(path, new StringContent(botToClientPayload)); + + var responseFromClient = await connection.SendStreamingRequestAsync(request).ConfigureAwait(false); + + Assert.Equal(200, responseFromClient.StatusCode); + + const string clientToBotPayload = "Hello bot, I'm Calculon!"; + var clientRequest = StreamingRequest.CreatePost(path, new StringContent(clientToBotPayload)); + + // Send request bot channel (client) -> (server) + var clientToBotResult = await client.SendAsync(clientRequest).ConfigureAwait(false); + + Assert.Equal(200, clientToBotResult.StatusCode); + client.Disconnect(); + } + + await serverTask.ConfigureAwait(false); + } + } + + [Theory] + [InlineData(32, 1024)] + [InlineData(4, 1000)] + [InlineData(4, 100)] + [InlineData(8, 100)] + [InlineData(16, 100)] + [InlineData(32, 100)] + public async Task Integration_Interop_LegacyClient_MiniLoad(int threadCount, int messageCount) + { + var logger = XUnitLogger.CreateLogger(_outputHelper); + + using (var webSocketFeature = new TestWebSocketConnectionFeature()) + { + var botRequestHandler = new Mock(); + + botRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + var connection = new WebSocketStreamingConnection(logger); + + var socket = await webSocketFeature.AcceptAsync().ConfigureAwait(false); + var serverTask = Task.Run(() => connection.ListenInternalAsync(socket, botRequestHandler.Object)); + await Task.Delay(TimeSpan.FromSeconds(1)); + var clients = new List(); + + var clientRequestHandler = new Mock(); + + clientRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + using (var client = new Microsoft.Bot.Streaming.Transport.WebSockets.WebSocketClient( + "wss://test", + clientRequestHandler.Object)) + { + await client.ConnectInternalAsync(webSocketFeature.Client).ConfigureAwait(false); + clients.Add(client); + + // Send request bot (server) -> channel (client) + const string path = "api/version"; + const string botToClientPayload = "Hello human, I'm Bender!"; + + Func testFlow = async (i) => + { + var request = StreamingRequest.CreatePost(path, new StringContent(botToClientPayload)); + + var stopwatch = Stopwatch.StartNew(); + var responseFromClient = + await connection.SendStreamingRequestAsync(request).ConfigureAwait(false); + stopwatch.Stop(); + + Assert.Equal(200, responseFromClient.StatusCode); + logger.LogInformation( + $"Server->Client {i} latency: {stopwatch.ElapsedMilliseconds}. Status code: {responseFromClient.StatusCode}"); + + const string clientToBotPayload = "Hello bot, I'm Calculon!"; + var clientRequest = StreamingRequest.CreatePost(path, new StringContent(clientToBotPayload)); + + stopwatch = Stopwatch.StartNew(); + + // Send request bot channel (client) -> (server) + var clientToBotResult = await client.SendAsync(clientRequest).ConfigureAwait(false); + stopwatch.Stop(); + + Assert.Equal(200, clientToBotResult.StatusCode); + + logger.LogInformation( + $"Client->Server {i} latency: {stopwatch.ElapsedMilliseconds}. Status code: {responseFromClient.StatusCode}"); + }; + + await testFlow(-1).ConfigureAwait(false); + var tasks = new List(); + + using (var throttler = new SemaphoreSlim(threadCount)) + { + for (int j = 0; j < messageCount; j++) + { + await throttler.WaitAsync().ConfigureAwait(false); + + // using Task.Run(...) to run the lambda in its own parallel + // flow on the threadpool + tasks.Add( + Task.Run(async () => + { + try + { + await testFlow(j).ConfigureAwait(false); + } + finally + { + throttler.Release(); + } + })); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + } + + client.Disconnect(); + } + + await serverTask.ConfigureAwait(false); + } + } + + [Theory] + [InlineData(32, 1000)] + public async Task Integration_NewClient_MiniLoad(int threadCount, int messageCount) + { + // TODO: Transform this test into a theory and do multi-message, multi-thread, multi-client, etc. + var logger = XUnitLogger.CreateLogger(_outputHelper); + + using (var webSocketFeature = new TestWebSocketConnectionFeature()) + { + // Bot / server setup + var botRequestHandler = new Mock(); + + botRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + var connection = new WebSocketStreamingConnection(logger); + + var socket = await webSocketFeature.AcceptAsync().ConfigureAwait(false); + var serverTask = Task.Run(() => connection.ListenInternalAsync(socket, botRequestHandler.Object)); + await Task.Delay(TimeSpan.FromSeconds(1)); + + //Parallel.For(0, clientCount, async i => + { + // Client / channel setup + var clientRequestHandler = new Mock(); + + clientRequestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }); + + using (var client = new WebSocketClient($"wss://test", clientRequestHandler.Object, logger: logger)) + { + var clientTask = client.ConnectInternalAsync(webSocketFeature.Client, CancellationToken.None); + + // Send request bot (server) -> channel (client) + const string path = "api/version"; + const string botToClientPayload = "Hello human, I'm Bender!"; + + Func testFlow = async (i) => + { + var request = StreamingRequest.CreatePost(path, new StringContent(botToClientPayload)); + + var stopwatch = Stopwatch.StartNew(); + var responseFromClient = await connection.SendStreamingRequestAsync(request).ConfigureAwait(false); + stopwatch.Stop(); + + Assert.Equal(200, responseFromClient.StatusCode); + logger.LogInformation($"Server->Client {i} latency: {stopwatch.ElapsedMilliseconds}. Status code: {responseFromClient.StatusCode}"); + + const string clientToBotPayload = "Hello bot, I'm Calculon!"; + var clientRequest = StreamingRequest.CreatePost(path, new StringContent(clientToBotPayload)); + + stopwatch = Stopwatch.StartNew(); + + // Send request bot channel (client) -> (server) + var clientToBotResult = await client.SendAsync(clientRequest).ConfigureAwait(false); + stopwatch.Stop(); + + Assert.Equal(200, clientToBotResult.StatusCode); + + logger.LogInformation($"Client->Server {i} latency: {stopwatch.ElapsedMilliseconds}. Status code: {responseFromClient.StatusCode}"); + }; + + await testFlow(-1).ConfigureAwait(false); + var tasks = new List(); + + using (var throttler = new SemaphoreSlim(threadCount)) + { + for (int j = 0; j < messageCount; j++) + { + await throttler.WaitAsync().ConfigureAwait(false); + + // using Task.Run(...) to run the lambda in its own parallel + // flow on the threadpool + tasks.Add( + Task.Run(async () => + { + try + { + await testFlow(j).ConfigureAwait(false); + } + finally + { + throttler.Release(); + } + })); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + } + + await client.DisconnectAsync().ConfigureAwait(false); + await clientTask.ConfigureAwait(false); + } + + await serverTask.ConfigureAwait(false); + } + } + } + + private static void RunWithLimitedParalelism(List tasks, int maxTasksToRunInParallel, int timeoutInMilliseconds, CancellationToken cancellationToken = new CancellationToken()) + { + // Convert to a list of tasks so that we don't enumerate over it multiple times needlessly. + using (var throttler = new SemaphoreSlim(maxTasksToRunInParallel)) + { + var postTaskTasks = new List(); + + // Have each task notify the throttler when it completes so that it decrements the number of tasks currently running. + tasks.ForEach(t => postTaskTasks.Add(t.ContinueWith(tsk => throttler.Release()))); + + // Start running each task. + foreach (var task in tasks) + { + // Increment the number of tasks currently running and wait if too many are running. + throttler.Wait(timeoutInMilliseconds, cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); + task.Start(); + } + + // Wait for all of the provided tasks to complete. + // We wait on the list of "post" tasks instead of the original tasks, otherwise there is a potential race condition where the throttler's using block is exited before some Tasks have had their "post" action completed, which references the throttler, resulting in an exception due to accessing a disposed object. + Task.WaitAll(postTaskTasks.ToArray(), cancellationToken); + } + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/WebSocketTransportClientServerTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/WebSocketTransportClientServerTests.cs new file mode 100644 index 0000000000..ba8d6e6133 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Integration/WebSocketTransportClientServerTests.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Tests.Features; +using Microsoft.Bot.Connector.Streaming.Transport; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + public class WebSocketTransportClientServerTests + { + [Fact] + public async Task WebSocketTransport_ClientServer_WhatIsSentIsReceived() + { + var serverPipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var clientPipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + using (var webSocketFeature = new TestWebSocketConnectionFeature()) + { + // Build server transport + var serverTransport = new WebSocketTransport(serverPipePair.Application, NullLogger.Instance); + + // Accept server web socket, start receiving / sending at the transport level + var serverTask = serverTransport.ProcessSocketAsync(await webSocketFeature.AcceptAsync(), CancellationToken.None); + + var clientTransport = new WebSocketTransport(clientPipePair.Application, NullLogger.Instance); + var clientTask = clientTransport.ProcessSocketAsync(webSocketFeature.Client, CancellationToken.None); + + // Send a frame client -> server + await clientPipePair.Transport.Output.WriteAsync(new ArraySegment(Encoding.UTF8.GetBytes("Hello"))); + await clientPipePair.Transport.Output.FlushAsync(); + + var result = await serverPipePair.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.ToArray())); + serverPipePair.Transport.Input.AdvanceTo(buffer.End); + + // Send a frame server -> client + await serverPipePair.Transport.Output.WriteAsync(new ArraySegment(Encoding.UTF8.GetBytes("World"))); + await serverPipePair.Transport.Output.FlushAsync(); + + var clientResult = await clientPipePair.Transport.Input.ReadAsync(); + buffer = clientResult.Buffer; + + Assert.Equal("World", Encoding.UTF8.GetString(buffer.ToArray())); + clientPipePair.Transport.Input.AdvanceTo(buffer.End); + + clientPipePair.Transport.Output.Complete(); + serverPipePair.Transport.Output.Complete(); + + // The transport should finish now + await serverTask; + await clientTask; + } + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Microsoft.Bot.Connector.Streaming.Tests.csproj b/tests/Microsoft.Bot.Connector.Streaming.Tests/Microsoft.Bot.Connector.Streaming.Tests.csproj new file mode 100644 index 0000000000..ca36a5412a --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Microsoft.Bot.Connector.Streaming.Tests.csproj @@ -0,0 +1,47 @@ + + + + netcoreapp2.1 + netcoreapp3.1 + netcoreapp2.1;netcoreapp3.1 + false + false + Debug;Release + + + + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Session/ProtocolDispatcherTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Session/ProtocolDispatcherTests.cs new file mode 100644 index 0000000000..75e7976df4 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Session/ProtocolDispatcherTests.cs @@ -0,0 +1,194 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Text; +using System.Threading; +using Microsoft.Bot.Connector.Streaming.Session; +using Microsoft.Bot.Connector.Streaming.Transport; +using Microsoft.Bot.Streaming; +using Microsoft.Bot.Streaming.Payloads; +using Microsoft.Extensions.Logging.Abstractions; +using Moq; +using Newtonsoft.Json; +using Xunit; +using static Microsoft.Bot.Connector.Streaming.Session.StreamingSession; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + public class ProtocolDispatcherTests + { + [Fact] + public void ProtocolDispatcher_NullSession_Throws() + { + Assert.Throws(() => new ProtocolDispatcher(null)); + } + + [Fact] + public void ProtocolDispatcher_DispatchRequest() + { + // Arrange + var request = new RequestPayload() + { + Verb = "GET", + Path = "api/version", + Streams = new List() + { + new StreamDescription() { ContentType = "json", Id = Guid.NewGuid().ToString(), Length = 18 }, + new StreamDescription() { ContentType = "text", Id = Guid.NewGuid().ToString(), Length = 24 } + } + }; + + var requestJson = JsonConvert.SerializeObject(request); + var requestBytes = Encoding.UTF8.GetBytes(requestJson); + + var header = new Header() + { + End = true, + Id = Guid.NewGuid(), + PayloadLength = requestBytes.Length, + Type = PayloadTypes.Request + }; + + var callCount = 0; + + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + var requestHandler = new Mock(); + + var session = new Mock(requestHandler.Object, transportHandler.Object, NullLogger.Instance, CancellationToken.None); + + session.Setup( + s => s.ReceiveRequest(It.IsAny
(), It.IsAny())) + .Callback((Header h, ReceiveRequest r) => + { + callCount++; + + // Assert + Assert.Equal(h.Id, header.Id); + Assert.Equal(request.Verb, r.Verb); + Assert.Equal(request.Path, r.Path); + Assert.Equal(request.Streams.Count, r.Streams.Count); + + var firstStream = r.Streams.First() as StreamDefinition; + Assert.Equal(request.Streams.First().Id, firstStream.Id.ToString()); + Assert.Equal(request.Streams.First().Length, firstStream.Length); + Assert.IsType(firstStream.Stream); + Assert.Equal(h.Id, firstStream.PayloadId); + }); + + var dispatcher = new ProtocolDispatcher(session.Object); + + // Act + dispatcher.OnNext((header, new ReadOnlySequence(requestBytes))); + + // Assert + Assert.Equal(1, callCount); + } + + [Fact] + public void ProtocolDispatcher_DispatchResponse() + { + // Arrange + var request = new ResponsePayload() + { + StatusCode = 200, + Streams = new List() + { + new StreamDescription() { ContentType = "json", Id = Guid.NewGuid().ToString(), Length = 18 }, + new StreamDescription() { ContentType = "text", Id = Guid.NewGuid().ToString(), Length = 24 } + } + }; + + var requestJson = JsonConvert.SerializeObject(request); + var requestBytes = Encoding.UTF8.GetBytes(requestJson); + + var header = new Header() + { + End = true, + Id = Guid.NewGuid(), + PayloadLength = requestBytes.Length, + Type = PayloadTypes.Response + }; + + var callCount = 0; + + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + var requestHandler = new Mock(); + + var session = new Mock(requestHandler.Object, transportHandler.Object, NullLogger.Instance, CancellationToken.None); + + session.Setup( + s => s.ReceiveResponse(It.IsAny
(), It.IsAny())) + .Callback((Header h, ReceiveResponse r) => + { + callCount++; + + // Assert + Assert.Equal(h.Id, header.Id); + Assert.Equal(request.StatusCode, r.StatusCode); + Assert.Equal(request.Streams.Count, r.Streams.Count); + + var firstStream = r.Streams.First() as StreamDefinition; + Assert.Equal(request.Streams.First().Id, firstStream.Id.ToString()); + Assert.Equal(request.Streams.First().Length, firstStream.Length); + Assert.IsType(firstStream.Stream); + Assert.Equal(0, firstStream.Stream.Length); + }); + + var dispatcher = new ProtocolDispatcher(session.Object); + + // Act + dispatcher.OnNext((header, new ReadOnlySequence(requestBytes))); + + // Assert + Assert.Equal(1, callCount); + } + + [Fact] + public void ProtocolDispatcher_DispatchStream() + { + // Arrange + var buffer = new byte[256]; + new Random().NextBytes(buffer); + + var header = new Header() + { + End = true, + Id = Guid.NewGuid(), + PayloadLength = buffer.Length, + Type = PayloadTypes.Stream + }; + + var callCount = 0; + + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + var requestHandler = new Mock(); + + var session = new Mock(requestHandler.Object, transportHandler.Object, NullLogger.Instance, CancellationToken.None); + + session + .Setup(s => s.ReceiveStream(It.IsAny
(), It.IsAny>())) + .Callback((Header h, ArraySegment s) => + { + callCount++; + + // Assert + Assert.Equal(h.Id, header.Id); + Assert.True(s.Array.SequenceEqual(buffer)); + }); + + var dispatcher = new ProtocolDispatcher(session.Object); + + // Act + dispatcher.OnNext((header, new ReadOnlySequence(buffer))); + + // Assert + Assert.Equal(1, callCount); + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Session/StreamingSessionTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Session/StreamingSessionTests.cs new file mode 100644 index 0000000000..89b2874856 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Session/StreamingSessionTests.cs @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Session; +using Microsoft.Bot.Connector.Streaming.Transport; +using Microsoft.Bot.Streaming; +using Microsoft.Bot.Streaming.Payloads; +using Microsoft.Extensions.Logging.Abstractions; +using Moq; +using Xunit; +using static Microsoft.Bot.Connector.Streaming.Session.StreamingSession; +using RequestModel = Microsoft.Bot.Connector.Streaming.Payloads.RequestPayload; +using ResponseModel = Microsoft.Bot.Connector.Streaming.Payloads.ResponsePayload; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + public class StreamingSessionTests + { + public static IEnumerable ReceiveRequestParameterValidationData => + new List + { + new object[] { new Header() { Type = PayloadTypes.Request }, null, typeof(ArgumentNullException) }, + new object[] { null, new ReceiveRequest(), typeof(ArgumentNullException) }, + new object[] { new Header() { Type = PayloadTypes.Response }, new ReceiveRequest(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.Stream }, new ReceiveRequest(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.CancelStream }, new ReceiveRequest(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.CancelAll }, new ReceiveRequest(), typeof(InvalidOperationException) }, + }; + + public static IEnumerable ReceiveResponseParameterValidationData => + new List + { + new object[] { new Header() { Type = PayloadTypes.Response }, null, typeof(ArgumentNullException) }, + new object[] { null, new ReceiveResponse(), typeof(ArgumentNullException) }, + new object[] { new Header() { Type = PayloadTypes.Request }, new ReceiveResponse(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.Stream }, new ReceiveResponse(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.CancelStream }, new ReceiveResponse(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.CancelAll }, new ReceiveResponse(), typeof(InvalidOperationException) }, + }; + + public static IEnumerable ReceiveStreamParameterValidationData => + new List + { + new object[] { new Header() { Type = PayloadTypes.Response }, null, typeof(ArgumentNullException) }, + new object[] { null, Array.Empty(), typeof(ArgumentNullException) }, + new object[] { new Header() { Type = PayloadTypes.Request }, Array.Empty(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.Response }, Array.Empty(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.CancelStream }, Array.Empty(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.CancelAll }, Array.Empty(), typeof(InvalidOperationException) }, + }; + + public static IEnumerable SendResponseParameterValidationData => + new List + { + new object[] { new Header() { Type = PayloadTypes.Response }, null, typeof(ArgumentNullException) }, + new object[] { null, new StreamingResponse(), typeof(ArgumentNullException) }, + new object[] { new Header() { Type = PayloadTypes.Request }, new StreamingResponse(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.Stream }, new StreamingResponse(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.CancelAll }, new StreamingResponse(), typeof(InvalidOperationException) }, + new object[] { new Header() { Type = PayloadTypes.CancelStream }, new StreamingResponse(), typeof(InvalidOperationException) }, + }; + + [Fact] + public void StreamingSession_Constructor_NullRequestHandler_Throws() + { + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + Assert.Throws( + () => new StreamingSession(null, transportHandler.Object, NullLogger.Instance)); + } + + [Fact] + public void StreamingSession_Constructor_NullTransportHandler_Throws() + { + Assert.Throws( + () => new StreamingSession(new Mock().Object, null, NullLogger.Instance)); + } + + [Fact] + public async Task StreamingSession_SendRequest_ParameterValidation() + { + // Arrange + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + var session = new StreamingSession(new Mock().Object, transportHandler.Object, NullLogger.Instance); + + // Act + Assert + await Assert.ThrowsAsync(() => session.SendRequestAsync(null, CancellationToken.None)); + } + + [Theory] + [MemberData(nameof(SendResponseParameterValidationData))] + public async Task StreamingSession_SendResponse_ParameterValidation(Header header, StreamingResponse response, Type exceptionType) + { + // Arrange + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + var session = new StreamingSession(new Mock().Object, transportHandler.Object, NullLogger.Instance); + + // Act + Assert + await Assert.ThrowsAsync(exceptionType, () => session.SendResponseAsync(header, response, CancellationToken.None)); + } + + [Theory] + [MemberData(nameof(ReceiveRequestParameterValidationData))] + public void StreamingSession_ReceiveRequest_ParameterValidation(Header header, ReceiveRequest request, Type exceptionType) + { + // Arrange + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + var session = new StreamingSession(new Mock().Object, transportHandler.Object, NullLogger.Instance); + + // Act + Assert + Assert.Throws(exceptionType, () => session.ReceiveRequest(header, request)); + } + + [Theory] + [MemberData(nameof(ReceiveResponseParameterValidationData))] + public void StreamingSession_ReceiveResponse_ParameterValidation(Header header, ReceiveResponse response, Type exceptionType) + { + // Arrange + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + var session = new StreamingSession(new Mock().Object, transportHandler.Object, NullLogger.Instance); + + // Act + Assert + Assert.Throws(exceptionType, () => session.ReceiveResponse(header, response)); + } + + [Theory] + [MemberData(nameof(ReceiveStreamParameterValidationData))] + public void StreamingSession_Receivestream_ParameterValidation(Header header, byte[] payload, Type exceptionType) + { + // Arrange + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + var session = new StreamingSession(new Mock().Object, transportHandler.Object, NullLogger.Instance); + + // Act + Assert + Assert.Throws(exceptionType, () => session.ReceiveStream(header, new ArraySegment(payload))); + } + + [Theory] + [InlineData(10, 1, 1)] + [InlineData(100, 1, 1)] + [InlineData(1000, 1, 1)] + [InlineData(10000, 1, 1)] + [InlineData(1000, 2, 1)] + [InlineData(1000, 1, 2)] + [InlineData(1000, 1, 10)] + [InlineData(1000, 10, 10)] + [InlineData(1000, 100, 10)] + public async Task StreamingSession_RequestWithStreams_SentToHandler(int streamLength, int streamCount, int chunkCount) + { + // Arrange + var requestId = Guid.NewGuid(); + + var request = new ReceiveRequest() + { + Verb = "GET", + Path = "api/version", + Streams = new List() + }; + + request.Streams = StreamingDataGenerator.CreateStreams(requestId, streamLength, streamCount, chunkCount); + + var requestHandler = new Mock(); + + var requestCompletionSource = new TaskCompletionSource(); + + requestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }) + .Callback(() => requestCompletionSource.SetResult(true)); + + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + + var responseCompletionSource = new TaskCompletionSource(); + + transportHandler + .Setup(t => t.SendResponseAsync(It.IsAny(), It.Is(r => r.StatusCode == 200), CancellationToken.None)) + .Callback(() => responseCompletionSource.SetResult(true)); + + // Act + var session = new StreamingSession(requestHandler.Object, transportHandler.Object, NullLogger.Instance); + + session.ReceiveRequest(new Header() { Id = requestId, Type = PayloadTypes.Request }, request); + + foreach (AugmentedStreamDefinition definition in request.Streams) + { + var chunkList = definition.Chunks; + + for (int i = 0; i < chunkList.Count; i++) + { + bool isLast = i == chunkList.Count - 1; + + session.ReceiveStream( + new Header() { End = isLast, Id = definition.Id, PayloadLength = chunkList[i].Length, Type = PayloadTypes.Stream }, + chunkList[i]); + } + } + + var roundtripTask = Task.WhenAll(requestCompletionSource.Task, responseCompletionSource.Task); + var result = await Task.WhenAny(roundtripTask, Task.Delay(TimeSpan.FromSeconds(5))); + + // Assert + Assert.Equal(result, roundtripTask); + } + + [Theory] + [InlineData(10, 1, 1)] + [InlineData(100, 1, 1)] + [InlineData(1000, 1, 1)] + [InlineData(10000, 1, 1)] + [InlineData(1000, 2, 1)] + [InlineData(1000, 1, 2)] + [InlineData(1000, 1, 10)] + [InlineData(1000, 10, 10)] + [InlineData(1000, 100, 10)] + public async Task StreamingSession_SendRequest_ReceiveResponse(int streamLength, int streamCount, int chunkCount) + { + // Arrange + var request = new StreamingRequest() + { + Verb = "GET", + Path = "api/version", + Streams = new List() + }; + + request.AddStream(new StringContent("Hello human, I'm Bender!")); + + var requestHandler = new Mock(); + + var requestCompletionSource = new TaskCompletionSource(); + + requestHandler + .Setup(r => r.ProcessRequestAsync(It.IsAny(), null, null, CancellationToken.None)) + .ReturnsAsync(() => new StreamingResponse() { StatusCode = 200 }) + .Callback(() => requestCompletionSource.SetResult(true)); + + var transportHandler = new Mock(new Mock().Object, NullLogger.Instance); + + var responseCompletionSource = new TaskCompletionSource(); + + var transportHandlerSetup = transportHandler.Setup(t => t.SendRequestAsync(It.IsAny(), It.IsAny(), CancellationToken.None)); + + var session = new StreamingSession(requestHandler.Object, transportHandler.Object, NullLogger.Instance); + + Header responseHeader = null; + ReceiveResponse response = null; + + transportHandlerSetup.Callback( + (Guid requestId, RequestModel requestPayload, CancellationToken cancellationToken) => + { + responseHeader = new Header() { Id = requestId, Type = PayloadTypes.Response }; + response = new ReceiveResponse() { StatusCode = 200, Streams = StreamingDataGenerator.CreateStreams(requestId, streamLength, streamCount, chunkCount, PayloadTypes.Response) }; + + session.ReceiveResponse(responseHeader, response); + + foreach (AugmentedStreamDefinition definition in response.Streams) + { + var chunkList = definition.Chunks; + + for (int i = 0; i < chunkCount; i++) + { + bool isLast = i == chunkCount - 1; + + session.ReceiveStream( + new Header() { End = isLast, Id = definition.Id, PayloadLength = chunkList[i].Length, Type = PayloadTypes.Stream }, + chunkList[i]); + } + } + }); + + // Act + + var responseTask = session.SendRequestAsync(request, CancellationToken.None); + var responseWithTimeout = await Task.WhenAny(responseTask, Task.Delay(TimeSpan.FromSeconds(5))); + + // Assert + Assert.Equal(responseTask, responseWithTimeout); + + var receivedResponse = await responseTask; + + Assert.Equal(response.StatusCode, receivedResponse.StatusCode); + Assert.Equal(response.Streams.Count, receivedResponse.Streams.Count); + + Assert.True(response.Streams.SequenceEqual(receivedResponse.Streams)); + } + + internal static class StreamingDataGenerator + { + public static List CreateStreams(Guid requestId, int streamLength, int streamCount = 1, int chunkCount = 1, char type = PayloadTypes.Request) + { + var result = new List(); + + for (int i = 0; i < streamCount; i++) + { + // To keep code simple, asking that stream length can be equally divided in chunks. Feel + // free to adapt code to support it if needed. + Assert.Equal(0, streamLength % chunkCount); + + var definition = new AugmentedStreamDefinition() + { + Complete = false, + Id = Guid.NewGuid(), + PayloadId = requestId, + Length = streamLength, + PayloadType = type, + Stream = new MemoryStream() + }; + + int chunkSize = streamLength / chunkCount; + int current = 0; + + while (current < streamLength) + { + var data = new byte[chunkSize]; + new Random().NextBytes(data); + + definition.Chunks.Add(data); + + current += chunkSize; + } + + result.Add(definition); + } + + return result; + } + } + + private class AugmentedStreamDefinition : StreamDefinition + { + public List Chunks { get; set; } = new List(); + } + + //[Fact] + //public async Task StreamingSession_RequestWithNoStreams_SentToHandler() + //{ + //} + + //[Fact] + //public async Task StreamingSession_ResponseWithNoStreams_SentToHandler() + //{ + + //} + + //[Fact] + //public async Task StreamingSession_ResponseWithStreams_SentToHandler() + //{ + + //} + + //[Fact] + //public async Task StreamingSession_SendRequest_ResponseReceivedAsynchronously() + //{ + + //} + + //[Fact] + //public async Task StreamingSession_SendResponse_Succeeds() + //{ + + //} + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/MemorySegment{T}.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/MemorySegment{T}.cs new file mode 100644 index 0000000000..3bb13de1b9 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/MemorySegment{T}.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Tools +{ + internal class MemorySegment : ReadOnlySequenceSegment + { + public MemorySegment(ReadOnlyMemory memory) + { + Memory = memory; + } + + public MemorySegment Append(ReadOnlyMemory memory) + { + var segment = new MemorySegment(memory) + { + RunningIndex = RunningIndex + Memory.Length + }; + + Next = segment; + + return segment; + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/SyncPoint.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/SyncPoint.cs new file mode 100644 index 0000000000..48aa2522c7 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/SyncPoint.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + internal class SyncPoint + { + private readonly TaskCompletionSource _atSyncPoint = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _continueFromSyncPoint = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + /// + /// Cretes a sync point and returns the associated handler. + /// + /// The created . + /// A representing the sync point. + public static Func Create(out SyncPoint syncPoint) + { + var handler = Create(1, out var syncPoints); + syncPoint = syncPoints[0]; + return handler; + } + + /// + /// Creates a re-entrant function that waits for sync points in sequence. + /// + /// The number of sync points to expect. + /// The objects that can be used to coordinate the sync point. + /// A representing the sync point next step. + public static Func Create(int count, out SyncPoint[] syncPoints) + { + // Need to use a local so the closure can capture it. You can't use out vars in a closure. + var localSyncPoints = new SyncPoint[count]; + for (var i = 0; i < count; i += 1) + { + localSyncPoints[i] = new SyncPoint(); + } + + syncPoints = localSyncPoints; + + var counter = 0; + return () => + { + if (counter >= localSyncPoints.Length) + { + return Task.CompletedTask; + } + else + { + var syncPoint = localSyncPoints[counter]; + + counter += 1; + return syncPoint.WaitToContinue(); + } + }; + } + + /// + /// Waits for the code-under-test to reach . + /// + /// A representing the asynchronous operation. + public Task WaitForSyncPoint() => _atSyncPoint.Task; + + /// + /// Releases the code-under-test to continue past where it waited for . + /// + /// The result of the sync point continuation. + public void Continue(object obj = null) => _continueFromSyncPoint.TrySetResult(obj); + + /// + /// Used by the code-under-test to wait for the test code to sync up. + /// + /// + /// This code will unblock and then block waiting for to be called. + /// + /// The underlying task result. + /// A representing the asynchronous operation. + public Task WaitToContinue(object obj = null) + { + _atSyncPoint.TrySetResult(obj); + return _continueFromSyncPoint.Task; + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TaskExtensions.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TaskExtensions.cs new file mode 100644 index 0000000000..d35b0d252f --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TaskExtensions.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Tools +{ + internal static class TaskExtensions + { + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + { + // Don't create a timer if the task is already completed + // or the debugger is attached + if (task.IsCompleted || Debugger.IsAttached) + { + return await task; + } + + using (var cts = new CancellationTokenSource()) + { + if (task == await Task.WhenAny(task, Task.Delay(timeout, cts.Token))) + { + cts.Cancel(); + return await task; + } + else + { + throw new TimeoutException(); + } + } + } + + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + { + // Don't create a timer if the task is already completed + // or the debugger is attached + if (task.IsCompleted || Debugger.IsAttached) + { + await task; + return; + } + + using (var cts = new CancellationTokenSource()) + { + if (task == await Task.WhenAny(task, Task.Delay(timeout, cts.Token))) + { + cts.Cancel(); + await task; + } + else + { + throw new TimeoutException(); + } + } + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TestTransportObserver.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TestTransportObserver.cs new file mode 100644 index 0000000000..82ca3b400f --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TestTransportObserver.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Collections.Generic; +using Microsoft.Bot.Streaming.Payloads; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Features +{ + internal class TestTransportObserver : IObserver<(Header Header, ReadOnlySequence Payload)> + { + public List<(Header Header, byte[] Payload)> Received { get; private set; } = new List<(Header Header, byte[] Payload)>(); + + public void OnCompleted() + { + throw new NotImplementedException(); + } + + public void OnError(Exception error) + { + throw new NotImplementedException(); + } + + public void OnNext((Header Header, ReadOnlySequence Payload) value) + { + Received.Add((value.Header, value.Payload.ToArray())); + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TestWebSocketConnectionFeature.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TestWebSocketConnectionFeature.cs new file mode 100644 index 0000000000..d4d6b6981e --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/TestWebSocketConnectionFeature.cs @@ -0,0 +1,283 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Features +{ + internal class TestWebSocketConnectionFeature : IHttpWebSocketFeature, IDisposable + { + private readonly SyncPoint _sync; + private readonly TaskCompletionSource _accepted = new TaskCompletionSource(); + + public TestWebSocketConnectionFeature() + { + } + + public TestWebSocketConnectionFeature(SyncPoint sync) + { + _sync = sync; + } + + public bool IsWebSocketRequest => true; + + public WebSocketChannel Client { get; private set; } + + public string SubProtocol { get; private set; } + + public Task Accepted => _accepted.Task; + + public Task AcceptAsync() => AcceptAsync(new WebSocketAcceptContext()); + + public Task AcceptAsync(WebSocketAcceptContext context) + { + var clientToServer = Channel.CreateUnbounded(); + var serverToClient = Channel.CreateUnbounded(); + + var clientSocket = new WebSocketChannel(serverToClient.Reader, clientToServer.Writer, _sync); + var serverSocket = new WebSocketChannel(clientToServer.Reader, serverToClient.Writer, _sync); + + Client = clientSocket; + SubProtocol = context.SubProtocol; + + _accepted.TrySetResult(new object()); + return Task.FromResult(serverSocket); + } + + public void Dispose() + { + } + + public class WebSocketChannel : WebSocket + { + private readonly ChannelReader _input; + private readonly ChannelWriter _output; + private readonly SyncPoint _sync; + + private WebSocketCloseStatus? _closeStatus; + private string _closeStatusDescription; + private WebSocketState _state; + private WebSocketMessage _internalBuffer = new WebSocketMessage(); + + public WebSocketChannel(ChannelReader input, ChannelWriter output, SyncPoint sync = null) + { + _input = input; + _output = output; + _sync = sync; + _state = WebSocketState.Open; + } + + public override WebSocketCloseStatus? CloseStatus => _closeStatus; + + public override string CloseStatusDescription => _closeStatusDescription; + + public override WebSocketState State => _state; + + public override string SubProtocol => null; + + public override void Abort() + { + _output.TryComplete(new OperationCanceledException()); + _state = WebSocketState.Aborted; + } + + public void SendAbort() + { + _output.TryComplete(new WebSocketException(WebSocketError.ConnectionClosedPrematurely)); + } + + public override async Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + await SendMessageAsync( + new WebSocketMessage + { + CloseStatus = closeStatus, + CloseStatusDescription = statusDescription, + MessageType = WebSocketMessageType.Close, + }, + cancellationToken).ConfigureAwait(false); + + _state = WebSocketState.CloseSent; + + _output.TryComplete(); + } + + public override async Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + await SendMessageAsync( + new WebSocketMessage + { + CloseStatus = closeStatus, + CloseStatusDescription = statusDescription, + MessageType = WebSocketMessageType.Close, + }, + cancellationToken).ConfigureAwait(false); + + _state = WebSocketState.CloseSent; + + _output.TryComplete(); + } + + public override void Dispose() + { + _state = WebSocketState.Closed; + _output.TryComplete(); + } + + public override async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + try + { + while (_internalBuffer.Buffer == null || _internalBuffer.Buffer.Length == 0) + { + await _input.WaitToReadAsync(cancellationToken).ConfigureAwait(false); + + if (_input.TryRead(out var message)) + { + if (message.MessageType == WebSocketMessageType.Close) + { + _state = WebSocketState.CloseReceived; + _closeStatus = message.CloseStatus; + _closeStatusDescription = message.CloseStatusDescription; + return new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, message.CloseStatus, message.CloseStatusDescription); + } + + _internalBuffer = message; + } + else + { + await Task.Delay(100).ConfigureAwait(false); + } + } + + var length = _internalBuffer.Buffer.Length; + if (buffer.Count < _internalBuffer.Buffer.Length) + { + length = Math.Min(buffer.Count, _internalBuffer.Buffer.Length); + Buffer.BlockCopy(_internalBuffer.Buffer, 0, buffer.Array, buffer.Offset, length); + } + else + { + Buffer.BlockCopy(_internalBuffer.Buffer, 0, buffer.Array, buffer.Offset, length); + } + + var endOfMessage = _internalBuffer.EndOfMessage; + if (length > 0) + { + // Remove the sent bytes from the remaining buffer + _internalBuffer.Buffer = _internalBuffer.Buffer.AsMemory().Slice(length).ToArray(); + endOfMessage = _internalBuffer.Buffer.Length == 0 && endOfMessage; + } + + return new WebSocketReceiveResult(length, _internalBuffer.MessageType, endOfMessage); + } + catch (WebSocketException ex) + { + switch (ex.WebSocketErrorCode) + { + case WebSocketError.ConnectionClosedPrematurely: + _state = WebSocketState.Aborted; + break; + } + + // Complete the client side if there's an error + _output.TryComplete(); + + throw; + } + + throw new InvalidOperationException("Unexpected close"); + } + + public override async Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + if (_sync != null) + { + await _sync.WaitToContinue().ConfigureAwait(false); + } + + cancellationToken.ThrowIfCancellationRequested(); + + var copy = new byte[buffer.Count]; + Buffer.BlockCopy(buffer.Array, buffer.Offset, copy, 0, buffer.Count); + await SendMessageAsync( + new WebSocketMessage + { + Buffer = copy, + MessageType = messageType, + EndOfMessage = endOfMessage + }, + cancellationToken).ConfigureAwait(false); + } + + public async Task ExecuteAndCaptureFramesAsync() + { + var frames = new List(); + while (await _input.WaitToReadAsync().ConfigureAwait(false)) + { + while (_input.TryRead(out var message)) + { + if (message.MessageType == WebSocketMessageType.Close) + { + _state = WebSocketState.CloseReceived; + _closeStatus = message.CloseStatus; + _closeStatusDescription = message.CloseStatusDescription; + return new WebSocketConnectionSummary(frames, new WebSocketReceiveResult(0, message.MessageType, message.EndOfMessage, message.CloseStatus, message.CloseStatusDescription)); + } + + frames.Add(message); + } + } + + _state = WebSocketState.Closed; + _closeStatus = WebSocketCloseStatus.InternalServerError; + return new WebSocketConnectionSummary(frames, new WebSocketReceiveResult(0, WebSocketMessageType.Close, endOfMessage: true, closeStatus: WebSocketCloseStatus.InternalServerError, closeStatusDescription: string.Empty)); + } + + private async Task SendMessageAsync(WebSocketMessage webSocketMessage, CancellationToken cancellationToken) + { + while (await _output.WaitToWriteAsync(cancellationToken).ConfigureAwait(false)) + { + if (_output.TryWrite(webSocketMessage)) + { + break; + } + } + } + } + + internal class WebSocketConnectionSummary + { + public WebSocketConnectionSummary(IList received, WebSocketReceiveResult closeResult) + { + Received = received; + CloseResult = closeResult; + } + + public IList Received { get; } + + public WebSocketReceiveResult CloseResult { get; } + } + + internal class WebSocketMessage + { + public byte[] Buffer { get; set; } + + public WebSocketMessageType MessageType { get; set; } + + public bool EndOfMessage { get; set; } + + public WebSocketCloseStatus? CloseStatus { get; set; } + + public string CloseStatusDescription { get; set; } + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/XUnitLogger.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/XUnitLogger.cs new file mode 100644 index 0000000000..b0d93bd147 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Tools/XUnitLogger.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Microsoft.Extensions.Logging; +using Xunit.Abstractions; + +namespace Microsoft.Bot.Connector.Streaming.Tests.Tools +{ + internal class XUnitLogger : ILogger + { + private readonly ITestOutputHelper _testOutputHelper; + private readonly string _categoryName; + private readonly LoggerExternalScopeProvider _scopeProvider; + + public XUnitLogger(ITestOutputHelper testOutputHelper, LoggerExternalScopeProvider scopeProvider, string categoryName) + { + _testOutputHelper = testOutputHelper; + _scopeProvider = scopeProvider; + _categoryName = categoryName; + } + + public static ILogger CreateLogger(ITestOutputHelper testOutputHelper) => new XUnitLogger(testOutputHelper, new LoggerExternalScopeProvider(), string.Empty); + + public bool IsEnabled(LogLevel logLevel) => logLevel != LogLevel.None; + + public IDisposable BeginScope(TState state) => _scopeProvider.Push(state); + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + { + var sb = new StringBuilder(); + sb.Append(Enum.GetName(typeof(LogLevel), logLevel)) + .Append(" [").Append(_categoryName).Append("] ") + .Append(formatter(state, exception)); + + if (exception != null) + { + sb.Append('\n').Append(exception); + } + + // Append scopes + _scopeProvider.ForEachScope( + (scope, state) => + { + state.Append("\n => "); + state.Append(scope); + }, sb); + + Debug.WriteLine(sb.ToString()); + _testOutputHelper.WriteLine(sb.ToString()); + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Transport/TransportHandlerTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Transport/TransportHandlerTests.cs new file mode 100644 index 0000000000..555dd11d91 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Transport/TransportHandlerTests.cs @@ -0,0 +1,424 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Bot.Connector.Streaming.Tests.Features; +using Microsoft.Bot.Connector.Streaming.Transport; +using Microsoft.Bot.Streaming.Payloads; +using Microsoft.Bot.Streaming.Transport; +using Microsoft.Extensions.Logging.Abstractions; +using Newtonsoft.Json; +using Xunit; +using RequestModel = Microsoft.Bot.Connector.Streaming.Payloads.RequestPayload; +using ResponseModel = Microsoft.Bot.Connector.Streaming.Payloads.ResponsePayload; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + public class TransportHandlerTests + { + public static IEnumerable PipeToObserverData => + new List() + { + new object[] { new List<(Header Header, byte[] Payload)>() }, + new object[] { GenerateHeaderPayloadData(1, 1) }, + new object[] { GenerateHeaderPayloadData(1, 1) }, + new object[] { GenerateHeaderPayloadData(10, 1) }, + new object[] { GenerateHeaderPayloadData(100, 1) }, + new object[] { GenerateHeaderPayloadData(1000, 1) }, + new object[] { GenerateHeaderPayloadData(10000, 1) }, + new object[] { GenerateHeaderPayloadData(100000, 1) }, + new object[] { GenerateHeaderPayloadData(100000, 2) }, + new object[] { GenerateHeaderPayloadData(1000, 20) }, + new object[] { GenerateHeaderPayloadData(1000, 200) }, + }; + + public static IEnumerable ErrorScenarioData => + new List() + { + new object[] { GenerateHeaderPayloadData(1000, 2), true, false, false }, + new object[] { GenerateHeaderPayloadData(1000, 2), false, true, false }, + new object[] { GenerateHeaderPayloadData(1000, 2), false, false, true }, + }; + + [Theory] + [MemberData(nameof(PipeToObserverData))] + public async Task TransportHandler_ReceiveFromPipe_IsSentToObserver( + List<(Header Header, byte[] Payload)> transportData) + { + await RunTransportHandlerReceiveTestAsync(transportData, false, false, false); + } + + [Theory] + [MemberData(nameof(ErrorScenarioData))] + public async Task TransportHandler_ReceiveFromPipe_ErrorScenarios( + List<(Header Header, byte[] Payload)> transportData, + bool cancelAfterFirst, + bool cancelWithoutPayload, + bool completeWithException) + { + await RunTransportHandlerReceiveTestAsync(transportData, cancelAfterFirst, cancelWithoutPayload, completeWithException); + } + + [Fact] + public void TransportHandler_NullObserver_Throws() + { + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + Assert.Throws(() => transportHandler.Subscribe(null)); + } + + [Fact] + public void TransportHandler_DoubleObserverRegistration_Throws() + { + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + transportHandler.Subscribe(new TestTransportObserver()); + Assert.Throws(() => transportHandler.Subscribe(new TestTransportObserver())); + } + + [Fact] + public async Task TransportHandler_SendRequest_ThrowsOnNull() + { + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + await Assert.ThrowsAsync( + async () => await transportHandler.SendRequestAsync(Guid.NewGuid(), null, CancellationToken.None)); + } + + [Fact] + public async Task TransportHandler_SendRequest_TransportReceivesHeaderAndPayload() + { + var request = new RequestModel() + { + Verb = "GET", + Path = "api/version", + Streams = new List() + { + new StreamDescription() { ContentType = "json", Id = Guid.NewGuid().ToString(), Length = 18 }, + new StreamDescription() { ContentType = "text", Id = Guid.NewGuid().ToString(), Length = 24 } + } + }; + + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + + var transport = pipePair.Application.Input; + + await transportHandler.SendRequestAsync(Guid.NewGuid(), request, CancellationToken.None); + + var result = await transport.ReadAsync(); + var buffer = result.Buffer; + + var headerBuffer = buffer.Slice(0, Math.Min(TransportConstants.MaxHeaderLength, buffer.Length)); + + var header = HeaderSerializer.Deserialize(headerBuffer.ToArray(), 0, TransportConstants.MaxHeaderLength); + + buffer = buffer.Slice(TransportConstants.MaxHeaderLength); + + if (buffer.Length < header.PayloadLength) + { + transport.AdvanceTo(buffer.Start, buffer.End); + + result = await transport.ReadAsync().ConfigureAwait(false); + + Assert.False(result.IsCanceled); + + buffer = result.Buffer; + } + + var payload = buffer.Slice(buffer.Start, header.PayloadLength).ToArray(); + + var payloadJson = Encoding.UTF8.GetString(payload); + var receivedPayload = JsonConvert.DeserializeObject(payloadJson); + + Assert.NotNull(receivedPayload); + + Assert.Equal(request.Path, receivedPayload.Path); + Assert.Equal(request.Verb, receivedPayload.Verb); + + Assert.Equal(request.Streams.Count, receivedPayload.Streams.Count); + + for (int i = 0; i < request.Streams.Count; i++) + { + Assert.Equal(request.Streams[i].ContentType, receivedPayload.Streams[i].ContentType); + Assert.Equal(request.Streams[i].Id, receivedPayload.Streams[i].Id); + Assert.Equal(request.Streams[i].Length, receivedPayload.Streams[i].Length); + } + } + + [Fact] + public async Task TransportHandler_SendResponse_TransportReceivesHeaderAndPayload() + { + var response = new ResponseModel() + { + StatusCode = 200, + Streams = new List() + { + new StreamDescription() { ContentType = "json", Id = Guid.NewGuid().ToString(), Length = 18 }, + new StreamDescription() { ContentType = "text", Id = Guid.NewGuid().ToString(), Length = 24 } + } + }; + + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + + var transport = pipePair.Application.Input; + + await transportHandler.SendResponseAsync(Guid.NewGuid(), response, CancellationToken.None); + + var result = await transport.ReadAsync(); + var buffer = result.Buffer; + + var headerBuffer = buffer.Slice(0, Math.Min(TransportConstants.MaxHeaderLength, buffer.Length)); + + var header = HeaderSerializer.Deserialize(headerBuffer.ToArray(), 0, TransportConstants.MaxHeaderLength); + + buffer = buffer.Slice(TransportConstants.MaxHeaderLength); + + if (buffer.Length < header.PayloadLength) + { + transport.AdvanceTo(buffer.Start, buffer.End); + + result = await transport.ReadAsync().ConfigureAwait(false); + + Assert.False(result.IsCanceled); + + buffer = result.Buffer; + } + + var payload = buffer.Slice(buffer.Start, header.PayloadLength).ToArray(); + + var payloadJson = Encoding.UTF8.GetString(payload); + var receivedPayload = JsonConvert.DeserializeObject(payloadJson); + + Assert.NotNull(receivedPayload); + + Assert.Equal(response.StatusCode, receivedPayload.StatusCode); + + Assert.Equal(response.Streams.Count, receivedPayload.Streams.Count); + + for (int i = 0; i < response.Streams.Count; i++) + { + Assert.Equal(response.Streams[i].ContentType, receivedPayload.Streams[i].ContentType); + Assert.Equal(response.Streams[i].Id, receivedPayload.Streams[i].Id); + Assert.Equal(response.Streams[i].Length, receivedPayload.Streams[i].Length); + } + } + + [Fact] + public async Task TransportHandler_SendStream_TransportReceivesHeaderAndPayload() + { + var text = "Hello human, I'm Bender"; + + // TODO: make this a theory with increasing byte count. Implement chunking in the transport handler + // to ensure once byte size increases we still send manageable packet size, and test the chunking here. + var stream = new MemoryStream(Encoding.UTF8.GetBytes(text)); + + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + + var transport = pipePair.Application.Input; + + await transportHandler.SendStreamAsync(Guid.NewGuid(), stream, CancellationToken.None); + + var result = await transport.ReadAsync(); + var buffer = result.Buffer; + + var headerBuffer = buffer.Slice(0, Math.Min(TransportConstants.MaxHeaderLength, buffer.Length)); + + var header = HeaderSerializer.Deserialize(headerBuffer.ToArray(), 0, TransportConstants.MaxHeaderLength); + + buffer = buffer.Slice(TransportConstants.MaxHeaderLength); + + if (buffer.Length < header.PayloadLength) + { + transport.AdvanceTo(buffer.Start, buffer.End); + + result = await transport.ReadAsync().ConfigureAwait(false); + + Assert.False(result.IsCanceled); + + buffer = result.Buffer; + } + + var payload = buffer.Slice(buffer.Start, header.PayloadLength).ToArray(); + + var payloadString = Encoding.UTF8.GetString(payload); + + Assert.Equal(text, payloadString); + } + + [Fact] + public async Task TransportHandler_SendResponse_ThrowsOnNull() + { + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + await Assert.ThrowsAsync( + async () => await transportHandler.SendResponseAsync(Guid.NewGuid(), null, CancellationToken.None)); + } + + [Fact] + public async Task TransportHandler_SendStream_ThrowsOnNull() + { + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + + await Assert.ThrowsAsync( + async () => await transportHandler.SendStreamAsync(Guid.NewGuid(), null, CancellationToken.None)); + } + + private static async Task RunTransportHandlerReceiveTestAsync( + List<(Header Header, byte[] Payload)> transportData, + bool cancelAfterFirst, + bool cancelWithoutPayload, + bool completeWithException) + { + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var applicationDuplexPipe = pipePair.Application; + + var transportHandler = new TransportHandler(pipePair.Transport, NullLogger.Instance); + + var transportObserver = new TestTransportObserver(); + transportHandler.Subscribe(transportObserver); + + var transportTask = transportHandler.ListenAsync(CancellationToken.None); + + var output = applicationDuplexPipe.Output; + bool first = true; + + foreach (var entry in transportData) + { + var headerBuffer = new byte[48]; + HeaderSerializer.Serialize(entry.Header, headerBuffer, 0); + + Assert.Equal(entry.Header.PayloadLength, entry.Payload.Length); + + await output.WriteAsync(headerBuffer, CancellationToken.None).ConfigureAwait(false); + + if (cancelWithoutPayload) + { + output.CancelPendingFlush(); + break; + } + + if (entry.Header.PayloadLength > 0) + { + await output.WriteAsync(entry.Payload, CancellationToken.None).ConfigureAwait(false); + } + + if (first && cancelAfterFirst) + { + output.CancelPendingFlush(); + break; + } + } + + if (completeWithException) + { + await Task.Delay(TimeSpan.FromSeconds(1)); + await output.CompleteAsync(new Exception()).ConfigureAwait(false); + await Assert.ThrowsAsync(async () => await transportTask); + } + else + { + await output.CompleteAsync().ConfigureAwait(false); + + if (Debugger.IsAttached) + { + await transportTask; + } + else + { + var result = await Task.WhenAny(transportTask, Task.Delay(TimeSpan.FromSeconds(5))).ConfigureAwait(false); + Assert.Equal(result, transportTask); + } + } + + var receivedData = transportObserver.Received; + + if (cancelAfterFirst) + { + Assert.Single(receivedData); + } + else if (cancelWithoutPayload) + { + Assert.Empty(receivedData); + } + else if (!completeWithException) + { + Assert.Equal(transportData.Count, receivedData.Count); + } + + if (!cancelAfterFirst && !cancelWithoutPayload && !completeWithException) + { + for (int i = 0; i < transportData.Count; i++) + { + Assert.Equal(transportData[i].Header.End, receivedData[i].Header.End); + Assert.Equal(transportData[i].Header.Id, receivedData[i].Header.Id); + Assert.Equal(transportData[i].Header.PayloadLength, receivedData[i].Header.PayloadLength); + Assert.Equal(transportData[i].Header.Type, receivedData[i].Header.Type); + + Assert.True(transportData[i].Payload.SequenceEqual(receivedData[i].Payload)); + } + } + } + + private static List<(Header Header, byte[] Payload)> GenerateHeaderPayloadData(int totalLength, int packageCount) + { + var result = new List<(Header Header, byte[] Payload)>(); + + if (totalLength == 0) + { + var header = new Header() + { + Id = Guid.NewGuid(), + End = true, + PayloadLength = 0, + Type = PayloadTypes.Stream + }; + + result.Add((header, null)); + return result; + } + + byte[] buffer = new byte[totalLength]; + + var random = new Random(); + random.NextBytes(buffer); + + var chunkSize = totalLength / packageCount; + + var current = 0; + + while (current < totalLength) + { + var currentSize = Math.Min(chunkSize, totalLength - current); + + var header = new Header() + { + Id = Guid.NewGuid(), + End = true, + PayloadLength = currentSize, + Type = PayloadTypes.Stream + }; + + var payload = new byte[currentSize]; + + result.Add((header, payload)); + + current += currentSize; + } + + return result; + } + } +} diff --git a/tests/Microsoft.Bot.Connector.Streaming.Tests/Transport/WebSocketTransportTests.cs b/tests/Microsoft.Bot.Connector.Streaming.Tests/Transport/WebSocketTransportTests.cs new file mode 100644 index 0000000000..fbdb8d70a1 --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Streaming.Tests/Transport/WebSocketTransportTests.cs @@ -0,0 +1,515 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Linq; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Bot.Connector.Streaming.Tests.Features; +using Microsoft.Bot.Connector.Streaming.Tests.Tools; +using Microsoft.Bot.Connector.Streaming.Transport; +using Microsoft.Extensions.Logging.Abstractions; +using Moq; +using Xunit; +using Xunit.Abstractions; +using static Microsoft.Bot.Connector.Streaming.Tests.Features.TestWebSocketConnectionFeature; + +namespace Microsoft.Bot.Connector.Streaming.Tests +{ + public class WebSocketTransportTests + { + private readonly ITestOutputHelper _testOutput; + + public WebSocketTransportTests(ITestOutputHelper testOutput) + { + _testOutput = testOutput; + } + + [Fact] + public async Task WebSocketTransport_WhatIsReceivedIsWritten() + { + var pipePair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + using (var webSocketFeature = new TestWebSocketConnectionFeature()) + { + // Build transport + var transport = new WebSocketTransport(pipePair.Application, NullLogger.Instance); + + // Accept web socket, start receiving / sending at the transport level + var processTask = transport.ProcessSocketAsync(await webSocketFeature.AcceptAsync(), CancellationToken.None); + + // Start a socket client that will capture traffic for posterior analysis + var clientTask = webSocketFeature.Client.ExecuteAndCaptureFramesAsync(); + + // Send a frame, then close + await webSocketFeature.Client.SendAsync( + buffer: new ArraySegment(Encoding.UTF8.GetBytes("Hello")), + messageType: WebSocketMessageType.Binary, + endOfMessage: true, + cancellationToken: CancellationToken.None); + await webSocketFeature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); + + var result = await pipePair.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.ToArray())); + pipePair.Transport.Input.AdvanceTo(buffer.End); + + pipePair.Transport.Output.Complete(); + + // The transport should finish now + await processTask; + + // The connection should close after this, which means the client will get a close frame. + var clientSummary = await clientTask; + + Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.CloseStatus); + } + } + + [Fact] + public async Task TransportCommunicatesErrorToApplicationWhenClientDisconnectsAbnormally() + { + //using (StartVerifiableLog()) + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + using (var feature = new TestWebSocketConnectionFeature()) + { + async Task CompleteApplicationAfterTransportCompletes() + { + try + { + // Wait until the transport completes so that we can end the application + var result = await pair.Transport.Input.ReadAsync(); + pair.Transport.Input.AdvanceTo(result.Buffer.End); + } + catch (Exception ex) + { + Assert.IsType(ex); + } + finally + { + // Complete the application so that the connection unwinds without aborting + pair.Transport.Output.Complete(); + } + } + + var transport = new WebSocketTransport(pair.Application, NullLogger.Instance); + + // Accept web socket, start receiving / sending at the transport level + var processTask = transport.ProcessSocketAsync(await feature.AcceptAsync(), CancellationToken.None); + + // Start a socket client that will capture traffic for posterior analysis + var clientTask = feature.Client.ExecuteAndCaptureFramesAsync(); + + // When the close frame is received, we complete the application so the send + // loop unwinds + _ = CompleteApplicationAfterTransportCompletes(); + + // Terminate the client to server channel with an exception + feature.Client.SendAbort(); + + // Wait for the transport + await processTask.TimeoutAfter(TimeSpan.FromSeconds(5)); + + await clientTask.TimeoutAfter(TimeSpan.FromSeconds(5)); + } + } + } + + [Fact] + public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails() + { + //using (StartVerifiableLog()) + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + using (var feature = new TestWebSocketConnectionFeature()) + { + var transport = new WebSocketTransport(pair.Application, NullLogger.Instance); + + // Accept web socket, start receiving / sending at the transport level + var processTask = transport.ProcessSocketAsync(await feature.AcceptAsync(), CancellationToken.None); + + // Start a socket client that will capture traffic for posterior analysis + var clientTask = feature.Client.ExecuteAndCaptureFramesAsync(); + + // Fail in the app + pair.Transport.Output.Complete(new InvalidOperationException("Catastrophic failure.")); + var clientSummary = await clientTask.TimeoutAfter(TimeSpan.FromSeconds(5)); + Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.CloseStatus); + + // Close from the client + await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); + + await processTask.TimeoutAfter(TimeSpan.FromSeconds(5)); + } + } + } + + [Fact] + public async Task TransportClosesOnCloseTimeoutIfClientDoesNotSendCloseFrame() + { + //using (StartVerifiableLog()) + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + using (var feature = new TestWebSocketConnectionFeature()) + { + var transport = new WebSocketTransport(pair.Application, NullLogger.Instance); + var serverSocket = await feature.AcceptAsync(); + + // Accept web socket, start receiving / sending at the transport level + var processTask = transport.ProcessSocketAsync(serverSocket, CancellationToken.None); + + // Start a socket client that will capture traffic for posterior analysis + var clientTask = feature.Client.ExecuteAndCaptureFramesAsync(); + + // End the app + pair.Transport.Output.Complete(); + + await processTask.TimeoutAfter(TimeSpan.FromSeconds(10)); + + // Now we're closed + Assert.Equal(WebSocketState.Aborted, serverSocket.State); + + serverSocket.Dispose(); + } + } + } + + [Fact] + public async Task TransportFailsOnTimeoutWithErrorWhenApplicationFailsAndClientDoesNotSendCloseFrame() + { + //using (StartVerifiableLog()) + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + using (var feature = new TestWebSocketConnectionFeature()) + { + var transport = new WebSocketTransport(pair.Application, NullLogger.Instance); + var serverSocket = await feature.AcceptAsync(); + + // Accept web socket, start receiving / sending at the transport level + var processTask = transport.ProcessSocketAsync(serverSocket, CancellationToken.None); + + // Start a socket client that will capture traffic for posterior analysis + var clientTask = feature.Client.ExecuteAndCaptureFramesAsync(); + + // fail the client to server channel + pair.Transport.Output.Complete(new Exception()); + + await processTask.TimeoutAfter(TimeSpan.FromSeconds(10)); + + Assert.Equal(WebSocketState.Aborted, serverSocket.State); + } + } + } + + [Fact] + public async Task ServerGracefullyClosesWhenApplicationEndsThenClientSendsCloseFrame() + { + //using (StartVerifiableLog()) + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + using (var feature = new TestWebSocketConnectionFeature()) + { + var transport = new WebSocketTransport(pair.Application, NullLogger.Instance); + var serverSocket = await feature.AcceptAsync(); + + // Accept web socket, start receiving / sending at the transport level + var processTask = transport.ProcessSocketAsync(serverSocket, CancellationToken.None); + + // Start a socket client that will capture traffic for posterior analysis + var clientTask = feature.Client.ExecuteAndCaptureFramesAsync(); + + // close the client to server channel + pair.Transport.Output.Complete(); + + _ = await clientTask.TimeoutAfter(TimeSpan.FromSeconds(5)); + + await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).TimeoutAfter(TimeSpan.FromSeconds(5)); + + await processTask.TimeoutAfter(TimeSpan.FromSeconds(5)); + + Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus); + } + } + } + + [Fact] + public async Task ServerGracefullyClosesWhenClientSendsCloseFrameThenApplicationEnds() + { + //using (StartVerifiableLog()) + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + using (var feature = new TestWebSocketConnectionFeature()) + { + var transport = new WebSocketTransport(pair.Application, NullLogger.Instance); + var serverSocket = await feature.AcceptAsync(); + + // Accept web socket, start receiving / sending at the transport level + var processTask = transport.ProcessSocketAsync(serverSocket, CancellationToken.None); + + // Start a socket client that will capture traffic for posterior analysis + var clientTask = feature.Client.ExecuteAndCaptureFramesAsync(); + + await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).TimeoutAfter(TimeSpan.FromSeconds(5)); + + // close the client to server channel + pair.Transport.Output.Complete(); + + _ = await clientTask.TimeoutAfter(TimeSpan.FromSeconds(5)); + + await processTask.TimeoutAfter(TimeSpan.FromSeconds(5)); + + Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus); + } + } + } + + [Fact] + public async Task MultiSegmentSendWillNotSendEmptyEndOfMessageFrame() + { + using (var feature = new TestWebSocketConnectionFeature()) + { + var serverSocket = await feature.AcceptAsync(); + + var firstSegment = new byte[] { 1 }; + var secondSegment = new byte[] { 15 }; + + var first = new MemorySegment(firstSegment); + var last = first.Append(secondSegment); + + var sequence = new ReadOnlySequence(first, 0, last, last.Memory.Length); + + Assert.False(sequence.IsSingleSegment); + + await serverSocket.SendAsync(sequence, WebSocketMessageType.Text); + + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); + + await serverSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, default); + + var messages = await client.TimeoutAfter(TimeSpan.FromSeconds(5)); + Assert.Equal(2, messages.Received.Count); + + // First message: 1 byte, endOfMessage false + Assert.Single(messages.Received[0].Buffer); + Assert.Equal(1, messages.Received[0].Buffer[0]); + Assert.False(messages.Received[0].EndOfMessage); + + // Second message: 1 byte, endOfMessage true + Assert.Single(messages.Received[1].Buffer); + Assert.Equal(15, messages.Received[1].Buffer[0]); + Assert.True(messages.Received[1].EndOfMessage); + } + } + + [Fact] + public async Task ServerTransportCanReceiveMessages() + { + var logger = XUnitLogger.CreateLogger(_testOutput); + + using (var connection = new TestWebSocketConnectionFeature()) + { + var server = connection.AcceptAsync(); + var client = connection.Client; + + var fromTransport = new Pipe(PipeOptions.Default); + var toTransport = new Pipe(PipeOptions.Default); + var listenerRunning = ListenAsync(fromTransport.Reader); + + var webSocketManager = new Mock(); + webSocketManager.Setup(m => m.AcceptWebSocketAsync()).Returns(server); + var httpContext = new Mock(); + httpContext.Setup(c => c.WebSockets).Returns(webSocketManager.Object); + + var sut = new WebSocketTransport(new DuplexPipe(toTransport.Reader, fromTransport.Writer), logger); + var serverTransportRunning = sut.ConnectAsync(httpContext.Object, CancellationToken.None); + + var messages = new List { Encoding.UTF8.GetBytes("foo"), Encoding.UTF8.GetBytes("bar") }; + SendBinaryAsync(client, messages).Wait(); + client.CloseAsync(WebSocketCloseStatus.NormalClosure, "Done sending.", CancellationToken.None).Wait(); + + serverTransportRunning.Wait(); + + var output = await listenerRunning; + + Assert.Equal("foo", Encoding.UTF8.GetString(output[0])); + Assert.Equal("bar", Encoding.UTF8.GetString(output[1])); + } + } + + [Fact] + public void ServerTransportCanSendMessages() + { + var logger = XUnitLogger.CreateLogger(_testOutput); + + using (var connection = new TestWebSocketConnectionFeature()) + { + var server = connection.AcceptAsync().GetAwaiter().GetResult(); + var client = connection.Client; + var receiverRunning = ReceiveAsync(client); + + var fromTransport = new Pipe(PipeOptions.Default); + var toTransport = new Pipe(PipeOptions.Default); + + var webSocketManager = new Mock(); + webSocketManager.Setup(m => m.AcceptWebSocketAsync()).Returns(Task.FromResult(server)); + var httpContext = new Mock(); + httpContext.Setup(c => c.WebSockets).Returns(webSocketManager.Object); + + var sut = new WebSocketTransport(new DuplexPipe(toTransport.Reader, fromTransport.Writer), logger); + var serverTransportRunning = sut.ConnectAsync(httpContext.Object, CancellationToken.None); + + var messages = new List { Encoding.UTF8.GetBytes("foo") }; + WriteAsync(toTransport.Writer, messages).Wait(); + toTransport.Writer.CompleteAsync().GetAwaiter().GetResult(); + + serverTransportRunning.Wait(); + + var output = receiverRunning.GetAwaiter().GetResult(); + + Assert.Equal("foo", Encoding.UTF8.GetString(output[0])); + } + } + + [Fact] + public void ClientTransportCanReceiveMessages() + { + var logger = XUnitLogger.CreateLogger(_testOutput); + + using (var connection = new TestWebSocketConnectionFeature()) + { + var server = connection.AcceptAsync().GetAwaiter().GetResult(); + var client = connection.Client; + + var fromTransport = new Pipe(PipeOptions.Default); + var toTransport = new Pipe(PipeOptions.Default); + var listenerRunning = ListenAsync(fromTransport.Reader); + + var sut = new WebSocketTransport(new DuplexPipe(toTransport.Reader, fromTransport.Writer), logger); + var clientTransportRunning = sut.ProcessSocketAsync(client, CancellationToken.None); + + var messages = new List { Encoding.UTF8.GetBytes("foo") }; + SendBinaryAsync(server, messages).Wait(); + server.CloseAsync(WebSocketCloseStatus.NormalClosure, "Done sending.", CancellationToken.None).Wait(); + + clientTransportRunning.Wait(); + + var output = listenerRunning.GetAwaiter().GetResult(); + + Assert.Equal("foo", Encoding.UTF8.GetString(output[0])); + } + } + + [Fact] + public void ClientTransportCanSendMessages() + { + var logger = XUnitLogger.CreateLogger(_testOutput); + + using (var connection = new TestWebSocketConnectionFeature()) + { + var server = connection.AcceptAsync().GetAwaiter().GetResult(); + var client = connection.Client; + var receiverRunning = ReceiveAsync(server); + + var fromTransport = new Pipe(PipeOptions.Default); + var toTransport = new Pipe(PipeOptions.Default); + + var sut = new WebSocketTransport(new DuplexPipe(toTransport.Reader, fromTransport.Writer), logger); + var clientTransportRunning = sut.ProcessSocketAsync(client, CancellationToken.None); + + var messages = new List { Encoding.UTF8.GetBytes("foo") }; + WriteAsync(toTransport.Writer, messages).Wait(); + toTransport.Writer.CompleteAsync().GetAwaiter().GetResult(); + + clientTransportRunning.Wait(); + + var output = receiverRunning.GetAwaiter().GetResult(); + + Assert.Equal("foo", Encoding.UTF8.GetString(output[0])); + } + } + + private static async Task> ListenAsync(PipeReader input) + { + var messages = new List(); + const int messageLength = 3; + + while (true) + { + var result = await input.ReadAsync(); + + var buffer = result.Buffer; + + while (!buffer.IsEmpty) + { + var payload = buffer.Slice(0, messageLength); + messages.Add(payload.ToArray()); + buffer = buffer.Slice(messageLength); + } + + input.AdvanceTo(buffer.Start, buffer.End); + + if (result.IsCompleted) + { + break; + } + } + + await input.CompleteAsync(); + + return messages; + } + + private static async Task SendBinaryAsync(WebSocket socket, List messages) + { + foreach (var message in messages) + { + var buffer = new ArraySegment(message); + await socket.SendAsync(buffer, WebSocketMessageType.Binary, endOfMessage: true, CancellationToken.None); + } + } + + private static async Task> ReceiveAsync(WebSocket socket) + { + var messages = new List(); + + while (true) + { + var buffer = new byte[1024 * 4]; + var result = await socket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + + if (result.MessageType == WebSocketMessageType.Close) + { + await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); + break; + } + + messages.Add(buffer.Take(result.Count).ToArray()); + } + + return messages; + } + + private static async Task WriteAsync(PipeWriter output, List messages) + { + foreach (var message in messages) + { + var buffer = new ReadOnlyMemory(message); + await output.WriteAsync(buffer); + } + } + } +} diff --git a/tests/Microsoft.Bot.Streaming.Tests/RequestTests.cs b/tests/Microsoft.Bot.Streaming.Tests/RequestTests.cs index 02870a18ed..61a2ade605 100644 --- a/tests/Microsoft.Bot.Streaming.Tests/RequestTests.cs +++ b/tests/Microsoft.Bot.Streaming.Tests/RequestTests.cs @@ -129,7 +129,7 @@ public void Request_Create_Get_Success() Assert.Equal(StreamingRequest.GET, r.Verb); Assert.Null(r.Path); - Assert.Null(r.Streams); + Assert.Empty(r.Streams); } [Fact] @@ -139,7 +139,7 @@ public void Request_Create_Post_Success() Assert.Equal(StreamingRequest.POST, r.Verb); Assert.Null(r.Path); - Assert.Null(r.Streams); + Assert.Empty(r.Streams); } [Fact] @@ -149,7 +149,7 @@ public void Request_Create_Delete_Success() Assert.Equal(StreamingRequest.DELETE, r.Verb); Assert.Null(r.Path); - Assert.Null(r.Streams); + Assert.Empty(r.Streams); } [Fact] @@ -159,7 +159,7 @@ public void Request_Create_Put_Success() Assert.Equal(StreamingRequest.PUT, r.Verb); Assert.Null(r.Path); - Assert.Null(r.Streams); + Assert.Empty(r.Streams); } [Fact] diff --git a/tests/integration/Microsoft.Bot.Builder.Integration.AspNet.Core.Tests/CloudAdapterTests.cs b/tests/integration/Microsoft.Bot.Builder.Integration.AspNet.Core.Tests/CloudAdapterTests.cs index 5d34b3515c..5b0391592f 100644 --- a/tests/integration/Microsoft.Bot.Builder.Integration.AspNet.Core.Tests/CloudAdapterTests.cs +++ b/tests/integration/Microsoft.Bot.Builder.Integration.AspNet.Core.Tests/CloudAdapterTests.cs @@ -17,8 +17,14 @@ using Microsoft.AspNetCore.Http; using Microsoft.Bot.Connector; using Microsoft.Bot.Connector.Authentication; +using Microsoft.Bot.Connector.Streaming.Application; using Microsoft.Bot.Schema; +using Microsoft.Bot.Streaming; +using Microsoft.Bot.Streaming.Payloads; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Primitives; using Microsoft.Rest; using Microsoft.Rest.Serialization; using Moq; @@ -186,6 +192,143 @@ public async Task WebSocketRequestShouldCallAuthenticateStreamingRequestAsync() botFrameworkAuthenticationMock.Verify(x => x.AuthenticateStreamingRequestAsync(It.Is(v => true), It.Is(v => true), It.Is(ct => true)), Times.Once()); } + [Fact] + public void CanContinueConversationOverWebSocket() + { + // Arrange + var continueConversationWaiter = new AutoResetEvent(false); + var verifiedValidContinuation = false; + + var appId = "testAppId"; + var tenantId = "testTenantId"; + var token = "Bearer testjwt"; + var channelId = "testChannel"; + var audience = "testAudience"; + var callerId = "testCallerId"; + + var authResult = new AuthenticateRequestResult + { + Audience = audience, + CallerId = callerId, + ClaimsIdentity = new ClaimsIdentity(new List + { + new Claim("aud", audience), + new Claim("iss", $"https://login.microsoftonline.com/{tenantId}/"), + new Claim("azp", appId), + new Claim("tid", tenantId), + new Claim("ver", "2.0") + }) + }; + + var userTokenClient = new TestUserTokenClient(appId); + + var validActivity = new Activity + { + Id = Guid.NewGuid().ToString("N"), + Type = ActivityTypes.Message, + From = new ChannelAccount { Id = "testUser" }, + Conversation = new ConversationAccount { Id = Guid.NewGuid().ToString("N") }, + Recipient = new ChannelAccount { Id = "testBot" }, + ServiceUrl = "wss://InvalidServiceUrl/api/messages", + ChannelId = channelId, + Text = "hi", + }; + var validContent = new StringContent(JsonConvert.SerializeObject(validActivity), Encoding.UTF8, "application/json"); + + var invalidActivity = new Activity + { + Id = Guid.NewGuid().ToString("N"), + Type = ActivityTypes.Message, + From = new ChannelAccount { Id = "testUser" }, + Conversation = new ConversationAccount { Id = Guid.NewGuid().ToString("N") }, + Recipient = new ChannelAccount { Id = "testBot" }, + ServiceUrl = "wss://InvalidServiceUrl/api/messages", + ChannelId = channelId, + Text = "hi", + }; + + var streamingConnection = new Mock(); + streamingConnection + .Setup(c => c.ListenAsync(It.IsAny(), It.IsAny())) + .Returns((handler, cancellationToken) => handler.ProcessRequestAsync( + new ReceiveRequest + { + Verb = "POST", + Path = "/api/messages", + Streams = new List + { + new TestContentStream + { + Id = Guid.NewGuid(), + ContentType = "application/json", + Length = (int?)validContent.Headers.ContentLength, + Stream = validContent.ReadAsStreamAsync().GetAwaiter().GetResult() + } + } + }, + null, + cancellationToken: cancellationToken)); + + var auth = new Mock(); + auth.Setup(a => a.AuthenticateStreamingRequestAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(authResult)); + auth.Setup(a => a.CreateUserTokenClientAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(userTokenClient)); + + var webSocketManager = new Mock(); + webSocketManager.Setup(m => m.IsWebSocketRequest).Returns(true); + var httpContext = new Mock(); + httpContext.Setup(c => c.WebSockets).Returns(webSocketManager.Object); + var httpRequest = new Mock(); + httpRequest.Setup(r => r.Method).Returns("GET"); + httpRequest.Setup(r => r.HttpContext).Returns(httpContext.Object); + httpRequest.Setup(r => r.Headers).Returns(new HeaderDictionary + { + { "Authorization", new StringValues(token) }, + { "channelid", new StringValues(channelId) } + }); + + var httpResponse = new Mock(); + + var bot = new Mock(); + bot.Setup(b => b.OnTurnAsync(It.IsAny(), It.IsAny())) + .Returns(Task.Factory.StartNew(() => { continueConversationWaiter.WaitOne(); })); // Simulate listening on web socket + + // Act + var adapter = new StreamingTestCloudAdapter(auth.Object, streamingConnection.Object); + var processRequest = adapter.ProcessAsync(httpRequest.Object, httpResponse.Object, bot.Object, CancellationToken.None); + + var validContinuation = adapter.ContinueConversationAsync( + authResult.ClaimsIdentity, + validActivity, + (turn, cancellationToken) => + { + var connectorFactory = turn.TurnState.Get(); + Assert.NotNull(connectorFactory); + var connectorFactoryTypeName = connectorFactory.GetType().FullName ?? string.Empty; + Assert.EndsWith("StreamingConnectorFactory", connectorFactoryTypeName); + verifiedValidContinuation = true; + + return Task.CompletedTask; + }, + CancellationToken.None); + + var invalidContinuation = adapter.ContinueConversationAsync( + authResult.ClaimsIdentity, invalidActivity, (turn, cancellationToken) => Task.CompletedTask, CancellationToken.None); + + continueConversationWaiter.Set(); + processRequest.Wait(); + + // Assert + Assert.True(processRequest.IsCompletedSuccessfully); + Assert.True(verifiedValidContinuation); + Assert.True(validContinuation.IsCompletedSuccessfully); + Assert.Null(validContinuation.Exception); + Assert.True(invalidContinuation.IsFaulted); + Assert.NotEmpty(invalidContinuation.Exception.InnerExceptions); + Assert.True(invalidContinuation.Exception.InnerExceptions[0] is ApplicationException); + } + [Fact] public async Task MessageActivityWithHttpClient() { @@ -831,5 +974,32 @@ public override Task CreateAsync(string serviceUrl, string aud return Task.FromResult((IConnectorClient)new ConnectorClient(new Uri(serviceUrl), credentials, null, disposeHttpClient: true)); } } + + private class TestContentStream : IContentStream + { + public Guid Id { get; set; } + + public string ContentType { get; set; } + + public int? Length { get; set; } + + public Stream Stream { get; set; } + } + + private class StreamingTestCloudAdapter : CloudAdapter + { + private readonly StreamingConnection _connection; + + public StreamingTestCloudAdapter(BotFrameworkAuthentication auth, StreamingConnection connection) + : base(auth) + { + _connection = connection; + } + + protected override StreamingConnection CreateWebSocketConnection(HttpContext httpContext, ILogger logger) + { + return _connection; + } + } } } diff --git a/tests/tests.uischema b/tests/tests.uischema index f926d1b4e2..3f601049c5 100644 --- a/tests/tests.uischema +++ b/tests/tests.uischema @@ -8,7 +8,8 @@ "triggers", "generator", "selector", - "schema" + "schema", + "dialogs" ], "label": "Adaptive dialog", "order": [