diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt b/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt index c92803527..0f82dfce6 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt @@ -21,6 +21,7 @@ import com.google.crypto.tink.jwt.JwkSetConverter import com.google.crypto.tink.jwt.JwtPublicKeyVerify import com.google.crypto.tink.jwt.JwtSignatureConfig import com.google.crypto.tink.jwt.JwtValidator +import com.google.crypto.tink.jwt.VerifiedJwt import com.google.gson.JsonObject import com.google.gson.JsonParser import io.grpc.Metadata @@ -32,17 +33,11 @@ import org.wfanet.measurement.common.base64UrlDecode /** Utility for extracting OpenID Connect (OIDC) token information from gRPC request headers. */ class OpenIdConnectAuthentication( - audience: String, openIdProviderConfigs: Iterable, clock: Clock = Clock.systemUTC(), ) { - private val jwtValidator = - JwtValidator.newBuilder().setClock(clock).expectAudience(audience).ignoreIssuer().build() - - private val jwksHandleByIssuer: Map = - openIdProviderConfigs.associateBy({ it.issuer }) { - JwkSetConverter.toPublicKeysetHandle(it.jwks) - } + private val openIdProviderByIssuer: Map = + openIdProviderConfigs.associateBy({ it.issuer }) { OpenIdProvider(it, clock) } /** * Verifies and decodes an OIDC bearer token from [headers]. @@ -74,13 +69,13 @@ class OpenIdConnectAuthentication( val issuer = payload.get(ISSUER_CLAIM)?.asString ?: throw Status.UNAUTHENTICATED.withDescription("Issuer not found").asException() - val jwksHandle = - jwksHandleByIssuer[issuer] + val provider: OpenIdProvider = + openIdProviderByIssuer[issuer] ?: throw Status.UNAUTHENTICATED.withDescription("Unknown issuer").asException() - val verifiedJwt = + val verifiedJwt: VerifiedJwt = try { - jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator) + provider.verifyAndDecode(token) } catch (e: GeneralSecurityException) { throw Status.UNAUTHENTICATED.withCause(e).withDescription(e.message).asException() } @@ -113,5 +108,20 @@ class OpenIdConnectAuthentication( val issuer: String, /** JSON Web Key Set (JWKS) for the provider. */ val jwks: String, + /** Client ID registered with the provider. */ + val clientId: String, ) + + private class OpenIdProvider(providerConfig: OpenIdProviderConfig, clock: Clock) { + private val jwtValidator: JwtValidator = + JwtValidator.newBuilder() + .expectIssuer(providerConfig.issuer) + .expectAudience(providerConfig.clientId) + .setClock(clock) + .build() + private val jwksHandle: KeysetHandle = JwkSetConverter.toPublicKeysetHandle(providerConfig.jwks) + + fun verifyAndDecode(token: String): VerifiedJwt = + jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator) + } } diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt b/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt index 7ffdc1563..e611f97ba 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt @@ -29,35 +29,33 @@ import org.wfanet.measurement.common.grpc.BearerTokenCallCredentials import org.wfanet.measurement.common.grpc.OpenIdConnectAuthentication /** An ephemeral OpenID provider for testing. */ -class OpenIdProvider(private val issuer: String) { +class OpenIdProvider(issuer: String, clientId: String) { private val jwkSetHandle = KeysetHandle.generateNew(KEY_TEMPLATE) val providerConfig: OpenIdConnectAuthentication.OpenIdProviderConfig by lazy { val jwks = JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle) - OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks) + OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks, clientId) } fun generateCredentials( - audience: String, subject: String, scopes: Set, expiration: Instant = Instant.now().plus(Duration.ofMinutes(5)), ): BearerTokenCallCredentials { - val token = generateSignedToken(audience, subject, scopes, expiration) + val token = generateSignedToken(subject, scopes, expiration) return BearerTokenCallCredentials(token, false) } /** Generates a signed and encoded JWT using the specified parameters. */ private fun generateSignedToken( - audience: String, subject: String, scopes: Set, expiration: Instant, ): String { val rawJwt = RawJwt.newBuilder() - .setAudience(audience) - .setIssuer(issuer) + .setAudience(providerConfig.clientId) + .setIssuer(providerConfig.issuer) .setSubject(subject) .addStringClaim("scope", scopes.joinToString(" ")) .setExpiration(expiration) diff --git a/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt b/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt index 168290886..a83c0aefe 100644 --- a/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt @@ -38,11 +38,11 @@ class OpenIdConnectAuthenticationTest { fun `verifyAndDecodeBearerToken returns VerifiedToken`() { val issuer = "example.com" val subject = "user1@example.com" - val audience = "foobar" + val clientId = "foobar" val scopes = setOf("foo.bar", "foo.baz") - val openIdProvider = OpenIdProvider(issuer) - val credentials = openIdProvider.generateCredentials(audience, subject, scopes) - val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) + val openIdProvider = OpenIdProvider(issuer, clientId) + val credentials = openIdProvider.generateCredentials(subject, scopes) + val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig)) val token = auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) @@ -53,17 +53,16 @@ class OpenIdConnectAuthenticationTest { fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is expired`() { val issuer = "example.com" val subject = "user1@example.com" - val audience = "foobar" + val clientId = "foobar" val scopes = setOf("foo.bar", "foo.baz") - val openIdProvider = OpenIdProvider(issuer) + val openIdProvider = OpenIdProvider(issuer, clientId) val credentials = openIdProvider.generateCredentials( - audience, subject, scopes, Instant.now().minus(Duration.ofMinutes(5)), ) - val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) + val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig)) val exception = assertFailsWith { @@ -74,34 +73,15 @@ class OpenIdConnectAuthenticationTest { assertThat(exception).hasMessageThat().ignoringCase().contains("expired") } - @Test - fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when audience does not match`() { - val issuer = "example.com" - val subject = "user1@example.com" - val audience = "foobar" - val scopes = setOf("foo.bar", "foo.baz") - val openIdProvider = OpenIdProvider(issuer) - val credentials = openIdProvider.generateCredentials("bad-audience", subject, scopes) - val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) - - val exception = - assertFailsWith { - auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) - } - - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception).hasMessageThat().ignoringCase().contains("audience") - } - @Test fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when provider not found for issuer`() { val issuer = "example.com" val subject = "user1@example.com" - val audience = "foobar" + val clientId = "foobar" val scopes = setOf("foo.bar", "foo.baz") - val openIdProvider = OpenIdProvider(issuer) - val credentials = openIdProvider.generateCredentials(audience, subject, scopes) - val auth = OpenIdConnectAuthentication(audience, emptyList()) + val openIdProvider = OpenIdProvider(issuer, clientId) + val credentials = openIdProvider.generateCredentials(subject, scopes) + val auth = OpenIdConnectAuthentication(emptyList()) val exception = assertFailsWith { @@ -114,9 +94,8 @@ class OpenIdConnectAuthenticationTest { @Test fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is not a valid JWT`() { - val audience = "foobar" val credentials = BearerTokenCallCredentials("foo", false) - val auth = OpenIdConnectAuthentication(audience, emptyList()) + val auth = OpenIdConnectAuthentication(emptyList()) val exception = assertFailsWith { @@ -129,8 +108,7 @@ class OpenIdConnectAuthenticationTest { @Test fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when header not found`() { - val audience = "foobar" - val auth = OpenIdConnectAuthentication(audience, emptyList()) + val auth = OpenIdConnectAuthentication(emptyList()) val exception = assertFailsWith { auth.verifyAndDecodeBearerToken(Metadata()) }