Skip to content

Commit

Permalink
Issue 17: Retry http errors #18
Browse files Browse the repository at this point in the history
Retry on certain HTTP errors.

Signed-off-by: Christophe Balczunas <[email protected]>
  • Loading branch information
Ravi Sharda authored and Christophe Balczunas committed Aug 24, 2020
2 parents 5ea42a3 + 2fcd492 commit 742c2fd
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
* ( Strangely, things like providing a bad client secret result in a 400 error )
*/
public class KeycloakAuthenticationException extends RuntimeException {

public KeycloakAuthenticationException(Throwable e) {
super("Authentication failure", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

package io.pravega.keycloak.client;

import io.pravega.common.util.Retry;
import org.apache.http.HttpStatus;
import org.keycloak.authorization.client.AuthzClient;
import org.keycloak.authorization.client.ClientAuthenticator;
Expand All @@ -27,10 +28,12 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.ConnectException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Predicate;

/**
* Wrapper to manage a service account obtaining access tokens and RPTs for a given audience
Expand All @@ -40,9 +43,13 @@ public class KeycloakAuthzClient {

public static final String DEFAULT_PRAVEGA_CONTROLLER_CLIENT_ID = "pravega-controller";
private static final int DEFAULT_TOKEN_MIN_TIME_TO_LIVE_SECS = 60;
private static final int DEFAULT_HTTP_MAX_RETRIES = 20;
private static final int DEFAULT_HTTP_INITIAL_DELAY_MS = 100;

private final AuthzClient client;
private final TokenCache tokenCache;
private final int httpMaxRetries;
private final int httpInitialDelayMs;

/**
* Builds a Keycloak authorization client.
Expand All @@ -54,6 +61,15 @@ public static Builder builder() {
public KeycloakAuthzClient(AuthzClient client, TokenCache tokenCache) {
this.client = client;
this.tokenCache = tokenCache;
this.httpMaxRetries = DEFAULT_HTTP_MAX_RETRIES;
this.httpInitialDelayMs = DEFAULT_HTTP_INITIAL_DELAY_MS;
}

public KeycloakAuthzClient(AuthzClient client, TokenCache tokenCache, int httpMaxRetries, int httpInitialDelayMs) {
this.client = client;
this.tokenCache = tokenCache;
this.httpMaxRetries = httpMaxRetries;
this.httpInitialDelayMs = httpInitialDelayMs;
}

/**
Expand All @@ -72,27 +88,23 @@ public String getRPT() {
// obtain an access token with which to make an authorization request
AccessTokenResponse accResponse;
try {
accResponse = client.obtainAccessToken();
accResponse = Retry.withExpBackoff(httpInitialDelayMs, 2, httpMaxRetries)
.retryWhen(isRetryable()).run(() -> client.obtainAccessToken());
LOG.debug("Obtained access token from Keycloak");
} catch (HttpResponseException e) {
LOG.error("Failed to obtain access token from Keycloak", e);
if (e.getStatusCode() == HttpStatus.SC_BAD_REQUEST) {
throw new KeycloakAuthenticationException(e);
}
throw e;
throw new KeycloakAuthenticationException(e);
}

// obtain an RPT
// obtain a Relying Party Token (RPT)
AuthorizationRequest request = new AuthorizationRequest();
try {
token = client.authorization(accResponse.getToken()).authorize(request);
token = Retry.withExpBackoff(httpInitialDelayMs, 2, httpMaxRetries)
.retryWhen(isRetryable()).run(() -> client.authorization(accResponse.getToken()).authorize(request));
LOG.debug("Obtained RPT from Keycloak");
} catch (HttpResponseException e) {
LOG.error("Failed to obtain RPT from Keycloak", e);
if (e.getStatusCode() == HttpStatus.SC_BAD_REQUEST) {
throw new KeycloakAuthorizationException(e);
}
throw e;
throw new KeycloakAuthorizationException(e);
}

// update the token cache
Expand All @@ -104,6 +116,51 @@ public String getRPT() {
return token.getToken();
}

/**
* Predicate to determine what is retryable and what is not.
* HttpResponseException with error code of 400, 401, 403 are not retryable.
* All other HttpResponseException are retryable as well as java.net.ConnectException.
* All others are not retryable.
* @return
*/
private static Predicate<Throwable> isRetryable() {
return e -> {
Throwable rootCause = unwrap(e);
if (rootCause instanceof HttpResponseException) {
int statusCode = ((HttpResponseException) e).getStatusCode();
if (statusCode == HttpStatus.SC_BAD_REQUEST ||
statusCode == HttpStatus.SC_UNAUTHORIZED ||
statusCode == HttpStatus.SC_FORBIDDEN ) {
// these are authN or authZ related errors.
LOG.error("Non retryable HttpResponseException with HTTP code: {}", statusCode);
return false;
} else { // 5xx errors etc.
LOG.warn("Retryable HttpResponseException with HTTP code: {}", statusCode);
return true;
}
} else if (rootCause instanceof ConnectException) {
LOG.warn("Retryable connection exception", rootCause);
return true;
} else {
// random unexpected Exceptions, don't retry, these should be debugged.
LOG.error("Other non retryable exception", rootCause);
return false;
}
};
}

/**
* Fully and uncondtionally unwraps an exception to get the cause.
* @param e the exception to unwrap
* @return the unwrapped exception
*/
private static Throwable unwrap(Throwable e) {
if (e.getCause() != null) {
return unwrap(e.getCause());
}
return e;
}

/**
* Deserialize a raw access token into an AccessToken object.
*
Expand Down Expand Up @@ -152,10 +209,14 @@ public static class Builder {
private String audience;
private String configFile;
private BiFunction<Configuration, ClientAuthenticator, AuthzClient> clientSupplier;
private int httpMaxRetries;
private int httpInitialDelayMs;

public Builder() {
audience = DEFAULT_PRAVEGA_CONTROLLER_CLIENT_ID;
clientSupplier = AuthzClient::create;
httpMaxRetries = DEFAULT_HTTP_MAX_RETRIES;
httpInitialDelayMs = DEFAULT_HTTP_INITIAL_DELAY_MS;
}

/**
Expand All @@ -178,6 +239,20 @@ public Builder withAudience(String audience) {
return this;
}

/**
* Sets max http retries and initial delay different than the default for
* the creation of the AuthzClient, as well as the calls to Keycloak inside the getRPT() method
* of the KeycloakAuthzClient
* @param httpMaxRetries maximum retries for the request
* @param httpInitialDelayMs initial delay between retries (will be exponentially increased by a factor of 2)
* @return
*/
public Builder withCustomRetries(int httpMaxRetries, int httpInitialDelayMs) {
this.httpInitialDelayMs = httpInitialDelayMs;
this.httpMaxRetries = httpMaxRetries;
return this;
}

/**
* Sets the Keycloak {@link AuthzClient} authz client provider. For test purposes only.
* @param clientSupplier a function which maps to an authz client.
Expand Down Expand Up @@ -208,12 +283,13 @@ public KeycloakAuthzClient build() {

// create the Keycloak authz client
ClientAuthenticator authenticator = createClientAuthenticator(configuration.getResource(), (String) configuration.getCredentials().get("secret"));
AuthzClient client = clientSupplier.apply(configuration, authenticator);
AuthzClient client = Retry.withExpBackoff(httpInitialDelayMs, 2, httpMaxRetries)
.retryWhen(isRetryable()).run(() -> clientSupplier.apply(configuration, authenticator));

// hack: convey the intended audience by setting the configuration resource
configuration.setResource(audience);

return new KeycloakAuthzClient(client, new TokenCache(configuration.getTokenMinimumTimeToLive()));
return new KeycloakAuthzClient(client, new TokenCache(configuration.getTokenMinimumTimeToLive()), httpMaxRetries, httpInitialDelayMs);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

package io.pravega.keycloak.client;

import io.pravega.common.util.RetriesExhaustedException;
import io.pravega.keycloak.client.KeycloakAuthzClient.TokenCache;
import io.pravega.keycloak.client.helpers.AccessTokenBuilder;
import io.pravega.keycloak.client.helpers.AccessTokenIssuer;
import org.junit.Assert;
import org.junit.Test;
import org.keycloak.authorization.client.AuthzClient;
import org.keycloak.authorization.client.ClientAuthenticator;
Expand All @@ -25,6 +27,7 @@
import org.mockito.Mockito;

import java.io.File;
import java.net.ConnectException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -38,12 +41,10 @@

public class KeycloakAuthzClientTest {
private static final String SVC_ACCOUNT_JSON_FILE = getResourceFile("service-account.json");
// private static final KeycloakDeployment DEPLOYMENT = KeycloakDeploymentResolver.resolve(SVC_ACCOUNT_JSON_FILE).get();

private static final AccessTokenIssuer ISSUER = new AccessTokenIssuer();

@Test
public void getRPT_caching() {
public void getRPTCacheHits() {
AuthzClient client = mock(AuthzClient.class, Mockito.RETURNS_DEEP_STUBS);
TokenCache tokenCache = spy(new TokenCache(0));

Expand All @@ -67,39 +68,87 @@ public void getRPT_caching() {
verify(tokenCache, times(1)).update(any());
}

@Test(expected = KeycloakAuthenticationException.class)
public void getRPT_error_authn() {
@Test
public void getRPTFailsToGetAccessToken() {
AuthzClient client = mock(AuthzClient.class, Mockito.RETURNS_DEEP_STUBS);
TokenCache tokenCache = spy(new TokenCache(0));
when(client.obtainAccessToken()).thenThrow(new HttpResponseException("", 400, "", null));

KeycloakAuthzClient authzClient = new KeycloakAuthzClient(client, tokenCache);
authzClient.getRPT();
try {
authzClient.getRPT();
Assert.fail();
} catch(KeycloakAuthenticationException e) {
}
verify(client, times(1)).obtainAccessToken();
}
@Test(expected = KeycloakAuthorizationException.class)
public void getRPT_error_authz() {

@Test
public void getRPTCannotExchangeAccessTokenForRPT() {
AuthzClient client = mock(AuthzClient.class, Mockito.RETURNS_DEEP_STUBS);
TokenCache tokenCache = spy(new TokenCache(0));
AccessTokenResponse accessToken = accessTokenResponse();
when(client.obtainAccessToken()).thenReturn(accessToken);
when(client.authorization(any()).authorize(any())).thenThrow(new HttpResponseException("", 400, "", null));

KeycloakAuthzClient authzClient = new KeycloakAuthzClient(client, tokenCache);
authzClient.getRPT();
try {
authzClient.getRPT();
Assert.fail();
} catch(KeycloakAuthorizationException e) {
}
verify(client.authorization(any()), times(1)).authorize(any());
}

@Test(expected = HttpResponseException.class)
public void getRPT_error_other() {
@Test
public void getRPTWithHttp500Exception() {
AuthzClient client = mock(AuthzClient.class, Mockito.RETURNS_DEEP_STUBS);
TokenCache tokenCache = spy(new TokenCache(0));

when(client.obtainAccessToken()).thenThrow(new HttpResponseException("", 500, "", null));
KeycloakAuthzClient authzClient = new KeycloakAuthzClient(client, tokenCache, 3, 1);
try {
authzClient.getRPT();
Assert.fail();
} catch(RetriesExhaustedException e) {
}
verify(client, times(3)).obtainAccessToken();
}

KeycloakAuthzClient authzClient = new KeycloakAuthzClient(client, tokenCache);
authzClient.getRPT();
@Test
public void getRPTWithRuntimeConnectException() {
AuthzClient client = mock(AuthzClient.class, Mockito.RETURNS_DEEP_STUBS);
TokenCache tokenCache = spy(new TokenCache(0));

when(client.obtainAccessToken()).thenThrow(new RuntimeException(new ConnectException()));
KeycloakAuthzClient authzClient = new KeycloakAuthzClient(client, tokenCache,3, 1);
try {
authzClient.getRPT();
Assert.fail();
} catch(RetriesExhaustedException e) {
}
verify(client, times(3)).obtainAccessToken();
}

@Test
public void getRPTWithRandomRuntimeException() {
AuthzClient client = mock(AuthzClient.class, Mockito.RETURNS_DEEP_STUBS);
TokenCache tokenCache = spy(new TokenCache(0));

when(client.obtainAccessToken()).thenThrow(new RuntimeException("bogus"));
KeycloakAuthzClient authzClient = new KeycloakAuthzClient(client, tokenCache, 3, 1);
try {
authzClient.getRPT();
Assert.fail();
} catch(RetriesExhaustedException e) {
Assert.fail();
} catch (RuntimeException e) {
}
verify(client, times(1)).obtainAccessToken();
}

@Test
public void tokenCache_expiration() {
public void tokenCacheExpiration() {
AuthorizationResponse response;
TokenCache tokenCache = new TokenCache(0);

Expand All @@ -119,28 +168,28 @@ public void tokenCache_expiration() {
}

@Test
public void builder_defaultAudience() {
public void builderDefaultAudience() {
TestSupplier supplier = new TestSupplier();
KeycloakAuthzClient.builder().withAuthzClientSupplier(supplier).withConfigFile(SVC_ACCOUNT_JSON_FILE).build();
assertEquals(DEFAULT_PRAVEGA_CONTROLLER_CLIENT_ID, supplier.configuration.getResource());
}

@Test
public void builder_setAudience() {
public void builderSetAudience() {
TestSupplier supplier = new TestSupplier();
KeycloakAuthzClient.builder().withAuthzClientSupplier(supplier).withConfigFile(SVC_ACCOUNT_JSON_FILE)
.withAudience("builder_setAudience").build();
assertEquals("builder_setAudience", supplier.configuration.getResource());
}

@Test(expected = KeycloakConfigurationException.class)
public void builder_noConfig() {
public void builderNoConfig() {
TestSupplier supplier = new TestSupplier();
KeycloakAuthzClient.builder().withAuthzClientSupplier(supplier).build();
}

@Test
public void builder_authenticator() {
public void builderAuthenticator() {
TestSupplier supplier = new TestSupplier();
KeycloakAuthzClient.builder().withAuthzClientSupplier(supplier).withConfigFile(SVC_ACCOUNT_JSON_FILE).build();

Expand Down

0 comments on commit 742c2fd

Please sign in to comment.