Skip to content

Commit

Permalink
Add delay before retrying same address
Browse files Browse the repository at this point in the history
Retrying same address without a delay may lead to infinite loop
  • Loading branch information
arhimondr authored and electrum committed Dec 13, 2018
1 parent 035d946 commit 800c823
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package io.airlift.drift.client;

import com.google.common.base.Ticker;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
Expand Down Expand Up @@ -71,6 +73,8 @@ class DriftMethodInvocation<A extends Address>
@GuardedBy("this")
private final Set<A> attemptedAddresses = new LinkedHashSet<>();
@GuardedBy("this")
private final Multiset<A> failedConnectionAttempts = HashMultiset.create();
@GuardedBy("this")
private int failedConnections;
@GuardedBy("this")
private int overloadedRejects;
Expand Down Expand Up @@ -104,7 +108,7 @@ static <A extends Address> DriftMethodInvocation<A> createDriftMethodInvocation(
stat,
ticker);
// invocation can not be started from constructor, because it may start threads that can call back into the unpublished object
invocation.nextAttempt();
invocation.nextAttempt(true);
return invocation;
}

Expand Down Expand Up @@ -138,7 +142,7 @@ private DriftMethodInvocation(
}, directExecutor());
}

private synchronized void nextAttempt()
private synchronized void nextAttempt(boolean noConnectDelay)
{
try {
// request was already canceled
Expand All @@ -156,8 +160,32 @@ private synchronized void nextAttempt()
stat.recordRetry();
}

if (noConnectDelay) {
invoke(address.get());
return;
}

int connectionFailuresCount = failedConnectionAttempts.count(address.get());
if (connectionFailuresCount == 0) {
invoke(address.get());
return;
}

Duration connectDelay = retryPolicy.getBackoffDelay(connectionFailuresCount);
log.debug("Failed connection to %s with attempt %s, will retry in %s", address.get(), connectionFailuresCount, connectDelay);
schedule(connectDelay, () -> invoke(address.get()));
}
catch (Throwable t) {
// this should never happen, but ensure that invocation always finishes
unexpectedError(t);
}
}

private synchronized void invoke(A address)
{
try {
long invocationStartTime = ticker.read();
ListenableFuture<Object> result = invoker.invoke(new InvokeRequest(metadata, address.get(), headers, parameters));
ListenableFuture<Object> result = invoker.invoke(new InvokeRequest(metadata, address, headers, parameters));
stat.recordResult(invocationStartTime, result);
currentTask = result;

Expand All @@ -166,13 +194,14 @@ private synchronized void nextAttempt()
@Override
public void onSuccess(Object result)
{
resetConnectionFailures(address);
set(result);
}

@Override
public void onFailure(Throwable t)
{
handleFailure(address.get(), t);
handleFailure(address, t);
}
},
directExecutor());
Expand All @@ -183,6 +212,11 @@ public void onFailure(Throwable t)
}
}

private synchronized void resetConnectionFailures(A address)
{
failedConnectionAttempts.setCount(address, 0);
}

private synchronized void handleFailure(A address, Throwable throwable)
{
try {
Expand All @@ -199,12 +233,12 @@ private synchronized void handleFailure(A address, Throwable throwable)
lastException = throwable;
invocationAttempts++;
}
else if (exceptionClassification.getHostStatus() == DOWN) {
else if (exceptionClassification.getHostStatus() == DOWN || exceptionClassification.getHostStatus() == OVERLOADED) {
addressSelector.markdown(address);
}
else if (exceptionClassification.getHostStatus() == OVERLOADED) {
addressSelector.markdown(address);
overloadedRejects++;
failedConnectionAttempts.add(address);
if (exceptionClassification.getHostStatus() == OVERLOADED) {
overloadedRejects++;
}
}

// should retry?
Expand All @@ -224,9 +258,11 @@ else if (exceptionClassification.getHostStatus() == OVERLOADED) {
return;
}

// A request to down or overloaded server is not counted as an attempt, and retries are not delayed
// A request to down or overloaded server is not counted as an attempt
// Retries are not delayed based on the invocationAttempts, but may be delayed
// based on the failed connection attempts for a selected address
if (exceptionClassification.getHostStatus() != NORMAL) {
nextAttempt();
nextAttempt(false);
return;
}

Expand All @@ -238,22 +274,32 @@ else if (exceptionClassification.getHostStatus() == OVERLOADED) {
backoffDelay,
overloadedRejects,
throwable.getMessage());
schedule(backoffDelay, () -> nextAttempt(true));
}
catch (Throwable t) {
// this should never happen, but ensure that invocation always finishes
unexpectedError(t);
}
}

