Skip to content

Commit

Permalink
simplify cache, perform refresh in provider
Browse files Browse the repository at this point in the history
  • Loading branch information
lauzadis committed Feb 20, 2024
1 parent 7e2cffc commit ec1dabb
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
import aws.smithy.kotlin.runtime.collections.Attributes
import aws.smithy.kotlin.runtime.collections.get
import aws.smithy.kotlin.runtime.io.SdkManagedBase
import aws.smithy.kotlin.runtime.telemetry.logging.Logger
import aws.smithy.kotlin.runtime.telemetry.TelemetryProvider
import aws.smithy.kotlin.runtime.telemetry.logging.getLogger
import aws.smithy.kotlin.runtime.time.Clock
import aws.smithy.kotlin.runtime.time.until
Expand All @@ -24,41 +24,50 @@ import kotlin.time.Duration.Companion.seconds
import kotlin.time.TimeMark
import kotlin.time.TimeSource

/**
* The duration before expiration that [Credentials] are considered expired
*/
internal val REFRESH_BUFFER = 1.minutes

/**
* How long to wait between cache refresh attempts if no [Credentials] are in the cache
*/
private val DEFAULT_REFRESH_PERIOD = 3.minutes

private const val CREDENTIALS_PROVIDER_NAME = "DefaultS3ExpressCredentialsProvider"

