diff --git a/cache/cache.go b/cache/cache.go index 24b2e187fce..c1b6861e6c8 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,7 +1,6 @@ package cache import ( - "errors" "math/rand" "time" @@ -16,23 +15,19 @@ type Cache interface { } type GetSetCache struct { - lru *lru.Cache - locker *ChanLocker - jitterFn JitterFn - baseExpiry time.Duration + lru *lru.Cache + computations *ChanOnlyOne + jitterFn JitterFn + baseExpiry time.Duration } -var ( - ErrCacheItemNotFound = errors.New("cache item not found") -) - func NewCache(size int, expiry time.Duration, jitterFn JitterFn) *GetSetCache { c, _ := lru.New(size) return &GetSetCache{ - lru: c, - locker: NewChanLocker(), - jitterFn: jitterFn, - baseExpiry: expiry, + lru: c, + computations: NewChanOnlyOne(), + jitterFn: jitterFn, + baseExpiry: expiry, } } @@ -40,25 +35,14 @@ func (c *GetSetCache) GetOrSet(k interface{}, setFn SetFn) (v interface{}, err e if v, ok := c.lru.Get(k); ok { return v, nil } - acquired := c.locker.Lock(k, func() { + return c.computations.Compute(k, func() (interface{}, error) { v, err = setFn() - if err != nil { - return + if err != nil { // Don't cache errors + return nil, err } c.lru.AddEx(k, v, c.baseExpiry+c.jitterFn()) - }) - if acquired { - return v, err - } - - // someone else got the lock first and should have inserted something - if v, ok := c.lru.Get(k); ok { return v, nil - } - - // someone else acquired the lock, but no key was found - // (most likely this value doesn't exist or the upstream fetch failed) - return nil, ErrCacheItemNotFound + }) } func NewJitterFn(jitter time.Duration) JitterFn { diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 00000000000..bdbe7d9edc0 --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,84 @@ +package cache_test + +import ( + "sync" + "testing" + "time" + + "github.com/treeverse/lakefs/cache" + "github.com/treeverse/lakefs/testutil" +) + +func TestCache(t *testing.T) { + const ( + n = 200 + // Thrash the cache by placing worldSize-1 every even iteration and the + // remaining values ever odd iteration. In particular must have cacheSize < + // worldSize-1. + worldSize = 10 + cacheSize = 7 + ) + + c := cache.NewCache(cacheSize, time.Hour*12, cache.NewJitterFn(time.Millisecond)) + + numCalls := 0 + for i := 0; i < n; i++ { + var k int + if i%2 == 0 { + k = worldSize - 1 + } else { + k = (i / 2) % (worldSize - 1) + } + actual, err := c.GetOrSet(k, func() (interface{}, error) { + numCalls++ + return k * k, nil + }) + testutil.MustDo(t, "GetOrSet", err) + if actual.(int) != k*k { + t.Errorf("got %v != %d at %d", actual, k*k, k) + } + } + // Every even call except the first is served from cache; no odd call is ever served + // from cache. + expectedNumCalls := 1 + n/2 + if numCalls != expectedNumCalls { + t.Errorf("cache called refill %d times instead of %d", numCalls, expectedNumCalls) + } +} + +func TestCacheRace(t *testing.T) { + const ( + parallelism = 25 + n = 200 + worldSize = 10 + cacheSize = 7 + ) + + c := cache.NewCache(cacheSize, time.Hour*12, cache.NewJitterFn(time.Millisecond)) + + start := make(chan struct{}) + wg := sync.WaitGroup{} + + for i := 0; i < parallelism; i++ { + wg.Add(1) + go func(i int) { + <-start + for j := 0; j < n; j++ { + k := j % worldSize + kk, err := c.GetOrSet(k, func() (interface{}, error) { + return k * k, nil + }) + if err != nil { + t.Error(err) + return + } + if kk.(int) != k*k { + t.Errorf("[%d] got %d^2=%d, expected %d", i, k, kk, k*k) + } + } + wg.Done() + }(i) + } + close(start) + wg.Wait() +} diff --git a/cache/lock.go b/cache/lock.go deleted file mode 100644 index ad1ac37d06b..00000000000 --- a/cache/lock.go +++ /dev/null @@ -1,31 +0,0 @@ -package cache - -import "sync" - -type Locker interface { - Lock(v interface{}, onAcquireFn func()) -} - -type ChanLocker struct { - m *sync.Map -} - -func NewChanLocker() *ChanLocker { - return &ChanLocker{ - m: &sync.Map{}, - } -} - -func (c ChanLocker) Lock(v interface{}, onAcquireFn func()) (acquired bool) { - tid := make(chan struct{}) - actual, alreadyLocked := c.m.LoadOrStore(v, tid) - if !alreadyLocked { - onAcquireFn() - c.m.Delete(v) - close(tid) - return true - } - - <-actual.(chan struct{}) - return false -} diff --git a/cache/lock_test.go b/cache/lock_test.go deleted file mode 100644 index 2bef80d370d..00000000000 --- a/cache/lock_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package cache_test - -import ( - "sync" - "testing" - "time" - - "github.com/treeverse/lakefs/cache" -) - -func TestChanLocker_LockAfterLock(t *testing.T) { - c := cache.NewChanLocker() - acq := c.Lock("foo", func() {}) - if !acq { - t.Fatalf("expected first lock to acquire") - } - - acq = c.Lock("foo", func() {}) - if !acq { - t.Fatalf("expected second lock to acquire") - } -} - -func TestChanLocker_Lock(t *testing.T) { - c := cache.NewChanLocker() - - var wg sync.WaitGroup - wg.Add(3) - - var foo100, getFoo100 bool - ch := make(chan bool) - go func(acq *bool, getter *bool, ch chan bool) { - close(ch) - defer wg.Done() - *acq = c.Lock("foo", func() { - *getter = true - time.Sleep(time.Millisecond * 100) - }) - }(&foo100, &getFoo100, ch) - <-ch // wait until goroutine starts - - var foo10, getFoo10 bool - go func(acq *bool, getter *bool) { - defer wg.Done() - time.Sleep(10 * time.Millisecond) - *acq = c.Lock("foo", func() { - *getter = true - }) - }(&foo10, &getFoo10) - - var bar10 bool - var getBar10 bool - go func(acq *bool, getter *bool) { - defer wg.Done() - time.Sleep(10 * time.Millisecond) - *acq = c.Lock("bar", func() { - *getter = true - }) - }(&bar10, &getBar10) - - wg.Wait() - if foo10 { - t.Error("expected to not acquire foo after 10ms") - } - if getFoo10 { - t.Error("expected foo (10ms) getter not to be called") - } - if !getFoo100 { - t.Error("expected foo (100ms) getter to be called") - } - if !foo100 { - t.Error("expected to acquire foo after 100ms") - } - if !getBar10 { - t.Error("expected bar getter to be called") - } - if !bar10 { - t.Error("expected to acquire bar") - } -} diff --git a/cache/only_one.go b/cache/only_one.go new file mode 100644 index 00000000000..999be06054b --- /dev/null +++ b/cache/only_one.go @@ -0,0 +1,40 @@ +package cache + +import "sync" + +// OnlyOne ensures only one concurrent evaluation of a keyed expression. +type OnlyOne interface { + // Compute returns the value of calling fn(), but only calls fn once concurrently for + // each k. + Compute(k interface{}, fn func() (interface{}, error)) (interface{}, error) +} + +type ChanOnlyOne struct { + m *sync.Map +} + +func NewChanOnlyOne() *ChanOnlyOne { + return &ChanOnlyOne{ + m: &sync.Map{}, + } +} + +type chanAndResult struct { + ch chan struct{} + value interface{} + err error +} + +func (c *ChanOnlyOne) Compute(k interface{}, fn func() (interface{}, error)) (interface{}, error) { + stop := chanAndResult{ch: make(chan struct{})} + actual, inFlight := c.m.LoadOrStore(k, &stop) + actualStop := actual.(*chanAndResult) + if inFlight { + <-actualStop.ch + } else { + actualStop.value, actualStop.err = fn() + close(actualStop.ch) + c.m.Delete(k) + } + return actualStop.value, actualStop.err +} diff --git a/cache/only_one_test.go b/cache/only_one_test.go new file mode 100644 index 00000000000..3b352725ae9 --- /dev/null +++ b/cache/only_one_test.go @@ -0,0 +1,89 @@ +package cache_test + +import ( + "sync" + "testing" + "time" + + "github.com/treeverse/lakefs/cache" + "github.com/treeverse/lakefs/testutil" +) + +func TestOnlyOne_ComputeInSequence(t *testing.T) { + const ( + one = "foo" + two = "bar" + ) + c := cache.NewChanOnlyOne() + first, err := c.Compute("foo", func() (interface{}, error) { return one, nil }) + testutil.MustDo(t, "first Compute", err) + second, err := c.Compute("foo", func() (interface{}, error) { return two, nil }) + testutil.MustDo(t, "second Compute", err) + if first.(string) != one { + t.Errorf("got first compute %s, expected %s", first, one) + } + if second.(string) != two { + t.Errorf("got second compute %s, expected %s", second, two) + } +} + +func TestOnlyOne_ComputeConcurrentlyOnce(t *testing.T) { + c := cache.NewChanOnlyOne() + + var wg sync.WaitGroup + wg.Add(3) + + ch := make(chan struct{}) + did100 := false + go func(didIt *bool) { + defer wg.Done() + value, err := c.Compute("foo", func() (interface{}, error) { + close(ch) + *didIt = true + time.Sleep(time.Millisecond * 100) + return 100, nil + }) + if value != 100 || err != nil { + t.Errorf("got %v, %v not 100, nil", value, err) + } + }(&did100) + + <-ch // Ensure first computation is in progress + + did10 := false + go func(didIt *bool) { + defer wg.Done() + time.Sleep(10 * time.Millisecond) + value, err := c.Compute("foo", func() (interface{}, error) { + *didIt = true + return 101, nil + }) + if value != 100 || err != nil { + t.Errorf("got %v, %v not 100, nil", value, err) + } + }(&did10) + + did5 := false + go func(didIt *bool) { + defer wg.Done() + time.Sleep(5 * time.Millisecond) + value, err := c.Compute("foo", func() (interface{}, error) { + *didIt = true + return 102, nil + }) + if value != 100 || err != nil { + t.Errorf("got %v, %v not 100, nil", value, err) + } + }(&did5) + + wg.Wait() + if !did100 { + t.Error("expected to run first concurrent compute and wait 100ms") + } + if did10 { + t.Error("did not expect to run concurrent compute after 10ms") + } + if did5 { + t.Error("did not expect to run concurrent compute after 5ms") + } +} diff --git a/catalog/mvcc/cataloger_cache.go b/catalog/mvcc/cataloger_cache.go index 4e256c6498a..ba32d4121e1 100644 --- a/catalog/mvcc/cataloger_cache.go +++ b/catalog/mvcc/cataloger_cache.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" - "github.com/treeverse/lakefs/cache" "github.com/treeverse/lakefs/catalog" "github.com/treeverse/lakefs/db" ) @@ -17,9 +16,6 @@ func (c *cataloger) getRepositoryCache(tx db.Tx, repository string) (*catalog.Re } return repo, nil }) - if errors.Is(err, cache.ErrCacheItemNotFound) { - return repo, catalog.ErrRepositoryNotFound - } return repo, err } @@ -31,9 +27,6 @@ func (c *cataloger) getRepositoryIDCache(tx db.Tx, repository string) (int, erro } return repoID, nil }) - if errors.Is(err, cache.ErrCacheItemNotFound) { - return repoID, catalog.ErrRepositoryNotFound - } return repoID, err } @@ -46,7 +39,7 @@ func (c *cataloger) getBranchIDCache(tx db.Tx, repository string, branch string) return branchID, nil }) - if !(errors.Is(err, cache.ErrCacheItemNotFound) || errors.Is(err, db.ErrNotFound)) { + if !errors.Is(err, db.ErrNotFound) { return branchID, err } diff --git a/pyramid/tier_fs.go b/pyramid/tier_fs.go index 8c963a11240..087ed2b1c31 100644 --- a/pyramid/tier_fs.go +++ b/pyramid/tier_fs.go @@ -26,7 +26,7 @@ type TierFS struct { adaptor block.Adapter eviction eviction - keyLock *cache.ChanLocker + keyLock cache.OnlyOne syncDir *directory fsName string @@ -73,7 +73,7 @@ func NewFS(c *Config) (FS, error) { logger: c.logger, fsLocalBaseDir: fsLocalBaseDir, syncDir: &directory{ceilingDir: fsLocalBaseDir}, - keyLock: cache.NewChanLocker(), + keyLock: cache.NewChanOnlyOne(), remotePrefix: path.Join(c.fsBlockStoragePrefix, c.fsName), } eviction, err := newLRUSizeEviction(c.allocatedDiskBytes, tierFS.removeFromLocal) @@ -242,35 +242,33 @@ func (tfs *TierFS) openFile(fileRef localFileRef, fh *os.File) (*ROFile, error) // and places it in the local FS for further reading. // It returns a file handle to the local file. func (tfs *TierFS) readFromBlockStorage(fileRef localFileRef) (*os.File, error) { - var e error - tfs.keyLock.Lock(fileRef.filename, func() { + _, err := tfs.keyLock.Compute(fileRef.filename, func() (interface{}, error) { + var err error reader, err := tfs.adaptor.Get(tfs.objPointer(fileRef.namespace, fileRef.filename), 0) if err != nil { - e = fmt.Errorf("read from block storage: %w", err) - return + return nil, fmt.Errorf("read from block storage: %w", err) } defer reader.Close() writer, err := tfs.syncDir.createFile(fileRef.fullPath) if err != nil { - e = fmt.Errorf("creating file: %w", err) - return + return nil, fmt.Errorf("creating file: %w", err) } written, err := io.Copy(writer, reader) if err != nil { - e = fmt.Errorf("copying date to file: %w", err) - return + return nil, fmt.Errorf("copying date to file: %w", err) } - if err := writer.Close(); err != nil { - e = fmt.Errorf("writer close: %w", err) + if err = writer.Close(); err != nil { + err = fmt.Errorf("writer close: %w", err) } downloadHistograms.WithLabelValues(tfs.fsName).Observe(float64(written)) + return nil, err }) - if e != nil { - return nil, e + if err != nil { + return nil, err } fh, err := os.Open(fileRef.fullPath)