diff --git a/cache.go b/cache.go index db88d2f..6429fbd 100644 --- a/cache.go +++ b/cache.go @@ -87,6 +87,38 @@ func (c *cache) SetDefault(k string, x interface{}) { c.Set(k, x, DefaultExpiration) } +// Update the expiration on an item in the cache Update the expiration an item +// from the cache. Returns the previous expiration time if one is set (if the +// item never expires a zero value for time.Time is returned), and a +// bool indicating whether the key was found. +func (c *cache) UpdateExpiration(k string, d time.Duration) (time.Time, bool) { + c.mu.Lock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.Unlock() + return time.Time{}, false + } + + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.Unlock() + return time.Time{}, false + } + } + + c.set(k, item.Object, d) + c.mu.Unlock() + + if item.Expiration > 0 { + return time.Unix(0, item.Expiration), true + } + + // If expiration <= 0 (i.e. no expiration time set) then return a zeroed + // time.Time + return time.Time{}, true +} + // Add an item to the cache only if an item doesn't already exist for the given // key, or if the existing item has expired. Returns an error otherwise. func (c *cache) Add(k string, x interface{}, d time.Duration) error { @@ -1159,3 +1191,6 @@ func New(defaultExpiration, cleanupInterval time.Duration) *Cache { func NewFrom(defaultExpiration, cleanupInterval time.Duration, items map[string]Item) *Cache { return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) } + + + diff --git a/cache_test.go b/cache_test.go index de3e9d6..f524e60 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1769,3 +1769,98 @@ func TestGetWithExpiration(t *testing.T) { t.Error("expiration for e is in the past") } } + +func TestUpdateExpiration(t *testing.T) { + tc := New(50*time.Millisecond, 1*time.Millisecond) + + tc.Set("a", 1, DefaultExpiration) + x, e1, found := tc.GetWithExpiration("a") + if !found { + t.Error("did not find a") + } + if x != 1 { + t.Error("a should be 1; value:", x) + } + + tc.Increment("a", 1) + x, e2, found := tc.GetWithExpiration("a") + if !found { + t.Error("did not find a2") + } + if x != 2 { + t.Error("a2 should be 2; value:", x) + } + if e1 != e2 { + t.Error("expiration changed when it should not have") + } + + <-time.After(5*time.Millisecond) + tc.Increment("a", 1) + e3, found := tc.UpdateExpiration("a", DefaultExpiration) + if !found { + t.Error("did not find a3") + } + if e1 != e3 { + t.Error("did not get previous expiration back; e1:", e1, "e3:", e3) + } + x, e3, found = tc.GetWithExpiration("a") + if !found { + t.Error("did not find a3") + } + if x != 3 { + t.Error("a3 should be 3; value:", x) + } + if e1 == e3 { + t.Error("should have different expiration but did not change") + } + if e3.UnixNano() != tc.items["a"].Expiration { + t.Error("expiration for a3 is not the correct time") + } + if e3.UnixNano() < time.Now().UnixNano() { + t.Error("expiration for a3 is in the past") + } + + _, found = tc.UpdateExpiration("a", NoExpiration) + if !found { + t.Error("did not find a4") + } + _, e, found := tc.GetWithExpiration("a") + if !found { + t.Error("did not find a4 a second time") + } + if !e.IsZero() { + t.Error("expiration for a4 is not a zeroed time") + } + + _, found = tc.UpdateExpiration("a", 50*time.Millisecond) + if !found { + t.Error("did not find a5") + } + _, e, found = tc.GetWithExpiration("a") + if !found { + t.Error("did not find a5 a second time") + } + if e.UnixNano() != tc.items["a"].Expiration { + t.Error("expiration for a5 is not the correct time") + } + if e.UnixNano() < time.Now().UnixNano() { + t.Error("expiration for a5 is in the past") + } + + <-time.After(49*time.Millisecond) + _, found = tc.Get("a") + if !found { + t.Error("did not find a6") + } + + <-time.After(2*time.Millisecond) + _, found = tc.Get("a") + if found { + t.Error("found a7, but it should have expired") + } + + _, found = tc.UpdateExpiration("a", DefaultExpiration) + if found { + t.Error("found a8, but it should have expired") + } +}