diff --git a/Sources/NIOSSL/NIOSSLServerHandler.swift b/Sources/NIOSSL/NIOSSLServerHandler.swift index 60e785de1..322f4202d 100644 --- a/Sources/NIOSSL/NIOSSLServerHandler.swift +++ b/Sources/NIOSSL/NIOSSLServerHandler.swift @@ -18,6 +18,11 @@ import NIO /// handler can be used in channels that are acting as the server in /// the TLS dialog. For client connections, use the `NIOSSLClientHandler`. public final class NIOSSLServerHandler: NIOSSLHandler { + public convenience init(context: NIOSSLContext) { + self.init(context: context, optionalCustomVerificationCallback: nil) + } + + @available(*, deprecated, renamed: "init(context:serverHostname:customVerificationCallback:)") public init(context: NIOSSLContext, verificationCallback: NIOSSLVerificationCallback? = nil) throws { guard let connection = context.createConnection() else { fatalError("Failed to create new connection in NIOSSLContext") @@ -31,4 +36,23 @@ public final class NIOSSLServerHandler: NIOSSLHandler { super.init(connection: connection, shutdownTimeout: context.configuration.shutdownTimeout) } + + public convenience init(context: NIOSSLContext, customVerificationCallback: @escaping NIOSSLCustomVerificationCallback) { + self.init(context: context, optionalCustomVerificationCallback: customVerificationCallback) + } + + /// This exists to handle the explosion of initializers I got when I deprecated the first one. + private init(context: NIOSSLContext, optionalCustomVerificationCallback: NIOSSLCustomVerificationCallback?) { + guard let connection = context.createConnection() else { + fatalError("Failed to create new connection in NIOSSLContext") + } + + connection.setAcceptState() + + if let customVerificationCallback = optionalCustomVerificationCallback { + connection.setCustomVerificationCallback(.init(callback: customVerificationCallback)) + } + + super.init(connection: connection, shutdownTimeout: context.configuration.shutdownTimeout) + } } diff --git a/Sources/NIOSSLPerformanceTester/BenchManyWrites.swift b/Sources/NIOSSLPerformanceTester/BenchManyWrites.swift index 6e1fedf90..f9b05f8eb 100644 --- a/Sources/NIOSSLPerformanceTester/BenchManyWrites.swift +++ b/Sources/NIOSSLPerformanceTester/BenchManyWrites.swift @@ -34,7 +34,7 @@ final class BenchManyWrites: Benchmark { } func setUp() throws { - let serverHandler = try NIOSSLServerHandler(context: self.serverContext) + let serverHandler = NIOSSLServerHandler(context: self.serverContext) let clientHandler = try NIOSSLClientHandler(context: self.clientContext, serverHostname: "localhost") try self.backToBack.client.pipeline.addHandler(clientHandler).wait() try self.backToBack.server.pipeline.addHandler(serverHandler).wait() diff --git a/Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift b/Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift index c76a605d1..6b8fe0a3e 100644 --- a/Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift +++ b/Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift @@ -35,7 +35,7 @@ final class BenchRepeatedHandshakes: Benchmark { func run() throws -> Int { for _ in 0.. Channel { return try assertNoThrowWithValue(ServerBootstrap(group: group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in let results = preHandlers.map { channel.pipeline.addHandler($0) } - return EventLoopFuture.andAllSucceed(results, on: channel.eventLoop).flatMapThrowing { - try NIOSSLServerHandler(context: context) - }.flatMap { - channel.pipeline.addHandler($0) - }.flatMap { - let results = postHandlers.map { channel.pipeline.addHandler($0) } - return EventLoopFuture.andAllSucceed(results, on: channel.eventLoop) + return EventLoopFuture.andAllSucceed(results, on: channel.eventLoop).map { + if let verify = customVerificationCallback { + return NIOSSLServerHandler(context: context, customVerificationCallback: verify) + } else { + return NIOSSLServerHandler(context: context) + } + }.flatMap { + channel.pipeline.addHandler($0) + }.flatMap { + let results = postHandlers.map { channel.pipeline.addHandler($0) } + return EventLoopFuture.andAllSucceed(results, on: channel.eventLoop) } }.bind(host: "127.0.0.1", port: 0).wait(), file: file, line: line) } @@ -902,7 +907,7 @@ class NIOSSLIntegrationTest: XCTestCase { let context = try configuredSSLContext() - try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait() + try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait() try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait() let addr: SocketAddress = try SocketAddress(unixDomainSocketPath: "/tmp/whatever") @@ -1039,7 +1044,7 @@ class NIOSSLIntegrationTest: XCTestCase { let context = try configuredSSLContext() - try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait() + try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait() try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait() let addr = try SocketAddress(unixDomainSocketPath: "/tmp/whatever2") @@ -1076,7 +1081,7 @@ class NIOSSLIntegrationTest: XCTestCase { let completePromise: EventLoopPromise = serverChannel.eventLoop.makePromise() - XCTAssertNoThrow(try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait()) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait()) XCTAssertNoThrow(try serverChannel.pipeline.addHandler(ReadRecordingHandler(completePromise: completePromise)).wait()) XCTAssertNoThrow(try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait()) @@ -1158,7 +1163,7 @@ class NIOSSLIntegrationTest: XCTestCase { let context = try configuredSSLContext() - try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait() + try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait() try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait() let addr = try SocketAddress(unixDomainSocketPath: "/tmp/whatever2") @@ -1244,7 +1249,7 @@ class NIOSSLIntegrationTest: XCTestCase { let completePromise: EventLoopPromise = serverChannel.eventLoop.makePromise() - XCTAssertNoThrow(try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait()) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait()) XCTAssertNoThrow(try serverChannel.pipeline.addHandler(ReadRecordingHandler(completePromise: completePromise)).wait()) XCTAssertNoThrow(try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait()) @@ -1607,6 +1612,46 @@ class NIOSSLIntegrationTest: XCTestCase { XCTAssertNoThrow(try handshakeCompletePromise.futureResult.wait()) } + func testServerHasNewCallbackCalledToo() throws { + let config = TLSConfiguration.forServer(certificateChain: [.certificate(NIOSSLIntegrationTest.cert)], + privateKey: .privateKey(NIOSSLIntegrationTest.key), + certificateVerification: .fullVerification, + trustRoots: .default) + let context = try assertNoThrowWithValue(NIOSSLContext(configuration: config)) + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let handshakeResultPromise = group.next().makePromise(of: Void.self) + let handshakeWatcher = WaitForHandshakeHandler(handshakeResultPromise: handshakeResultPromise) + let serverChannel: Channel = try serverTLSChannel(context: context, + preHandlers: [], + postHandlers: [handshakeWatcher], + group: group, + customVerificationCallback: { _, promise in + promise.succeed(.failed) + }) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + + let clientChannel = try clientTLSChannel(context: try configuredSSLContext(), + preHandlers: [], + postHandlers: [], + group: group, + connectingTo: serverChannel.localAddress!) + + defer { + // Ignore errors here, the channel should be closed already by the time this happens. + try? clientChannel.close().wait() + } + + XCTAssertThrowsError(try handshakeResultPromise.futureResult.wait()) + } + func testRepeatedClosure() throws { let serverChannel = EmbeddedChannel() let clientChannel = EmbeddedChannel() @@ -1834,7 +1879,7 @@ class NIOSSLIntegrationTest: XCTestCase { } XCTAssertNoThrow(try serverChannel.pipeline.addHandler(SecondChannelInactiveSwallower()).wait()) - XCTAssertNoThrow(try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait()) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait()) XCTAssertNoThrow(try serverChannel.pipeline.addHandler(FlushOnReadHandler()).wait()) XCTAssertNoThrow(try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait()) @@ -1955,7 +2000,7 @@ class NIOSSLIntegrationTest: XCTestCase { let context = try configuredSSLContext() - try serverChannel.pipeline.addHandler(try NIOSSLServerHandler(context: context)).wait() + try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait() try clientChannel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: nil)).wait() // Do the handshake.