Skip to content

Commit

Permalink
Fix race during concurrent cache entry creation (#1053)
Browse files Browse the repository at this point in the history
* Add a test for racing construction on Cache

* Compute values for cache once and return them on concurrent accesses

When computing a new value to place in the cache, return it from all concurrent GetOrSet
calls.

Also remove ErrCacheItemNotFound, it can no longer be generated.  Cached values
can always outlive their lifetimes in the cache itself!  The actual values returned from
GetOrSet are no longer controlled by the cache.

* Remove ErrCacheItemNotFound: no longer in use

* Remove ChanLocker

No longer needed.

* Fix OnlyOne interface and add tests

* Rebase and fix change tier_fs_test to use OnlyOne

It was using cache.ChanLocker, which is now gone.

* [checks] Wrap the right err

Thanks, golangci, good catch there!
  • Loading branch information
arielshaqed authored Dec 14, 2020
1 parent c6ff255 commit 11b432d
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 161 deletions.
40 changes: 12 additions & 28 deletions cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cache

import (
"errors"
"math/rand"
"time"

Expand All @@ -16,49 +15,34 @@ type Cache interface {
}

// GetSetCache is an expiring LRU cache whose missing entries are computed at
// most once concurrently per key: racing GetOrSet calls for the same absent
// key share a single computation.
type GetSetCache struct {
	lru          *lru.Cache
	computations *ChanOnlyOne
	jitterFn     JitterFn
	baseExpiry   time.Duration
}

// NewCache returns a GetSetCache holding up to size entries.  Each entry
// expires after baseExpiry plus a per-entry jitter taken from jitterFn (the
// jitter spreads out expiry so entries do not all refill at once).
func NewCache(size int, expiry time.Duration, jitterFn JitterFn) *GetSetCache {
	// lru.New only fails for a non-positive size; ignore the error as the
	// original code does.
	c, _ := lru.New(size)
	return &GetSetCache{
		lru:          c,
		computations: NewChanOnlyOne(),
		jitterFn:     jitterFn,
		baseExpiry:   expiry,
	}
}

func (c *GetSetCache) GetOrSet(k interface{}, setFn SetFn) (v interface{}, err error) {
if v, ok := c.lru.Get(k); ok {
return v, nil
}
acquired := c.locker.Lock(k, func() {
return c.computations.Compute(k, func() (interface{}, error) {
v, err = setFn()
if err != nil {
return
if err != nil { // Don't cache errors
return nil, err
}
c.lru.AddEx(k, v, c.baseExpiry+c.jitterFn())
})
if acquired {
return v, err
}

// someone else got the lock first and should have inserted something
if v, ok := c.lru.Get(k); ok {
return v, nil
}

// someone else acquired the lock, but no key was found
// (most likely this value doesn't exist or the upstream fetch failed)
return nil, ErrCacheItemNotFound
})
}

func NewJitterFn(jitter time.Duration) JitterFn {
Expand Down
84 changes: 84 additions & 0 deletions cache/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package cache_test

import (
"sync"
"testing"
"time"

"github.com/treeverse/lakefs/cache"
"github.com/treeverse/lakefs/testutil"
)

// TestCache verifies hit/miss accounting on a deliberately thrashed cache.
func TestCache(t *testing.T) {
	const (
		n = 200
		// Thrash the cache: every even iteration asks for worldSize-1, every
		// odd iteration cycles through the remaining values.  Requires
		// cacheSize < worldSize-1 so the odd keys keep evicting each other.
		worldSize = 10
		cacheSize = 7
	)

	c := cache.NewCache(cacheSize, time.Hour*12, cache.NewJitterFn(time.Millisecond))

	numCalls := 0
	for i := 0; i < n; i++ {
		k := (i / 2) % (worldSize - 1)
		if i%2 == 0 {
			k = worldSize - 1
		}
		actual, err := c.GetOrSet(k, func() (interface{}, error) {
			numCalls++
			return k * k, nil
		})
		testutil.MustDo(t, "GetOrSet", err)
		if actual.(int) != k*k {
			t.Errorf("got %v != %d at %d", actual, k*k, k)
		}
	}
	// The first even call misses and every later even call hits; odd calls
	// always miss because their keys are continually evicted.
	if expectedNumCalls := 1 + n/2; numCalls != expectedNumCalls {
		t.Errorf("cache called refill %d times instead of %d", numCalls, expectedNumCalls)
	}
}

// TestCacheRace hammers a single cache from many goroutines; run with -race
// to detect unsynchronized access during concurrent entry creation.
func TestCacheRace(t *testing.T) {
	const (
		parallelism = 25
		n           = 200
		worldSize   = 10
		cacheSize   = 7
	)

	c := cache.NewCache(cacheSize, time.Hour*12, cache.NewJitterFn(time.Millisecond))

	// All workers block on start so they begin racing at the same moment.
	start := make(chan struct{})
	var wg sync.WaitGroup

	wg.Add(parallelism)
	for i := 0; i < parallelism; i++ {
		go func(i int) {
			defer wg.Done()
			<-start
			for j := 0; j < n; j++ {
				k := j % worldSize
				kk, err := c.GetOrSet(k, func() (interface{}, error) {
					return k * k, nil
				})
				if err != nil {
					t.Error(err)
					return
				}
				if kk.(int) != k*k {
					t.Errorf("[%d] got %d^2=%d, expected %d", i, k, kk, k*k)
				}
			}
		}(i)
	}
	close(start)
	wg.Wait()
}
31 changes: 0 additions & 31 deletions cache/lock.go

This file was deleted.

80 changes: 0 additions & 80 deletions cache/lock_test.go

This file was deleted.

40 changes: 40 additions & 0 deletions cache/only_one.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package cache

import (
	"fmt"
	"sync"
)

// OnlyOne ensures only one concurrent evaluation of a keyed expression.
type OnlyOne interface {
	// Compute returns the value of calling fn(), but only calls fn once concurrently for
	// each k.  Callers that arrive while fn is in flight for the same k block
	// until it finishes and then all receive that single value and error.
	// Once a computation completes, a later Compute for k evaluates fn again.
	Compute(k interface{}, fn func() (interface{}, error)) (interface{}, error)
}

// ChanOnlyOne implements OnlyOne with a sync.Map of per-key channels: the
// first caller for a key installs a channel and runs the computation; later
// concurrent callers for the same key wait on that channel and share the
// result.
type ChanOnlyOne struct {
	m *sync.Map
}

// NewChanOnlyOne returns an empty ChanOnlyOne ready for use.
func NewChanOnlyOne() *ChanOnlyOne {
	return &ChanOnlyOne{
		m: &sync.Map{},
	}
}

// chanAndResult holds one in-flight computation.  ch is closed only after
// value and err have been set, so waiters may read them once ch is closed.
type chanAndResult struct {
	ch    chan struct{}
	value interface{}
	err   error
}

// Compute returns fn(), calling fn at most once concurrently per k.  Losers
// of the race to compute k block until the winner finishes, then return the
// winner's value and error.
func (c *ChanOnlyOne) Compute(k interface{}, fn func() (interface{}, error)) (interface{}, error) {
	mine := &chanAndResult{ch: make(chan struct{})}
	actual, inFlight := c.m.LoadOrStore(k, mine)
	shared := actual.(*chanAndResult)
	if inFlight {
		// Another goroutine is already computing k: wait for it to finish
		// and share its result.
		<-shared.ch
		return shared.value, shared.err
	}
	// This goroutine won the race and runs fn.  If fn panics, still release
	// waiters (with an error) and remove the map entry before re-panicking;
	// otherwise waiting goroutines block forever on ch and the stale entry
	// makes every future Compute for k hang too.
	defer func() {
		if r := recover(); r != nil {
			shared.err = fmt.Errorf("computation panicked: %v", r)
			close(shared.ch)
			c.m.Delete(k)
			panic(r)
		}
	}()
	shared.value, shared.err = fn()
	close(shared.ch)
	c.m.Delete(k)
	return shared.value, shared.err
}
89 changes: 89 additions & 0 deletions cache/only_one_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package cache_test

import (
"sync"
"testing"
"time"

"github.com/treeverse/lakefs/cache"
"github.com/treeverse/lakefs/testutil"
)

// TestOnlyOne_ComputeInSequence checks that non-concurrent Compute calls for
// the same key each run their own fn: the second call must not reuse the
// first call's result.
func TestOnlyOne_ComputeInSequence(t *testing.T) {
	const (
		one = "foo"
		two = "bar"
	)
	c := cache.NewChanOnlyOne()

	first, err := c.Compute("foo", func() (interface{}, error) { return one, nil })
	testutil.MustDo(t, "first Compute", err)
	second, err := c.Compute("foo", func() (interface{}, error) { return two, nil })
	testutil.MustDo(t, "second Compute", err)

	if first.(string) != one {
		t.Errorf("got first compute %s, expected %s", first, one)
	}
	if second.(string) != two {
		t.Errorf("got second compute %s, expected %s", second, two)
	}
}

// TestOnlyOne_ComputeConcurrentlyOnce checks that while one Compute for a key
// is in flight, concurrent Compute calls for the same key do NOT run their
// own fn but instead receive the in-flight call's result (100).
//
// NOTE(review): the test is timing-based — the winner holds the computation
// for 100ms while the other callers arrive 5ms and 10ms in.  A heavily
// loaded machine could in principle delay a goroutine past the 100ms window
// and flake.
func TestOnlyOne_ComputeConcurrentlyOnce(t *testing.T) {
	c := cache.NewChanOnlyOne()

	var wg sync.WaitGroup
	wg.Add(3)

	// ch is closed by the winning computation once it has started, so the
	// main goroutine knows the computation is in flight before spawning the
	// competing callers.
	ch := make(chan struct{})
	did100 := false
	go func(didIt *bool) {
		defer wg.Done()
		// Winner: deliberately slow (100ms) so the other goroutines arrive
		// while it is still computing.
		value, err := c.Compute("foo", func() (interface{}, error) {
			close(ch)
			*didIt = true
			time.Sleep(time.Millisecond * 100)
			return 100, nil
		})
		if value != 100 || err != nil {
			t.Errorf("got %v, %v not 100, nil", value, err)
		}
	}(&did100)

	<-ch // Ensure first computation is in progress

	did10 := false
	go func(didIt *bool) {
		defer wg.Done()
		time.Sleep(10 * time.Millisecond)
		// Arrives 10ms into the 100ms window: must share the winner's 100;
		// its own fn (which would return 101) must never run.
		value, err := c.Compute("foo", func() (interface{}, error) {
			*didIt = true
			return 101, nil
		})
		if value != 100 || err != nil {
			t.Errorf("got %v, %v not 100, nil", value, err)
		}
	}(&did10)

	did5 := false
	go func(didIt *bool) {
		defer wg.Done()
		time.Sleep(5 * time.Millisecond)
		// Same for a caller arriving 5ms in (would return 102).
		value, err := c.Compute("foo", func() (interface{}, error) {
			*didIt = true
			return 102, nil
		})
		if value != 100 || err != nil {
			t.Errorf("got %v, %v not 100, nil", value, err)
		}
	}(&did5)

	wg.Wait()
	if !did100 {
		t.Error("expected to run first concurrent compute and wait 100ms")
	}
	if did10 {
		t.Error("did not expect to run concurrent compute after 10ms")
	}
	if did5 {
		t.Error("did not expect to run concurrent compute after 5ms")
	}
}
Loading

0 comments on commit 11b432d

Please sign in to comment.