/**
* The default implementation of [S3ExpressCredentialsProvider]
* @param timeSource the time source to use. defaults to [TimeSource.Monotonic]
* @param clock the clock to use. defaults to [Clock.System]. note: the clock is only used to get an initial [Duration]
* until credentials expiration.
*/
internal class DefaultS3ExpressCredentialsProvider(
private val timeSource: TimeSource = TimeSource.Monotonic,
private val clock: Clock = Clock.System,
) : S3ExpressCredentialsProvider, SdkManagedBase(), CoroutineScope {
private lateinit var client: S3Client
private lateinit var logger: Logger
private val credentialsCache = S3ExpressCredentialsCache()

override val coroutineContext: CoroutineContext = Job() +
CoroutineName(CREDENTIALS_PROVIDER_NAME)
override val coroutineContext: CoroutineContext = Job() + CoroutineName(CREDENTIALS_PROVIDER_NAME)

init {
launch(coroutineContext) {
refresh()
}
}

@OptIn(ExperimentalApi::class)
override suspend fun resolve(attributes: Attributes): Credentials {
client = (attributes[S3Attributes.ExpressClient] as S3Client)
logger = client.config.telemetryProvider.loggerProvider.getLogger<S3ExpressCredentialsProvider>()
client = attributes[S3Attributes.ExpressClient] as S3Client

val key = S3ExpressCredentialsCacheKey(attributes[S3Attributes.Bucket], client.config.credentialsProvider.resolve(attributes))

return credentialsCache.get(key)?.takeIf { !it.isExpired }?.value
?: createSessionCredentials(key.bucket).also { credentialsCache.put(key, it) }.value
return credentialsCache.get(key)?.expiringCredentials?.takeIf { !it.isExpired }?.value
?: createSessionCredentials(key.bucket).also { credentialsCache.put(key, S3ExpressCredentialsCacheValue(it, usedSinceLastRefresh = true)) }.value
}

override fun close() = coroutineContext.cancel(null)


/**
* Attempt to refresh the credentials in the cache. A refresh is initiated when the `nextRefresh` time has been reached,
* which is either `DEFAULT_REFRESH_PERIOD` or the soonest credentials expiration time (minus a buffer), whichever comes first.
Expand Down Expand Up @@ -96,7 +105,7 @@ internal class DefaultS3ExpressCredentialsProvider(

try {
val refreshed = async { createSessionCredentials(entry.key.bucket) }.await()
credentialsCache.put(entry.key, refreshed, false)
credentialsCache.put(entry.key, S3ExpressCredentialsCacheValue(refreshed, usedSinceLastRefresh = false))
} catch (e: Exception) {
logger.warn(e) { "Failed to refresh credentials for ${entry.key.bucket}" }
}
Expand All @@ -106,8 +115,8 @@ internal class DefaultS3ExpressCredentialsProvider(

// Find the next expiring credentials, sleep until then
val nextExpiringEntry = entries.maxByOrNull {
// note: `expiresAt` is always in the future, which means the `elapsedNow` values are negative.
// that's the reason `maxBy` is used instead of `minBy`
// note: `expiresAt` is a future time, which means the `elapsedNow` values are negative
// and count up until expiration at t=0. that's why `maxBy` is used instead of `minBy`
it.value.expiringCredentials.expiresAt.elapsedNow()
}

Expand Down Expand Up @@ -135,8 +144,16 @@ internal class DefaultS3ExpressCredentialsProvider(
expirationTimeMark,
)
}

@OptIn(ExperimentalApi::class)
internal val logger get() = if (this::client.isInitialized) {
client.config.telemetryProvider.loggerProvider.getLogger<DefaultS3ExpressCredentialsProvider>()
} else {
TelemetryProvider.None.loggerProvider.getLogger<DefaultS3ExpressCredentialsProvider>()
}
}


/**
* Get the [Duration] between [this] TimeMark and an [other] TimeMark
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,17 @@ private const val DEFAULT_S3_EXPRESS_CACHE_SIZE: Int = 100
internal class S3ExpressCredentialsCache {
private val lru = LruCache<S3ExpressCredentialsCacheKey, S3ExpressCredentialsCacheValue>(DEFAULT_S3_EXPRESS_CACHE_SIZE)

suspend fun get(key: S3ExpressCredentialsCacheKey): ExpiringValue<Credentials>? = lru.get(key)?.expiringCredentials
suspend fun get(key: S3ExpressCredentialsCacheKey): S3ExpressCredentialsCacheValue? = lru.get(key)

suspend fun put(key: S3ExpressCredentialsCacheKey, value: ExpiringValue<Credentials>, usedSinceLastRefresh: Boolean = true) =
lru.put(key, S3ExpressCredentialsCacheValue(value, usedSinceLastRefresh))
suspend fun put(key: S3ExpressCredentialsCacheKey, value: S3ExpressCredentialsCacheValue) = lru.put(key, value)

suspend fun remove(key: S3ExpressCredentialsCacheKey) : ExpiringValue<Credentials>? =
lru.remove(key)?.expiringCredentials
suspend fun remove(key: S3ExpressCredentialsCacheKey) : S3ExpressCredentialsCacheValue? = lru.remove(key)

public val size: Int
get() = lru.size

public val entries: Set<Map.Entry<S3ExpressCredentialsCacheKey, S3ExpressCredentialsCacheValue>>
get() = lru.entries

// suspend fun get(key: S3ExpressCredentialsCacheKey): Credentials = lru
// .get(key)
// ?.takeIf { !it.expiringCredentials.isExpired }
// ?.let {
// it.usedSinceLastRefresh = true
// it.expiringCredentials.value
// }
// ?: (createSessionCredentials(key.bucket).also { put(key, it) }).value


}

internal data class S3ExpressCredentialsCacheKey(
Expand All @@ -58,7 +45,7 @@ internal data class S3ExpressCredentialsCacheKey(

internal data class S3ExpressCredentialsCacheValue(
val expiringCredentials: ExpiringValue<Credentials>,
var usedSinceLastRefresh: Boolean,
var usedSinceLastRefresh: Boolean = false,
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ package aws.sdk.kotlin.services.s3.express

import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.*
import kotlin.time.TestTimeSource
import kotlin.time.Duration.Companion.seconds

public class S3ExpressCredentialsCacheTest {
@Test
Expand All @@ -21,4 +22,39 @@ public class S3ExpressCredentialsCacheTest {

assertEquals(key1, key2)
}

@Test
fun testCacheOperations() = runTest {
val cache = S3ExpressCredentialsCache()

val bucket = "bucket"
val bootstrapCredentials = Credentials("accessKeyId", "secretAccessKey", "sessionToken")
val key = S3ExpressCredentialsCacheKey(bucket, bootstrapCredentials)

val sessionCredentials = Credentials("superFastAccessKey", "superSecretSecretKey", "s3SessionToken")
val expiringSessionCredentials = ExpiringValue(sessionCredentials, TestTimeSource().markNow())
val value = S3ExpressCredentialsCacheValue(expiringSessionCredentials)

cache.put(key, value) // put
assertEquals(expiringSessionCredentials, cache.get(key)?.expiringCredentials) // get
assertEquals(1, cache.size) // size
assertContains(cache.entries.map { it.key }, key) // entries
assertContains(cache.entries.map { it.value }, value) // entries

cache.remove(key)
assertEquals(0, cache.size)
assertNull(cache.get(key))
}

@Test
fun testIsExpired() = runTest {
val timeSource = TestTimeSource()

val sessionCredentials = Credentials("superFastAccessKey", "superSecretSecretKey", "s3SessionToken")
val expiringSessionCredentials = ExpiringValue(sessionCredentials, timeSource.markNow())

assertFalse(expiringSessionCredentials.isExpired)
timeSource += 1.seconds
assertTrue(expiringSessionCredentials.isExpired)
}
}

0 comments on commit ec1dabb

Please sign in to comment.