diff --git a/cache.go b/cache.go index 812a4d3..23d66af 100644 --- a/cache.go +++ b/cache.go @@ -70,6 +70,11 @@ func (bs baseCache[K, V]) Delete(key K) { bs.cache.Delete(key) } +// DeleteByFunc removes the association for this key from the cache when the given function returns true. +func (bs baseCache[K, V]) DeleteByFunc(f func(key K, value V) bool) { + bs.cache.DeleteByFunc(f) +} + // Range iterates over all items in the cache. // // Iteration stops early when the given function returns false. diff --git a/cache_test.go b/cache_test.go index ba8961d..7a280bf 100644 --- a/cache_test.go +++ b/cache_test.go @@ -182,6 +182,29 @@ func TestCache_SetWithTTL(t *testing.T) { } } +func TestBaseCache_DeleteByFunc(t *testing.T) { + size := 256 + c, err := MustBuilder[int, int](size).WithTTL(time.Hour).Build() + if err != nil { + t.Fatalf("can not create builder: %v", err) + } + + for i := 0; i < size; i++ { + c.Set(i, i) + } + + c.DeleteByFunc(func(key int, value int) bool { + return key%2 == 1 + }) + + c.Range(func(key int, value int) bool { + if key%2 == 1 { + t.Fatalf("key should be odd, but got: %d", key) + } + return true + }) +} + func TestCache_Ratio(t *testing.T) { c, err := MustBuilder[uint64, uint64](100).CollectStats().Build() if err != nil { diff --git a/internal/core/cache.go b/internal/core/cache.go index d5f7c98..85c4e73 100644 --- a/internal/core/cache.go +++ b/internal/core/cache.go @@ -232,6 +232,22 @@ func (c *Cache[K, V]) Delete(key K) { } } +// DeleteByFunc removes the association for this key from the cache when the given function returns true. +func (c *Cache[K, V]) DeleteByFunc(f func(key K, value V) bool) { + // TODO(maypok86): This function can be implemented more efficiently, if the performance of this implementation is not enough for you, then come with an issue :) + var keysToDelete []K + c.Range(func(key K, value V) bool { + if f(key, value) { + keysToDelete = append(keysToDelete, key) + } + return true + }) + + for _, key := range keysToDelete { + c.Delete(key) + } +} + func (c *Cache[K, V]) cleanup() { expired := make([]*node.Node[K, V], 0, 128) for {