Skip to content

Commit

Permalink
chore: add a draft version of Load
Browse files Browse the repository at this point in the history
  • Loading branch information
maypok86 committed Sep 29, 2024
1 parent 38ea0a6 commit 40ad400
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 1 deletion.
50 changes: 49 additions & 1 deletion cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package otter

import (
"context"
"sync"
"time"

Expand Down Expand Up @@ -73,6 +74,7 @@ type Cache[K comparable, V any] struct {
clock *clock.Clock
stripedBuffer *lossy.Striped[K, V]
writeBuffer *queue.Growable[task[K, V]]
singleflight *group[K, V]
evictionMutex sync.Mutex
closeOnce sync.Once
doneClear chan struct{}
Expand Down Expand Up @@ -114,6 +116,7 @@ func newCache[K comparable, V any](b *Builder[K, V]) *Cache[K, V] {
stats: newStatsRecorder(b.statsRecorder),
logger: b.logger,
stripedBuffer: stripedBuffer,
singleflight: &group[K, V]{},
doneClear: make(chan struct{}),
doneClose: make(chan struct{}, 1),
weigher: b.getWeigher(),
Expand Down Expand Up @@ -274,6 +277,7 @@ func (c *Cache[K, V]) set(key K, value V, expiration int64, onlyIfAbsent bool) (
return current
}
// set
c.singleflight.delete(key)
return n
})
if onlyIfAbsent {
Expand Down Expand Up @@ -315,14 +319,57 @@ func (c *Cache[K, V]) afterWrite(n, old node.Node[K, V]) {
c.writeBuffer.Push(newUpdateTask(n, old, cause))
}

func (c *Cache[K, V]) Load(ctx context.Context, loader Loader[K, V], key K) (V, error) {
if value, ok := c.Get(key); ok {
return value, nil
}

c.singleflight.lazyInit()

// node.Node compute?
cl, shouldDo := c.singleflight.startCall(ctx, key)
if shouldDo {
c.singleflight.doCall(cl, loader, c.afterDeleteCall)
}
cl.wait()

if err := cl.err; err != nil {
return zeroValue[V](), err
}

return cl.value, nil
}

func (c *Cache[K, V]) afterDeleteCall(cl *call[K, V]) {
var (
inserted bool
old node.Node[K, V]
)
newNode := c.hashmap.Compute(cl.key, func(oldNode node.Node[K, V]) node.Node[K, V] {
old = oldNode
deleted := c.singleflight.deleteCall(cl)
if !deleted || cl.err != nil {
return oldNode
}
inserted = true
return c.nodeManager.Create(cl.key, cl.value, c.defaultExpiration(), c.weigher(cl.key, cl.value))
})
if inserted {
c.afterWrite(newNode, old)
}
}

// Invalidate discards any cached value for the key.
//
// Returns previous value if any. The invalidated result reports whether the key was
// present.
func (c *Cache[K, V]) Invalidate(key K) (value V, invalidated bool) {
var d node.Node[K, V]
c.hashmap.Compute(key, func(n node.Node[K, V]) node.Node[K, V] {
d = n
if n != nil {
c.singleflight.delete(key)
d = n
}
return nil
})
c.afterDelete(d)
Expand All @@ -339,6 +386,7 @@ func (c *Cache[K, V]) deleteNodeFromMap(n node.Node[K, V]) node.Node[K, V] {
return nil
}
if n.AsPointer() == current.AsPointer() {
c.singleflight.delete(n.Key())
deleted = current
return nil
}
Expand Down
226 changes: 226 additions & 0 deletions singleflight.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
// Copyright (c) 2024 Alexey Mayshev and contributors. All rights reserved.
// Copyright 2009 The Go Authors. All rights reserved.
//
// Copyright notice. Initial version of the following code was based on
// the following file from the Go Programming Language core repo:
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/container/list/list_test.go
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// That can be found at https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:LICENSE

package otter

import (
"bytes"
"context"
"errors"
"fmt"
"runtime/debug"
"sync"
"sync/atomic"
"unsafe"

"github.com/maypok86/otter/v2/internal/hashmap"
)

// errGoexit indicates the runtime.Goexit was called in
// the user given function.
var errGoexit = errors.New("runtime.Goexit was called")

// A panicError is an arbitrary value recovered from a panic
// with the stack trace during the execution of given function.
type panicError struct {
value any
stack []byte
}

// Error implements error interface.
func (p *panicError) Error() string {
return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
}

func (p *panicError) Unwrap() error {
err, ok := p.value.(error)
if !ok {
return nil
}

return err
}

func newPanicError(v any) error {
stack := debug.Stack()

// The first line of the stack trace is of the form "goroutine N [status]:"
// but by the time the panic reaches Do the goroutine may no longer exist
// and its status will have changed. Trim out the misleading line.
if line := bytes.IndexByte(stack, '\n'); line >= 0 {
stack = stack[line+1:]
}
return &panicError{value: v, stack: stack}
}

type call[K comparable, V any] struct {
key K
value V
err error
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}

func newCall[K comparable, V any](ctx context.Context, key K) *call[K, V] {
ctx, cancel := context.WithCancel(ctx)
c := &call[K, V]{
key: key,
ctx: ctx,
cancel: cancel,
}
c.wg.Add(1)
return c
}

func (c *call[K, V]) Key() K {
return c.key
}

func (c *call[K, V]) Value() V {
return c.value
}

func (c *call[K, V]) AsPointer() unsafe.Pointer {
return unsafe.Pointer(c)
}

func (c *call[K, V]) wait() {
c.wg.Wait()
}

type mapCallManager[K comparable, V any] struct{}

func (m *mapCallManager[K, V]) FromPointer(ptr unsafe.Pointer) *call[K, V] {
return (*call[K, V])(ptr)
}

func (m *mapCallManager[K, V]) IsNil(c *call[K, V]) bool {
return c == nil
}

type group[K comparable, V any] struct {
calls *hashmap.Map[K, V, *call[K, V]]
initMutex sync.Mutex
isInitialized atomic.Bool
}

func (g *group[K, V]) lazyInit() {
if !g.isInitialized.Load() {
g.initMutex.Lock()
if !g.isInitialized.Load() {
g.calls = hashmap.New[K, V, *call[K, V]](&mapCallManager[K, V]{})
g.isInitialized.Store(true)
}
g.initMutex.Unlock()
}
}

func (g *group[K, V]) getCall(key K) *call[K, V] {
return g.calls.Get(key)
}

func (g *group[K, V]) startCall(ctx context.Context, key K) (c *call[K, V], shouldDo bool) {
// fast path
if c := g.getCall(key); c != nil {
return c, shouldDo
}

return g.calls.Compute(key, func(prevCall *call[K, V]) *call[K, V] {
// double check
if prevCall != nil {
return prevCall
}
shouldDo = true
return newCall[K, V](ctx, key)
}), shouldDo
}

func (g *group[K, V]) doCall(c *call[K, V], loader Loader[K, V], afterFinish func(c *call[K, V])) {
normalReturn := false
recovered := false

// use double-defer to distinguish panic from runtime.Goexit,
// more details see https://golang.org/cl/134395
defer func() {
// the given function invoked runtime.Goexit
if !normalReturn && !recovered {
c.err = errGoexit
}

afterFinish(c)

var e *panicError
if errors.As(c.err, &e) {
// In order to prevent the waiting channels from being blocked forever,
// needs to ensure that this panic cannot be recovered.
panic(e)
}
// Normal return
}()

func() {
defer func() {
if !normalReturn {
// Ideally, we would wait to take a stack trace until we've determined
// whether this is a panic or a runtime.Goexit.
//
// Unfortunately, the only way we can distinguish the two is to see
// whether the recover stopped the goroutine from terminating, and by
// the time we know that, the part of the stack trace relevant to the
// panic has been discarded.
if r := recover(); r != nil {
c.err = newPanicError(r)
}
}
}()

c.value, c.err = loader.Load(c.ctx, c.key)
normalReturn = true
}()

if !normalReturn {
recovered = true
}
}

func (g *group[K, V]) deleteCall(c *call[K, V]) (deleted bool) {
// fast path
if got := g.getCall(c.key); got != c {
return false
}

cl := g.calls.Compute(c.key, func(prevCall *call[K, V]) *call[K, V] {
// double check
if prevCall == c {
// delete
c.wg.Done()
return nil
}
return prevCall
})
return cl == nil
}

func (g *group[K, V]) delete(key K) {
if !g.isInitialized.Load() {
return
}

var prev *call[K, V]
g.calls.Compute(key, func(prevCall *call[K, V]) *call[K, V] {
prev = prevCall
return nil
})
if prev != nil {
prev.cancel()
}
}

0 comments on commit 40ad400

Please sign in to comment.