Skip to content

Commit

Permalink
feat: add cache criteria to cached tool
Browse files Browse the repository at this point in the history
  • Loading branch information
realdavidvega committed Sep 30, 2024
1 parent e5916d6 commit 1258b99
Showing 1 changed file with 37 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,44 @@ abstract class CachedTool<Input, Output>(
private val timeCachePolicy: Duration = 1.days
) : Tool<Input, Output> {

override suspend fun invoke(input: Input): Output {
return cache(CachedToolKey(input, seed)) { onCacheMissed(input) }
}
/**
* Logic to be executed when the cache is missed.
*
* @return the output.
*/
abstract suspend fun onCacheMissed(input: Input): Output

/**
* Criteria to check if the cache should be used for the given [input]. By default, it returns
* true, meaning always use the cache if available.
*
* @return true if the cache should be used.
*/
open fun shouldUseCache(input: Input): Boolean = true

/**
* Criteria to check if the result should be cached based on the given [input] and [output]. By
* default, it returns true, meaning always cache the result.
*
* @return true if the result should be cached.
*/
open fun shouldCacheOutput(input: Input, output: Output): Boolean = true

/**
* Caches the result of [onCacheMissed] if [shouldCacheOutput] returns true. Otherwise, returns
* the result of [onCacheMissed].
*
* @return the output.
*/
override suspend fun invoke(input: Input): Output =
if (shouldUseCache(input)) cache(CachedToolKey(input, seed)) { onCacheMissed(input) }
else onCacheMissed(input)

/**
* Exposes the cache as a [Map] of [Input] to [Output] filtered by instance [seed] and
* [timeCachePolicy]. Removes expired cache entries.
*
* @return the map of input to output.
*/
suspend fun getCache(): Map<Input, Output> {
val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
Expand All @@ -42,8 +73,6 @@ abstract class CachedTool<Input, Output>(
return withoutExpired.map { it.key.value to it.value.value }.toMap()
}

abstract suspend fun onCacheMissed(input: Input): Output

private suspend fun cache(input: CachedToolKey<Input>, block: suspend () -> Output): Output {
val cachedToolInfo = cache.get()[input]
if (cachedToolInfo != null) {
Expand All @@ -55,7 +84,9 @@ abstract class CachedTool<Input, Output>(
}
}
val response = block()
cache.get()[input] = CachedToolValue(response, timeInMillis())
if (shouldCacheOutput(input.value, response)) {
cache.get()[input] = CachedToolValue(response, timeInMillis())
}
return response
}
}

0 comments on commit 1258b99

Please sign in to comment.