diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java index 50f8fb52..728dd992 100644 --- a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java @@ -98,7 +98,10 @@ public class GcpManagedChannel extends ManagedChannel { private final ExecutorService stateNotificationExecutor = Executors.newCachedThreadPool( new ThreadFactoryBuilder().setNameFormat("gcp-mc-state-notifications-%d").build()); - private List stateChangeCallbacks = Collections.synchronizedList(new LinkedList<>()); + + // Callbacks to call when state changes. + @GuardedBy("this") + private List stateChangeCallbacks = new LinkedList<>(); // Metrics configuration. private MetricRegistry metricRegistry; @@ -882,16 +885,19 @@ private void recordUnresponsiveDetection(long nanos, long dropCount) { @Override public void notifyWhenStateChanged(ConnectivityState source, Runnable callback) { - if (!getState(false).equals(source)) { - try { - stateNotificationExecutor.execute(callback); - } catch (RejectedExecutionException e) { - // Ignore exceptions on shutdown. - logger.fine(log("State notification change task rejected: %s", e.getMessage())); + if (getState(false).equals(source)) { + synchronized (this) { + stateChangeCallbacks.add(callback); } return; } - stateChangeCallbacks.add(callback); + + try { + stateNotificationExecutor.execute(callback); + } catch (RejectedExecutionException e) { + // Ignore exceptions on shutdown. + logger.fine(log("State notification change task rejected: %s", e.getMessage())); + } } /** diff --git a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java index 2bd51f5a..4041ceaa 100644 --- a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java +++ b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java @@ -32,6 +32,7 @@ import com.google.cloud.grpc.proto.ApiConfig; import com.google.cloud.grpc.proto.ChannelPoolConfig; import com.google.cloud.grpc.proto.MethodConfig; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.spanner.v1.PartitionReadRequest; import com.google.spanner.v1.TransactionSelector; import io.grpc.CallOptions; @@ -1265,6 +1266,69 @@ public void run() { .isAnyOf(ConnectivityState.CONNECTING, ConnectivityState.TRANSIENT_FAILURE); } + @Test + public void testParallelStateNotifications() throws InterruptedException { + AtomicReference exception = new AtomicReference<>(); + + ExecutorService grpcExecutor = Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setUncaughtExceptionHandler((t, e) -> + exception.set(e) + ).build() + ); + + ManagedChannelBuilder builder = ManagedChannelBuilder.forAddress(TARGET, 443); + GcpManagedChannel pool = (GcpManagedChannel) GcpManagedChannelBuilder.forDelegateBuilder( + builder) + .executor(grpcExecutor) + .withOptions(GcpManagedChannelOptions.newBuilder() + .withChannelPoolOptions(GcpChannelPoolOptions.newBuilder() + .setMaxSize(1) + .build()) + .build()) + .build(); + + // Pre-populate with a fake channel to control state changes. + FakeManagedChannel channel = new FakeManagedChannel(grpcExecutor); + ChannelRef ref = pool.new ChannelRef(channel, 0); + pool.channelRefs.add(ref); + + // Always re-subscribe for notification to have constant callbacks flowing. + final Runnable callback = new Runnable() { + @Override + public void run() { + ConnectivityState state = pool.getState(false); + pool.notifyWhenStateChanged(state, this); + } + }; + + // Update channels state and subscribe for pool state changes in parallel. + final ExecutorService executor = Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat("gcp-mc-test-%d").build()); + + for (int i = 0; i < 300; i++) { + executor.execute(() -> { + ConnectivityState currentState = pool.getState(true); + pool.notifyWhenStateChanged(currentState, callback); + }); + executor.execute(() -> { + channel.setState(ConnectivityState.IDLE); + channel.setState(ConnectivityState.CONNECTING); + }); + } + + executor.shutdown(); + //noinspection StatementWithEmptyBody + while (!executor.awaitTermination(10, TimeUnit.MILLISECONDS)) {} + + channel.setState(ConnectivityState.SHUTDOWN); + pool.shutdownNow(); + + // Make sure no exceptions were raised in callbacks. + assertThat(exception.get()).isNull(); + + grpcExecutor.shutdown(); + } + @Test public void testParallelGetChannelRefWontExceedMaxSize() throws InterruptedException { resetGcpChannel();