From d6f7e8e38b9d5dfa33fd1c4f4eae475c992454d2 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Aug 2023 16:16:14 +0200 Subject: [PATCH] Make sure correct error is thrown, if server closes connection --- .../ConnectionStateMachine.swift | 28 ++++++++++--------- .../New/PostgresConnectionTests.swift | 28 +++++++++++++++++++ 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index bbfa0faa..b7ecc461 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -203,7 +203,7 @@ struct ConnectionStateMachine { preconditionFailure("How can a connection be closed, if it was never connected.") case .closed: - preconditionFailure("How can a connection be closed, if it is already closed.") + return .wait case .authenticated, .sslRequestSent, @@ -214,8 +214,8 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand: - return self.errorHappened(.uncleanShutdown) - + return self.errorHappened(.serverClosedConnection(underlying: nil)) + case .closing(let error): self.state = .closed(clientInitiated: true, error: error) self.quiescingState = .notQuiescing @@ -910,7 +910,7 @@ struct ConnectionStateMachine { // the error state and will try to close the connection. However the server might have // send further follow up messages. In those cases we will run into this method again // and again. We should just ignore those events. - return .wait + return .closeConnection(closePromise) case .modifying: preconditionFailure("Invalid state: \(self.state)") @@ -1034,16 +1034,16 @@ extension ConnectionStateMachine { case .clientClosesConnection, .clientClosedConnection: preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") case .serverClosedConnection: - preconditionFailure("Pure client error, that is thrown directly and should never ") + return true } } mutating func setErrorAndCreateCleanupContextIfNeeded(_ error: PSQLError) -> ConnectionAction.CleanUpContext? { - guard self.shouldCloseConnection(reason: error) else { - return nil + if self.shouldCloseConnection(reason: error) { + return self.setErrorAndCreateCleanupContext(error) } - return self.setErrorAndCreateCleanupContext(error) + return nil } mutating func setErrorAndCreateCleanupContext(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction.CleanUpContext { @@ -1060,13 +1060,15 @@ extension ConnectionStateMachine { forwardedPromise = closePromise } - self.state = .closing(error) - - var action = ConnectionAction.CleanUpContext.Action.close - if case .uncleanShutdown = error.code.base { + let action: ConnectionAction.CleanUpContext.Action + if case .serverClosedConnection = error.code.base { + self.state = .closed(clientInitiated: false, error: error) action = .fireChannelInactive + } else { + self.state = .closing(error) + action = .close } - + return .init(action: action, tasks: tasks, error: error, closePromise: forwardedPromise) } } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 9c4dc5cb..59917c40 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -275,6 +275,34 @@ class PostgresConnectionTests: XCTestCase { } } + func testIfServerJustClosesTheErrorReflectsThat() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + async let response = try await connection.query("SELECT 1;", logger: self.logger) + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } + + do { + _ = try await response + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + + // retry on same connection + + do { + _ = try await connection.query("SELECT 1;", logger: self.logger) + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + } + struct TestPrepareStatement: PostgresPreparedStatement { static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" typealias Row = String