Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple signing keys to be provided #4632

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,25 @@ public class JwtAuthenticationTests {

public static final String QA_SONG_INDEX_NAME = String.format("song_lyrics_%s", QA_DEPARTMENT);

private static final KeyPair KEY_PAIR = Keys.keyPairFor(SignatureAlgorithm.RS256);
private static final String PUBLIC_KEY = new String(Base64.getEncoder().encode(KEY_PAIR.getPublic().getEncoded()), US_ASCII);
private static final KeyPair KEY_PAIR1 = Keys.keyPairFor(SignatureAlgorithm.RS256);
private static final String PUBLIC_KEY1 = new String(Base64.getEncoder().encode(KEY_PAIR1.getPublic().getEncoded()), US_ASCII);

private static final KeyPair KEY_PAIR2 = Keys.keyPairFor(SignatureAlgorithm.RS256);
private static final String PUBLIC_KEY2 = new String(Base64.getEncoder().encode(KEY_PAIR2.getPublic().getEncoded()), US_ASCII);

static final TestSecurityConfig.User ADMIN_USER = new TestSecurityConfig.User("admin").roles(ALL_ACCESS);

private static final String JWT_AUTH_HEADER = "jwt-auth";

private static final JwtAuthorizationHeaderFactory tokenFactory = new JwtAuthorizationHeaderFactory(
KEY_PAIR.getPrivate(),
private static final JwtAuthorizationHeaderFactory tokenFactory1 = new JwtAuthorizationHeaderFactory(
KEY_PAIR1.getPrivate(),
CLAIM_USERNAME,
CLAIM_ROLES,
JWT_AUTH_HEADER
);

private static final JwtAuthorizationHeaderFactory tokenFactory2 = new JwtAuthorizationHeaderFactory(
KEY_PAIR2.getPrivate(),
CLAIM_USERNAME,
CLAIM_ROLES,
JWT_AUTH_HEADER
Expand All @@ -108,7 +118,10 @@ public class JwtAuthenticationTests {
"jwt",
BASIC_AUTH_DOMAIN_ORDER - 1
).jwtHttpAuthenticator(
new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER).signingKey(PUBLIC_KEY).subjectKey(CLAIM_USERNAME).rolesKey(CLAIM_ROLES)
new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER)
.signingKey(List.of(PUBLIC_KEY1, PUBLIC_KEY2))
.subjectKey(CLAIM_USERNAME)
.rolesKey(CLAIM_ROLES)
).backend("noop");
public static final String SONG_ID_1 = "song-id-01";

Expand Down Expand Up @@ -143,7 +156,7 @@ public static void createTestData() {

@Test
public void shouldAuthenticateWithJwtToken_positive() {
try (TestRestClient client = cluster.getRestClient(tokenFactory.generateValidToken(USER_SUPERHERO))) {
try (TestRestClient client = cluster.getRestClient(tokenFactory1.generateValidToken(USER_SUPERHERO))) {

HttpResponse response = client.getAuthInfo();

Expand All @@ -155,7 +168,7 @@ public void shouldAuthenticateWithJwtToken_positive() {

@Test
public void shouldAuthenticateWithJwtToken_positiveWithAnotherUsername() {
try (TestRestClient client = cluster.getRestClient(tokenFactory.generateValidToken(USERNAME_ROOT))) {
try (TestRestClient client = cluster.getRestClient(tokenFactory1.generateValidToken(USERNAME_ROOT))) {

HttpResponse response = client.getAuthInfo();

Expand All @@ -167,7 +180,7 @@ public void shouldAuthenticateWithJwtToken_positiveWithAnotherUsername() {

@Test
public void shouldAuthenticateWithJwtToken_failureLackingUserName() {
try (TestRestClient client = cluster.getRestClient(tokenFactory.generateTokenWithoutPreferredUsername(USER_SUPERHERO))) {
try (TestRestClient client = cluster.getRestClient(tokenFactory1.generateTokenWithoutPreferredUsername(USER_SUPERHERO))) {

HttpResponse response = client.getAuthInfo();

Expand All @@ -178,7 +191,7 @@ public void shouldAuthenticateWithJwtToken_failureLackingUserName() {

@Test
public void shouldAuthenticateWithJwtToken_failureExpiredToken() {
try (TestRestClient client = cluster.getRestClient(tokenFactory.generateExpiredToken(USER_SUPERHERO))) {
try (TestRestClient client = cluster.getRestClient(tokenFactory1.generateExpiredToken(USER_SUPERHERO))) {

HttpResponse response = client.getAuthInfo();

Expand All @@ -202,7 +215,7 @@ public void shouldAuthenticateWithJwtToken_failureIncorrectFormatOfToken() {
@Test
public void shouldAuthenticateWithJwtToken_failureIncorrectSignature() {
KeyPair incorrectKeyPair = Keys.keyPairFor(SignatureAlgorithm.RS256);
Header header = tokenFactory.generateTokenSignedWithKey(incorrectKeyPair.getPrivate(), USER_SUPERHERO);
Header header = tokenFactory1.generateTokenSignedWithKey(incorrectKeyPair.getPrivate(), USER_SUPERHERO);
try (TestRestClient client = cluster.getRestClient(header)) {

HttpResponse response = client.getAuthInfo();
Expand All @@ -214,7 +227,7 @@ public void shouldAuthenticateWithJwtToken_failureIncorrectSignature() {

@Test
public void shouldReadRolesFromToken_positiveFirstRoleSet() {
Header header = tokenFactory.generateValidToken(USER_SUPERHERO, ROLE_ADMIN, ROLE_DEVELOPER, ROLE_QA);
Header header = tokenFactory1.generateValidToken(USER_SUPERHERO, ROLE_ADMIN, ROLE_DEVELOPER, ROLE_QA);
try (TestRestClient client = cluster.getRestClient(header)) {

HttpResponse response = client.getAuthInfo();
Expand All @@ -228,7 +241,7 @@ public void shouldReadRolesFromToken_positiveFirstRoleSet() {

@Test
public void shouldReadRolesFromToken_positiveSecondRoleSet() {
Header header = tokenFactory.generateValidToken(USER_SUPERHERO, ROLE_CTO, ROLE_CEO, ROLE_VP);
Header header = tokenFactory1.generateValidToken(USER_SUPERHERO, ROLE_CTO, ROLE_CEO, ROLE_VP);
try (TestRestClient client = cluster.getRestClient(header)) {

HttpResponse response = client.getAuthInfo();
Expand All @@ -244,7 +257,7 @@ public void shouldReadRolesFromToken_positiveSecondRoleSet() {
public void shouldExposeTokenClaimsAsUserAttributes_positive() throws IOException {
String[] roles = { ROLE_VP };
Map<String, Object> additionalClaims = Map.of(CLAIM_DEPARTMENT, QA_DEPARTMENT);
Header header = tokenFactory.generateValidTokenWithCustomClaims(USER_SUPERHERO, roles, additionalClaims);
Header header = tokenFactory1.generateValidTokenWithCustomClaims(USER_SUPERHERO, roles, additionalClaims);
try (RestHighLevelClient client = cluster.getRestHighLevelClient(List.of(header))) {
SearchRequest searchRequest = queryStringQueryRequest(QA_SONG_INDEX_NAME, QUERY_TITLE_MAGNUM_OPUS);

Expand All @@ -261,11 +274,36 @@ public void shouldExposeTokenClaimsAsUserAttributes_positive() throws IOExceptio
public void shouldExposeTokenClaimsAsUserAttributes_negative() throws IOException {
String[] roles = { ROLE_VP };
Map<String, Object> additionalClaims = Map.of(CLAIM_DEPARTMENT, "department-without-access-to-qa-song-index");
Header header = tokenFactory.generateValidTokenWithCustomClaims(USER_SUPERHERO, roles, additionalClaims);
Header header = tokenFactory1.generateValidTokenWithCustomClaims(USER_SUPERHERO, roles, additionalClaims);
try (RestHighLevelClient client = cluster.getRestHighLevelClient(List.of(header))) {
SearchRequest searchRequest = queryStringQueryRequest(QA_SONG_INDEX_NAME, QUERY_TITLE_MAGNUM_OPUS);

assertThatThrownBy(() -> client.search(searchRequest, DEFAULT), statusException(FORBIDDEN));
}
}

@Test
public void secondKeypairShouldAuthenticateWithJwtToken_positive() {
try (TestRestClient client = cluster.getRestClient(tokenFactory2.generateValidToken(USER_SUPERHERO))) {

HttpResponse response = client.getAuthInfo();

response.assertStatusCode(200);
String username = response.getTextFromJsonBody(POINTER_USERNAME);
assertThat(username, equalTo(USER_SUPERHERO));
}
}

@Test
public void secondKeypairShouldAuthenticateWithJwtToken_positiveWithAnotherUsername() {
try (TestRestClient client = cluster.getRestClient(tokenFactory2.generateValidToken(USERNAME_ROOT))) {

HttpResponse response = client.getAuthInfo();

response.assertStatusCode(200);
String username = response.getTextFromJsonBody(POINTER_USERNAME);
assertThat(username, equalTo(USERNAME_ROOT));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ public class JwtAuthenticationWithUrlParamTests {
"jwt",
BASIC_AUTH_DOMAIN_ORDER - 1
).jwtHttpAuthenticator(
new JwtConfigBuilder().jwtUrlParameter(TOKEN_URL_PARAM).signingKey(PUBLIC_KEY).subjectKey(CLAIM_USERNAME).rolesKey(CLAIM_ROLES)
new JwtConfigBuilder().jwtUrlParameter(TOKEN_URL_PARAM)
.signingKey(List.of(PUBLIC_KEY))
.subjectKey(CLAIM_USERNAME)
.rolesKey(CLAIM_ROLES)
).backend("noop");

@Rule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
*/
package org.opensearch.test.framework;

import java.util.List;
import java.util.Map;
import java.util.Objects;

Expand All @@ -19,7 +20,7 @@
public class JwtConfigBuilder {
private String jwtHeader;
private String jwtUrlParameter;
private String signingKey;
private List<String> signingKeys;
private String subjectKey;
private String rolesKey;

Expand All @@ -33,8 +34,8 @@ public JwtConfigBuilder jwtUrlParameter(String jwtUrlParameter) {
return this;
}

public JwtConfigBuilder signingKey(String signingKey) {
stephen-crawford marked this conversation as resolved.
Show resolved Hide resolved
this.signingKey = signingKey;
public JwtConfigBuilder signingKey(List<String> signingKeys) {
this.signingKeys = signingKeys;
return this;
}

Expand All @@ -50,10 +51,10 @@ public JwtConfigBuilder rolesKey(String rolesKey) {

public Map<String, Object> build() {
Builder<String, Object> builder = new Builder<>();
if (Objects.isNull(signingKey)) {
if (Objects.isNull(signingKeys)) {
throw new IllegalStateException("Signing key is required.");
}
builder.put("signing_key", signingKey);
builder.put("signing_key", signingKeys);
if (isNoneBlank(jwtHeader)) {
builder.put("jwt_header", jwtHeader);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ public static class AuthcDomain implements ToXContentObject {
).httpAuthenticator("basic").backend("internal");

public final static AuthcDomain JWT_AUTH_DOMAIN = new TestSecurityConfig.AuthcDomain("jwt", 1).jwtHttpAuthenticator(
new JwtConfigBuilder().jwtHeader(AUTHORIZATION).signingKey(PUBLIC_KEY)
new JwtConfigBuilder().jwtHeader(AUTHORIZATION).signingKey(List.of(PUBLIC_KEY))
).backend("noop");

private final String id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -54,7 +55,7 @@
private static final Pattern BASIC = Pattern.compile("^\\s*Basic\\s.*", Pattern.CASE_INSENSITIVE);
private static final String BEARER = "bearer ";

private final JwtParser jwtParser;
private final List<JwtParser> jwtParsers = new ArrayList<>();
private final String jwtHeaderName;
private final boolean isDefaultAuthHeader;
private final String jwtUrlParameter;
Expand All @@ -67,7 +68,8 @@
public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
super();

String signingKey = settings.get("signing_key");
List<String> signingKeys = settings.getAsList("signing_key");
stephen-crawford marked this conversation as resolved.
Show resolved Hide resolved

jwtUrlParameter = settings.get("jwt_url_parameter");
jwtHeaderName = settings.get("jwt_header", AUTHORIZATION);
isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName);
Expand All @@ -83,19 +85,23 @@
);
}

final JwtParserBuilder jwtParserBuilder = KeyUtils.createJwtParserBuilderFromSigningKey(signingKey, log);
if (jwtParserBuilder == null) {
jwtParser = null;
} else {
if (requireIssuer != null) {
jwtParserBuilder.requireIssuer(requireIssuer);
}

final SecurityManager sm = System.getSecurityManager();
if (sm != null) {
sm.checkPermission(new SpecialPermission());
for (String key : signingKeys) {
JwtParser jwtParser;
final JwtParserBuilder jwtParserBuilder = KeyUtils.createJwtParserBuilderFromSigningKey(key, log);
if (jwtParserBuilder == null) {
jwtParser = null;

Check warning on line 92 in src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java#L92

Added line #L92 was not covered by tests
} else {
if (requireIssuer != null) {
jwtParserBuilder.requireIssuer(requireIssuer);
}

final SecurityManager sm = System.getSecurityManager();
if (sm != null) {
sm.checkPermission(new SpecialPermission());

Check warning on line 100 in src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java#L100

Added line #L100 was not covered by tests
}
jwtParser = AccessController.doPrivileged((PrivilegedAction<JwtParser>) jwtParserBuilder::build);
}
jwtParser = AccessController.doPrivileged((PrivilegedAction<JwtParser>) jwtParserBuilder::build);
jwtParsers.add(jwtParser);
}
}

Expand All @@ -120,7 +126,8 @@
}

private AuthCredentials extractCredentials0(final SecurityRequest request) {
if (jwtParser == null) {

if (jwtParsers.isEmpty() || jwtParsers.getFirst() == null) {
log.error("Missing Signing Key. JWT authentication will not work");
return null;
}
Expand Down Expand Up @@ -157,39 +164,43 @@
}
}

try {
final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody();
for (JwtParser jwtParser : jwtParsers) {
try {

if (!requiredAudience.isEmpty()) {
assertValidAudienceClaim(claims);
}
final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody();

final String subject = extractSubject(claims, request);
if (!requiredAudience.isEmpty()) {
assertValidAudienceClaim(claims);
}

if (subject == null) {
log.error("No subject found in JWT token");
return null;
}
final String subject = extractSubject(claims, request);

final String[] roles = extractRoles(claims, request);
if (subject == null) {
log.error("No subject found in JWT token");
return null;
}

final AuthCredentials ac = new AuthCredentials(subject, roles).markComplete();
final String[] roles = extractRoles(claims, request);

for (Entry<String, Object> claim : claims.entrySet()) {
ac.addAttribute("attr.jwt." + claim.getKey(), String.valueOf(claim.getValue()));
}
final AuthCredentials ac = new AuthCredentials(subject, roles).markComplete();

return ac;
for (Entry<String, Object> claim : claims.entrySet()) {
ac.addAttribute("attr.jwt." + claim.getKey(), String.valueOf(claim.getValue()));
}

} catch (WeakKeyException e) {
log.error("Cannot authenticate user with JWT because of ", e);
return null;
} catch (Exception e) {
if (log.isDebugEnabled()) {
log.debug("Invalid or expired JWT token.", e);
return ac;

} catch (WeakKeyException e) {
log.error("Cannot authenticate user with JWT because of ", e);
return null;

Check warning on line 195 in src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java#L193-L195

Added lines #L193 - L195 were not covered by tests
} catch (Exception e) {
if (log.isDebugEnabled()) {
log.debug("Invalid or expired JWT token.", e);

Check warning on line 198 in src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java#L198

Added line #L198 was not covered by tests
}
}
return null;
}
log.error("Failed to parse JWT token using any of the available parsers");
return null;
}

private void assertValidAudienceClaim(Claims claims) throws BadJWTException {
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/security/util/KeyUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public JwtParserBuilder run() {
PublicKey key = null;

final String minimalKeyFormat = signingKey.replace("-----BEGIN PUBLIC KEY-----\n", "")
.replace("-----END PUBLIC KEY-----", "");

.replace("-----END PUBLIC KEY-----", "")
.trim();
stephen-crawford marked this conversation as resolved.
Show resolved Hide resolved
final byte[] decoded = Base64.getDecoder().decode(minimalKeyFormat);

try {
Expand Down
Loading
Loading