From e56585d8dbefda58a8a62fa313d0290dcd05380f Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Tue, 17 Dec 2024 14:23:31 -0800 Subject: [PATCH] fix!: Verify audience claim matches client ID for OpenID provider The client ID is generated by the OpenID provider on registration, so the OpenID provider configuration needs to include the client ID. --- .../grpc/OpenIdConnectAuthentication.kt | 34 ++++++++----- .../common/grpc/testing/OpenIdProvider.kt | 12 ++--- .../grpc/OpenIdConnectAuthenticationTest.kt | 48 +++++-------------- 3 files changed, 40 insertions(+), 54 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt b/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt index c92803527..0f82dfce6 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt @@ -21,6 +21,7 @@ import com.google.crypto.tink.jwt.JwkSetConverter import com.google.crypto.tink.jwt.JwtPublicKeyVerify import com.google.crypto.tink.jwt.JwtSignatureConfig import com.google.crypto.tink.jwt.JwtValidator +import com.google.crypto.tink.jwt.VerifiedJwt import com.google.gson.JsonObject import com.google.gson.JsonParser import io.grpc.Metadata @@ -32,17 +33,11 @@ import org.wfanet.measurement.common.base64UrlDecode /** Utility for extracting OpenID Connect (OIDC) token information from gRPC request headers. */ class OpenIdConnectAuthentication( - audience: String, openIdProviderConfigs: Iterable, clock: Clock = Clock.systemUTC(), ) { - private val jwtValidator = - JwtValidator.newBuilder().setClock(clock).expectAudience(audience).ignoreIssuer().build() - - private val jwksHandleByIssuer: Map = - openIdProviderConfigs.associateBy({ it.issuer }) { - JwkSetConverter.toPublicKeysetHandle(it.jwks) - } + private val openIdProviderByIssuer: Map = + openIdProviderConfigs.associateBy({ it.issuer }) { OpenIdProvider(it, clock) } /** * Verifies and decodes an OIDC bearer token from [headers]. @@ -74,13 +69,13 @@ class OpenIdConnectAuthentication( val issuer = payload.get(ISSUER_CLAIM)?.asString ?: throw Status.UNAUTHENTICATED.withDescription("Issuer not found").asException() - val jwksHandle = - jwksHandleByIssuer[issuer] + val provider: OpenIdProvider = + openIdProviderByIssuer[issuer] ?: throw Status.UNAUTHENTICATED.withDescription("Unknown issuer").asException() - val verifiedJwt = + val verifiedJwt: VerifiedJwt = try { - jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator) + provider.verifyAndDecode(token) } catch (e: GeneralSecurityException) { throw Status.UNAUTHENTICATED.withCause(e).withDescription(e.message).asException() } @@ -113,5 +108,20 @@ class OpenIdConnectAuthentication( val issuer: String, /** JSON Web Key Set (JWKS) for the provider. */ val jwks: String, + /** Client ID registered with the provider. */ + val clientId: String, ) + + private class OpenIdProvider(providerConfig: OpenIdProviderConfig, clock: Clock) { + private val jwtValidator: JwtValidator = + JwtValidator.newBuilder() + .expectIssuer(providerConfig.issuer) + .expectAudience(providerConfig.clientId) + .setClock(clock) + .build() + private val jwksHandle: KeysetHandle = JwkSetConverter.toPublicKeysetHandle(providerConfig.jwks) + + fun verifyAndDecode(token: String): VerifiedJwt = + jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator) + } } diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt b/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt index 7ffdc1563..e611f97ba 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt @@ -29,35 +29,33 @@ import org.wfanet.measurement.common.grpc.BearerTokenCallCredentials import org.wfanet.measurement.common.grpc.OpenIdConnectAuthentication /** An ephemeral OpenID provider for testing. */ -class OpenIdProvider(private val issuer: String) { +class OpenIdProvider(issuer: String, clientId: String) { private val jwkSetHandle = KeysetHandle.generateNew(KEY_TEMPLATE) val providerConfig: OpenIdConnectAuthentication.OpenIdProviderConfig by lazy { val jwks = JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle) - OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks) + OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks, clientId) } fun generateCredentials( - audience: String, subject: String, scopes: Set, expiration: Instant = Instant.now().plus(Duration.ofMinutes(5)), ): BearerTokenCallCredentials { - val token = generateSignedToken(audience, subject, scopes, expiration) + val token = generateSignedToken(subject, scopes, expiration) return BearerTokenCallCredentials(token, false) } /** Generates a signed and encoded JWT using the specified parameters. */ private fun generateSignedToken( - audience: String, subject: String, scopes: Set, expiration: Instant, ): String { val rawJwt = RawJwt.newBuilder() - .setAudience(audience) - .setIssuer(issuer) + .setAudience(providerConfig.clientId) + .setIssuer(providerConfig.issuer) .setSubject(subject) .addStringClaim("scope", scopes.joinToString(" ")) .setExpiration(expiration) diff --git a/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt b/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt index 168290886..a83c0aefe 100644 --- a/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt @@ -38,11 +38,11 @@ class OpenIdConnectAuthenticationTest { fun `verifyAndDecodeBearerToken returns VerifiedToken`() { val issuer = "example.com" val subject = "user1@example.com" - val audience = "foobar" + val clientId = "foobar" val scopes = setOf("foo.bar", "foo.baz") - val openIdProvider = OpenIdProvider(issuer) - val credentials = openIdProvider.generateCredentials(audience, subject, scopes) - val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) + val openIdProvider = OpenIdProvider(issuer, clientId) + val credentials = openIdProvider.generateCredentials(subject, scopes) + val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig)) val token = auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) @@ -53,17 +53,16 @@ class OpenIdConnectAuthenticationTest { fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is expired`() { val issuer = "example.com" val subject = "user1@example.com" - val audience = "foobar" + val clientId = "foobar" val scopes = setOf("foo.bar", "foo.baz") - val openIdProvider = OpenIdProvider(issuer) + val openIdProvider = OpenIdProvider(issuer, clientId) val credentials = openIdProvider.generateCredentials( - audience, subject, scopes, Instant.now().minus(Duration.ofMinutes(5)), ) - val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) + val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig)) val exception = assertFailsWith { @@ -74,34 +73,15 @@ class OpenIdConnectAuthenticationTest { assertThat(exception).hasMessageThat().ignoringCase().contains("expired") } - @Test - fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when audience does not match`() { - val issuer = "example.com" - val subject = "user1@example.com" - val audience = "foobar" - val scopes = setOf("foo.bar", "foo.baz") - val openIdProvider = OpenIdProvider(issuer) - val credentials = openIdProvider.generateCredentials("bad-audience", subject, scopes) - val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) - - val exception = - assertFailsWith { - auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) - } - - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception).hasMessageThat().ignoringCase().contains("audience") - } - @Test fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when provider not found for issuer`() { val issuer = "example.com" val subject = "user1@example.com" - val audience = "foobar" + val clientId = "foobar" val scopes = setOf("foo.bar", "foo.baz") - val openIdProvider = OpenIdProvider(issuer) - val credentials = openIdProvider.generateCredentials(audience, subject, scopes) - val auth = OpenIdConnectAuthentication(audience, emptyList()) + val openIdProvider = OpenIdProvider(issuer, clientId) + val credentials = openIdProvider.generateCredentials(subject, scopes) + val auth = OpenIdConnectAuthentication(emptyList()) val exception = assertFailsWith { @@ -114,9 +94,8 @@ class OpenIdConnectAuthenticationTest { @Test fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is not a valid JWT`() { - val audience = "foobar" val credentials = BearerTokenCallCredentials("foo", false) - val auth = OpenIdConnectAuthentication(audience, emptyList()) + val auth = OpenIdConnectAuthentication(emptyList()) val exception = assertFailsWith { @@ -129,8 +108,7 @@ class OpenIdConnectAuthenticationTest { @Test fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when header not found`() { - val audience = "foobar" - val auth = OpenIdConnectAuthentication(audience, emptyList()) + val auth = OpenIdConnectAuthentication(emptyList()) val exception = assertFailsWith { auth.verifyAndDecodeBearerToken(Metadata()) }