Skip to content

Commit

Permalink
fix!: Verify audience claim matches client ID for OpenID provider
Browse files Browse the repository at this point in the history
The client ID is generated by the OpenID provider on registration, so the OpenID provider configuration needs to include the client ID.
  • Loading branch information
SanjayVas committed Dec 19, 2024
1 parent d03e8e4 commit b3944d0
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.google.crypto.tink.jwt.JwkSetConverter
import com.google.crypto.tink.jwt.JwtPublicKeyVerify
import com.google.crypto.tink.jwt.JwtSignatureConfig
import com.google.crypto.tink.jwt.JwtValidator
import com.google.crypto.tink.jwt.VerifiedJwt
import com.google.gson.JsonObject
import com.google.gson.JsonParser
import io.grpc.Metadata
Expand All @@ -32,17 +33,11 @@ import org.wfanet.measurement.common.base64UrlDecode

/** Utility for extracting OpenID Connect (OIDC) token information from gRPC request headers. */
class OpenIdConnectAuthentication(
audience: String,
openIdProviderConfigs: Iterable<OpenIdProviderConfig>,
clock: Clock = Clock.systemUTC(),
) {
private val jwtValidator =
JwtValidator.newBuilder().setClock(clock).expectAudience(audience).ignoreIssuer().build()

private val jwksHandleByIssuer: Map<String, KeysetHandle> =
openIdProviderConfigs.associateBy({ it.issuer }) {
JwkSetConverter.toPublicKeysetHandle(it.jwks)
}
private val openIdProviderByIssuer: Map<String, OpenIdProvider> =
openIdProviderConfigs.associateBy({ it.issuer }) { OpenIdProvider(it, clock) }

/**
* Verifies and decodes an OIDC bearer token from [headers].
Expand Down Expand Up @@ -74,13 +69,13 @@ class OpenIdConnectAuthentication(
val issuer =
payload.get(ISSUER_CLAIM)?.asString
?: throw Status.UNAUTHENTICATED.withDescription("Issuer not found").asException()
val jwksHandle =
jwksHandleByIssuer[issuer]
val provider: OpenIdProvider =
openIdProviderByIssuer[issuer]
?: throw Status.UNAUTHENTICATED.withDescription("Unknown issuer").asException()

val verifiedJwt =
val verifiedJwt: VerifiedJwt =
try {
jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator)
provider.verifyAndDecode(token)
} catch (e: GeneralSecurityException) {
throw Status.UNAUTHENTICATED.withCause(e).withDescription(e.message).asException()
}
Expand Down Expand Up @@ -113,5 +108,20 @@ class OpenIdConnectAuthentication(
val issuer: String,
/** JSON Web Key Set (JWKS) for the provider. */
val jwks: String,
/** Client ID registered with the provider. */
val clientId: String,
)

private class OpenIdProvider(providerConfig: OpenIdProviderConfig, clock: Clock) {
private val jwtValidator: JwtValidator =
JwtValidator.newBuilder()
.expectIssuer(providerConfig.issuer)
.expectAudience(providerConfig.clientId)
.setClock(clock)
.build()
private val jwksHandle: KeysetHandle = JwkSetConverter.toPublicKeysetHandle(providerConfig.jwks)

fun verifyAndDecode(token: String): VerifiedJwt =
jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,35 @@ import org.wfanet.measurement.common.grpc.BearerTokenCallCredentials
import org.wfanet.measurement.common.grpc.OpenIdConnectAuthentication

/** An ephemeral OpenID provider for testing. */
class OpenIdProvider(private val issuer: String) {
class OpenIdProvider(issuer: String, clientId: String) {
private val jwkSetHandle = KeysetHandle.generateNew(KEY_TEMPLATE)

val providerConfig: OpenIdConnectAuthentication.OpenIdProviderConfig by lazy {
val jwks = JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle)
OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks)
}
val providerConfig =
OpenIdConnectAuthentication.OpenIdProviderConfig(
issuer,
JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle),
clientId,
)

fun generateCredentials(
audience: String,
subject: String,
scopes: Set<String>,
expiration: Instant = Instant.now().plus(Duration.ofMinutes(5)),
): BearerTokenCallCredentials {
val token = generateSignedToken(audience, subject, scopes, expiration)
val token = generateSignedToken(subject, scopes, expiration)
return BearerTokenCallCredentials(token, false)
}

/** Generates a signed and encoded JWT using the specified parameters. */
private fun generateSignedToken(
audience: String,
subject: String,
scopes: Set<String>,
expiration: Instant,
): String {
val rawJwt =
RawJwt.newBuilder()
.setAudience(audience)
.setIssuer(issuer)
.setAudience(providerConfig.clientId)
.setIssuer(providerConfig.issuer)
.setSubject(subject)
.addStringClaim("scope", scopes.joinToString(" "))
.setExpiration(expiration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken returns VerifiedToken`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials(audience, subject, scopes)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig))

val token = auth.verifyAndDecodeBearerToken(extractHeaders(credentials))

Expand All @@ -53,17 +53,16 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is expired`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials =
openIdProvider.generateCredentials(
audience,
subject,
scopes,
Instant.now().minus(Duration.ofMinutes(5)),
)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig))

val exception =
assertFailsWith<StatusException> {
Expand All @@ -78,11 +77,14 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when audience does not match`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials("bad-audience", subject, scopes)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth =
OpenIdConnectAuthentication(
listOf(openIdProvider.providerConfig.copy(clientId = "bad-client-id"))
)

val exception =
assertFailsWith<StatusException> {
Expand All @@ -97,11 +99,11 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when provider not found for issuer`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials(audience, subject, scopes)
val auth = OpenIdConnectAuthentication(audience, emptyList())
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth = OpenIdConnectAuthentication(emptyList())

val exception =
assertFailsWith<StatusException> {
Expand All @@ -114,9 +116,8 @@ class OpenIdConnectAuthenticationTest {

@Test
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is not a valid JWT`() {
val audience = "foobar"
val credentials = BearerTokenCallCredentials("foo", false)
val auth = OpenIdConnectAuthentication(audience, emptyList())
val auth = OpenIdConnectAuthentication(emptyList())

val exception =
assertFailsWith<StatusException> {
Expand All @@ -129,8 +130,7 @@ class OpenIdConnectAuthenticationTest {

@Test
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when header not found`() {
val audience = "foobar"
val auth = OpenIdConnectAuthentication(audience, emptyList())
val auth = OpenIdConnectAuthentication(emptyList())

val exception = assertFailsWith<StatusException> { auth.verifyAndDecodeBearerToken(Metadata()) }

Expand Down

0 comments on commit b3944d0

Please sign in to comment.