Skip to content

Commit

Permalink
remove all usages of cxf.rs.security.jose
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Mierzwa <[email protected]>
  • Loading branch information
MaciejMierzwa committed Sep 29, 2023
1 parent 48658dd commit 93c1bce
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExc
}
JWK key = keyProvider.getKey(kid);

// TODO algorithm is final in jose implementation. Algorithm is not mandatory for the key material, so we set it to the same as the JWT
// TODO algorithm is final in jose implementation. Algorithm is not mandatory for the key material, so we set it to the same as the JWT, check if it's even necessary
if (key.getAlgorithm() == null && key.getKeyUse() == KeyUse.SIGNATURE && key.getKeyType() == KeyType.RSA) {
// key.setAlgorithm(jwt.getJwsHeaders().getAlgorithm());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import com.nimbusds.jose.jwk.JWK;

public interface KeyProvider {
public JWK getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;

public JWK getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
JWK getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
JWK getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import javax.net.ssl.SSLParameters;
import javax.net.ssl.TrustManagerFactory;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import com.nimbusds.jose.jwk.JWKSet;
import org.apache.hc.core5.function.Callback;
import org.apache.hc.core5.http.ClassicHttpRequest;
import org.apache.hc.core5.http.ClassicHttpResponse;
Expand All @@ -42,8 +42,6 @@
import org.opensearch.security.test.helper.file.FileHelper;
import org.opensearch.security.test.helper.network.SocketUtils;

import static com.amazon.dlic.auth.http.jwt.keybyoidc.CxfTestTools.toJson;

class MockIpdServer implements Closeable {
final static String CTX_DISCOVER = "/discover";
final static String CTX_KEYS = "/api/oauth/keys";
Expand All @@ -52,13 +50,13 @@ class MockIpdServer implements Closeable {
private final int port;
private final String uri;
private final boolean ssl;
private final JsonWebKeys jwks;
private final JWKSet jwks;

MockIpdServer(JsonWebKeys jwks) throws IOException {
MockIpdServer(JWKSet jwks) throws IOException {
this(jwks, SocketUtils.findAvailableTcpPort(), false);
}

MockIpdServer(JsonWebKeys jwks, int port, boolean ssl) throws IOException {
MockIpdServer(JWKSet jwks, int port, boolean ssl) throws IOException {
this.port = port;
this.uri = (ssl ? "https" : "http") + "://localhost:" + port;
this.ssl = ssl;
Expand Down Expand Up @@ -143,7 +141,7 @@ protected void handleDiscoverRequest(HttpRequest request, ClassicHttpResponse re
protected void handleKeysRequest(HttpRequest request, ClassicHttpResponse response, HttpContext context) throws HttpException,
IOException {
response.setCode(200);
response.setEntity(new StringEntity(toJson(jwks)));
response.setEntity(new StringEntity(jwks.toString()));
}

private SSLContext createSSLContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import org.junit.Assert;
import org.junit.Test;

Expand All @@ -26,12 +27,12 @@ public class SelfRefreshingKeySetTest {
public void basicTest() throws AuthenticatorUnavailableException, BadCredentialsException {
SelfRefreshingKeySet selfRefreshingKeySet = new SelfRefreshingKeySet(new MockKeySetProvider());

JsonWebKey key1 = selfRefreshingKeySet.getKey("kid/a");
Assert.assertEquals(TestJwk.OCT_1_K, key1.getProperty("k"));
OctetSequenceKey key1 = (OctetSequenceKey) selfRefreshingKeySet.getKey("kid/a");
Assert.assertEquals(TestJwk.OCT_1_K, key1.getKeyValue().toString());
Assert.assertEquals(1, selfRefreshingKeySet.getRefreshCount());

JsonWebKey key2 = selfRefreshingKeySet.getKey("kid/b");
Assert.assertEquals(TestJwk.OCT_2_K, key2.getProperty("k"));
OctetSequenceKey key2 = (OctetSequenceKey) selfRefreshingKeySet.getKey("kid/b");
Assert.assertEquals(TestJwk.OCT_2_K, key2.getKeyValue().toString());
Assert.assertEquals(1, selfRefreshingKeySet.getRefreshCount());

try {
Expand All @@ -51,20 +52,20 @@ public void twoThreadedTest() throws Exception {

ExecutorService executorService = Executors.newCachedThreadPool();

Future<JsonWebKey> f1 = executorService.submit(() -> selfRefreshingKeySet.getKey("kid/a"));
Future<JWK> f1 = executorService.submit(() -> selfRefreshingKeySet.getKey("kid/a"));

provider.waitForCalled();

Future<JsonWebKey> f2 = executorService.submit(() -> selfRefreshingKeySet.getKey("kid/b"));
Future<JWK> f2 = executorService.submit(() -> selfRefreshingKeySet.getKey("kid/b"));

while (selfRefreshingKeySet.getQueuedGetCount() == 0) {
Thread.sleep(10);
}

provider.unblock();

Assert.assertEquals(TestJwk.OCT_1_K, f1.get().getProperty("k"));
Assert.assertEquals(TestJwk.OCT_2_K, f2.get().getProperty("k"));
Assert.assertEquals(TestJwk.OCT_1_K, ((OctetSequenceKey) f1.get()).getKeyValue().toString());
Assert.assertEquals(TestJwk.OCT_2_K, ((OctetSequenceKey) f2.get()).getKeyValue().toString());

Assert.assertEquals(1, selfRefreshingKeySet.getRefreshCount());
Assert.assertEquals(1, selfRefreshingKeySet.getQueuedGetCount());
Expand All @@ -74,7 +75,7 @@ public void twoThreadedTest() throws Exception {
static class MockKeySetProvider implements KeySetProvider {

@Override
public JsonWebKeys get() throws AuthenticatorUnavailableException {
public JWKSet get() throws AuthenticatorUnavailableException {
return TestJwk.OCT_1_2_3;
}

Expand All @@ -85,7 +86,7 @@ static class BlockingMockKeySetProvider extends MockKeySetProvider {
private boolean called = false;

@Override
public synchronized JsonWebKeys get() throws AuthenticatorUnavailableException {
public synchronized JWKSet get() throws AuthenticatorUnavailableException {

called = true;
notifyAll();
Expand Down
103 changes: 50 additions & 53 deletions src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwk.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@

import java.util.Arrays;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import org.apache.cxf.rs.security.jose.jwk.KeyType;
import org.apache.cxf.rs.security.jose.jwk.PublicKeyUse;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.util.Base64URL;


class TestJwk {

Expand All @@ -29,13 +33,13 @@ class TestJwk {
static final String OCT_3_K =
"r3aeW3OK7-B4Hs3hq9BmlT1D3jRiolH9PL82XUz9xAS7dniAdmvMnN5GkOc1vqibOe2T-CC_103UglDm9D0iU9S9zn6wTuQt1L5wfZIoHd9f5IjJ_YFEzZMvsoUY_-ji_0K_ugVvBPwi9JnBQHHS4zrgmP06dGjmcnZDcIf4W_iFas3lDYSXilL1V2QhNaynpSqTarpfBGSphKv4Zg2JhsX8xB0VSaTlEq4lF8pzvpWSxXCW9CtomhB80daSuTizrmSTEPpdN3XzQ2-Tovo1ieMOfDU4csvjEk7Bwc2ThjpnA8ucKQUYpUv9joBxKuCdUltssthWnetrogjYOn_xGA";

static final JsonWebKey OCT_1 = createOct("kid/a", "HS256", OCT_1_K);
static final JsonWebKey OCT_2 = createOct("kid/b", "HS256", OCT_2_K);
static final JsonWebKey OCT_3 = createOct("kid/c", "HS256", OCT_3_K);
static final JsonWebKey ESCAPED_SLASH_KID_OCT_1 = createOct("kid\\/_a", "HS256", OCT_1_K);
static final JsonWebKey FORWARD_SLASH_KID_OCT_1 = createOct("kid/_a", "HS256", OCT_1_K);
static final JWK OCT_1 = createOct("kid/a", "HS256", OCT_1_K);
static final JWK OCT_2 = createOct("kid/b", "HS256", OCT_2_K);
static final JWK OCT_3 = createOct("kid/c", "HS256", OCT_3_K);
static final JWK ESCAPED_SLASH_KID_OCT_1 = createOct("kid\\/_a", "HS256", OCT_1_K);
static final JWK FORWARD_SLASH_KID_OCT_1 = createOct("kid/_a", "HS256", OCT_1_K);

static final JsonWebKeys OCT_1_2_3 = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, ESCAPED_SLASH_KID_OCT_1);
static final JWKSet OCT_1_2_3 = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, ESCAPED_SLASH_KID_OCT_1);

static final String RSA_1_D =
"On8XGMmdM5Fm5hvuhQk-qAkIP2CoK5QMx0OH5m_WDzKXZv8lZ2eg89I4ehBiOKGdw1h_mjmWwTah-evpXV-BF5QpejPQqxkXS-8s5r2AvietQq32jl-gwIwZWTvfzjpT9On0YJZ4q01tMDj3r-YOLUW2xrz3za9tl6pPU_5kP63C-hoj1ybTwcC7ujbCPwhY6yAopMA1v10uVmCxsjsNikEjB6YePgHixez51wO3Z8mXNwefWukFWYJ5T7t4kHMSf5P_8FJZ14u5yvYZnngE_tJCyHFdIDb6UWsrgxomtlQU-SdZYK_NY6gw6mCkjjlqOoYqlsrRJ16kJ81Ds269oQ";
Expand All @@ -55,67 +59,60 @@ class TestJwk {
"jDDVUMXOXDVcaRVAT5TtuiAsLxk7XAAwyyECfmySZul7D5XVLMtGe6rP2900q3nM4BaCEiuwXjmTCZDAGlFGs2a3eQ1vbBSv9_0KGHL-gZGFPNiv0v8aR7QzZ-abhGnRy5F52PlTWsypGgG_kQpF2t2TBotvYhvVPagAt4ljllDKvY1siOvS3nh4TqcUtWcbgQZEWPmaXuhx0eLmhQJca7UEw99YlGNew48AEzt7ZnfU0Qkz3JwSz7IcPx-NfIh6BN6LwAg_ASdoM3MR8rDOtLYavmJVhutrfOpE-4-fw1mf3eLYu7xrxIplSiOIsHunTUssnTiBkXAaGqGJs604Pw";
static final String RSA_X_E = "AQAB";

static final JsonWebKey RSA_1 = createRsa("kid/1", "RS256", RSA_1_E, RSA_1_N, RSA_1_D);
static final JsonWebKey RSA_1_PUBLIC = createRsaPublic("kid/1", "RS256", RSA_1_E, RSA_1_N);
static final JsonWebKey RSA_1_PUBLIC_NO_ALG = createRsaPublic("kid/1", null, RSA_1_E, RSA_1_N);
static final JsonWebKey RSA_1_PUBLIC_WRONG_ALG = createRsaPublic("kid/1", "HS256", RSA_1_E, RSA_1_N);
static final JWK RSA_1 = createRsa("kid/1", "RS256", RSA_1_E, RSA_1_N, RSA_1_D);
static final JWK RSA_KEY = new RSAKey.Builder(new Base64URL(RSA_1_N), new Base64URL(RSA_1_E))
.privateExponent(new Base64URL(RSA_1_D))
.algorithm(JWSAlgorithm.RS512)
.keyID("kid/1")
.build();

static final JWK RSA_1_PUBLIC = createRsaPublic("kid/1", "RS256", RSA_1_E, RSA_1_N);
static final JWK RSA_1_PUBLIC_NO_ALG = createRsaPublic("kid/1", null, RSA_1_E, RSA_1_N);
static final JWK RSA_1_PUBLIC_WRONG_ALG = createRsaPublic("kid/1", "HS256", RSA_1_E, RSA_1_N);

static final JsonWebKey RSA_2 = createRsa("kid/2", "RS256", RSA_2_E, RSA_2_N, RSA_2_D);
static final JsonWebKey RSA_2_PUBLIC = createRsaPublic("kid/2", "RS256", RSA_2_E, RSA_2_N);
static final JWK RSA_2 = createRsa("kid/2", "RS256", RSA_2_E, RSA_2_N, RSA_2_D);
static final JWK RSA_2_PUBLIC = createRsaPublic("kid/2", "RS256", RSA_2_E, RSA_2_N);

static final JsonWebKey RSA_X = createRsa("kid/2", "RS256", RSA_X_E, RSA_X_N, RSA_X_D);
static final JsonWebKey RSA_X_PUBLIC = createRsaPublic("kid/2", "RS256", RSA_X_E, RSA_X_N);
static final JWK RSA_X = createRsa("kid/2", "RS256", RSA_X_E, RSA_X_N, RSA_X_D);
static final JWK RSA_X_PUBLIC = createRsaPublic("kid/2", "RS256", RSA_X_E, RSA_X_N);

static final JsonWebKeys RSA_1_2_PUBLIC = createJwks(RSA_1_PUBLIC, RSA_2_PUBLIC);
static final JWKSet RSA_1_2_PUBLIC = createJwks(RSA_1_PUBLIC, RSA_2_PUBLIC);

static class Jwks {
static final JsonWebKeys ALL = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, RSA_1_PUBLIC, RSA_2_PUBLIC);
static final JsonWebKeys RSA_1 = createJwks(RSA_1_PUBLIC);
static final JsonWebKeys RSA_2 = createJwks(RSA_2_PUBLIC);
static final JsonWebKeys RSA_1_NO_ALG = createJwks(RSA_1_PUBLIC_NO_ALG);
static final JsonWebKeys RSA_1_WRONG_ALG = createJwks(RSA_1_PUBLIC_WRONG_ALG);
static final JWKSet ALL = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, RSA_1_PUBLIC, RSA_2_PUBLIC);
static final JWKSet RSA_1 = createJwks(RSA_1_PUBLIC);
static final JWKSet RSA_2 = createJwks(RSA_2_PUBLIC);
static final JWKSet RSA_1_NO_ALG = createJwks(RSA_1_PUBLIC_NO_ALG);
static final JWKSet RSA_1_WRONG_ALG = createJwks(RSA_1_PUBLIC_WRONG_ALG);
}

private static JsonWebKey createOct(String keyId, String algorithm, String k) {
JsonWebKey result = new JsonWebKey();

result.setKeyId(keyId);
result.setKeyType(KeyType.OCTET);
result.setAlgorithm(algorithm);
result.setPublicKeyUse(PublicKeyUse.SIGN);
result.setProperty("k", k);

return result;
private static JWK createOct(String keyId, String algorithm, String k) {
return new OctetSequenceKey.Builder(new Base64URL(k))
.keyID(keyId)
.keyUse(KeyUse.SIGNATURE)
.algorithm(JWSAlgorithm.parse(algorithm))
.build();
}

private static JsonWebKey createRsa(String keyId, String algorithm, String e, String n, String d) {
JsonWebKey result = new JsonWebKey();

result.setKeyId(keyId);
result.setKeyType(KeyType.RSA);
result.setAlgorithm(algorithm);
result.setPublicKeyUse(PublicKeyUse.SIGN);
private static JWK createRsa(String keyId, String algorithm, String e, String n, String d) {
RSAKey.Builder builder = new RSAKey.Builder(new Base64URL(n), new Base64URL(e))
.keyUse(KeyUse.SIGNATURE)
.algorithm(algorithm == null ? null : JWSAlgorithm.parse(algorithm))
.keyID(keyId);

if (d != null) {
result.setProperty("d", d);
builder.privateExponent(new Base64URL(d));
}

result.setProperty("e", e);
result.setProperty("n", n);

return result;
return builder.build();
}

private static JsonWebKey createRsaPublic(String keyId, String algorithm, String e, String n) {
private static JWK createRsaPublic(String keyId, String algorithm, String e, String n) {
return createRsa(keyId, algorithm, e, n, null);
}

private static JsonWebKeys createJwks(JsonWebKey... array) {
JsonWebKeys result = new JsonWebKeys();

result.setKeys(Arrays.asList(array));

return result;
private static JWKSet createJwks(JWK... array) {
return new JWKSet(Arrays.asList(array));
}

}
Loading

0 comments on commit 93c1bce

Please sign in to comment.