Skip to content

Commit

Permalink
support stop gates
Browse files Browse the repository at this point in the history
  • Loading branch information
ktong committed Mar 15, 2024
1 parent dd1e96b commit 816fe12
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 46 deletions.
10 changes: 10 additions & 0 deletions run/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ func WithStartGate(gates ...func(context.Context) error) Option {
}
}

// WithStopGate provides gates to block the stop of main runs provided in Runner.Run,
// until all stop gates returns.
//
// All stop gates must return in limited time to avoid blocking the main runs.
func WithStopGate(gates ...func(context.Context) error) Option {
return func(opts *options) {
opts.stopGates = append(opts.stopGates, gates...)
}
}

type (
// Option configures the Runner with specific options.
Option func(*options)
Expand Down
29 changes: 21 additions & 8 deletions run/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
type Runner struct {
preRuns []func(context.Context) error
startGates []func(context.Context) error
stopGates []func(context.Context) error
running atomic.Bool
}

Expand All @@ -40,7 +41,7 @@ func New(opts ...Option) *Runner {
// The execution can be interrupted if any run returns non-nil error,
// or it receives an OS signal syscall.SIGINT or syscall.SIGTERM.
// It waits all run return unless it's forcefully terminated by OS.
func (e *Runner) Run(ctx context.Context, runs ...func(context.Context) error) error {
func (e *Runner) Run(ctx context.Context, runs ...func(context.Context) error) error { //nolint:funlen
if e == nil {
// Use empty instance instead to avoid nil pointer dereference,
// Assignment propagates only to callee but not to caller.
Expand Down Expand Up @@ -75,24 +76,36 @@ func (e *Runner) Run(ctx context.Context, runs ...func(context.Context) error) e
},
)

// Root context which is used for pre-runs.
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
// Context can be terminated by OS signals, which is used for start-gates and parent of context for main runs.
signalCtx, signalCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer signalCancel()
// Context is used for main runs and stop-gates.
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()

allRuns = append(allRuns,
func(ctx context.Context) (err error) { //nolint:nonamedreturns
func(context.Context) error {
defer runCancel()

<-signalCtx.Done()

return Parallel(runCtx, e.stopGates...)
},
func(context.Context) (err error) { //nolint:nonamedreturns
defer func() {
cancel(err)
}()

// Terminate signals apply to the runs, then cancel the root context for pre-runs.
nctx, ncancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer ncancel()

// Wait for all startGates to open.
if err = Parallel(nctx, e.startGates...); err != nil {
// Use signalCtx to allow it to be interrupted by OS signals.
if err = Parallel(signalCtx, e.startGates...); err != nil {
return err
}

return Parallel(nctx, runs...)
return Parallel(runCtx, runs...)
},
)

Expand Down
112 changes: 74 additions & 38 deletions run/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,19 @@ package run_test
import (
"context"
"errors"
"os"
"syscall"
"testing"
"time"

"github.com/nil-go/nilgo/internal/assert"
"github.com/nil-go/nilgo/run"
)

func TestRunner_Run(t *testing.T) {
for _, testcase := range testcases() {
testcase := testcase

var ran bool
t.Run(testcase.description, func(t *testing.T) {
err := testcase.runner.Run(
context.Background(),
func(context.Context) error {
ran = true
if testcase.err != "" {
return errors.New(testcase.err)
}

return nil
},
)
t.Parallel()

assert.Equal(t, testcase.ran, ran)
if testcase.err == "" {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, testcase.err)
}
})
}
}

func testcases() []struct {
description string
runner *run.Runner
ran bool
err string
} {
return []struct {
testcases := []struct {
description string
runner *run.Runner
ran bool
Expand All @@ -61,6 +33,17 @@ func testcases() []struct {
ran: true,
err: "run error",
},
{
description: "with pre-run",
runner: run.New(run.WithPreRun(func(context.Context) error { return nil })),
ran: true,
},
{
description: "pre-run error",
runner: run.New(run.WithPreRun(func(context.Context) error { return errors.New("pre-run error") })),
err: "pre-run error",
ran: true,
},
{
description: "with start gate",
runner: run.New(run.WithStartGate(func(context.Context) error { return nil })),
Expand All @@ -72,19 +55,50 @@ func testcases() []struct {
err: "start gate error",
},
{
description: "with pre-run",
runner: run.New(run.WithPreRun(func(context.Context) error { return nil })),
description: "with stop gate",
runner: run.New(run.WithStopGate(func(context.Context) error { return nil })),
ran: true,
},
{
description: "pre-run error",
runner: run.New(run.WithPreRun(func(context.Context) error { return errors.New("pre-run error") })),
err: "pre-run error",
description: "stop gate error",
runner: run.New(run.WithStopGate(func(context.Context) error { return errors.New("stop gate error") })),
ran: true,
err: "stop gate error",
},
}

for _, testcase := range testcases {
testcase := testcase

var ran bool
t.Run(testcase.description, func(t *testing.T) {
t.Parallel()

err := testcase.runner.Run(
context.Background(),
func(context.Context) error {
ran = true
if testcase.err != "" {
return errors.New(testcase.err)
}

return nil
},
)

assert.Equal(t, testcase.ran, ran)
if testcase.err == "" {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, testcase.err)
}
})
}
}

func TestRunner_Run_twice(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -107,3 +121,25 @@ func TestRunner_Run_twice(t *testing.T) {
// It should return an error as it's already running.
assert.EqualError(t, runner.Run(ctx), "runner is already running")
}

func TestRunner_Run_signal(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var runner run.Runner
assert.NoError(t, runner.Run(ctx,
func(ctx context.Context) error {
timer := time.NewTimer(time.Minute)
defer timer.Stop()
select {
case <-ctx.Done():
return nil
case <-timer.C:
return errors.New("timeout")
}
},
func(context.Context) error {
return syscall.Kill(os.Getpid(), syscall.SIGINT)
},
))
}

0 comments on commit 816fe12

Please sign in to comment.