diff --git a/drift-client/src/main/java/io/airlift/drift/client/DriftMethodInvocation.java b/drift-client/src/main/java/io/airlift/drift/client/DriftMethodInvocation.java index eb6d80e60..d7c264ce8 100644 --- a/drift-client/src/main/java/io/airlift/drift/client/DriftMethodInvocation.java +++ b/drift-client/src/main/java/io/airlift/drift/client/DriftMethodInvocation.java @@ -16,6 +16,8 @@ package io.airlift.drift.client; import com.google.common.base.Ticker; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; @@ -71,6 +73,8 @@ class DriftMethodInvocation @GuardedBy("this") private final Set attemptedAddresses = new LinkedHashSet<>(); @GuardedBy("this") + private final Multiset failedConnectionAttempts = HashMultiset.create(); + @GuardedBy("this") private int failedConnections; @GuardedBy("this") private int overloadedRejects; @@ -104,7 +108,7 @@ static DriftMethodInvocation createDriftMethodInvocation( stat, ticker); // invocation can not be started from constructor, because it may start threads that can call back into the unpublished object - invocation.nextAttempt(); + invocation.nextAttempt(true); return invocation; } @@ -138,7 +142,7 @@ private DriftMethodInvocation( }, directExecutor()); } - private synchronized void nextAttempt() + private synchronized void nextAttempt(boolean noConnectDelay) { try { // request was already canceled @@ -156,8 +160,32 @@ private synchronized void nextAttempt() stat.recordRetry(); } + if (noConnectDelay) { + invoke(address.get()); + return; + } + + int connectionFailuresCount = failedConnectionAttempts.count(address.get()); + if (connectionFailuresCount == 0) { + invoke(address.get()); + return; + } + + Duration connectDelay = retryPolicy.getBackoffDelay(connectionFailuresCount); + log.debug("Failed connection to %s with attempt %s, will retry in %s", address.get(), connectionFailuresCount, connectDelay); + schedule(connectDelay, () -> invoke(address.get())); + } + catch (Throwable t) { + // this should never happen, but ensure that invocation always finishes + unexpectedError(t); + } + } + + private synchronized void invoke(A address) + { + try { long invocationStartTime = ticker.read(); - ListenableFuture result = invoker.invoke(new InvokeRequest(metadata, address.get(), headers, parameters)); + ListenableFuture result = invoker.invoke(new InvokeRequest(metadata, address, headers, parameters)); stat.recordResult(invocationStartTime, result); currentTask = result; @@ -166,13 +194,14 @@ private synchronized void nextAttempt() @Override public void onSuccess(Object result) { + resetConnectionFailures(address); set(result); } @Override public void onFailure(Throwable t) { - handleFailure(address.get(), t); + handleFailure(address, t); } }, directExecutor()); @@ -183,6 +212,11 @@ public void onFailure(Throwable t) } } + private synchronized void resetConnectionFailures(A address) + { + failedConnectionAttempts.setCount(address, 0); + } + private synchronized void handleFailure(A address, Throwable throwable) { try { @@ -199,12 +233,12 @@ private synchronized void handleFailure(A address, Throwable throwable) lastException = throwable; invocationAttempts++; } - else if (exceptionClassification.getHostStatus() == DOWN) { + else if (exceptionClassification.getHostStatus() == DOWN || exceptionClassification.getHostStatus() == OVERLOADED) { addressSelector.markdown(address); - } - else if (exceptionClassification.getHostStatus() == OVERLOADED) { - addressSelector.markdown(address); - overloadedRejects++; + failedConnectionAttempts.add(address); + if (exceptionClassification.getHostStatus() == OVERLOADED) { + overloadedRejects++; + } } // should retry? @@ -224,9 +258,11 @@ else if (exceptionClassification.getHostStatus() == OVERLOADED) { return; } - // A request to down or overloaded server is not counted as an attempt, and retries are not delayed + // A request to down or overloaded server is not counted as an attempt + // Retries are not delayed based on the invocationAttempts, but may be delayed + // based on the failed connection attempts for a selected address if (exceptionClassification.getHostStatus() != NORMAL) { - nextAttempt(); + nextAttempt(false); return; } @@ -238,22 +274,32 @@ else if (exceptionClassification.getHostStatus() == OVERLOADED) { backoffDelay, overloadedRejects, throwable.getMessage()); + schedule(backoffDelay, () -> nextAttempt(true)); + } + catch (Throwable t) { + // this should never happen, but ensure that invocation always finishes + unexpectedError(t); + } + } - ListenableFuture delay = invoker.delay(backoffDelay); + private synchronized void schedule(Duration timeout, Runnable task) + { + try { + ListenableFuture delay = invoker.delay(timeout); currentTask = delay; Futures.addCallback(delay, new FutureCallback() { @Override public void onSuccess(Object result) { - nextAttempt(); + task.run(); } @Override - public void onFailure(Throwable throwable) + public void onFailure(Throwable t) { // this should never happen in a delay future - unexpectedError(throwable); + unexpectedError(t); } }, directExecutor()); diff --git a/drift-client/src/test/java/io/airlift/drift/client/TestDriftMethodInvocation.java b/drift-client/src/test/java/io/airlift/drift/client/TestDriftMethodInvocation.java index b22e279fa..c3cdb5ce9 100644 --- a/drift-client/src/test/java/io/airlift/drift/client/TestDriftMethodInvocation.java +++ b/drift-client/src/test/java/io/airlift/drift/client/TestDriftMethodInvocation.java @@ -25,6 +25,7 @@ import io.airlift.drift.TException; import io.airlift.drift.client.ExceptionClassification.HostStatus; import io.airlift.drift.client.address.AddressSelector; +import io.airlift.drift.client.address.SimpleAddressSelector.SimpleAddress; import io.airlift.drift.codec.ThriftCodec; import io.airlift.drift.codec.internal.builtin.ShortThriftCodec; import io.airlift.drift.protocol.TTransportException; @@ -39,7 +40,7 @@ import javax.annotation.concurrent.GuardedBy; -import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; @@ -370,6 +371,57 @@ private static void testConnectionFailed(int expectedInvocationAttempts, int fai } } + @Test + public void testConnectionFailedDelay() + throws Exception + { + testConnectionFailedDelay(0, 0, 0); + testConnectionFailedDelay(1, 1, 0); + testConnectionFailedDelay(10, 1, 0); + testConnectionFailedDelay(1, 2, 1); + testConnectionFailedDelay(2, 2, 2); + testConnectionFailedDelay(10, 2, 10); + testConnectionFailedDelay(10, 5, 40); + } + + private static void testConnectionFailedDelay(int numberOfAddresses, int numberOfRetriesPerAddress, int expectedDelays) + throws Exception + { + testConnectionFailedDelay(false, numberOfAddresses, numberOfRetriesPerAddress, expectedDelays); + testConnectionFailedDelay(true, numberOfAddresses, numberOfRetriesPerAddress, expectedDelays); + } + + private static void testConnectionFailedDelay(boolean overloaded, int numberOfAddresses, int numberOfRetriesPerAddress, int expectedDelays) + throws Exception + { + ImmutableList.Builder
addresses = ImmutableList.builder(); + for (int i = 0; i < numberOfAddresses; i++) { + Address address = createTestingAddress(20_000 + i); + for (int j = 0; j < numberOfRetriesPerAddress; j++) { + addresses.add(address); + } + } + + MockMethodInvoker invoker = new MockMethodInvoker(request -> immediateFailedFuture(createClassifiedException(true, overloaded ? OVERLOADED : DOWN))); + DriftMethodInvocation methodInvocation = createDriftMethodInvocation( + new RetryPolicy(new DriftClientConfig(), new TestingExceptionClassifier()), + new TestingMethodInvocationStat(), + invoker, + new TestingAddressSelector(addresses.build()), + systemTicker()); + + try { + methodInvocation.get(); + fail("Expected exception"); + } + catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TTransportException); + TTransportException transportException = (TTransportException) e.getCause(); + assertTrue(transportException.getMessage().startsWith("No hosts available")); + } + assertEquals(invoker.getDelays().size(), expectedDelays); + } + @Test(timeOut = 60000) public void testExceptionFromInvokerInvoke() throws Exception @@ -711,6 +763,11 @@ private static void assertDelays(MockMethodInvoker invoker, RetryPolicy retryPol .collect(toImmutableList())); } + private static Address createTestingAddress(int port) + { + return new SimpleAddress(HostAndPort.fromParts("localhost", port)); + } + private static class TestingExceptionClassifier implements ExceptionClassifier { @@ -758,10 +815,10 @@ public ExceptionClassification getClassification() public static class TestingAddressSelector implements AddressSelector
{ - private final int maxAddresses; + private List
addresses; @GuardedBy("this") - private final List markdownHosts = new ArrayList<>(); + private final Set
markdownHosts = new HashSet<>(); @GuardedBy("this") private int addressCount; @@ -771,7 +828,19 @@ public static class TestingAddressSelector public TestingAddressSelector(int maxAddresses) { - this.maxAddresses = maxAddresses; + this(createAddresses(maxAddresses)); + } + + private static List
createAddresses(int count) + { + return IntStream.range(0, count) + .mapToObj(i -> createTestingAddress(20_000 + i)) + .collect(toImmutableList()); + } + + public TestingAddressSelector(List
addresses) + { + this.addresses = ImmutableList.copyOf(requireNonNull(addresses, "addresses is null")); } @Override @@ -784,24 +853,21 @@ public synchronized Optional
selectAddress(Optional addressSele public synchronized Optional
selectAddress(Optional addressSelectionContext, Set
attempted) { lastAttemptedSet = ImmutableSet.copyOf(attempted); - if (addressCount >= maxAddresses) { + if (addressCount >= addresses.size()) { return Optional.empty(); } - HostAndPort hostAndPort = HostAndPort.fromParts("localhost", 20_000 + addressCount++); - return Optional.of(() -> hostAndPort); + return Optional.of(addresses.get(addressCount++)); } @Override public synchronized void markdown(Address address) { - markdownHosts.add(address.getHostAndPort()); + markdownHosts.add(address); } public synchronized void assertAllDown() { - assertEquals(markdownHosts, IntStream.range(0, addressCount) - .mapToObj(i -> HostAndPort.fromParts("localhost", 20_000 + i)) - .collect(toImmutableList())); + assertEquals(markdownHosts, ImmutableSet.copyOf(addresses)); } public synchronized Set
getLastAttemptedSet()