Skip to content

Commit

Permalink
allow two distinct ChannelOptions of one type (#597)
Browse files Browse the repository at this point in the history
Motivation:

Quite embarrasingly, we previously would only store one `ChannelOption`
per `ChannelOption` type. Most channel option types are distinct and
that's probably why it took so long to find this issue. Thanks
@pushkarnk for reporting. Unfortunately though, the most important
`ChannelOption` is `.socket` which crucially also holds a level and a
name. That means if you set two `ChannelOptions.socket` options with
distinct name/level, one would still override the other.

Example:

    .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: 1)
    .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)

would only actually set the latter.

Modifications:

- made all common `ChannelOption` types equatable (for 2.0 this will
  be a protocol requirement)
- deprecated non-Equatable `ChannelOption` types
- zero out buffer before calling getsockopt as Linux doesn't do that

Result:

you can now set two distinct `ChannelOptions` for one type
Motivation:

Explain here the context, and why you're making that change.
What is the problem you're trying to solve.

Modifications:

Describe the modifications you've done.

Result:

After your change, what will change.
  • Loading branch information
weissi committed Aug 29, 2018
1 parent 34ec7b3 commit 8220bdf
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 6 deletions.
16 changes: 13 additions & 3 deletions Sources/NIO/BaseSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,12 @@ class BaseSocket: Selectable {
/// - value: The value for the option.
/// - throws: An `IOError` if the operation failed.
final func setOption<T>(level: Int32, name: Int32, value: T) throws {
try withUnsafeFileDescriptor { fd in
if level == SocketOptionValue(IPPROTO_TCP) && name == TCP_NODELAY && (try? self.localAddress().protocolFamily) == Optional<Int32>.some(Int32(Posix.AF_UNIX)) {
// setting TCP_NODELAY on UNIX domain sockets will fail. Previously we had a bug where we would ignore
// most socket options settings so for the time being we'll just ignore this. Let's revisit for NIO 2.0.
return
}
return try withUnsafeFileDescriptor { fd in
var val = value

_ = try Posix.setsockopt(
Expand All @@ -355,10 +360,15 @@ class BaseSocket: Selectable {
final func getOption<T>(level: Int32, name: Int32) throws -> T {
return try withUnsafeFileDescriptor { fd in
var length = socklen_t(MemoryLayout<T>.size)
var val = UnsafeMutablePointer<T>.allocate(capacity: 1)
let storage = UnsafeMutableRawBufferPointer.allocate(byteCount: MemoryLayout<T>.stride,
alignment: MemoryLayout<T>.alignment)
// write zeroes into the memory as Linux's getsockopt doesn't zero them out
_ = storage.initializeMemory(as: UInt8.self, repeating: 0)
var val = storage.bindMemory(to: T.self).baseAddress!
// initialisation will be done by getsockopt
defer {
val.deinitialize(count: 1)
val.deallocate()
storage.deallocate()
}

try Posix.getsockopt(socket: fd, level: level, optionName: name, optionValue: val, optionLen: &length)
Expand Down
25 changes: 22 additions & 3 deletions Sources/NIO/Bootstrap.swift
Original file line number Diff line number Diff line change
Expand Up @@ -672,11 +672,30 @@ public final class DatagramBootstrap {
}
}

fileprivate struct ChannelOptionStorage {
/* for tests */ internal struct ChannelOptionStorage {
private var storage: [(Any, (Any, (Channel) -> (Any, Any) -> EventLoopFuture<Void>))] = []

mutating func put<K: ChannelOption & Equatable>(key: K, value: K.OptionType) {
return self.put(key: key, value: value, equalsFunc: ==)
}

// HACK: this function should go for NIO 2.0, all ChannelOptions should be equatable
mutating func put<K: ChannelOption>(key: K, value: K.OptionType) {
if K.self == SocketOption.self {
return self.put(key: key as! SocketOption, value: value as! SocketOptionValue) { lhs, rhs in
switch (lhs, rhs) {
case (.const(let lLevel, let lName), .const(let rLevel, let rName)):
return lLevel == rLevel && lName == rName
}
}
} else {
return self.put(key: key, value: value) { _, _ in true }
}
}

mutating func put<K: ChannelOption>(key: K,
value newValue: K.OptionType) {
value newValue: K.OptionType,
equalsFunc: (K, K) -> Bool) {
func applier(_ t: Channel) -> (Any, Any) -> EventLoopFuture<Void> {
return { (x, y) in
return t.setOption(option: x as! K, value: y as! K.OptionType)
Expand All @@ -685,7 +704,7 @@ fileprivate struct ChannelOptionStorage {
var hasSet = false
self.storage = self.storage.map { typeAndValue in
let (type, value) = typeAndValue
if type is K {
if type is K && equalsFunc(type as! K, key) {
hasSet = true
return (key, (newValue, applier))
} else {
Expand Down
10 changes: 10 additions & 0 deletions Sources/NIO/ByteBuffer-core.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ let sysFree: @convention(c) (UnsafeMutableRawPointer?) -> Void = free
}
}
public extension UnsafeMutableRawBufferPointer {
internal static func allocate(byteCount: Int, alignment: Int) -> UnsafeMutableRawBufferPointer {
return UnsafeMutableRawBufferPointer.allocate(count: byteCount)
}

internal func initializeMemory<T>(as type: T.Type, repeating repeatedValue: T) -> UnsafeMutableBufferPointer<T> {
let ptr = self.bindMemory(to: T.self)
ptr.initialize(from: repeatElement(repeatedValue, count: self.count / MemoryLayout<T>.stride))
return ptr
}

public func copyMemory(from src: UnsafeRawBufferPointer) {
self.copyBytes(from: src)
}
Expand Down
1 change: 1 addition & 0 deletions Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import XCTest
testCase(ByteBufferTest.allTests),
testCase(ByteToMessageDecoderTest.allTests),
testCase(ChannelNotificationTest.allTests),
testCase(ChannelOptionStorageTest.allTests),
testCase(ChannelPipelineTest.allTests),
testCase(ChannelTests.allTests),
testCase(CircularBufferTests.allTests),
Expand Down
36 changes: 36 additions & 0 deletions Tests/NIOTests/ChannelOptionStorageTest+XCTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
//
// ChannelOptionStorageTest+XCTest.swift
//
import XCTest

///
/// NOTE: This file was generated by generate_linux_tests.rb
///
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
///

extension ChannelOptionStorageTest {

static var allTests : [(String, (ChannelOptionStorageTest) -> () throws -> Void)] {
return [
("testWeStartWithNoOptions", testWeStartWithNoOptions),
("testSetTwoOptionsOfDifferentType", testSetTwoOptionsOfDifferentType),
("testSetTwoOptionsOfSameType", testSetTwoOptionsOfSameType),
("testSetOneOptionTwice", testSetOneOptionTwice),
]
}
}

99 changes: 99 additions & 0 deletions Tests/NIOTests/ChannelOptionStorageTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import XCTest

@testable import NIO

class ChannelOptionStorageTest: XCTestCase {
func testWeStartWithNoOptions() throws {
let cos = ChannelOptionStorage()
let optionsCollector = OptionsCollectingChannel()
XCTAssertNoThrow(try cos.applyAll(channel: optionsCollector).wait())
XCTAssertEqual(0, optionsCollector.allOptions.count)
}

func testSetTwoOptionsOfDifferentType() throws {
var cos = ChannelOptionStorage()
let optionsCollector = OptionsCollectingChannel()
cos.put(key: ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
cos.put(key: ChannelOptions.backlog, value: 2)
XCTAssertNoThrow(try cos.applyAll(channel: optionsCollector).wait())
XCTAssertEqual(2, optionsCollector.allOptions.count)
}

func testSetTwoOptionsOfSameType() throws {
let options: [(SocketOption, SocketOptionValue)] = [(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), 1),
(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), 2)]
var cos = ChannelOptionStorage()
let optionsCollector = OptionsCollectingChannel()
for kv in options {
cos.put(key: kv.0, value: kv.1)
}
XCTAssertNoThrow(try cos.applyAll(channel: optionsCollector).wait())
XCTAssertEqual(2, optionsCollector.allOptions.count)
XCTAssertEqual(options.map { $0.0 },
(optionsCollector.allOptions as! [(SocketOption, SocketOptionValue)]).map { $0.0 })
XCTAssertEqual(options.map { $0.1 },
(optionsCollector.allOptions as! [(SocketOption, SocketOptionValue)]).map { $0.1 })
}

func testSetOneOptionTwice() throws {
var cos = ChannelOptionStorage()
let optionsCollector = OptionsCollectingChannel()
cos.put(key: ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
cos.put(key: ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 2)
XCTAssertNoThrow(try cos.applyAll(channel: optionsCollector).wait())
XCTAssertEqual(1, optionsCollector.allOptions.count)
XCTAssertEqual([ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR)],
(optionsCollector.allOptions as! [(SocketOption, SocketOptionValue)]).map { $0.0 })
XCTAssertEqual([SocketOptionValue(2)],
(optionsCollector.allOptions as! [(SocketOption, SocketOptionValue)]).map { $0.1 })
}
}

class OptionsCollectingChannel: Channel {
var allOptions: [(Any, Any)] = []

var allocator: ByteBufferAllocator { fatalError() }

var closeFuture: EventLoopFuture<Void> { fatalError() }

var pipeline: ChannelPipeline { fatalError() }

var localAddress: SocketAddress? { fatalError() }

var remoteAddress: SocketAddress? { fatalError() }

var parent: Channel? { fatalError() }

func setOption<T>(option: T, value: T.OptionType) -> EventLoopFuture<Void> where T : ChannelOption {
self.allOptions.append((option, value))
return self.eventLoop.newSucceededFuture(result: ())
}

func getOption<T>(option: T) -> EventLoopFuture<T.OptionType> where T : ChannelOption {
fatalError()
}

var isWritable: Bool { fatalError() }

var isActive: Bool { fatalError() }

var _unsafe: ChannelCore { fatalError() }

var eventLoop: EventLoop {
return EmbeddedEventLoop()
}
}
1 change: 1 addition & 0 deletions Tests/NIOTests/ChannelTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ extension ChannelTests {
("testFailedRegistrationOfServerSocket", testFailedRegistrationOfServerSocket),
("testTryingToBindOnPortThatIsAlreadyBoundFailsButDoesNotCrash", testTryingToBindOnPortThatIsAlreadyBoundFailsButDoesNotCrash),
("testCloseInReadTriggeredByDrainingTheReceiveBufferBecauseOfWriteError", testCloseInReadTriggeredByDrainingTheReceiveBufferBecauseOfWriteError),
("testApplyingTwoDistinctSocketOptionsOfSameTypeWorks", testApplyingTwoDistinctSocketOptionsOfSameTypeWorks),
]
}
}
Expand Down
86 changes: 86 additions & 0 deletions Tests/NIOTests/ChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2565,6 +2565,83 @@ public class ChannelTests: XCTestCase {
XCTAssertNoThrow(try allDonePromise.futureResult.wait())
XCTAssertFalse(c.isActive)
}

func testApplyingTwoDistinctSocketOptionsOfSameTypeWorks() throws {
let singleThreadedELG = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
XCTAssertNoThrow(try singleThreadedELG.syncShutdownGracefully())
}
var numberOfAcceptedChannel = 0
var acceptedChannels: [EventLoopPromise<Channel>] = [singleThreadedELG.next().newPromise(),
singleThreadedELG.next().newPromise(),
singleThreadedELG.next().newPromise()]
let server = try assertNoThrowWithValue(ServerBootstrap(group: singleThreadedELG)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_TIMESTAMP), value: 1)
.childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_KEEPALIVE), value: 1)
.childChannelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)
.childChannelInitializer { channel in
acceptedChannels[numberOfAcceptedChannel].succeed(result: channel)
numberOfAcceptedChannel += 1
return channel.eventLoop.newSucceededFuture(result: ())
}
.bind(host: "127.0.0.1", port: 0)
.wait())
defer {
XCTAssertNoThrow(try server.close().wait())
}
XCTAssertTrue(try getBoolSocketOption(channel: server, level: SOL_SOCKET, name: SO_REUSEADDR))
XCTAssertTrue(try getBoolSocketOption(channel: server, level: SOL_SOCKET, name: SO_TIMESTAMP))

let client1 = try assertNoThrowWithValue(ClientBootstrap(group: singleThreadedELG)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)
.connect(to: server.localAddress!)
.wait())
let accepted1 = try assertNoThrowWithValue(acceptedChannels[0].futureResult.wait())
defer {
XCTAssertNoThrow(try client1.close().wait())
}
let client2 = try assertNoThrowWithValue(ClientBootstrap(group: singleThreadedELG)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.connect(to: server.localAddress!)
.wait())
let accepted2 = try assertNoThrowWithValue(acceptedChannels[0].futureResult.wait())
defer {
XCTAssertNoThrow(try client2.close().wait())
}
let client3 = try assertNoThrowWithValue(ClientBootstrap(group: singleThreadedELG)
.connect(to: server.localAddress!)
.wait())
let accepted3 = try assertNoThrowWithValue(acceptedChannels[0].futureResult.wait())
defer {
XCTAssertNoThrow(try client3.close().wait())
}

XCTAssertTrue(try getBoolSocketOption(channel: client1, level: SOL_SOCKET, name: SO_REUSEADDR))

XCTAssertTrue(try getBoolSocketOption(channel: client1, level: IPPROTO_TCP, name: TCP_NODELAY))

XCTAssertTrue(try getBoolSocketOption(channel: accepted1, level: SOL_SOCKET, name: SO_KEEPALIVE))

XCTAssertTrue(try getBoolSocketOption(channel: accepted1, level: IPPROTO_TCP, name: TCP_NODELAY))

XCTAssertTrue(try getBoolSocketOption(channel: client2, level: SOL_SOCKET, name: SO_REUSEADDR))

XCTAssertFalse(try getBoolSocketOption(channel: client2, level: IPPROTO_TCP, name: TCP_NODELAY))

XCTAssertTrue(try getBoolSocketOption(channel: accepted2, level: SOL_SOCKET, name: SO_KEEPALIVE))

XCTAssertTrue(try getBoolSocketOption(channel: accepted2, level: IPPROTO_TCP, name: TCP_NODELAY))

XCTAssertFalse(try getBoolSocketOption(channel: client3, level: SOL_SOCKET, name: SO_REUSEADDR))

XCTAssertFalse(try getBoolSocketOption(channel: client3, level: IPPROTO_TCP, name: TCP_NODELAY))

XCTAssertTrue(try getBoolSocketOption(channel: accepted3, level: SOL_SOCKET, name: SO_KEEPALIVE))

XCTAssertTrue(try getBoolSocketOption(channel: accepted3, level: IPPROTO_TCP, name: TCP_NODELAY))
}
}

fileprivate final class FailRegistrationAndDelayCloseHandler: ChannelOutboundHandler {
Expand Down Expand Up @@ -2620,3 +2697,12 @@ fileprivate class VerifyConnectionFailureHandler: ChannelInboundHandler {
ctx.fireChannelUnregistered()
}
}

extension SocketOption: Equatable {
public static func == (lhs: SocketOption, rhs: SocketOption) -> Bool {
switch (lhs, rhs) {
case (.const(let lLevel, let lName), .const(let rLevel, let rName)):
return lLevel == rLevel && lName == rName
}
}
}
1 change: 1 addition & 0 deletions Tests/NIOTests/DatagramChannelTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ extension DatagramChannelTests {
("testRecvFromFailsWithEFAULT", testRecvFromFailsWithEFAULT),
("testSetGetOptionClosedDatagramChannel", testSetGetOptionClosedDatagramChannel),
("testWritesAreAccountedCorrectly", testWritesAreAccountedCorrectly),
("testSettingTwoDistinctChannelOptionsWorksForDatagramChannel", testSettingTwoDistinctChannelOptionsWorksForDatagramChannel),
]
}
}
Expand Down
14 changes: 14 additions & 0 deletions Tests/NIOTests/DatagramChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -443,4 +443,18 @@ final class DatagramChannelTests: XCTestCase {
XCTAssertEqual(reads[1].data, buffer)
XCTAssertEqual(reads[1].remoteAddress, self.firstChannel.localAddress!)
}

func testSettingTwoDistinctChannelOptionsWorksForDatagramChannel() throws {
let channel = try assertNoThrowWithValue(DatagramBootstrap(group: group)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_TIMESTAMP), value: 1)
.bind(host: "127.0.0.1", port: 0)
.wait())
defer {
XCTAssertNoThrow(try channel.close().wait())
}
XCTAssertTrue(try getBoolSocketOption(channel: channel, level: SOL_SOCKET, name: SO_REUSEADDR))
XCTAssertTrue(try getBoolSocketOption(channel: channel, level: SOL_SOCKET, name: SO_TIMESTAMP))
XCTAssertFalse(try getBoolSocketOption(channel: channel, level: SOL_SOCKET, name: SO_KEEPALIVE))
}
}
8 changes: 8 additions & 0 deletions Tests/NIOTests/TestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,11 @@ func assert(_ condition: @autoclosure () -> Bool, within time: TimeAmount, testI
XCTFail(message)
}
}

func getBoolSocketOption<IntType: SignedInteger>(channel: Channel, level: IntType, name: SocketOptionName,
file: StaticString = #file, line: UInt = #line) throws -> Bool {
return try assertNoThrowWithValue(channel.getOption(option: ChannelOptions.socket(SocketOptionLevel(level),
name)),
file: file,
line: line).wait() != 0
}

0 comments on commit 8220bdf

Please sign in to comment.