Skip to content

Commit

Permalink
chore
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-olaveide committed Jan 15, 2024
1 parent a53f0ce commit 39e1e51
Show file tree
Hide file tree
Showing 17 changed files with 62 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ClientProperties @JvmOverloads constructor(var tokenEndpointUrl: URI? = nu

private fun endpointUrlFromMetadata(wellKnown: URI?) =
runCatching {
wellKnown?.let { AuthorizationServerMetadata.parse(DefaultResourceRetriever().retrieveResource(wellKnown.toURL()).content).tokenEndpointURI }
wellKnown?.let { AuthorizationServerMetadata.parse(DefaultResourceRetriever().retrieveResource(it.toURL()).content).tokenEndpointURI }
?: throw OAuth2ClientException("Well-known url cannot be null, please check your configuration")
}.getOrElse {
when(it) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ class ClientAssertion(private val tokenEndpointUrl : URI, private val clientId :
.issueTime(Date.from(this))
.build()).serialize()
}

@Deprecated("Use com.nimbusds.oauth2.sdk.auth.JWTAuthentication instead", ReplaceWith("JWTAuthentication.CLIENT_ASSERTION_TYPE"), WARNING)
fun assertionType() = CLIENT_ASSERTION_TYPE

private fun createSignedJWT(rsaJwk : RSAKey, claimsSet : JWTClaimsSet) =
runCatching {
SignedJWT(JWSHeader.Builder(RS256)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package no.nav.security.token.support.client.core.http
import no.nav.security.token.support.client.core.oauth2.OAuth2AccessTokenResponse

interface OAuth2HttpClient {
fun post(request : OAuth2HttpRequest) : OAuth2AccessTokenResponse?
fun post(request : OAuth2HttpRequest) : OAuth2AccessTokenResponse
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import java.lang.String.CASE_INSENSITIVE_ORDER
import java.util.Objects
import java.util.TreeMap

class OAuth2HttpHeaders (val headers : Map<String, List<String>>) {
class OAuth2HttpHeaders(val headers : Map<String, List<String>> = emptyMap()) {

override fun equals(other : Any?) : Boolean {
if (this === other) return true
Expand All @@ -27,7 +27,7 @@ class OAuth2HttpHeaders (val headers : Map<String, List<String>>) {
companion object {

@JvmField
val NONE = OAuth2HttpHeaders(emptyMap())
val NONE = OAuth2HttpHeaders()
@JvmStatic
fun of(headers : Map<String, List<String>>) = OAuth2HttpHeaders(headers)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package no.nav.security.token.support.client.core.http

import java.net.URI
import java.util.Collections.unmodifiableMap
import no.nav.security.token.support.client.core.http.OAuth2HttpHeaders.Companion.NONE

class OAuth2HttpRequest(val tokenEndpointUrl : URI, val oAuth2HttpHeaders : OAuth2HttpHeaders = NONE, val formParameters : Map<String, String>) {
Expand All @@ -18,7 +17,7 @@ class OAuth2HttpRequest(val tokenEndpointUrl : URI, val oAuth2HttpHeaders : OAut

fun formParameters(entries: Map<String, String>) = this.also { formParameters.putAll(entries) }

fun build(): OAuth2HttpRequest = OAuth2HttpRequest(tokenEndpointUrl, oAuth2HttpHeaders, unmodifiableMap(formParameters))
fun build(): OAuth2HttpRequest = OAuth2HttpRequest(tokenEndpointUrl, oAuth2HttpHeaders, formParameters.toMap())

@Override
override fun toString() = "OAuth2HttpRequest.OAuth2HttpRequestBuilder(tokenEndpointUrl=$tokenEndpointUrl, oAuth2HttpHeaders=$oAuth2HttpHeaders, entries=$formParameters"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.nimbusds.oauth2.sdk.auth.JWTAuthentication
import java.lang.String.join
import java.nio.charset.StandardCharsets.UTF_8
import java.util.Base64.getEncoder
import org.checkerframework.checker.nullness.qual.NonNull
import no.nav.security.token.support.client.core.ClientProperties
import no.nav.security.token.support.client.core.OAuth2ClientException
import no.nav.security.token.support.client.core.OAuth2ParameterNames.CLIENT_ASSERTION
Expand All @@ -23,12 +24,12 @@ import no.nav.security.token.support.client.core.http.OAuth2HttpClient
import no.nav.security.token.support.client.core.http.OAuth2HttpHeaders
import no.nav.security.token.support.client.core.http.OAuth2HttpRequest

abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest?> internal constructor(private val oAuth2HttpClient : OAuth2HttpClient) {
abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest> internal constructor(private val oAuth2HttpClient : OAuth2HttpClient) {

protected abstract fun formParameters(grantRequest : T) : Map<String, String>

fun getTokenResponse(grantRequest : T) =
grantRequest?.clientProperties?.let {
grantRequest.clientProperties.let {
runCatching {
oAuth2HttpClient.post(OAuth2HttpRequest.builder(it.tokenEndpointUrl!!)
.oAuth2HttpHeaders(OAuth2HttpHeaders.of(tokenRequestHeaders(it)))
Expand Down Expand Up @@ -57,7 +58,7 @@ abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest?> intern
}

private fun defaultFormParameters(grantRequest : T) =
grantRequest?.clientProperties?.let {
grantRequest.clientProperties.let {
defaultClientAuthenticationFormParameters(grantRequest).apply {
put(GRANT_TYPE,grantRequest.grantType.value)
if (TOKEN_EXCHANGE != it.grantType) {
Expand All @@ -67,22 +68,20 @@ abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest?> intern
} ?: throw OAuth2ClientException("ClientProperties cannot be null")

private fun defaultClientAuthenticationFormParameters(grantRequest : T) =
grantRequest?.clientProperties?.let {
with(it) {
when (authentication.clientAuthMethod) {
CLIENT_SECRET_POST -> LinkedHashMap<String, String>().apply {
put(CLIENT_ID, authentication.clientId)
put(CLIENT_SECRET, authentication.clientSecret!!)
}
PRIVATE_KEY_JWT -> LinkedHashMap<String, String>().apply {
put(CLIENT_ID, authentication.clientId)
put(CLIENT_ASSERTION_TYPE, JWTAuthentication.CLIENT_ASSERTION_TYPE)
put(CLIENT_ASSERTION, ClientAssertion(tokenEndpointUrl!!, authentication).assertion())
}
else -> mutableMapOf()
with(grantRequest.clientProperties) {
when (authentication.clientAuthMethod) {
CLIENT_SECRET_POST -> LinkedHashMap<String, String>().apply {
put(CLIENT_ID, authentication.clientId)
put(CLIENT_SECRET, authentication.clientSecret!!)
}
PRIVATE_KEY_JWT -> LinkedHashMap<String, String>().apply {
put(CLIENT_ID, authentication.clientId)
put(CLIENT_ASSERTION_TYPE, JWTAuthentication.CLIENT_ASSERTION_TYPE)
put(CLIENT_ASSERTION, ClientAssertion(tokenEndpointUrl!!, authentication).assertion())
}
else -> mutableMapOf()
}
} ?: throw OAuth2ClientException("ClientProperties cannot be null")
}

private fun basicAuth(username : String, password : String) =
UTF_8.newEncoder().run {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class OAuth2AccessTokenService @JvmOverloads constructor(private val tokenResolv



fun getAccessToken(clientProperties : ClientProperties) : OAuth2AccessTokenResponse? {
fun getAccessToken(clientProperties : ClientProperties) : OAuth2AccessTokenResponse {
log.trace("Getting access_token for grant={}", clientProperties.grantType)
return when (clientProperties.grantType) {
JWT_BEARER -> executeOnBehalfOf(clientProperties)
Expand All @@ -40,18 +40,17 @@ class OAuth2AccessTokenService @JvmOverloads constructor(private val tokenResolv
getFromCacheIfEnabled(ClientCredentialsGrantRequest(clientProperties), clientCredentialsGrantCache, clientCredentialsTokenClient::getTokenResponse)

private fun tokenExchangeGrantRequest(clientProperties : ClientProperties) =
TokenExchangeGrantRequest(clientProperties, tokenResolver.token() ?: throw OAuth2ClientException("no authenticated jwt token found in validation context, cannot do token exchange"))
TokenExchangeGrantRequest(clientProperties, tokenResolver.token() ?: throw OAuth2ClientException("No authenticated jwt token found in validation context, cannot do token exchange"))

private fun onBehalfOfGrantRequest(clientProperties : ClientProperties) =
OnBehalfOfGrantRequest(clientProperties, tokenResolver.token() ?: throw OAuth2ClientException("no authenticated jwt token found in validation context, cannot do on-behalf-of"))
OnBehalfOfGrantRequest(clientProperties, tokenResolver.token() ?: throw OAuth2ClientException("No authenticated jwt token found in validation context, cannot do on-behalf-of"))

override fun toString() = "${javaClass.getSimpleName()} [clientCredentialsGrantCache=$clientCredentialsGrantCache, onBehalfOfGrantCache=$onBehalfOfGrantCache, tokenExchangeClient=$tokenExchangeClient, tokenResolver=$tokenResolver, onBehalfOfTokenClient=$onBehalfOfTokenClient, clientCredentialsTokenClient=$clientCredentialsTokenClient, exchangeGrantCache=$exchangeGrantCache]"
companion object {

private val SUPPORTED_GRANT_TYPES = listOf(JWT_BEARER, CLIENT_CREDENTIALS, TOKEN_EXCHANGE
)
private val SUPPORTED_GRANT_TYPES = listOf(JWT_BEARER, CLIENT_CREDENTIALS, TOKEN_EXCHANGE)
private val log = LoggerFactory.getLogger(OAuth2AccessTokenService::class.java)
private fun <T : AbstractOAuth2GrantRequest?> getFromCacheIfEnabled(grantRequest : T, cache : Cache<T, OAuth2AccessTokenResponse>?, client : Function<T, OAuth2AccessTokenResponse?>) =
private fun <T : AbstractOAuth2GrantRequest?> getFromCacheIfEnabled(grantRequest : T, cache : Cache<T, OAuth2AccessTokenResponse>?, client : Function<T, OAuth2AccessTokenResponse>) =
cache?.let {
log.debug("Cache is enabled so attempt to get from cache or update cache if not present.")
cache[grantRequest, client]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ class SimpleOAuth2HttpClient : OAuth2HttpClient {

private fun HttpRequest.sendRequest() = newHttpClient().send(this, BodyHandlers.ofString())
private fun HttpResponse<String>.processResponse() =
if (this.statusCode() in 200..299) {
MAPPER.readValue<OAuth2AccessTokenResponse>(body())
} else {
throw OAuth2ClientException("Error response from token endpoint: ${this.statusCode()} ${this.body()}")
with(this) {
if (statusCode() in 200..299) {
MAPPER.readValue<OAuth2AccessTokenResponse>(body())
} else {
throw OAuth2ClientException("Error response from token endpoint: ${statusCode()} ${body()}")
}
}

private fun Map<String, String>.toUrlEncodedString() = entries.joinToString("&") { (key, value) -> "$key=${URLEncoder.encode(value, UTF_8)}" }
companion object {
private val MAPPER = jacksonObjectMapper().configure(FAIL_ON_UNKNOWN_PROPERTIES, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ internal class OnBehalfOfTokenClientTest {
.contains("requested_token_use=on_behalf_of")
.contains("assertion=$assertion")
assertThat(response).isNotNull()
assertThat(response?.accessToken).isNotBlank()
assertThat(response?.expiresAt).isPositive()
assertThat(response?.expiresIn).isPositive()
assertThat(response.accessToken).isNotBlank()
assertThat(response.expiresAt).isPositive()
assertThat(response.expiresIn).isPositive()
}

@Test
Expand All @@ -63,8 +63,7 @@ internal class OnBehalfOfTokenClientTest {
val clientProperties = clientProperties(tokenEndpointUrl, JWT_BEARER)
val oAuth2OnBehalfOfGrantRequest = OnBehalfOfGrantRequest(clientProperties, assertion)
assertThrows<OAuth2ClientException> {
val res = onBehalfOfTokenResponseClient.getTokenResponse(oAuth2OnBehalfOfGrantRequest)
println(res)
onBehalfOfTokenResponseClient.getTokenResponse(oAuth2OnBehalfOfGrantRequest)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@ class ClientConfig(applicationConfig: ApplicationConfig, httpClient: HttpClient)
internal val clients =
applicationConfig.configList(CLIENTS_PATH)
.associate {
val wellKnownUrl = it.propertyToString("well_known_url")
val clientAuth = ClientAuthenticationProperties(
it.propertyToString("authentication.client_id"),
ClientAuthenticationMethod(it.propertyToString("authentication.client_auth_method")),
it.propertyToStringOrNull("client_secret"),
it.propertyToStringOrNull("authentication.client_jwk"))
it.propertyToString(CLIENT_NAME) to OAuth2Client(httpClient, wellKnownUrl, clientAuth, cacheConfig)
it.propertyToString(CLIENT_NAME) to OAuth2Client(httpClient, it.propertyToString("well_known_url"), clientAuth, cacheConfig)
}

companion object CommonConfigurationAttributes {
const val COMMON_PREFIX = "no.nav.security.jwt.client.registration"
private const val COMMON_PREFIX = "no.nav.security.jwt.client.registration"
const val CLIENTS_PATH = "${COMMON_PREFIX}.clients"
const val CACHE_PATH = "${COMMON_PREFIX}.cache"
const val CLIENT_NAME = "client_name"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import no.nav.security.token.support.client.spring.ClientConfigurationProperties
* Default implementation that matcher host in request URL with the registration
* name. Override for other strategies. Will typically be used with
* [OAuth2ClientRequestInterceptor]. Must be registered by the
* applications themselves, no automatic bean registration
* applications themselves, there is no automatic bean registration
*
*/
interface ClientConfigurationPropertiesMatcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package no.nav.security.token.support.client.spring.oauth2
import org.springframework.http.HttpHeaders
import org.springframework.util.LinkedMultiValueMap
import org.springframework.web.client.RestClient
import org.springframework.web.client.body
import no.nav.security.token.support.client.core.OAuth2ClientException
import no.nav.security.token.support.client.core.http.OAuth2HttpClient
import no.nav.security.token.support.client.core.http.OAuth2HttpRequest
Expand All @@ -12,18 +13,15 @@ open class DefaultOAuth2HttpClient(val restClient: RestClient) : OAuth2HttpClien


override fun post(request: OAuth2HttpRequest) =
restClient.post()
.uri(request.tokenEndpointUrl)
.headers { it.addAll(headers(request)) }
.body(LinkedMultiValueMap<String, String>().apply {
setAll(request.formParameters)
}).retrieve()
.onStatus({ it.isError }) { _, response ->
throw OAuth2ClientException("Received $response.statusCode from $request.tokenEndpointUrl")
}
.body(OAuth2AccessTokenResponse::class.java)

private fun headers(req: OAuth2HttpRequest): HttpHeaders = HttpHeaders().apply { req.oAuth2HttpHeaders?.let { putAll(it.headers) } }
with(request) {
restClient.post()
.uri(tokenEndpointUrl)
.headers { it.addAll(HttpHeaders().apply { putAll(oAuth2HttpHeaders.headers) }) }
.body(LinkedMultiValueMap<String, String>().apply { setAll(formParameters) })
.retrieve()
.onStatus({ it.isError }) { _, res -> throw OAuth2ClientException("Received $res.statusCode from $tokenEndpointUrl") }
.body<OAuth2AccessTokenResponse>() ?: throw OAuth2ClientException("No response from $tokenEndpointUrl")
}

override fun toString() = "$javaClass.simpleName [restClient=$restClient]"
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class OAuth2ClientRequestInterceptor(private val properties: ClientConfiguration
private val matcher: ClientConfigurationPropertiesMatcher) : ClientHttpRequestInterceptor {
override fun intercept(req: HttpRequest, body: ByteArray, execution: ClientHttpRequestExecution): ClientHttpResponse {
matcher.findProperties(properties, req.uri)?.let {
service.getAccessToken(it)?.accessToken?.let { token -> req.headers.setBearerAuth(token) }
service.getAccessToken(it).accessToken?.let { token -> req.headers.setBearerAuth(token) }
}
return execution.execute(req, body)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,10 @@ import no.nav.security.token.support.core.validation.JwtTokenValidatorFactory.to

open class IssuerConfiguration(val name : String, properties : IssuerProperties, val resourceRetriever : ResourceRetriever = ProxyAwareResourceRetriever()) {

val metadata : AuthorizationServerMetadata
val metadata = providerMetadata(resourceRetriever, properties.discoveryUrl)
val acceptedAudience = properties.acceptedAudience
val headerName = properties.headerName
val tokenValidator : JwtTokenValidator

init {
metadata = providerMetadata(resourceRetriever, properties.discoveryUrl)
tokenValidator = tokenValidator(properties, metadata, resourceRetriever)
}
val tokenValidator = tokenValidator(properties, metadata, resourceRetriever)

override fun toString() = ("${javaClass.simpleName} [name=$name, metaData=$metadata, acceptedAudience=$acceptedAudience, headerName=$headerName, tokenValidator=$tokenValidator, resourceRetriever=$resourceRetriever]")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class IssuerProperties @JvmOverloads constructor(val discoveryUrl : URL,
val usePlaintextForHttps: Boolean = false) {

init {
cookieName?.let { throw IllegalArgumentException("Cookie-support is discontinued, please remove $it from ypur configuration now") }
cookieName?.let { throw IllegalArgumentException("Cookie-support is discontinued, please remove $it from your configuration now") }
}

override fun toString() = "IssuerProperties(discoveryUrl=$discoveryUrl, acceptedAudience=$acceptedAudience, headerName=$headerName, proxyUrl=$proxyUrl, usePlaintextForHttps=$usePlaintextForHttps, validation=$validation, jwksCache=$jwksCache)"
Expand Down
Loading

0 comments on commit 39e1e51

Please sign in to comment.