Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: Verify audience claim matches client ID for OpenID provider #290

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.google.crypto.tink.jwt.JwkSetConverter
import com.google.crypto.tink.jwt.JwtPublicKeyVerify
import com.google.crypto.tink.jwt.JwtSignatureConfig
import com.google.crypto.tink.jwt.JwtValidator
import com.google.crypto.tink.jwt.VerifiedJwt
import com.google.gson.JsonObject
import com.google.gson.JsonParser
import io.grpc.Metadata
Expand All @@ -32,17 +33,11 @@ import org.wfanet.measurement.common.base64UrlDecode

/** Utility for extracting OpenID Connect (OIDC) token information from gRPC request headers. */
class OpenIdConnectAuthentication(
audience: String,
openIdProviderConfigs: Iterable<OpenIdProviderConfig>,
clock: Clock = Clock.systemUTC(),
) {
private val jwtValidator =
JwtValidator.newBuilder().setClock(clock).expectAudience(audience).ignoreIssuer().build()

private val jwksHandleByIssuer: Map<String, KeysetHandle> =
openIdProviderConfigs.associateBy({ it.issuer }) {
JwkSetConverter.toPublicKeysetHandle(it.jwks)
}
private val openIdProviderByIssuer: Map<String, OpenIdProvider> =
openIdProviderConfigs.associateBy({ it.issuer }) { OpenIdProvider(it, clock) }

/**
* Verifies and decodes an OIDC bearer token from [headers].
Expand Down Expand Up @@ -74,13 +69,13 @@ class OpenIdConnectAuthentication(
val issuer =
payload.get(ISSUER_CLAIM)?.asString
?: throw Status.UNAUTHENTICATED.withDescription("Issuer not found").asException()
val jwksHandle =
jwksHandleByIssuer[issuer]
val provider: OpenIdProvider =
openIdProviderByIssuer[issuer]
?: throw Status.UNAUTHENTICATED.withDescription("Unknown issuer").asException()

val verifiedJwt =
val verifiedJwt: VerifiedJwt =
try {
jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator)
provider.verifyAndDecode(token)
} catch (e: GeneralSecurityException) {
throw Status.UNAUTHENTICATED.withCause(e).withDescription(e.message).asException()
}
Expand Down Expand Up @@ -113,5 +108,20 @@ class OpenIdConnectAuthentication(
val issuer: String,
/** JSON Web Key Set (JWKS) for the provider. */
val jwks: String,
/** Client ID registered with the provider. */
val clientId: String,
)

private class OpenIdProvider(providerConfig: OpenIdProviderConfig, clock: Clock) {
private val jwtValidator: JwtValidator =
JwtValidator.newBuilder()
.expectIssuer(providerConfig.issuer)
.expectAudience(providerConfig.clientId)
.setClock(clock)
.build()
private val jwksHandle: KeysetHandle = JwkSetConverter.toPublicKeysetHandle(providerConfig.jwks)

fun verifyAndDecode(token: String): VerifiedJwt =
jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,35 @@ import org.wfanet.measurement.common.grpc.BearerTokenCallCredentials
import org.wfanet.measurement.common.grpc.OpenIdConnectAuthentication

/** An ephemeral OpenID provider for testing. */
class OpenIdProvider(private val issuer: String) {
class OpenIdProvider(issuer: String, clientId: String) {
private val jwkSetHandle = KeysetHandle.generateNew(KEY_TEMPLATE)

val providerConfig: OpenIdConnectAuthentication.OpenIdProviderConfig by lazy {
val jwks = JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle)
OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks)
}
val providerConfig =
OpenIdConnectAuthentication.OpenIdProviderConfig(
issuer,
JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle),
clientId,
)

fun generateCredentials(
audience: String,
subject: String,
scopes: Set<String>,
expiration: Instant = Instant.now().plus(Duration.ofMinutes(5)),
): BearerTokenCallCredentials {
val token = generateSignedToken(audience, subject, scopes, expiration)
val token = generateSignedToken(subject, scopes, expiration)
return BearerTokenCallCredentials(token, false)
}

/** Generates a signed and encoded JWT using the specified parameters. */
private fun generateSignedToken(
audience: String,
subject: String,
scopes: Set<String>,
expiration: Instant,
): String {
val rawJwt =
RawJwt.newBuilder()
.setAudience(audience)
.setIssuer(issuer)
.setAudience(providerConfig.clientId)
.setIssuer(providerConfig.issuer)
.setSubject(subject)
.addStringClaim("scope", scopes.joinToString(" "))
.setExpiration(expiration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken returns VerifiedToken`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials(audience, subject, scopes)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig))

val token = auth.verifyAndDecodeBearerToken(extractHeaders(credentials))

Expand All @@ -53,17 +53,16 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is expired`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials =
openIdProvider.generateCredentials(
audience,
subject,
scopes,
Instant.now().minus(Duration.ofMinutes(5)),
)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val auth = OpenIdConnectAuthentication(listOf(openIdProvider.providerConfig))

val exception =
assertFailsWith<StatusException> {
Expand All @@ -78,11 +77,14 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when audience does not match`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials("bad-audience", subject, scopes)
val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig))
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth =
OpenIdConnectAuthentication(
listOf(openIdProvider.providerConfig.copy(clientId = "bad-client-id"))
)

val exception =
assertFailsWith<StatusException> {
Expand All @@ -97,11 +99,11 @@ class OpenIdConnectAuthenticationTest {
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when provider not found for issuer`() {
val issuer = "example.com"
val subject = "[email protected]"
val audience = "foobar"
val clientId = "foobar"
val scopes = setOf("foo.bar", "foo.baz")
val openIdProvider = OpenIdProvider(issuer)
val credentials = openIdProvider.generateCredentials(audience, subject, scopes)
val auth = OpenIdConnectAuthentication(audience, emptyList())
val openIdProvider = OpenIdProvider(issuer, clientId)
val credentials = openIdProvider.generateCredentials(subject, scopes)
val auth = OpenIdConnectAuthentication(emptyList())

val exception =
assertFailsWith<StatusException> {
Expand All @@ -114,9 +116,8 @@ class OpenIdConnectAuthenticationTest {

@Test
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is not a valid JWT`() {
val audience = "foobar"
val credentials = BearerTokenCallCredentials("foo", false)
val auth = OpenIdConnectAuthentication(audience, emptyList())
val auth = OpenIdConnectAuthentication(emptyList())

val exception =
assertFailsWith<StatusException> {
Expand All @@ -129,8 +130,7 @@ class OpenIdConnectAuthenticationTest {

@Test
fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when header not found`() {
val audience = "foobar"
val auth = OpenIdConnectAuthentication(audience, emptyList())
val auth = OpenIdConnectAuthentication(emptyList())

val exception = assertFailsWith<StatusException> { auth.verifyAndDecodeBearerToken(Metadata()) }

Expand Down