Skip to content

Commit

Permalink
fix!: Verify audience claim matches client ID for OpenID provider (#290)
Browse files Browse the repository at this point in the history
This addresses a bug in #288 

The client ID is generated by the OpenID provider on registration, so the OpenID provider configuration needs to include the client ID.
  • Loading branch information
SanjayVas authored Dec 19, 2024
1 parent d03e8e4 commit 8441ad2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.google.crypto.tink.jwt.JwkSetConverter
import com.google.crypto.tink.jwt.JwtPublicKeyVerify
import com.google.crypto.tink.jwt.JwtSignatureConfig
import com.google.crypto.tink.jwt.JwtValidator
import com.google.crypto.tink.jwt.VerifiedJwt
import com.google.gson.JsonObject
import com.google.gson.JsonParser
import io.grpc.Metadata
Expand All @@ -32,17 +33,11 @@ import org.wfanet.measurement.common.base64UrlDecode

/** Utility for extracting OpenID Connect (OIDC) token information from gRPC request headers. */
class OpenIdConnectAuthentication(
audience: String,
openIdProviderConfigs: Iterable<OpenIdProviderConfig>,
clock: Clock = Clock.systemUTC(),
) {
private val jwtValidator =
JwtValidator.newBuilder().setClock(clock).expectAudience(audience).ignoreIssuer().build()

private val jwksHandleByIssuer: Map<String, KeysetHandle> =
openIdProviderConfigs.associateBy({ it.issuer }) {
JwkSetConverter.toPublicKeysetHandle(it.jwks)
}
private val openIdProviderByIssuer: Map<String, OpenIdProvider> =
openIdProviderConfigs.associateBy({ it.issuer }) { OpenIdProvider(it, clock) }

/**
* Verifies and decodes an OIDC bearer token from [headers].
Expand Down Expand Up @@ -74,13 +69,13 @@ class OpenIdConnectAuthentication(
val issuer =
payload.get(ISSUER_CLAIM)?.asString
?: throw Status.UNAUTHENTICATED.withDescription("Issuer not found").asException()
val jwksHandle =
jwksHandleByIssuer[issuer]
val provider: OpenIdProvider =
openIdProviderByIssuer[issuer]
?: throw Status.UNAUTHENTICATED.withDescription("Unknown issuer").asException()

val verifiedJwt =
val verifiedJwt: VerifiedJwt =
try {
jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator)
provider.verifyAndDecode(token)
} catch (e: GeneralSecurityException) {
throw Status.UNAUTHENTICATED.withCause(e).withDescription(e.message).asException()
}
Expand Down Expand Up @@ -113,5 +108,20 @@ class OpenIdConnectAuthentication(
val issuer: String,
/** JSON Web Key Set (JWKS) for the provider. */
val jwks: String,
/** Client ID registered with the provider. */
val clientId: String,
)

private class OpenIdProvider(providerConfig: OpenIdProviderConfig, clock: Clock) {
private val jwtValidator: JwtValidator =
JwtValidator.newBuilder()
.expectIssuer(providerConfig.issuer)
.expectAudience(providerConfig.clientId)
.setClock(clock)
.build()
private val jwksHandle: KeysetHandle = JwkSetConverter.toPublicKeysetHandle(providerConfig.jwks)

fun verifyAndDecode(token: String): VerifiedJwt =
jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,35 @@ import org.wfanet.measurement.common.grpc.BearerTokenCallCredentials
import org.wfanet.measurement.common.grpc.OpenIdConnectAuthentication

/** An ephemeral OpenID provider for testing. */
class OpenIdProvider(private val issuer: String) {
class OpenIdProvider(issuer: String, clientId: String) {
private val jwkSetHandle = KeysetHandle.generateNew(KEY_TEMPLATE)

val providerConfig: OpenIdConnectAuthentication.OpenIdProviderConfig by lazy {
val jwks = JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle)
OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks)
}
val providerConfig =
OpenIdConnectAuthentication.OpenIdProviderConfig(
issuer,
JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle),
clientId,
)

fun generateCredentials(
audience: String,
subject: String,
scopes: Set<String>,
expiration: Instant = Instant.now().plus(Duration.ofMinutes(5)),
): BearerTokenCallCredentials {
val token = generateSignedToken(audience, subject, scopes, expiration)
val token = generateSignedToken(subject, scopes, expiration)
return BearerTokenCallCredentials(token, false)
}

/** Generates a signed and encoded JWT using the specified parameters. */
private fun generateSignedToken(
audience: String,
subject: String,
scopes: Set<String>,
expiration: Instant,
): String {
val rawJwt =
RawJwt.newBuilder()
.setAudience(audience)
.setIssuer(issuer)
.setAudience(providerConfig.clientId)
.setIssuer(providerConfig.issuer)
.setSubject(subject)
.addStringClaim("scope", scopes.joinToString(" "))
.setExpiration(expiration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken returns VerifiedToken`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials(audience, subject, scopes)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig))

val token = auth.verifyAndDecodeBearerToken(extractHeaders(credentials))

Expand All @@ -53,17 +53,16 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is expired`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials =
openIdProvider.generateCredentials(
audience,
subject,
scopes,
Instant.now().minus(Duration.ofMinutes(5)),
)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig))

val exception =
assertFailsWith<StatusException> {
Expand All @@ -78,11 +77,14 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when audience does not match`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials("bad-audience", subject, scopes)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth =
OpenIdConnectAuthentication(
listOf(openIdProvider.providerConfig.copy(clientId = "bad-client-id"))
)

val exception =
assertFailsWith<StatusException> {
Expand All @@ -97,11 +99,11 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when provider not found for issuer`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials(audience, subject, scopes)
val auth = OpenIdConnectAuthentication(audience, emptyList())
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth = OpenIdConnectAuthentication(emptyList())

val exception =
assertFailsWith<StatusException> {
Expand All @@ -114,9 +116,8 @@ class OpenIdConnectAuthenticationTest {

@Test
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is not a valid JWT`() {
val audience = "foobar"
val credentials = BearerTokenCallCredentials("foo", false)
val auth = OpenIdConnectAuthentication(audience, emptyList())
val auth = OpenIdConnectAuthentication(emptyList())

val exception =
assertFailsWith<StatusException> {
Expand All @@ -129,8 +130,7 @@ class OpenIdConnectAuthenticationTest {

@Test
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when header not found`() {
val audience = "foobar"
val auth = OpenIdConnectAuthentication(audience, emptyList())
val auth = OpenIdConnectAuthentication(emptyList())

val exception = assertFailsWith<StatusException> { auth.verifyAndDecodeBearerToken(Metadata()) }

Expand Down

0 comments on commit 8441ad2

Please sign in to comment.