Skip to content

Commit

Permalink
fn: ContextWithQuit
Browse files Browse the repository at this point in the history
This commit adds a new context helper that combines the "quit" signals
of a parent context and a quit channel and derives a child context that
is cancelled when either one of those is cancelled/closed.
  • Loading branch information
ellemouton committed Dec 9, 2024
1 parent 5659c01 commit c0d2d1c
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 0 deletions.
88 changes: 88 additions & 0 deletions fn/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package fn

import (
"context"
)

// ctxWithQuit is a context that is canceled when either the passed context is
// canceled or the quit channel is closed.
type ctxWithQuit struct {
context.Context

// quit is the quit channel that was passed to ContextWithQuit. This is
// the caller's quit channel and is not ours to close. It should only
// ever be listened on.
quit <-chan struct{}
}

// Done returns a channel that is closed when either the passed context is
// canceled or the quit channel is closed.
//
// NOTE: this is part of the context.Context interface.
func (c *ctxWithQuit) Done() <-chan struct{} {
select {
case <-c.quit:
return c.quit
default:
return c.Context.Done()
}
}

// Err returns the error that caused the context to be canceled. If the passed
// context is canceled, then the error from the passed context is returned.
// Otherwise, if the quit channel is closed, then context.Canceled is returned.
//
// NOTE: this is part of the context.Context interface.
func (c *ctxWithQuit) Err() error {
select {
case <-c.Context.Done():
return c.Context.Err()
case <-c.quit:
return context.Canceled
default:
return nil
}
}

// ContextWithQuit returns a new context that is canceled when either the
// passed context is canceled or the quit channel is closed. This in essence
// combines the signals of the passed context and the quit channel.
//
// NOTE: if the parent context is canceled first, then the returned context is
// guaranteed to be cancelled immediately since it is derived from the parent.
// However, if the quit channel is closed first, then there the returned context
// will be closed once the closed quit channel signal has been responded to.
func ContextWithQuit(ctx context.Context,
quit <-chan struct{}) (context.Context, context.CancelFunc) {

// Derive a fresh context from the passed context. If the passed
// context has already been canceled, then this fresh one will also
// already be canceled.
ctx, cancel := context.WithCancel(ctx)

select {
case <-ctx.Done():
return ctx, cancel
case <-quit:
cancel()
return ctx, cancel
default:
}

go func() {
select {
// If the derived context is canceled, which will be the case
// if either the passed parent context is canceled or if the
// returned cancel function is called, then there is nothing
// left to do.
case <-ctx.Done():

// If the quit channel is closed, then we cancel the derived
// context which was returned.
case <-quit:
cancel()
}
}()

return &ctxWithQuit{Context: ctx, quit: quit}, cancel
}
181 changes: 181 additions & 0 deletions fn/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package fn

import (
"context"
"testing"
"time"
)

// TestContextWithQuit tests the behaviour of the ContextWithQuit function.
// We test that the derived context is correctly cancelled when either the
// passed context is cancelled or the quit channel is closed.
func TestContextWithQuit(t *testing.T) {
t.Parallel()

// Test that the derived context is cancelled when the passed context is
// cancelled.
t.Run("Parent context is cancelled", func(t *testing.T) {
t.Parallel()

var (
ctx, cancel = context.WithCancel(context.Background())
quit = make(chan struct{})
)

ctxc, _ := ContextWithQuit(ctx, quit)

// Cancel the parent context.
cancel()

// Assert that the derived context is cancelled.
select {
case <-ctxc.Done():
default:
t.Errorf("The derived context should be cancelled at " +
"this point")
}
})

// Test that the derived context is cancelled when the passed quit
// channel is closed.
t.Run("Quit channel is closed", func(t *testing.T) {
var (
ctx = context.Background()
quit = make(chan struct{})
)

ctxc, _ := ContextWithQuit(ctx, quit)

// Close the quit channel.
close(quit)

// Assert that the derived context is cancelled.
select {
case <-ctxc.Done():
default:
t.Errorf("The derived context should be cancelled at " +
"this point")
}
})

t.Run("Parent context is already closed", func(t *testing.T) {
var (
ctx, cancel = context.WithCancel(context.Background())
quit = make(chan struct{})
)

cancel()

ctxc, _ := ContextWithQuit(ctx, quit)

// Assert that the derived context is cancelled already
// cancelled.
select {
case <-ctxc.Done():
default:
t.Errorf("The derived context should be cancelled at " +
"this point")
}
})

t.Run("Passed quit channel is already closed", func(t *testing.T) {
var (
ctx = context.Background()
quit = make(chan struct{})
)

close(quit)

ctxc, _ := ContextWithQuit(ctx, quit)

// Assert that the derived context is cancelled already
// cancelled.
select {
case <-ctxc.Done():
default:
t.Errorf("The derived context should be cancelled at " +
"this point")
}
})

t.Run("Child context should be cancelled synchronously with "+
"parent", func(t *testing.T) {

var (
ctx, cancel = context.WithCancel(context.Background())
quit = make(chan struct{})
task = make(chan struct{})
done = make(chan struct{})
)

// Derive a child context.
ctxc, _ := ContextWithQuit(ctx, quit)

// Spin off a routine that exists cleaning if the child context
// is cancelled but fails if the task is performed.
go func() {
defer close(done)
select {
case <-ctxc.Done():
case <-task:
t.Fatalf("should not get here")
}
}()

// Give the goroutine some time to spin up.
time.Sleep(time.Millisecond * 500)

// First cancel the parent context. Then immediately execute the
// task.
cancel()
close(task)

// Wait for the goroutine to exit.
select {
case <-done:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
})

t.Run("Child context should be cancelled synchronously with the "+
"close of the quit channel", func(t *testing.T) {

var (
ctx = context.Background()
quit = make(chan struct{})
task = make(chan struct{})
done = make(chan struct{})
)

// Derive a child context.
ctxc, _ := ContextWithQuit(ctx, quit)

// Spin off a routine that exists cleaning if the child context
// is cancelled but fails if the task is performed.
go func() {
defer close(done)
select {
case <-ctxc.Done():
case <-task:
t.Fatalf("should not get here")
}
}()

// First cancel the parent context. Then immediately execute the
// task.
close(quit)

// Give the context some time to cancel the derived context.
time.Sleep(time.Millisecond * 500)

close(task)

// Wait for the goroutine to exit.
select {
case <-done:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
})
}

0 comments on commit c0d2d1c

Please sign in to comment.