From 4fda8d77cb4bdaf50cb0722f5b7ac53e1b69bf52 Mon Sep 17 00:00:00 2001 From: Stevenson Michel <130018170+thoven87@users.noreply.github.com> Date: Thu, 10 Oct 2024 22:48:37 -0400 Subject: [PATCH 1/2] feat(transaction): Adding withTransaction --- Sources/PostgresNIO/Pool/PostgresClient.swift | 22 ++++++++++++ .../PostgresClientTests.swift | 35 +++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 0907f1f8..0ba7725b 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -303,6 +303,28 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return try await closure(connection) } + + /// Lease a connection for the provided `closure`'s lifetime. + /// A transation starts with call to withConnection + /// A transaction should end with a call to COMMIT or ROLLBACK + /// COMMIT is called upon successful completion and ROLLBACK is called should any steps fail + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withTransaction(logger: Logger, _ process: (PostgresConnection) async throws -> Result) async throws -> Result { + try await withConnection { connection in + do { + try await connection.query("BEGIN;", logger: logger) + let value = try await process(connection) + try await connection.query("COMMIT;", logger: logger) + return value + } catch { + try await connection.query("ROLLBACK;", logger: logger) + throw error + } + } + } /// Run a query on the Postgres server the client is connected to. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index d6d89dc3..9daae857 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -42,6 +42,41 @@ final class PostgresClientTests: XCTestCase { taskGroup.cancelAll() } } + + func testTransaction() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + let iterations = 1000 + + for _ in 0.. Date: Mon, 21 Oct 2024 07:40:23 -0400 Subject: [PATCH 2/2] adding tests for commit and rollback --- Sources/PostgresNIO/Pool/PostgresClient.swift | 2 +- .../PostgresClientTests.swift | 83 +++++++++++++++---- 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 0ba7725b..9c79b9d9 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -314,8 +314,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// - Returns: The closure's return value. public func withTransaction(logger: Logger, _ process: (PostgresConnection) async throws -> Result) async throws -> Result { try await withConnection { connection in + try await connection.query("BEGIN;", logger: logger) do { - try await connection.query("BEGIN;", logger: logger) let value = try await process(connection) try await connection.query("COMMIT;", logger: logger) return value diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 9daae857..6e9e843d 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -52,29 +52,80 @@ final class PostgresClientTests: XCTestCase { try await eventLoopGroup.shutdownGracefully() } + let tableName = "test_client_trasactions" + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) - await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - await client.run() - } - - let iterations = 1000 - - for _ in 0..