Skip to content

Commit

Permalink
[Fix] Query Hangs if Connection is Closed (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM authored Jun 24, 2024
1 parent 7b621c1 commit f55caa7
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 12 deletions.
39 changes: 28 additions & 11 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable {
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
self.write(.extendedQuery(context), cascadingFailureTo: promise)

return promise.futureResult
}
Expand All @@ -239,7 +239,8 @@ public final class PostgresConnection: @unchecked Sendable {
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
self.write(.extendedQuery(context), cascadingFailureTo: promise)

return promise.futureResult.map { rowDescription in
PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription)
}
Expand All @@ -255,15 +256,17 @@ public final class PostgresConnection: @unchecked Sendable {
logger: logger,
promise: promise)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
self.write(.extendedQuery(context), cascadingFailureTo: promise)

return promise.futureResult
}

func close(_ target: CloseTarget, logger: Logger) -> EventLoopFuture<Void> {
let promise = self.channel.eventLoop.makePromise(of: Void.self)
let context = CloseCommandContext(target: target, logger: logger, promise: promise)

self.channel.write(HandlerTask.closeCommand(context), promise: nil)
self.write(.closeCommand(context), cascadingFailureTo: promise)

return promise.futureResult
}

Expand Down Expand Up @@ -426,7 +429,7 @@ extension PostgresConnection {
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
self.write(.extendedQuery(context), cascadingFailureTo: promise)

do {
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
Expand Down Expand Up @@ -455,7 +458,11 @@ extension PostgresConnection {

let task = HandlerTask.startListening(listener)

self.channel.write(task, promise: nil)
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
self.channel.write(task, promise: writePromise)
writePromise.futureResult.whenFailure { error in
listener.failed(error)
}
}
} onCancel: {
let task = HandlerTask.cancelListening(channel, id)
Expand All @@ -480,7 +487,9 @@ extension PostgresConnection {
logger: logger,
promise: promise
))
self.channel.write(task, promise: nil)

self.write(task, cascadingFailureTo: promise)

do {
return try await promise.futureResult
.map { $0.asyncSequence() }
Expand Down Expand Up @@ -515,7 +524,9 @@ extension PostgresConnection {
logger: logger,
promise: promise
))
self.channel.write(task, promise: nil)

self.write(task, cascadingFailureTo: promise)

do {
return try await promise.futureResult
.map { $0.commandTag }
Expand All @@ -530,6 +541,12 @@ extension PostgresConnection {
throw error // rethrow with more metadata
}
}

private func write<T>(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise<T>) {
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
self.channel.write(task, promise: writePromise)
writePromise.futureResult.cascadeFailure(to: promise)
}
}

// MARK: EventLoopFuture interface
Expand Down Expand Up @@ -674,7 +691,7 @@ internal enum PostgresCommands: PostgresRequest {

/// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support.
public final class PostgresListenContext: Sendable {
private let promise: EventLoopPromise<Void>
let promise: EventLoopPromise<Void>

var future: EventLoopFuture<Void> {
self.promise.futureResult
Expand Down Expand Up @@ -713,8 +730,7 @@ extension PostgresConnection {
closure: notificationHandler
)

let task = HandlerTask.startListening(listener)
self.channel.write(task, promise: nil)
self.write(.startListening(listener), cascadingFailureTo: listenContext.promise)

listenContext.future.whenComplete { _ in
let task = HandlerTask.cancelListening(channel, id)
Expand Down Expand Up @@ -761,3 +777,4 @@ extension PostgresConnection {
#endif
}
}

1 change: 0 additions & 1 deletion Tests/IntegrationTests/PSQLIntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -359,5 +359,4 @@ final class IntegrationTests: XCTestCase {
XCTAssertEqual(obj?.bar, 2)
}
}

}
169 changes: 169 additions & 0 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,63 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testSimpleListenFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.listen("test_channel")
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { taskGroup in
taskGroup.addTask {
let events = try await connection.listen("foo")
var iterator = events.makeAsyncIterator()
let first = try await iterator.next()
XCTAssertEqual(first?.payload, "wooohooo")
do {
_ = try await iterator.next()
XCTFail("Did not expect to not throw")
} catch let error as PSQLError {
XCTAssertEqual(error.code, .clientClosedConnection)
}
}

let listenMessage = try await channel.waitForUnpreparedRequest()
XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#)

try await channel.writeInbound(PostgresBackendMessage.parseComplete)
try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [])))
try await channel.writeInbound(PostgresBackendMessage.noData)
try await channel.writeInbound(PostgresBackendMessage.bindComplete)
try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN"))
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))

try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo")))

try await connection.close()

XCTAssertEqual(channel.isActive, false)

switch await taskGroup.nextResult()! {
case .success:
break
case .failure(let failure):
XCTFail("Unexpected error: \(failure)")
}
}
}

func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
Expand Down Expand Up @@ -638,6 +695,118 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testQueryFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.query("SELECT version;", logger: self.logger)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testPrepareStatementFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get()
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecuteFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil)
_ = try await connection.execute(statement, logger: .psqlTest).get()
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

struct TestPreparedStatement: PostgresPreparedStatement {
static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
typealias Row = (Int, String)

var state: String

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.state)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
try row.decode(Row.self)
}
}

do {
let preparedStatement = TestPreparedStatement(state: "active")
_ = try await connection.execute(preparedStatement, logger: .psqlTest)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

struct TestPreparedStatement: PostgresPreparedStatement {
static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1"
typealias Row = ()

var state: String

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.state)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
()
}
}

do {
let preparedStatement = TestPreparedStatement(state: "active")
_ = try await connection.execute(preparedStatement, logger: .psqlTest)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {
let eventLoop = NIOAsyncTestingEventLoop()
let channel = await NIOAsyncTestingChannel(handlers: [
Expand Down

0 comments on commit f55caa7

Please sign in to comment.