fn: Remove ctx from GoroutineManager constructor
ellemouton committed Dec 9, 2024
1 parent c0d2d1c commit 0f1388a
Showing 2 changed files with 159 additions and 156 deletions.
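Before the diffs, a caller-side sketch of what this API change means in practice: the context moves from the constructor to each Go call, and the callback is expected to exit once that context is done. The import path and the surrounding program are illustrative assumptions, not part of the commit.

package main

import (
	"context"
	"fmt"

	// Assumed import path for the fn package this commit touches.
	"github.com/lightningnetwork/lnd/fn"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Previously: m := fn.NewGoroutineManager(ctx). After this commit the
	// constructor takes no context.
	m := fn.NewGoroutineManager()

	// Each Go call now receives the caller's context; the callback must
	// return once that context expires.
	ok := m.Go(ctx, func(ctx context.Context) {
		<-ctx.Done()
	})
	fmt.Println("goroutine started:", ok)

	// Cancelling the caller context lets the goroutine exit, and Stop
	// blocks until it has done so.
	cancel()
	m.Stop()
}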
74 changes: 44 additions & 30 deletions fn/goroutine_manager.go
@@ -8,46 +8,55 @@ import (
 // GoroutineManager is used to launch goroutines until context expires or the
 // manager is stopped. The Stop method blocks until all started goroutines stop.
 type GoroutineManager struct {
-	wg     sync.WaitGroup
-	mu     sync.Mutex
-	ctx    context.Context
-	cancel func()
+	stopped sync.Once
+	wg      sync.WaitGroup
+	mu      sync.Mutex
+	quit    chan struct{}
 }
 
 // NewGoroutineManager constructs and returns a new instance of
 // GoroutineManager.
-func NewGoroutineManager(ctx context.Context) *GoroutineManager {
-	ctx, cancel := context.WithCancel(ctx)
-
+func NewGoroutineManager() *GoroutineManager {
 	return &GoroutineManager{
-		ctx:    ctx,
-		cancel: cancel,
+		quit: make(chan struct{}),
 	}
 }
 
 // Go tries to start a new goroutine and returns a boolean indicating its
-// success. It fails iff the goroutine manager is stopping or its context passed
-// to NewGoroutineManager has expired.
-func (g *GoroutineManager) Go(f func(ctx context.Context)) bool {
+// success. It returns true if the goroutine was successfully created and false
+// otherwise. A goroutine will fail to be created iff the goroutine manager is
+// stopping or the passed context has already expired. The passed call-back
+// function must exit if the passed context expires.
+func (g *GoroutineManager) Go(ctx context.Context,
+	f func(ctx context.Context)) bool {
+
+	// Derive a child context which will be canceled when either the passed
+	// context is canceled or the quit channel of the GoroutineManager is
+	// closed.
+	ctxc, cancel := ContextWithQuit(ctx, g.quit)
+
 	// Calling wg.Add(1) and wg.Wait() when wg's counter is 0 is a race
-	// condition, since it is not clear should Wait() block or not. This
+	// condition, since it is not clear if Wait() should block or not. This
 	// kind of race condition is detected by Go runtime and results in a
-	// crash if running with `-race`. To prevent this, whole Go method is
-	// protected with a mutex. The call to wg.Wait() inside Stop() can still
-	// run in parallel with Go, but in that case g.ctx is in expired state,
-	// because cancel() was called in Stop, so Go returns before wg.Add(1)
-	// call.
+	// crash if running with `-race`. To prevent this, we protect the calls
+	// to wg.Add(1) and wg.Wait() with a mutex. If we block here because
+	// Stop is running first, then Stop will close the quit channel, which
+	// will cause the context to be cancelled, and we will exit before
+	// calling wg.Add(1). If we grab the mutex here before Stop does, then
+	// Stop will block until after we call wg.Add(1).
 	g.mu.Lock()
 	defer g.mu.Unlock()
 
-	if g.ctx.Err() != nil {
+	// If we get here and Stop has already been called, or the passed
+	// context has expired, then the derived context is already cancelled
+	// and we must not start the goroutine.
+	if ctxc.Err() != nil {
 		return false
 	}
 
 	g.wg.Add(1)
 	go func() {
 		defer g.wg.Done()
-		f(g.ctx)
+		f(ctxc)
+		cancel()
 	}()
 
 	return true
@@ -56,20 +65,25 @@ func (g *GoroutineManager) Go(f func(ctx context.Context)) bool {
 // Stop prevents new goroutines from being added and waits for all running
 // goroutines to finish.
 func (g *GoroutineManager) Stop() {
-	g.mu.Lock()
-	g.cancel()
-	g.mu.Unlock()
+	g.stopped.Do(func() {
+		// Closing the quit channel will cause all goroutines to exit
+		// and will prevent any new goroutines from starting.
+		g.mu.Lock()
+		close(g.quit)
+		g.mu.Unlock()
 
-	// Wait for all goroutines to finish. Note that this wg.Wait() call is
-	// safe, since it can't run in parallel with wg.Add(1) call in Go, since
-	// we just cancelled the context and even if Go call starts running here
-	// after acquiring the mutex, it would see that the context has expired
-	// and return false instead of calling wg.Add(1).
-	g.wg.Wait()
+		// Wait for all goroutines to finish. Note that this wg.Wait()
+		// call is safe, since it can't run in parallel with the
+		// wg.Add(1) call in Go: we just closed the quit channel, and
+		// even if a Go call starts running here after acquiring the
+		// mutex, it would see that the derived context has expired and
+		// return false instead of calling wg.Add(1).
+		g.wg.Wait()
+	})
 }
 
-// Done returns a channel which is closed when either the context passed to
-// NewGoroutineManager expires or when Stop is called.
+// Done returns a channel which is closed once Stop has been called and the
+// quit channel closed.
 func (g *GoroutineManager) Done() <-chan struct{} {
-	return g.ctx.Done()
+	return g.quit
 }
241 changes: 115 additions & 126 deletions fn/goroutine_manager_test.go
@@ -2,156 +2,145 @@ package fn
 
 import (
 	"context"
-	"sync"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/require"
 )
 
-// TestGoroutineManager tests that the GoroutineManager starts goroutines until
-// ctx expires. It also makes sure it fails to start new goroutines after the
-// context expired and the GoroutineManager is in the process of waiting for
-// already started goroutines in the Stop method.
+// TestGoroutineManager tests the behaviour of the GoroutineManager.
 func TestGoroutineManager(t *testing.T) {
 	t.Parallel()
 
-	m := NewGoroutineManager(context.Background())
-
-	taskChan := make(chan struct{})
-
-	require.True(t, m.Go(func(ctx context.Context) {
-		<-taskChan
-	}))
-
-	t1 := time.Now()
-
-	// Close taskChan in 1s, causing the goroutine to stop.
-	time.AfterFunc(time.Second, func() {
-		close(taskChan)
-	})
-
-	m.Stop()
-	stopDelay := time.Since(t1)
-
-	// Make sure Stop was waiting for the goroutine to stop.
-	require.Greater(t, stopDelay, time.Second)
-
-	// Make sure new goroutines do not start after Stop.
-	require.False(t, m.Go(func(ctx context.Context) {}))
-
-	// When Stop() is called, the internal context expires and m.Done() is
-	// closed. Test this.
-	select {
-	case <-m.Done():
-	default:
-		t.Errorf("Done() channel must be closed at this point")
-	}
-}
-
-// TestGoroutineManagerContextExpires tests the effect of context expiry.
-func TestGoroutineManagerContextExpires(t *testing.T) {
-	t.Parallel()
-
-	ctx, cancel := context.WithCancel(context.Background())
-
-	m := NewGoroutineManager(ctx)
-
-	require.True(t, m.Go(func(ctx context.Context) {
-		<-ctx.Done()
-	}))
-
-	// The Done channel of the manager should not be closed, so the
-	// following call must block.
-	select {
-	case <-m.Done():
-		t.Errorf("Done() channel must not be closed at this point")
-	default:
-	}
-
-	cancel()
-
-	// The Done channel of the manager should be closed, so the following
-	// call must not block.
-	select {
-	case <-m.Done():
-	default:
-		t.Errorf("Done() channel must be closed at this point")
-	}
-
-	// Make sure new goroutines do not start after context expiry.
-	require.False(t, m.Go(func(ctx context.Context) {}))
-
-	// Stop will wait for all goroutines to stop.
-	m.Stop()
-}
-
-// TestGoroutineManagerStress starts many goroutines while calling Stop. It
-// is needed to make sure the GoroutineManager does not crash if this happen.
-// If the mutex was not used, it would crash because of a race condition between
-// wg.Add(1) and wg.Wait().
-func TestGoroutineManagerStress(t *testing.T) {
-	t.Parallel()
-
-	m := NewGoroutineManager(context.Background())
-
-	stopChan := make(chan struct{})
-
-	time.AfterFunc(1*time.Millisecond, func() {
-		m.Stop()
-		close(stopChan)
-	})
-
-	// Starts 100 goroutines sequentially. Sequential order is needed to
-	// keep wg.counter low (0 or 1) to increase probability of race
-	// condition to be caught if it exists. If mutex is removed in the
-	// implementation, this test crashes under `-race`.
-	for i := 0; i < 100; i++ {
-		taskChan := make(chan struct{})
-		ok := m.Go(func(ctx context.Context) {
-			close(taskChan)
-		})
-		// If goroutine was started, wait for its completion.
-		if ok {
-			<-taskChan
-		}
-	}
-
-	// Wait for Stop to complete.
-	<-stopChan
-}
-
-// TestGoroutineManagerStopsStress launches many Stop() calls in parallel with a
-// task exiting. It attempts to catch a race condition between wg.Done() and
-// wg.Wait() calls. According to documentation of wg.Wait() this is acceptable,
-// therefore this test passes even with -race.
-func TestGoroutineManagerStopsStress(t *testing.T) {
-	t.Parallel()
-
-	m := NewGoroutineManager(context.Background())
-
-	// jobChan is used to make the task to finish.
-	jobChan := make(chan struct{})
-
-	// Start a task and wait inside it until we start calling Stop() method.
-	ok := m.Go(func(ctx context.Context) {
-		<-jobChan
-	})
-	require.True(t, ok)
-
-	// Now launch many gorotines calling Stop() method in parallel.
-	var wg sync.WaitGroup
-	for i := 0; i < 100; i++ {
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			m.Stop()
-		}()
-	}
-
-	// Exit the task in parallel with Stop() calls.
-	close(jobChan)
-
-	// Wait until all the Stop() calls complete.
-	wg.Wait()
+	// Here we test that the GoroutineManager starts goroutines until it
+	// has been stopped.
+	t.Run("GM is stopped", func(t *testing.T) {
+		t.Parallel()
+
+		var (
+			ctx      = context.Background()
+			m        = NewGoroutineManager()
+			taskChan = make(chan struct{})
+		)
+
+		// The gm has not stopped yet and the passed in context has not
+		// expired, so we expect the goroutine to start. The taskChan
+		// is blocking, so this goroutine will be live for a while.
+		require.True(t, m.Go(ctx, func(ctx context.Context) {
+			<-taskChan
+		}))
+
+		t1 := time.Now()
+
+		// Close taskChan in 1s, causing the goroutine to stop.
+		time.AfterFunc(time.Second, func() {
+			close(taskChan)
+		})
+
+		m.Stop()
+		stopDelay := time.Since(t1)
+
+		// Make sure Stop was waiting for the goroutine to stop.
+		require.Greater(t, stopDelay, time.Second)
+
+		// Make sure new goroutines do not start after Stop.
+		require.False(t, m.Go(ctx, func(ctx context.Context) {}))
+
+		// When Stop() is called, the gm quit channel is closed and so
+		// Done() should return.
+		select {
+		case <-m.Done():
+		default:
+			t.Errorf("Done() channel must be closed at this point")
+		}
+	})
+
+	// Test that the GoroutineManager fails to start a goroutine or exits
+	// a goroutine if the caller context has expired.
+	t.Run("Caller context expires", func(t *testing.T) {
+		t.Parallel()
+
+		var (
+			ctx      = context.Background()
+			m        = NewGoroutineManager()
+			taskChan = make(chan struct{})
+		)
+
+		// Derive a child context with a cancel function.
+		ctxc, cancel := context.WithCancel(ctx)
+
+		// The gm has not stopped yet and the passed in context has not
+		// expired, so we expect the goroutine to start.
+		require.True(t, m.Go(ctxc, func(ctx context.Context) {
+			select {
+			case <-ctx.Done():
+			case <-taskChan:
+				t.Fatalf("The task was performed when it " +
					"should not have")
+			}
+		}))
+
+		// Give the GM a little bit of time to start the goroutine so
+		// that we can be sure that it is already listening on the
+		// ctx and taskChan before calling cancel.
+		time.Sleep(time.Millisecond * 500)
+
+		// Cancel the context so that the goroutine exits.
+		cancel()
+
+		// Attempt to send a signal on the task channel; nothing should
+		// happen since the goroutine has already exited.
+		select {
+		case taskChan <- struct{}{}:
+		case <-time.After(time.Millisecond * 200):
+		}
+
+		// Again attempt to add a goroutine with the same cancelled
+		// context. This should fail since the context has already
+		// expired.
+		require.False(t, m.Go(ctxc, func(ctx context.Context) {
+			t.Fatalf("The goroutine should not have started")
+		}))
+
+		// Stop the goroutine manager.
+		m.Stop()
+	})
+
+	// Start many goroutines while calling Stop. We do this to make sure
+	// that the GoroutineManager does not crash when these calls are done
+	// in parallel because of the potential race between wg.Add(1) and
+	// wg.Wait() when the wg counter is 0.
+	t.Run("Stress test", func(t *testing.T) {
+		t.Parallel()
+
+		var (
+			ctx      = context.Background()
+			m        = NewGoroutineManager()
+			stopChan = make(chan struct{})
+		)
+
+		time.AfterFunc(1*time.Millisecond, func() {
+			m.Stop()
+			close(stopChan)
+		})
+
+		// Start 100 goroutines sequentially. Sequential order is
+		// needed to keep wg.counter low (0 or 1) to increase the
+		// probability of the race condition triggering if it exists.
+		// If the mutex is removed in the implementation, this test
+		// crashes under `-race`.
+		for i := 0; i < 100; i++ {
+			taskChan := make(chan struct{})
+			ok := m.Go(ctx, func(ctx context.Context) {
+				close(taskChan)
+			})
+			// If the goroutine was started, wait for its
+			// completion.
+			if ok {
+				<-taskChan
+			}
+		}
+
+		// Wait for Stop to complete.
+		<-stopChan
+	})
 }
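For readers unfamiliar with the race these tests exercise, here is a standalone, hypothetical illustration (not part of the commit) of the locking pattern GoroutineManager uses to keep wg.Add(1) from racing wg.Wait(). The names and the bool-based stop flag are simplifications for the sketch.

package main

import "sync"

// manager is a stripped-down stand-in for fn.GoroutineManager, using a bool
// instead of a quit channel, purely to show the Add/Wait serialisation.
type manager struct {
	mu      sync.Mutex
	wg      sync.WaitGroup
	stopped bool
}

func (m *manager) launch(f func()) bool {
	// Without this lock, wg.Add(1) on a zero counter could run
	// concurrently with wg.Wait() in stop(), which the Go race detector
	// flags under `-race`.
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.stopped {
		return false
	}

	m.wg.Add(1)
	go func() {
		defer m.wg.Done()
		f()
	}()

	return true
}

func (m *manager) stop() {
	m.mu.Lock()
	m.stopped = true
	m.mu.Unlock()

	// Safe: any concurrent launch either completed its Add(1) before we
	// took the lock, or will observe stopped == true and refuse to start.
	m.wg.Wait()
}

func main() {
	var m manager
	done := make(chan struct{})

	m.launch(func() { close(done) })
	<-done

	m.stop()

	// After stop, new launches are rejected.
	if m.launch(func() {}) {
		panic("launch must fail after stop")
	}
}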
