From ca407737c8379001c514e906dc5abdc484edfd2d Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Tue, 16 Jun 2020 14:45:43 +0100 Subject: [PATCH] Add support for custom verify callback to servers. Motivation: In #171 when we worked on providing access to the better verification callback, we managed to entirely miss that we had not provided that access to servers. This meant they were stuck only with the substantially-less-useful old-school callback, instead of the much better new-school one. While we're here, as we had to add multiple new initializers to NIOSSLServerHandler, I took the opportunity to also resolve the server handler portion of #147. The issue itself is still open because the client handlers still have throwing inits, but all "preferred" initializers on NIOSSLServerHandler no longer throw. Modifications: - Deprecated NIOSSLServerHandler.init(context:verificationCallback:) - Implemented two new initializers on NIOSSLServerHandler. - Added tests to verify that the NIOSSLServerHandler verification callback is actually called. - Removed all now-unnecessary try keywords. Result: Users will be able to provide custom verification callbacks that work much better than they currently can when on the server, and the server is now back into feature parity with the client. --- Sources/NIOSSL/NIOSSLServerHandler.swift | 24 ++++++ .../BenchManyWrites.swift | 2 +- .../BenchRepeatedHandshakes.swift | 2 +- Sources/NIOTLSServer/main.swift | 2 +- .../NIOSSLIntegrationTest+XCTest.swift | 1 + Tests/NIOSSLTests/NIOSSLIntegrationTest.swift | 73 +++++++++++++++---- 6 files changed, 87 insertions(+), 17 deletions(-) diff --git a/Sources/NIOSSL/NIOSSLServerHandler.swift b/Sources/NIOSSL/NIOSSLServerHandler.swift index 60e785de1..322f4202d 100644 --- a/Sources/NIOSSL/NIOSSLServerHandler.swift +++ b/Sources/NIOSSL/NIOSSLServerHandler.swift @@ -18,6 +18,11 @@ import NIO /// handler can be used in channels that are acting as the server in /// the TLS dialog. For client connections, use the `NIOSSLClientHandler`. public final class NIOSSLServerHandler: NIOSSLHandler { + public convenience init(context: NIOSSLContext) { + self.init(context: context, optionalCustomVerificationCallback: nil) + } + + @available(*, deprecated, renamed: "init(context:serverHostname:customVerificationCallback:)") public init(context: NIOSSLContext, verificationCallback: NIOSSLVerificationCallback? = nil) throws { guard let connection = context.createConnection() else { fatalError("Failed to create new connection in NIOSSLContext") @@ -31,4 +36,23 @@ public final class NIOSSLServerHandler: NIOSSLHandler { super.init(connection: connection, shutdownTimeout: context.configuration.shutdownTimeout) } + + public convenience init(context: NIOSSLContext, customVerificationCallback: @escaping NIOSSLCustomVerificationCallback) { + self.init(context: context, optionalCustomVerificationCallback: customVerificationCallback) + } + + /// This exists to handle the explosion of initializers I got when I deprecated the first one. + private init(context: NIOSSLContext, optionalCustomVerificationCallback: NIOSSLCustomVerificationCallback?) { + guard let connection = context.createConnection() else { + fatalError("Failed to create new connection in NIOSSLContext") + } + + connection.setAcceptState() + + if let customVerificationCallback = optionalCustomVerificationCallback { + connection.setCustomVerificationCallback(.init(callback: customVerificationCallback)) + } + + super.init(connection: connection, shutdownTimeout: context.configuration.shutdownTimeout) + } } diff --git a/Sources/NIOSSLPerformanceTester/BenchManyWrites.swift b/Sources/NIOSSLPerformanceTester/BenchManyWrites.swift index 6e1fedf90..f9b05f8eb 100644 --- a/Sources/NIOSSLPerformanceTester/BenchManyWrites.swift +++ b/Sources/NIOSSLPerformanceTester/BenchManyWrites.swift @@ -34,7 +34,7 @@ final class BenchManyWrites: Benchmark { } func setUp() throws { - let serverHandler = try NIOSSLServerHandler(context: self.serverContext) + let serverHandler = NIOSSLServerHandler(context: self.serverContext) let clientHandler = try NIOSSLClientHandler(context: self.clientContext, serverHostname: "localhost") try self.backToBack.client.pipeline.addHandler(clientHandler).wait() try self.backToBack.server.pipeline.addHandler(serverHandler).wait() diff --git a/Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift b/Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift index c76a605d1..6b8fe0a3e 100644 --- a/Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift +++ b/Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift @@ -35,7 +35,7 @@ final class BenchRepeatedHandshakes: Benchmark { func run() throws -> Int { for _ in 0.. Channel { return try assertNoThrowWithValue(ServerBootstrap(group: group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in let results = preHandlers.map { channel.pipeline.addHandler($0) } - return EventLoopFuture.andAllSucceed(results, on: channel.eventLoop).flatMapThrowing { - try NIOSSLServerHandler(context: context) - }.flatMap { - channel.pipeline.addHandler($0) - }.flatMap { - let results = postHandlers.map { channel.pipeline.addHandler($0) } - return EventLoopFuture.andAllSucceed(results, on: channel.eventLoop) + return EventLoopFuture.andAllSucceed(results, on: channel.eventLoop).map { + if let verify = customVerificationCallback { + return NIOSSLServerHandler(context: context, customVerificationCallback: verify) + } else { + return NIOSSLServerHandler(context: context) + } + }.flatMap { + channel.pipeline.addHandler($0) + }.flatMap { + let results = postHandlers.map { channel.pipeline.addHandler($0) } + return EventLoopFuture.andAllSucceed(results, on: channel.eventLoop) } }.bind(host: "127.0.0.1", port: 0).wait(), file: file, line: line) } @@ -902,7 +907,7 @@ class NIOSSLIntegrationTest: XCTestCase { let context = try configuredSSLContext() - try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait() + try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait() try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait() let addr: SocketAddress = try SocketAddress(unixDomainSocketPath: "/tmp/whatever") @@ -1039,7 +1044,7 @@ class NIOSSLIntegrationTest: XCTestCase { let context = try configuredSSLContext() - try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait() + try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait() try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait() let addr = try SocketAddress(unixDomainSocketPath: "/tmp/whatever2") @@ -1076,7 +1081,7 @@ class NIOSSLIntegrationTest: XCTestCase { let completePromise: EventLoopPromise = serverChannel.eventLoop.makePromise() - XCTAssertNoThrow(try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait()) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait()) XCTAssertNoThrow(try serverChannel.pipeline.addHandler(ReadRecordingHandler(completePromise: completePromise)).wait()) XCTAssertNoThrow(try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait()) @@ -1158,7 +1163,7 @@ class NIOSSLIntegrationTest: XCTestCase { let context = try configuredSSLContext() - try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait() + try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait() try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait() let addr = try SocketAddress(unixDomainSocketPath: "/tmp/whatever2") @@ -1244,7 +1249,7 @@ class NIOSSLIntegrationTest: XCTestCase { let completePromise: EventLoopPromise = serverChannel.eventLoop.makePromise() - XCTAssertNoThrow(try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait()) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait()) XCTAssertNoThrow(try serverChannel.pipeline.addHandler(ReadRecordingHandler(completePromise: completePromise)).wait()) XCTAssertNoThrow(try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait()) @@ -1607,6 +1612,46 @@ class NIOSSLIntegrationTest: XCTestCase { XCTAssertNoThrow(try handshakeCompletePromise.futureResult.wait()) } + func testServerHasNewCallbackCalledToo() throws { + let config = TLSConfiguration.forServer(certificateChain: [.certificate(NIOSSLIntegrationTest.cert)], + privateKey: .privateKey(NIOSSLIntegrationTest.key), + certificateVerification: .fullVerification, + trustRoots: .default) + let context = try assertNoThrowWithValue(NIOSSLContext(configuration: config)) + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let handshakeResultPromise = group.next().makePromise(of: Void.self) + let handshakeWatcher = WaitForHandshakeHandler(handshakeResultPromise: handshakeResultPromise) + let serverChannel: Channel = try serverTLSChannel(context: context, + preHandlers: [], + postHandlers: [handshakeWatcher], + group: group, + customVerificationCallback: { _, promise in + promise.succeed(.failed) + }) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + + let clientChannel = try clientTLSChannel(context: try configuredSSLContext(), + preHandlers: [], + postHandlers: [], + group: group, + connectingTo: serverChannel.localAddress!) + + defer { + // Ignore errors here, the channel should be closed already by the time this happens. + try? clientChannel.close().wait() + } + + XCTAssertThrowsError(try handshakeResultPromise.futureResult.wait()) + } + func testRepeatedClosure() throws { let serverChannel = EmbeddedChannel() let clientChannel = EmbeddedChannel() @@ -1834,7 +1879,7 @@ class NIOSSLIntegrationTest: XCTestCase { } XCTAssertNoThrow(try serverChannel.pipeline.addHandler(SecondChannelInactiveSwallower()).wait()) - XCTAssertNoThrow(try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait()) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait()) XCTAssertNoThrow(try serverChannel.pipeline.addHandler(FlushOnReadHandler()).wait()) XCTAssertNoThrow(try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait()) @@ -1955,7 +2000,7 @@ class NIOSSLIntegrationTest: XCTestCase { let context = try configuredSSLContext() - try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait() + try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait() try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait() // Do the handshake.