diff --git a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/SkipTrailersTest.java b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/SkipTrailersTest.java index 07ac7deee..935516d83 100644 --- a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/SkipTrailersTest.java +++ b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/SkipTrailersTest.java @@ -15,14 +15,11 @@ */ package com.google.cloud.bigtable.data.v2.stub; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; +import static com.google.common.truth.Truth.assertThat; import static org.mockito.Mockito.when; import com.google.api.core.ApiFuture; import com.google.api.gax.core.NoCredentialsProvider; -import com.google.api.gax.tracing.ApiTracer; import com.google.api.gax.tracing.ApiTracerFactory; import com.google.auto.value.AutoValue; import com.google.bigtable.v2.BigtableGrpc; @@ -43,6 +40,7 @@ import com.google.cloud.bigtable.data.v2.models.TableId; import com.google.cloud.bigtable.data.v2.models.TargetId; import com.google.cloud.bigtable.data.v2.stub.metrics.BigtableTracer; +import com.google.cloud.bigtable.data.v2.stub.metrics.NoopMetricsProvider; import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; @@ -56,10 +54,13 @@ import io.grpc.ServerServiceDefinition; import io.grpc.stub.ServerCalls; import io.grpc.stub.StreamObserver; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -69,7 +70,6 @@ import org.junit.runners.JUnit4; import org.mockito.Mock; import org.mockito.Mockito; -import org.mockito.exceptions.verification.WantedButNotInvoked; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -85,7 +85,7 @@ public class SkipTrailersTest { private Server server; @Mock private ApiTracerFactory tracerFactory; - @Mock private BigtableTracer tracer; + private FakeTracer tracer = new FakeTracer(); private BigtableDataClient client; @@ -95,12 +95,12 @@ public void setUp() throws Exception { server = FakeServiceBuilder.create(hackedService).start(); when(tracerFactory.newTracer(Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(tracer); - when(tracer.inScope()).thenReturn(Mockito.mock(ApiTracer.Scope.class)); BigtableDataSettings.Builder clientBuilder = BigtableDataSettings.newBuilderForEmulator(server.getPort()) .setProjectId(PROJECT_ID) .setInstanceId(INSTANCE_ID) + .setMetricsProvider(NoopMetricsProvider.INSTANCE) .setCredentialsProvider(NoCredentialsProvider.create()); clientBuilder.stubSettings().setEnableSkipTrailers(true).setTracerFactory(tracerFactory); @@ -159,7 +159,7 @@ private void test(Supplier> invoker, T fakeResponse) // Wait for the call to start on the server @SuppressWarnings("unchecked") - ServerRpc rpc = (ServerRpc) hackedService.rpcs.poll(10, TimeUnit.SECONDS); + ServerRpc rpc = (ServerRpc) hackedService.rpcs.poll(30, TimeUnit.SECONDS); Preconditions.checkNotNull( rpc, "Timed out waiting for the call to be received by the mock server"); @@ -173,8 +173,21 @@ private void test(Supplier> invoker, T fakeResponse) Assert.fail("timed out waiting for the trailer optimization future to resolve"); } - verify(tracer, times(1)).operationFinishEarly(); - verify(tracer, never()).operationSucceeded(); + // The tracer will be notified in parallel to the future being resolved + // This normal and expected, but requires the test to wait a bit + for (int i = 10; i > 0; i--) { + try { + assertThat(tracer.getCallCount("operationFinishEarly")).isEqualTo(1); + break; + } catch (AssertionError e) { + if (i > 1) { + Thread.sleep(100); + } else { + throw e; + } + } + } + assertThat(tracer.getCallCount("operationSucceeded")).isEqualTo(0); // clean up rpc.getResponseStream().onCompleted(); @@ -183,9 +196,9 @@ private void test(Supplier> invoker, T fakeResponse) // Since we dont have a way to know exactly when this happens, we poll for (int i = 10; i > 0; i--) { try { - verify(tracer, times(1)).operationSucceeded(); + assertThat(tracer.getCallCount("operationSucceeded")).isEqualTo(1); break; - } catch (WantedButNotInvoked e) { + } catch (AssertionError e) { if (i > 1) { Thread.sleep(100); } else { @@ -195,6 +208,27 @@ private void test(Supplier> invoker, T fakeResponse) } } + static class FakeTracer extends BigtableTracer { + ConcurrentHashMap callCounts = new ConcurrentHashMap<>(); + + @Override + public void operationFinishEarly() { + record("operationFinishEarly"); + } + + @Override + public void operationSucceeded() { + record("operationSucceeded"); + } + + private void record(String op) { + callCounts.computeIfAbsent(op, (ignored) -> new AtomicInteger()).getAndIncrement(); + } + + private int getCallCount(String op) { + return Optional.ofNullable(callCounts.get(op)).map(AtomicInteger::get).orElse(0); + } + } /** * Hack the srvice definition to allow grpc server to simulate delayed trailers. This will augment * the bigtable service definition to promote unary rpcs to server streaming