Skip to content

Commit

Permalink
Asynchronous credentials refresh
Browse files Browse the repository at this point in the history
  • Loading branch information
lauzadis committed Jan 25, 2024
1 parent 840f62e commit 441cf26
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolReference
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.integration.AppendingSectionWriter
import software.amazon.smithy.kotlin.codegen.integration.AuthSchemeHandler
import software.amazon.smithy.kotlin.codegen.integration.SectionWriterBinding
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.expectShape
import software.amazon.smithy.kotlin.codegen.model.hasTrait
Expand Down Expand Up @@ -40,6 +42,14 @@ class SigV4S3ExpressAuthSchemeIntegration : SigV4AuthSchemeIntegration() {

override fun customizeEndpointResolution(ctx: ProtocolGenerator.GenerationContext): EndpointCustomization = SigV4S3ExpressEndpointCustomization

override val sectionWriters: List<SectionWriterBinding>
get() = listOf(SectionWriterBinding(HttpProtocolClientGenerator.ClientInitializer, renderClientInitializer))

// add S3 Express credentials provider to managed resources in the service client initializer
private val renderClientInitializer = AppendingSectionWriter { writer ->
writer.write("managedResources.addIfManaged(config.s3ExpressCredentialsProvider)")
}

override fun authSchemes(ctx: ProtocolGenerator.GenerationContext): List<AuthSchemeHandler> = listOf(SigV4S3ExpressAuthSchemeHandler())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,117 @@ import aws.smithy.kotlin.runtime.client.SdkClient
import aws.smithy.kotlin.runtime.collections.LruCache
import aws.smithy.kotlin.runtime.time.Clock
import aws.smithy.kotlin.runtime.util.ExpiringValue
import aws.smithy.kotlin.runtime.io.Closeable
import aws.smithy.kotlin.runtime.time.Instant
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.selects.*
import kotlin.coroutines.coroutineContext
import aws.sdk.kotlin.services.s3.*
import aws.sdk.kotlin.services.s3.model.CreateSessionRequest
import aws.sdk.kotlin.runtime.auth.credentials.internal.credentials
import kotlin.time.Duration.Companion.minutes
import kotlin.coroutines.CoroutineContext
import aws.smithy.kotlin.runtime.telemetry.logging.logger
import kotlin.time.Duration
import kotlin.time.*
import aws.smithy.kotlin.runtime.time.until

private const val DEFAULT_S3_EXPRESS_CACHE_SIZE: Int = 100
private val REFRESH_BUFFER = 1.minutes
private val DEFAULT_REFRESH_PERIOD = 5.minutes

public class S3ExpressCredentialsCache(
private val clock: Clock = Clock.System,
) {
) : CoroutineScope, Closeable {
override val coroutineContext: CoroutineContext = Job() + CoroutineName("S3ExpressCredentialsCacheRefresh")

private val lru = LruCache<S3ExpressCredentialsCacheKey, ExpiringValue<Credentials>>(DEFAULT_S3_EXPRESS_CACHE_SIZE)
private val immediateRefreshChannel = Channel<Unit>(Channel.CONFLATED) // channel used to indicate an immediate refresh attempt is required

init {
launch(coroutineContext) {
refresh()
}
}

public suspend fun get(key: S3ExpressCredentialsCacheKey): Credentials = lru.get(key)?.value
?: (createSessionCredentials(key).also { put(key, it) }).value

public suspend fun put(key: S3ExpressCredentialsCacheKey, value: ExpiringValue<Credentials>): Unit {
lru.put(key, value)
immediateRefreshChannel.send(Unit)
}

private suspend fun refresh(): Unit {
val logger = coroutineContext.logger<S3ExpressCredentialsCache>()
while (isActive) {
logger.trace { "Looping..." }
println("Looping...")
val refreshedCredentials = mutableMapOf<S3ExpressCredentialsCacheKey, ExpiringValue<Credentials>>()
var nextRefresh: Instant = clock.now() + DEFAULT_REFRESH_PERIOD

lru.withLock {
lru.entries.forEach { (key, cachedValue) ->
logger.trace { "Checking entry for ${key.bucket}" }
println("Checking entry for ${key.bucket}")
nextRefresh = minOf(nextRefresh, cachedValue.expiresAt)

public suspend fun get(key: S3ExpressCredentialsCacheKey): Credentials? = (
lru.get(key)
?.takeIf { it.expiresAt > clock.now() }
)?.value
if ((clock.now().until(cachedValue.expiresAt)).absoluteValue <= REFRESH_BUFFER) {
logger.trace { "Credentials for ${key.bucket} expire within the refresh buffer period, performing a refresh..." }
println("Credentials for ${key.bucket} expire within the refresh buffer period, performing a refresh...")
createSessionCredentials(key).also {
refreshedCredentials.put(key, it)
nextRefresh = minOf(nextRefresh, it.expiresAt)
}
}
}

public suspend fun put(key: S3ExpressCredentialsCacheKey, value: ExpiringValue<Credentials>): Unit = lru.put(key, value)
refreshedCredentials.forEach { (key, value) ->
lru.remove(key)
lru.putUnlocked(key, value)
}
}

// wake up when it's time to refresh or an immediate refresh has been triggered
select<Unit> {
onTimeout(clock.now().until(nextRefresh)) {
logger.trace { "Woke up from timeout" }
println("Woke up from timeout")
}
immediateRefreshChannel.onReceive {
logger.trace { "Woke up from channel" }
println("Woke up from channel")
}
}
}
}

private suspend fun createSessionCredentials(key: S3ExpressCredentialsCacheKey): ExpiringValue<Credentials> {
val logger = coroutineContext.logger<S3ExpressCredentialsCache>()

val credentials = (key.client as S3Client).createSession(
CreateSessionRequest {
bucket = key.bucket
},
).credentials!!

return ExpiringValue(
credentials(
accessKeyId = credentials.accessKeyId,
secretAccessKey = credentials.secretAccessKey,
sessionToken = credentials.sessionToken,
expiration = credentials.expiration,
providerName = "S3ExpressCredentialsProvider",
),
credentials.expiration,
).also { logger.debug { "got credentials ${it.value}" } }
}

override fun close(): Unit {
coroutineContext.cancel(null)
immediateRefreshChannel.close()
}
}

public class S3ExpressCredentialsCacheKey(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,31 @@ import aws.sdk.kotlin.services.s3.*
import aws.sdk.kotlin.services.s3.model.CreateSessionRequest
import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
import aws.smithy.kotlin.runtime.auth.awscredentials.CredentialsProvider
import aws.smithy.kotlin.runtime.auth.awscredentials.CloseableCredentialsProvider
import aws.smithy.kotlin.runtime.collections.Attributes
import aws.smithy.kotlin.runtime.collections.get
import aws.smithy.kotlin.runtime.telemetry.logging.logger
import aws.smithy.kotlin.runtime.util.ExpiringValue
import kotlin.coroutines.coroutineContext
import aws.smithy.kotlin.runtime.io.SdkManagedBase

public class S3ExpressCredentialsProvider(
public val bootstrapCredentialsProvider: CredentialsProvider,
) : CredentialsProvider {
) : CloseableCredentialsProvider, SdkManagedBase() {
private val credentialsCache = S3ExpressCredentialsCache()

override suspend fun resolve(attributes: Attributes): Credentials {
val logger = coroutineContext.logger<S3ExpressCredentialsProvider>()

val bucket: String = attributes[S3ExpressAttributes.Bucket]
val client = attributes[S3ExpressAttributes.Client]

val key = S3ExpressCredentialsCacheKey(bucket, client, bootstrapCredentialsProvider.resolve(attributes))

return credentialsCache.get(key)
?: (createSessionCredentials(key).also { credentialsCache.put(key, it) }).value
return credentialsCache.get(key).also { logger.trace { "Got credentials $it from cache" }}
}

private suspend fun createSessionCredentials(key: S3ExpressCredentialsCacheKey): ExpiringValue<Credentials> {
val logger = coroutineContext.logger<S3ExpressCredentialsCache>()

val credentials = (key.client as S3Client).createSession(
CreateSessionRequest {
bucket = key.bucket
},
).credentials!!

return ExpiringValue(
credentials(
accessKeyId = credentials.accessKeyId,
secretAccessKey = credentials.secretAccessKey,
sessionToken = credentials.sessionToken,
expiration = credentials.expiration,
providerName = "S3ExpressCredentialsProvider",
),
credentials.expiration,
).also { logger.debug { "got credentials ${it.value}" } }
override fun close() {
credentialsCache.close()
}
}

0 comments on commit 441cf26

Please sign in to comment.