diff --git a/Sources/NIO/BaseSocket.swift b/Sources/NIO/BaseSocket.swift index 0750f3c513..fa8a571422 100644 --- a/Sources/NIO/BaseSocket.swift +++ b/Sources/NIO/BaseSocket.swift @@ -332,7 +332,12 @@ class BaseSocket: Selectable { /// - value: The value for the option. /// - throws: An `IOError` if the operation failed. final func setOption(level: Int32, name: Int32, value: T) throws { - try withUnsafeFileDescriptor { fd in + if level == SocketOptionValue(IPPROTO_TCP) && name == TCP_NODELAY && (try? self.localAddress().protocolFamily) == Optional.some(Int32(Posix.AF_UNIX)) { + // setting TCP_NODELAY on UNIX domain sockets will fail. Previously we had a bug where we would ignore + // most socket options settings so for the time being we'll just ignore this. Let's revisit for NIO 2.0. + return + } + return try withUnsafeFileDescriptor { fd in var val = value _ = try Posix.setsockopt( @@ -355,10 +360,15 @@ class BaseSocket: Selectable { final func getOption(level: Int32, name: Int32) throws -> T { return try withUnsafeFileDescriptor { fd in var length = socklen_t(MemoryLayout.size) - var val = UnsafeMutablePointer.allocate(capacity: 1) + let storage = UnsafeMutableRawBufferPointer.allocate(byteCount: MemoryLayout.stride, + alignment: MemoryLayout.alignment) + // write zeroes into the memory as Linux's getsockopt doesn't zero them out + _ = storage.initializeMemory(as: UInt8.self, repeating: 0) + var val = storage.bindMemory(to: T.self).baseAddress! + // initialisation will be done by getsockopt defer { val.deinitialize(count: 1) - val.deallocate() + storage.deallocate() } try Posix.getsockopt(socket: fd, level: level, optionName: name, optionValue: val, optionLen: &length) diff --git a/Sources/NIO/Bootstrap.swift b/Sources/NIO/Bootstrap.swift index 85d58861c4..43eeface73 100644 --- a/Sources/NIO/Bootstrap.swift +++ b/Sources/NIO/Bootstrap.swift @@ -672,11 +672,30 @@ public final class DatagramBootstrap { } } -fileprivate struct ChannelOptionStorage { +/* for tests */ internal struct ChannelOptionStorage { private var storage: [(Any, (Any, (Channel) -> (Any, Any) -> EventLoopFuture))] = [] + mutating func put(key: K, value: K.OptionType) { + return self.put(key: key, value: value, equalsFunc: ==) + } + + // HACK: this function should go for NIO 2.0, all ChannelOptions should be equatable + mutating func put(key: K, value: K.OptionType) { + if K.self == SocketOption.self { + return self.put(key: key as! SocketOption, value: value as! SocketOptionValue) { lhs, rhs in + switch (lhs, rhs) { + case (.const(let lLevel, let lName), .const(let rLevel, let rName)): + return lLevel == rLevel && lName == rName + } + } + } else { + return self.put(key: key, value: value) { _, _ in true } + } + } + mutating func put(key: K, - value newValue: K.OptionType) { + value newValue: K.OptionType, + equalsFunc: (K, K) -> Bool) { func applier(_ t: Channel) -> (Any, Any) -> EventLoopFuture { return { (x, y) in return t.setOption(option: x as! K, value: y as! K.OptionType) @@ -685,7 +704,7 @@ fileprivate struct ChannelOptionStorage { var hasSet = false self.storage = self.storage.map { typeAndValue in let (type, value) = typeAndValue - if type is K { + if type is K && equalsFunc(type as! K, key) { hasSet = true return (key, (newValue, applier)) } else { diff --git a/Sources/NIO/ByteBuffer-core.swift b/Sources/NIO/ByteBuffer-core.swift index fbc84b429f..2eec4fb48b 100644 --- a/Sources/NIO/ByteBuffer-core.swift +++ b/Sources/NIO/ByteBuffer-core.swift @@ -23,6 +23,16 @@ let sysFree: @convention(c) (UnsafeMutableRawPointer?) -> Void = free } } public extension UnsafeMutableRawBufferPointer { + internal static func allocate(byteCount: Int, alignment: Int) -> UnsafeMutableRawBufferPointer { + return UnsafeMutableRawBufferPointer.allocate(count: byteCount) + } + + internal func initializeMemory(as type: T.Type, repeating repeatedValue: T) -> UnsafeMutableBufferPointer { + let ptr = self.bindMemory(to: T.self) + ptr.initialize(from: repeatElement(repeatedValue, count: self.count / MemoryLayout.stride)) + return ptr + } + public func copyMemory(from src: UnsafeRawBufferPointer) { self.copyBytes(from: src) } diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index e4dfc85e30..976f6c643d 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -40,6 +40,7 @@ import XCTest testCase(ByteBufferTest.allTests), testCase(ByteToMessageDecoderTest.allTests), testCase(ChannelNotificationTest.allTests), + testCase(ChannelOptionStorageTest.allTests), testCase(ChannelPipelineTest.allTests), testCase(ChannelTests.allTests), testCase(CircularBufferTests.allTests), diff --git a/Tests/NIOTests/ChannelOptionStorageTest+XCTest.swift b/Tests/NIOTests/ChannelOptionStorageTest+XCTest.swift new file mode 100644 index 0000000000..00a7ec9a66 --- /dev/null +++ b/Tests/NIOTests/ChannelOptionStorageTest+XCTest.swift @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// ChannelOptionStorageTest+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension ChannelOptionStorageTest { + + static var allTests : [(String, (ChannelOptionStorageTest) -> () throws -> Void)] { + return [ + ("testWeStartWithNoOptions", testWeStartWithNoOptions), + ("testSetTwoOptionsOfDifferentType", testSetTwoOptionsOfDifferentType), + ("testSetTwoOptionsOfSameType", testSetTwoOptionsOfSameType), + ("testSetOneOptionTwice", testSetOneOptionTwice), + ] + } +} + diff --git a/Tests/NIOTests/ChannelOptionStorageTest.swift b/Tests/NIOTests/ChannelOptionStorageTest.swift new file mode 100644 index 0000000000..8abd14efc4 --- /dev/null +++ b/Tests/NIOTests/ChannelOptionStorageTest.swift @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest + +@testable import NIO + +class ChannelOptionStorageTest: XCTestCase { + func testWeStartWithNoOptions() throws { + let cos = ChannelOptionStorage() + let optionsCollector = OptionsCollectingChannel() + XCTAssertNoThrow(try cos.applyAll(channel: optionsCollector).wait()) + XCTAssertEqual(0, optionsCollector.allOptions.count) + } + + func testSetTwoOptionsOfDifferentType() throws { + var cos = ChannelOptionStorage() + let optionsCollector = OptionsCollectingChannel() + cos.put(key: ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + cos.put(key: ChannelOptions.backlog, value: 2) + XCTAssertNoThrow(try cos.applyAll(channel: optionsCollector).wait()) + XCTAssertEqual(2, optionsCollector.allOptions.count) + } + + func testSetTwoOptionsOfSameType() throws { + let options: [(SocketOption, SocketOptionValue)] = [(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), 1), + (ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), 2)] + var cos = ChannelOptionStorage() + let optionsCollector = OptionsCollectingChannel() + for kv in options { + cos.put(key: kv.0, value: kv.1) + } + XCTAssertNoThrow(try cos.applyAll(channel: optionsCollector).wait()) + XCTAssertEqual(2, optionsCollector.allOptions.count) + XCTAssertEqual(options.map { $0.0 }, + (optionsCollector.allOptions as! [(SocketOption, SocketOptionValue)]).map { $0.0 }) + XCTAssertEqual(options.map { $0.1 }, + (optionsCollector.allOptions as! [(SocketOption, SocketOptionValue)]).map { $0.1 }) + } + + func testSetOneOptionTwice() throws { + var cos = ChannelOptionStorage() + let optionsCollector = OptionsCollectingChannel() + cos.put(key: ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + cos.put(key: ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 2) + XCTAssertNoThrow(try cos.applyAll(channel: optionsCollector).wait()) + XCTAssertEqual(1, optionsCollector.allOptions.count) + XCTAssertEqual([ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR)], + (optionsCollector.allOptions as! [(SocketOption, SocketOptionValue)]).map { $0.0 }) + XCTAssertEqual([SocketOptionValue(2)], + (optionsCollector.allOptions as! [(SocketOption, SocketOptionValue)]).map { $0.1 }) + } +} + +class OptionsCollectingChannel: Channel { + var allOptions: [(Any, Any)] = [] + + var allocator: ByteBufferAllocator { fatalError() } + + var closeFuture: EventLoopFuture { fatalError() } + + var pipeline: ChannelPipeline { fatalError() } + + var localAddress: SocketAddress? { fatalError() } + + var remoteAddress: SocketAddress? { fatalError() } + + var parent: Channel? { fatalError() } + + func setOption(option: T, value: T.OptionType) -> EventLoopFuture where T : ChannelOption { + self.allOptions.append((option, value)) + return self.eventLoop.newSucceededFuture(result: ()) + } + + func getOption(option: T) -> EventLoopFuture where T : ChannelOption { + fatalError() + } + + var isWritable: Bool { fatalError() } + + var isActive: Bool { fatalError() } + + var _unsafe: ChannelCore { fatalError() } + + var eventLoop: EventLoop { + return EmbeddedEventLoop() + } +} diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index c260c82e95..831a806989 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -76,6 +76,7 @@ extension ChannelTests { ("testFailedRegistrationOfServerSocket", testFailedRegistrationOfServerSocket), ("testTryingToBindOnPortThatIsAlreadyBoundFailsButDoesNotCrash", testTryingToBindOnPortThatIsAlreadyBoundFailsButDoesNotCrash), ("testCloseInReadTriggeredByDrainingTheReceiveBufferBecauseOfWriteError", testCloseInReadTriggeredByDrainingTheReceiveBufferBecauseOfWriteError), + ("testApplyingTwoDistinctSocketOptionsOfSameTypeWorks", testApplyingTwoDistinctSocketOptionsOfSameTypeWorks), ] } } diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 219879d294..78c946aafa 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -2565,6 +2565,83 @@ public class ChannelTests: XCTestCase { XCTAssertNoThrow(try allDonePromise.futureResult.wait()) XCTAssertFalse(c.isActive) } + + func testApplyingTwoDistinctSocketOptionsOfSameTypeWorks() throws { + let singleThreadedELG = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try singleThreadedELG.syncShutdownGracefully()) + } + var numberOfAcceptedChannel = 0 + var acceptedChannels: [EventLoopPromise] = [singleThreadedELG.next().newPromise(), + singleThreadedELG.next().newPromise(), + singleThreadedELG.next().newPromise()] + let server = try assertNoThrowWithValue(ServerBootstrap(group: singleThreadedELG) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_TIMESTAMP), value: 1) + .childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_KEEPALIVE), value: 1) + .childChannelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) + .childChannelInitializer { channel in + acceptedChannels[numberOfAcceptedChannel].succeed(result: channel) + numberOfAcceptedChannel += 1 + return channel.eventLoop.newSucceededFuture(result: ()) + } + .bind(host: "127.0.0.1", port: 0) + .wait()) + defer { + XCTAssertNoThrow(try server.close().wait()) + } + XCTAssertTrue(try getBoolSocketOption(channel: server, level: SOL_SOCKET, name: SO_REUSEADDR)) + XCTAssertTrue(try getBoolSocketOption(channel: server, level: SOL_SOCKET, name: SO_TIMESTAMP)) + + let client1 = try assertNoThrowWithValue(ClientBootstrap(group: singleThreadedELG) + .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) + .connect(to: server.localAddress!) + .wait()) + let accepted1 = try assertNoThrowWithValue(acceptedChannels[0].futureResult.wait()) + defer { + XCTAssertNoThrow(try client1.close().wait()) + } + let client2 = try assertNoThrowWithValue(ClientBootstrap(group: singleThreadedELG) + .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .connect(to: server.localAddress!) + .wait()) + let accepted2 = try assertNoThrowWithValue(acceptedChannels[0].futureResult.wait()) + defer { + XCTAssertNoThrow(try client2.close().wait()) + } + let client3 = try assertNoThrowWithValue(ClientBootstrap(group: singleThreadedELG) + .connect(to: server.localAddress!) + .wait()) + let accepted3 = try assertNoThrowWithValue(acceptedChannels[0].futureResult.wait()) + defer { + XCTAssertNoThrow(try client3.close().wait()) + } + + XCTAssertTrue(try getBoolSocketOption(channel: client1, level: SOL_SOCKET, name: SO_REUSEADDR)) + + XCTAssertTrue(try getBoolSocketOption(channel: client1, level: IPPROTO_TCP, name: TCP_NODELAY)) + + XCTAssertTrue(try getBoolSocketOption(channel: accepted1, level: SOL_SOCKET, name: SO_KEEPALIVE)) + + XCTAssertTrue(try getBoolSocketOption(channel: accepted1, level: IPPROTO_TCP, name: TCP_NODELAY)) + + XCTAssertTrue(try getBoolSocketOption(channel: client2, level: SOL_SOCKET, name: SO_REUSEADDR)) + + XCTAssertFalse(try getBoolSocketOption(channel: client2, level: IPPROTO_TCP, name: TCP_NODELAY)) + + XCTAssertTrue(try getBoolSocketOption(channel: accepted2, level: SOL_SOCKET, name: SO_KEEPALIVE)) + + XCTAssertTrue(try getBoolSocketOption(channel: accepted2, level: IPPROTO_TCP, name: TCP_NODELAY)) + + XCTAssertFalse(try getBoolSocketOption(channel: client3, level: SOL_SOCKET, name: SO_REUSEADDR)) + + XCTAssertFalse(try getBoolSocketOption(channel: client3, level: IPPROTO_TCP, name: TCP_NODELAY)) + + XCTAssertTrue(try getBoolSocketOption(channel: accepted3, level: SOL_SOCKET, name: SO_KEEPALIVE)) + + XCTAssertTrue(try getBoolSocketOption(channel: accepted3, level: IPPROTO_TCP, name: TCP_NODELAY)) + } } fileprivate final class FailRegistrationAndDelayCloseHandler: ChannelOutboundHandler { @@ -2620,3 +2697,12 @@ fileprivate class VerifyConnectionFailureHandler: ChannelInboundHandler { ctx.fireChannelUnregistered() } } + +extension SocketOption: Equatable { + public static func == (lhs: SocketOption, rhs: SocketOption) -> Bool { + switch (lhs, rhs) { + case (.const(let lLevel, let lName), .const(let rLevel, let rName)): + return lLevel == rLevel && lName == rName + } + } +} diff --git a/Tests/NIOTests/DatagramChannelTests+XCTest.swift b/Tests/NIOTests/DatagramChannelTests+XCTest.swift index cc0e559a93..f9f37bc13d 100644 --- a/Tests/NIOTests/DatagramChannelTests+XCTest.swift +++ b/Tests/NIOTests/DatagramChannelTests+XCTest.swift @@ -41,6 +41,7 @@ extension DatagramChannelTests { ("testRecvFromFailsWithEFAULT", testRecvFromFailsWithEFAULT), ("testSetGetOptionClosedDatagramChannel", testSetGetOptionClosedDatagramChannel), ("testWritesAreAccountedCorrectly", testWritesAreAccountedCorrectly), + ("testSettingTwoDistinctChannelOptionsWorksForDatagramChannel", testSettingTwoDistinctChannelOptionsWorksForDatagramChannel), ] } } diff --git a/Tests/NIOTests/DatagramChannelTests.swift b/Tests/NIOTests/DatagramChannelTests.swift index 52c5a6026d..6bd635f00e 100644 --- a/Tests/NIOTests/DatagramChannelTests.swift +++ b/Tests/NIOTests/DatagramChannelTests.swift @@ -443,4 +443,18 @@ final class DatagramChannelTests: XCTestCase { XCTAssertEqual(reads[1].data, buffer) XCTAssertEqual(reads[1].remoteAddress, self.firstChannel.localAddress!) } + + func testSettingTwoDistinctChannelOptionsWorksForDatagramChannel() throws { + let channel = try assertNoThrowWithValue(DatagramBootstrap(group: group) + .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_TIMESTAMP), value: 1) + .bind(host: "127.0.0.1", port: 0) + .wait()) + defer { + XCTAssertNoThrow(try channel.close().wait()) + } + XCTAssertTrue(try getBoolSocketOption(channel: channel, level: SOL_SOCKET, name: SO_REUSEADDR)) + XCTAssertTrue(try getBoolSocketOption(channel: channel, level: SOL_SOCKET, name: SO_TIMESTAMP)) + XCTAssertFalse(try getBoolSocketOption(channel: channel, level: SOL_SOCKET, name: SO_KEEPALIVE)) + } } diff --git a/Tests/NIOTests/TestUtils.swift b/Tests/NIOTests/TestUtils.swift index a992e70ae4..655a385688 100644 --- a/Tests/NIOTests/TestUtils.swift +++ b/Tests/NIOTests/TestUtils.swift @@ -220,3 +220,11 @@ func assert(_ condition: @autoclosure () -> Bool, within time: TimeAmount, testI XCTFail(message) } } + +func getBoolSocketOption(channel: Channel, level: IntType, name: SocketOptionName, + file: StaticString = #file, line: UInt = #line) throws -> Bool { + return try assertNoThrowWithValue(channel.getOption(option: ChannelOptions.socket(SocketOptionLevel(level), + name)), + file: file, + line: line).wait() != 0 +}