-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit adds a new context helper that combines the "quit" signals of a parent context and a quit channel and derives a child context that is cancelled when either one of those is cancelled/closed.
- Loading branch information
1 parent
5659c01
commit c0d2d1c
Showing
2 changed files
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package fn | ||
|
||
import ( | ||
"context" | ||
) | ||
|
||
// ctxWithQuit is a context that is canceled when either the passed context is | ||
// canceled or the quit channel is closed. | ||
type ctxWithQuit struct { | ||
context.Context | ||
|
||
// quit is the quit channel that was passed to ContextWithQuit. This is | ||
// the caller's quit channel and is not ours to close. It should only | ||
// ever be listened on. | ||
quit <-chan struct{} | ||
} | ||
|
||
// Done returns a channel that is closed when either the passed context is | ||
// canceled or the quit channel is closed. | ||
// | ||
// NOTE: this is part of the context.Context interface. | ||
func (c *ctxWithQuit) Done() <-chan struct{} { | ||
select { | ||
case <-c.quit: | ||
return c.quit | ||
default: | ||
return c.Context.Done() | ||
} | ||
} | ||
|
||
// Err returns the error that caused the context to be canceled. If the passed | ||
// context is canceled, then the error from the passed context is returned. | ||
// Otherwise, if the quit channel is closed, then context.Canceled is returned. | ||
// | ||
// NOTE: this is part of the context.Context interface. | ||
func (c *ctxWithQuit) Err() error { | ||
select { | ||
case <-c.Context.Done(): | ||
return c.Context.Err() | ||
case <-c.quit: | ||
return context.Canceled | ||
default: | ||
return nil | ||
} | ||
} | ||
|
||
// ContextWithQuit returns a new context that is canceled when either the | ||
// passed context is canceled or the quit channel is closed. This in essence | ||
// combines the signals of the passed context and the quit channel. | ||
// | ||
// NOTE: if the parent context is canceled first, then the returned context is | ||
// guaranteed to be cancelled immediately since it is derived from the parent. | ||
// However, if the quit channel is closed first, then there the returned context | ||
// will be closed once the closed quit channel signal has been responded to. | ||
func ContextWithQuit(ctx context.Context, | ||
quit <-chan struct{}) (context.Context, context.CancelFunc) { | ||
|
||
// Derive a fresh context from the passed context. If the passed | ||
// context has already been canceled, then this fresh one will also | ||
// already be canceled. | ||
ctx, cancel := context.WithCancel(ctx) | ||
|
||
select { | ||
case <-ctx.Done(): | ||
return ctx, cancel | ||
case <-quit: | ||
cancel() | ||
return ctx, cancel | ||
default: | ||
} | ||
|
||
go func() { | ||
select { | ||
// If the derived context is canceled, which will be the case | ||
// if either the passed parent context is canceled or if the | ||
// returned cancel function is called, then there is nothing | ||
// left to do. | ||
case <-ctx.Done(): | ||
|
||
// If the quit channel is closed, then we cancel the derived | ||
// context which was returned. | ||
case <-quit: | ||
cancel() | ||
} | ||
}() | ||
|
||
return &ctxWithQuit{Context: ctx, quit: quit}, cancel | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
package fn | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
"time" | ||
) | ||
|
||
// TestContextWithQuit tests the behaviour of the ContextWithQuit function. | ||
// We test that the derived context is correctly cancelled when either the | ||
// passed context is cancelled or the quit channel is closed. | ||
func TestContextWithQuit(t *testing.T) { | ||
t.Parallel() | ||
|
||
// Test that the derived context is cancelled when the passed context is | ||
// cancelled. | ||
t.Run("Parent context is cancelled", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
var ( | ||
ctx, cancel = context.WithCancel(context.Background()) | ||
quit = make(chan struct{}) | ||
) | ||
|
||
ctxc, _ := ContextWithQuit(ctx, quit) | ||
|
||
// Cancel the parent context. | ||
cancel() | ||
|
||
// Assert that the derived context is cancelled. | ||
select { | ||
case <-ctxc.Done(): | ||
default: | ||
t.Errorf("The derived context should be cancelled at " + | ||
"this point") | ||
} | ||
}) | ||
|
||
// Test that the derived context is cancelled when the passed quit | ||
// channel is closed. | ||
t.Run("Quit channel is closed", func(t *testing.T) { | ||
var ( | ||
ctx = context.Background() | ||
quit = make(chan struct{}) | ||
) | ||
|
||
ctxc, _ := ContextWithQuit(ctx, quit) | ||
|
||
// Close the quit channel. | ||
close(quit) | ||
|
||
// Assert that the derived context is cancelled. | ||
select { | ||
case <-ctxc.Done(): | ||
default: | ||
t.Errorf("The derived context should be cancelled at " + | ||
"this point") | ||
} | ||
}) | ||
|
||
t.Run("Parent context is already closed", func(t *testing.T) { | ||
var ( | ||
ctx, cancel = context.WithCancel(context.Background()) | ||
quit = make(chan struct{}) | ||
) | ||
|
||
cancel() | ||
|
||
ctxc, _ := ContextWithQuit(ctx, quit) | ||
|
||
// Assert that the derived context is cancelled already | ||
// cancelled. | ||
select { | ||
case <-ctxc.Done(): | ||
default: | ||
t.Errorf("The derived context should be cancelled at " + | ||
"this point") | ||
} | ||
}) | ||
|
||
t.Run("Passed quit channel is already closed", func(t *testing.T) { | ||
var ( | ||
ctx = context.Background() | ||
quit = make(chan struct{}) | ||
) | ||
|
||
close(quit) | ||
|
||
ctxc, _ := ContextWithQuit(ctx, quit) | ||
|
||
// Assert that the derived context is cancelled already | ||
// cancelled. | ||
select { | ||
case <-ctxc.Done(): | ||
default: | ||
t.Errorf("The derived context should be cancelled at " + | ||
"this point") | ||
} | ||
}) | ||
|
||
t.Run("Child context should be cancelled synchronously with "+ | ||
"parent", func(t *testing.T) { | ||
|
||
var ( | ||
ctx, cancel = context.WithCancel(context.Background()) | ||
quit = make(chan struct{}) | ||
task = make(chan struct{}) | ||
done = make(chan struct{}) | ||
) | ||
|
||
// Derive a child context. | ||
ctxc, _ := ContextWithQuit(ctx, quit) | ||
|
||
// Spin off a routine that exists cleaning if the child context | ||
// is cancelled but fails if the task is performed. | ||
go func() { | ||
defer close(done) | ||
select { | ||
case <-ctxc.Done(): | ||
case <-task: | ||
t.Fatalf("should not get here") | ||
} | ||
}() | ||
|
||
// Give the goroutine some time to spin up. | ||
time.Sleep(time.Millisecond * 500) | ||
|
||
// First cancel the parent context. Then immediately execute the | ||
// task. | ||
cancel() | ||
close(task) | ||
|
||
// Wait for the goroutine to exit. | ||
select { | ||
case <-done: | ||
case <-time.After(time.Second): | ||
t.Fatalf("timeout") | ||
} | ||
}) | ||
|
||
t.Run("Child context should be cancelled synchronously with the "+ | ||
"close of the quit channel", func(t *testing.T) { | ||
|
||
var ( | ||
ctx = context.Background() | ||
quit = make(chan struct{}) | ||
task = make(chan struct{}) | ||
done = make(chan struct{}) | ||
) | ||
|
||
// Derive a child context. | ||
ctxc, _ := ContextWithQuit(ctx, quit) | ||
|
||
// Spin off a routine that exists cleaning if the child context | ||
// is cancelled but fails if the task is performed. | ||
go func() { | ||
defer close(done) | ||
select { | ||
case <-ctxc.Done(): | ||
case <-task: | ||
t.Fatalf("should not get here") | ||
} | ||
}() | ||
|
||
// First cancel the parent context. Then immediately execute the | ||
// task. | ||
close(quit) | ||
|
||
// Give the context some time to cancel the derived context. | ||
time.Sleep(time.Millisecond * 500) | ||
|
||
close(task) | ||
|
||
// Wait for the goroutine to exit. | ||
select { | ||
case <-done: | ||
case <-time.After(time.Second): | ||
t.Fatalf("timeout") | ||
} | ||
}) | ||
} |