diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml
index 5b84cb4ebc3..c6796085d83 100644
--- a/google-cloud-spanner/clirr-ignored-differences.xml
+++ b/google-cloud-spanner/clirr-ignored-differences.xml
@@ -814,6 +814,12 @@
com/google/cloud/spanner/connection/TransactionRetryListener
void retryDmlAsPartitionedDmlFailed(java.util.UUID, com.google.cloud.spanner.Statement, java.lang.Throwable)
-
+
+
+
+ 7012
+ com/google/cloud/spanner/connection/Connection
+ java.lang.Object runTransaction(com.google.cloud.spanner.connection.Connection$TransactionCallable)
+
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java
index cafb27ba6b7..b1d37f3e4cd 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java
@@ -99,7 +99,7 @@ public void rollback() {
public TransactionContext resetForRetry() {
if (txn == null || !txn.isAborted() && txnState != TransactionState.ABORTED) {
throw new IllegalStateException(
- "resetForRetry can only be called if the previous attempt" + " aborted");
+ "resetForRetry can only be called if the previous attempt aborted");
}
try (IScope s = tracer.withSpan(span)) {
boolean useInlinedBegin = txn.transactionId != null;
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java
index 547d2466e3e..eb69ae132cc 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java
@@ -835,6 +835,21 @@ default boolean isKeepTransactionAlive() {
*/
ApiFuture rollbackAsync();
+ /** Functional interface for the {@link #runTransaction(TransactionCallable)} method. */
+ interface TransactionCallable {
+ /** This method is invoked with a fresh transaction on the connection. */
+ T run(Connection transaction);
+ }
+
+ /**
+ * Runs the given callable in a transaction. The transaction type is determined by the current
+ * state of the connection. That is; if the connection is in read/write mode, the transaction type
+ * will be a read/write transaction. If the connection is in read-only mode, it will be a
+ * read-only transaction. The transaction will automatically be retried if it is aborted by
+ * Spanner.
+ */
+ T runTransaction(TransactionCallable callable);
+
/** Returns the current savepoint support for this connection. */
SavepointSupport getSavepointSupport();
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
index 2d7c917d230..5ea249ee0ac 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
@@ -194,6 +194,11 @@ private LeakedConnectionException() {
*/
private final ConnectionOptions options;
+ enum Caller {
+ APPLICATION,
+ TRANSACTION_RUNNER,
+ }
+
/** The supported batch modes. */
enum BatchMode {
NONE,
@@ -267,6 +272,9 @@ static UnitOfWorkType of(TransactionMode transactionMode) {
*/
private boolean transactionBeginMarked = false;
+ /** This field is set to true when a transaction runner is active for this connection. */
+ private boolean transactionRunnerActive = false;
+
private BatchMode batchMode;
private UnitOfWorkType unitOfWorkType;
private final Stack transactionStack = new Stack<>();
@@ -1164,16 +1172,19 @@ public void onFailure() {
@Override
public void commit() {
- get(commitAsync(CallType.SYNC));
+ get(commitAsync(CallType.SYNC, Caller.APPLICATION));
}
@Override
public ApiFuture commitAsync() {
- return commitAsync(CallType.ASYNC);
+ return commitAsync(CallType.ASYNC, Caller.APPLICATION);
}
- private ApiFuture commitAsync(CallType callType) {
+ ApiFuture commitAsync(CallType callType, Caller caller) {
ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG);
+ ConnectionPreconditions.checkState(
+ !transactionRunnerActive || caller == Caller.TRANSACTION_RUNNER,
+ "Cannot call commit when a transaction runner is active");
maybeAutoCommitOrFlushCurrentUnitOfWork(COMMIT_STATEMENT.getType(), COMMIT_STATEMENT);
return endCurrentTransactionAsync(callType, commit, COMMIT_STATEMENT);
}
@@ -1201,16 +1212,19 @@ public void onFailure() {
@Override
public void rollback() {
- get(rollbackAsync(CallType.SYNC));
+ get(rollbackAsync(CallType.SYNC, Caller.APPLICATION));
}
@Override
public ApiFuture rollbackAsync() {
- return rollbackAsync(CallType.ASYNC);
+ return rollbackAsync(CallType.ASYNC, Caller.APPLICATION);
}
- private ApiFuture rollbackAsync(CallType callType) {
+ ApiFuture rollbackAsync(CallType callType, Caller caller) {
ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG);
+ ConnectionPreconditions.checkState(
+ !transactionRunnerActive || caller == Caller.TRANSACTION_RUNNER,
+ "Cannot call rollback when a transaction runner is active");
maybeAutoCommitOrFlushCurrentUnitOfWork(ROLLBACK_STATEMENT.getType(), ROLLBACK_STATEMENT);
return endCurrentTransactionAsync(callType, rollback, ROLLBACK_STATEMENT);
}
@@ -1243,6 +1257,27 @@ private ApiFuture endCurrentTransactionAsync(
return res;
}
+ @Override
+ public T runTransaction(TransactionCallable callable) {
+ ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG);
+ ConnectionPreconditions.checkState(!isBatchActive(), "Cannot run transaction while in a batch");
+ ConnectionPreconditions.checkState(
+ !isTransactionStarted(), "Cannot run transaction when a transaction is already active");
+ ConnectionPreconditions.checkState(
+ !transactionRunnerActive, "A transaction runner is already active for this connection");
+ this.transactionRunnerActive = true;
+ try {
+ return new TransactionRunnerImpl(this).run(callable);
+ } finally {
+ this.transactionRunnerActive = false;
+ }
+ }
+
+ void resetForRetry(UnitOfWork retryUnitOfWork) {
+ retryUnitOfWork.resetForRetry();
+ this.currentUnitOfWork = retryUnitOfWork;
+ }
+
@Override
public SavepointSupport getSavepointSupport() {
return getConnectionPropertyValue(SAVEPOINT_SUPPORT);
@@ -2000,7 +2035,7 @@ private UnitOfWork maybeStartAutoDmlBatch(UnitOfWork transaction) {
return transaction;
}
- private UnitOfWork getCurrentUnitOfWorkOrStartNewUnitOfWork() {
+ UnitOfWork getCurrentUnitOfWorkOrStartNewUnitOfWork() {
return getCurrentUnitOfWorkOrStartNewUnitOfWork(
StatementType.UNKNOWN, /* parsedStatement = */ null, /* internalMetadataQuery = */ false);
}
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java
index 4ae0ae00608..1f6ab6bf0c6 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java
@@ -1261,6 +1261,11 @@ private ApiFuture rollbackAsync(CallType callType, boolean updateStatusAnd
}
}
+ @Override
+ public void resetForRetry() {
+ txContextFuture = ApiFutures.immediateFuture(txManager.resetForRetry());
+ }
+
@Override
String getUnitOfWorkName() {
return "read/write transaction";
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/TransactionRunnerImpl.java
new file mode 100644
index 00000000000..6c959d3e5f9
--- /dev/null
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/TransactionRunnerImpl.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.connection;
+
+import static com.google.cloud.spanner.SpannerApiFutures.get;
+
+import com.google.cloud.spanner.AbortedException;
+import com.google.cloud.spanner.SpannerExceptionFactory;
+import com.google.cloud.spanner.connection.Connection.TransactionCallable;
+import com.google.cloud.spanner.connection.ConnectionImpl.Caller;
+import com.google.cloud.spanner.connection.UnitOfWork.CallType;
+
+class TransactionRunnerImpl {
+ private final ConnectionImpl connection;
+
+ TransactionRunnerImpl(ConnectionImpl connection) {
+ this.connection = connection;
+ }
+
+ T run(TransactionCallable callable) {
+ connection.beginTransaction();
+ // Disable internal retries during this transaction.
+ connection.setRetryAbortsInternally(/* retryAbortsInternally = */ false, /* local = */ true);
+ UnitOfWork transaction = connection.getCurrentUnitOfWorkOrStartNewUnitOfWork();
+ while (true) {
+ try {
+ T result = callable.run(connection);
+ get(connection.commitAsync(CallType.SYNC, Caller.TRANSACTION_RUNNER));
+ return result;
+ } catch (AbortedException abortedException) {
+ try {
+ //noinspection BusyWait
+ Thread.sleep(abortedException.getRetryDelayInMillis());
+ connection.resetForRetry(transaction);
+ } catch (InterruptedException interruptedException) {
+ connection.rollbackAsync(CallType.SYNC, Caller.TRANSACTION_RUNNER);
+ throw SpannerExceptionFactory.propagateInterrupt(interruptedException);
+ } catch (Throwable t) {
+ connection.rollbackAsync(CallType.SYNC, Caller.TRANSACTION_RUNNER);
+ throw t;
+ }
+ } catch (Throwable t) {
+ connection.rollbackAsync(CallType.SYNC, Caller.TRANSACTION_RUNNER);
+ throw t;
+ }
+ }
+ }
+}
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/UnitOfWork.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/UnitOfWork.java
index ffa93d486e1..80981922225 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/UnitOfWork.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/UnitOfWork.java
@@ -125,6 +125,10 @@ interface EndTransactionCallback {
ApiFuture rollbackAsync(
@Nonnull CallType callType, @Nonnull EndTransactionCallback callback);
+ default void resetForRetry() {
+ throw new UnsupportedOperationException();
+ }
+
/** @see Connection#savepoint(String) */
void savepoint(@Nonnull String name, @Nonnull Dialect dialect);
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RunTransactionMockServerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RunTransactionMockServerTest.java
new file mode 100644
index 00000000000..91662ef8668
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RunTransactionMockServerTest.java
@@ -0,0 +1,226 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.connection;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.spanner.ErrorCode;
+import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime;
+import com.google.cloud.spanner.ResultSet;
+import com.google.cloud.spanner.SpannerException;
+import com.google.spanner.v1.CommitRequest;
+import com.google.spanner.v1.ExecuteSqlRequest;
+import com.google.spanner.v1.RollbackRequest;
+import io.grpc.Status;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class RunTransactionMockServerTest extends AbstractMockServerTest {
+
+ @Test
+ public void testRunTransaction() {
+ try (Connection connection = createConnection()) {
+ connection.runTransaction(
+ transaction -> {
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ return null;
+ });
+ }
+ assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class));
+ }
+
+ @Test
+ public void testRunTransactionInAutoCommit() {
+ try (Connection connection = createConnection()) {
+ connection.setAutocommit(true);
+
+ connection.runTransaction(
+ transaction -> {
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ return null;
+ });
+ }
+ assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class));
+ }
+
+ @Test
+ public void testRunTransactionInReadOnly() {
+ try (Connection connection = createConnection()) {
+ connection.setReadOnly(true);
+ connection.setAutocommit(false);
+
+ assertEquals(
+ RANDOM_RESULT_SET_ROW_COUNT,
+ connection
+ .runTransaction(
+ transaction -> {
+ int rows = 0;
+ try (ResultSet resultSet = transaction.executeQuery(SELECT_RANDOM_STATEMENT)) {
+ while (resultSet.next()) {
+ rows++;
+ }
+ }
+ return rows;
+ })
+ .intValue());
+ }
+ assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class));
+ assertEquals(0, mockSpanner.countRequestsOfType(RollbackRequest.class));
+ }
+
+ @Test
+ public void testRunTransaction_rollbacksAfterException() {
+ try (Connection connection = createConnection()) {
+ SpannerException exception =
+ assertThrows(
+ SpannerException.class,
+ () ->
+ connection.runTransaction(
+ transaction -> {
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ mockSpanner.setExecuteSqlExecutionTime(
+ SimulatedExecutionTime.ofException(
+ Status.INVALID_ARGUMENT
+ .withDescription("invalid statement")
+ .asRuntimeException()));
+ // This statement will fail.
+ transaction.executeUpdate(INSERT_STATEMENT);
+ return null;
+ }));
+ assertEquals(ErrorCode.INVALID_ARGUMENT, exception.getErrorCode());
+ assertTrue(exception.getMessage(), exception.getMessage().endsWith("invalid statement"));
+ }
+ assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class));
+ assertEquals(1, mockSpanner.countRequestsOfType(RollbackRequest.class));
+ }
+
+ @Test
+ public void testRunTransactionCommitAborted() {
+ final AtomicInteger attempts = new AtomicInteger();
+ try (Connection connection = createConnection()) {
+ connection.runTransaction(
+ transaction -> {
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ if (attempts.incrementAndGet() == 1) {
+ mockSpanner.abortNextStatement();
+ }
+ return null;
+ });
+ }
+ assertEquals(2, attempts.get());
+ assertEquals(4, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(2, mockSpanner.countRequestsOfType(CommitRequest.class));
+ }
+
+ @Test
+ public void testRunTransactionDmlAborted() {
+ final AtomicInteger attempts = new AtomicInteger();
+ try (Connection connection = createConnection()) {
+ assertTrue(connection.isRetryAbortsInternally());
+ connection.runTransaction(
+ transaction -> {
+ assertFalse(transaction.isRetryAbortsInternally());
+ if (attempts.incrementAndGet() == 1) {
+ mockSpanner.abortNextStatement();
+ }
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ return null;
+ });
+ assertTrue(connection.isRetryAbortsInternally());
+ }
+ assertEquals(2, attempts.get());
+ assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class));
+ }
+
+ @Test
+ public void testRunTransactionQueryAborted() {
+ final AtomicInteger attempts = new AtomicInteger();
+ try (Connection connection = createConnection()) {
+ int rowCount =
+ connection.runTransaction(
+ transaction -> {
+ if (attempts.incrementAndGet() == 1) {
+ mockSpanner.abortNextStatement();
+ }
+ int rows = 0;
+ try (ResultSet resultSet = transaction.executeQuery(SELECT_RANDOM_STATEMENT)) {
+ while (resultSet.next()) {
+ rows++;
+ }
+ }
+ return rows;
+ });
+ assertEquals(RANDOM_RESULT_SET_ROW_COUNT, rowCount);
+ }
+ assertEquals(2, attempts.get());
+ assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class));
+ }
+
+ @Test
+ public void testCommitInRunTransaction() {
+ try (Connection connection = createConnection()) {
+ connection.runTransaction(
+ transaction -> {
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ SpannerException exception = assertThrows(SpannerException.class, transaction::commit);
+ assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode());
+ assertEquals(
+ "FAILED_PRECONDITION: Cannot call commit when a transaction runner is active",
+ exception.getMessage());
+ return null;
+ });
+ }
+ assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class));
+ }
+
+ @Test
+ public void testRollbackInRunTransaction() {
+ try (Connection connection = createConnection()) {
+ connection.runTransaction(
+ transaction -> {
+ assertEquals(1L, transaction.executeUpdate(INSERT_STATEMENT));
+ SpannerException exception =
+ assertThrows(SpannerException.class, transaction::rollback);
+ assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode());
+ assertEquals(
+ "FAILED_PRECONDITION: Cannot call rollback when a transaction runner is active",
+ exception.getMessage());
+ return null;
+ });
+ }
+ assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class));
+ assertEquals(0, mockSpanner.countRequestsOfType(RollbackRequest.class));
+ }
+}