ListenableFuture<?> delay = invoker.delay(backoffDelay);
private synchronized void schedule(Duration timeout, Runnable task)
{
try {
ListenableFuture<?> delay = invoker.delay(timeout);
currentTask = delay;
Futures.addCallback(delay, new FutureCallback<Object>()
{
@Override
public void onSuccess(Object result)
{
nextAttempt();
task.run();
}

@Override
public void onFailure(Throwable throwable)
public void onFailure(Throwable t)
{
// this should never happen in a delay future
unexpectedError(throwable);
unexpectedError(t);
}
},
directExecutor());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.airlift.drift.TException;
import io.airlift.drift.client.ExceptionClassification.HostStatus;
import io.airlift.drift.client.address.AddressSelector;
import io.airlift.drift.client.address.SimpleAddressSelector.SimpleAddress;
import io.airlift.drift.codec.ThriftCodec;
import io.airlift.drift.codec.internal.builtin.ShortThriftCodec;
import io.airlift.drift.protocol.TTransportException;
Expand All @@ -39,7 +40,7 @@

import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -370,6 +371,57 @@ private static void testConnectionFailed(int expectedInvocationAttempts, int fai
}
}

@Test
public void testConnectionFailedDelay()
throws Exception
{
testConnectionFailedDelay(0, 0, 0);
testConnectionFailedDelay(1, 1, 0);
testConnectionFailedDelay(10, 1, 0);
testConnectionFailedDelay(1, 2, 1);
testConnectionFailedDelay(2, 2, 2);
testConnectionFailedDelay(10, 2, 10);
testConnectionFailedDelay(10, 5, 40);
}

private static void testConnectionFailedDelay(int numberOfAddresses, int numberOfRetriesPerAddress, int expectedDelays)
throws Exception
{
testConnectionFailedDelay(false, numberOfAddresses, numberOfRetriesPerAddress, expectedDelays);
testConnectionFailedDelay(true, numberOfAddresses, numberOfRetriesPerAddress, expectedDelays);
}

private static void testConnectionFailedDelay(boolean overloaded, int numberOfAddresses, int numberOfRetriesPerAddress, int expectedDelays)
throws Exception
{
ImmutableList.Builder<Address> addresses = ImmutableList.builder();
for (int i = 0; i < numberOfAddresses; i++) {
Address address = createTestingAddress(20_000 + i);
for (int j = 0; j < numberOfRetriesPerAddress; j++) {
addresses.add(address);
}
}

MockMethodInvoker invoker = new MockMethodInvoker(request -> immediateFailedFuture(createClassifiedException(true, overloaded ? OVERLOADED : DOWN)));
DriftMethodInvocation<?> methodInvocation = createDriftMethodInvocation(
new RetryPolicy(new DriftClientConfig(), new TestingExceptionClassifier()),
new TestingMethodInvocationStat(),
invoker,
new TestingAddressSelector(addresses.build()),
systemTicker());

try {
methodInvocation.get();
fail("Expected exception");
}
catch (ExecutionException e) {
assertTrue(e.getCause() instanceof TTransportException);
TTransportException transportException = (TTransportException) e.getCause();
assertTrue(transportException.getMessage().startsWith("No hosts available"));
}
assertEquals(invoker.getDelays().size(), expectedDelays);
}

@Test(timeOut = 60000)
public void testExceptionFromInvokerInvoke()
throws Exception
Expand Down Expand Up @@ -711,6 +763,11 @@ private static void assertDelays(MockMethodInvoker invoker, RetryPolicy retryPol
.collect(toImmutableList()));
}

private static Address createTestingAddress(int port)
{
return new SimpleAddress(HostAndPort.fromParts("localhost", port));
}

private static class TestingExceptionClassifier
implements ExceptionClassifier
{
Expand Down Expand Up @@ -758,10 +815,10 @@ public ExceptionClassification getClassification()
public static class TestingAddressSelector
implements AddressSelector<Address>
{
private final int maxAddresses;
private List<Address> addresses;

@GuardedBy("this")
private final List<HostAndPort> markdownHosts = new ArrayList<>();
private final Set<Address> markdownHosts = new HashSet<>();

@GuardedBy("this")
private int addressCount;
Expand All @@ -771,7 +828,19 @@ public static class TestingAddressSelector

public TestingAddressSelector(int maxAddresses)
{
this.maxAddresses = maxAddresses;
this(createAddresses(maxAddresses));
}

private static List<Address> createAddresses(int count)
{
return IntStream.range(0, count)
.mapToObj(i -> createTestingAddress(20_000 + i))
.collect(toImmutableList());
}

public TestingAddressSelector(List<Address> addresses)
{
this.addresses = ImmutableList.copyOf(requireNonNull(addresses, "addresses is null"));
}

@Override
Expand All @@ -784,24 +853,21 @@ public synchronized Optional<Address> selectAddress(Optional<String> addressSele
public synchronized Optional<Address> selectAddress(Optional<String> addressSelectionContext, Set<Address> attempted)
{
lastAttemptedSet = ImmutableSet.copyOf(attempted);
if (addressCount >= maxAddresses) {
if (addressCount >= addresses.size()) {
return Optional.empty();
}
HostAndPort hostAndPort = HostAndPort.fromParts("localhost", 20_000 + addressCount++);
return Optional.of(() -> hostAndPort);
return Optional.of(addresses.get(addressCount++));
}

@Override
public synchronized void markdown(Address address)
{
markdownHosts.add(address.getHostAndPort());
markdownHosts.add(address);
}

public synchronized void assertAllDown()
{
assertEquals(markdownHosts, IntStream.range(0, addressCount)
.mapToObj(i -> HostAndPort.fromParts("localhost", 20_000 + i))
.collect(toImmutableList()));
assertEquals(markdownHosts, ImmutableSet.copyOf(addresses));
}

public synchronized Set<Address> getLastAttemptedSet()
Expand Down

0 comments on commit 800c823

Please sign in to comment.