diff --git a/run/option.go b/run/option.go index 9c6a716..5f540f1 100644 --- a/run/option.go +++ b/run/option.go @@ -25,6 +25,16 @@ func WithStartGate(gates ...func(context.Context) error) Option { } } +// WithStopGate provides gates to block the stop of main runs provided in Runner.Run, +// until all stop gates returns. +// +// All stop gates must return in limited time to avoid blocking the main runs. +func WithStopGate(gates ...func(context.Context) error) Option { + return func(opts *options) { + opts.stopGates = append(opts.stopGates, gates...) + } +} + type ( // Option configures the Runner with specific options. Option func(*options) diff --git a/run/runner.go b/run/runner.go index f68ada4..641bfc0 100644 --- a/run/runner.go +++ b/run/runner.go @@ -18,6 +18,7 @@ import ( type Runner struct { preRuns []func(context.Context) error startGates []func(context.Context) error + stopGates []func(context.Context) error running atomic.Bool } @@ -40,7 +41,7 @@ func New(opts ...Option) *Runner { // The execution can be interrupted if any run returns non-nil error, // or it receives an OS signal syscall.SIGINT or syscall.SIGTERM. // It waits all run return unless it's forcefully terminated by OS. -func (e *Runner) Run(ctx context.Context, runs ...func(context.Context) error) error { +func (e *Runner) Run(ctx context.Context, runs ...func(context.Context) error) error { //nolint:funlen if e == nil { // Use empty instance instead to avoid nil pointer dereference, // Assignment propagates only to callee but not to caller. @@ -75,24 +76,36 @@ func (e *Runner) Run(ctx context.Context, runs ...func(context.Context) error) e }, ) + // Root context which is used for pre-runs. ctx, cancel := context.WithCancelCause(ctx) defer cancel(nil) + // Context can be terminated by OS signals, which is used for start-gates and parent of context for main runs. + signalCtx, signalCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + defer signalCancel() + // Context is used for main runs and stop-gates. + runCtx, runCancel := context.WithCancel(ctx) + defer runCancel() + allRuns = append(allRuns, - func(ctx context.Context) (err error) { //nolint:nonamedreturns + func(context.Context) error { + defer runCancel() + + <-signalCtx.Done() + + return Parallel(runCtx, e.stopGates...) + }, + func(context.Context) (err error) { //nolint:nonamedreturns defer func() { cancel(err) }() - // Terminate signals apply to the runs, then cancel the root context for pre-runs. - nctx, ncancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) - defer ncancel() - // Wait for all startGates to open. - if err = Parallel(nctx, e.startGates...); err != nil { + // Use signalCtx to allow it to be interrupted by OS signals. + if err = Parallel(signalCtx, e.startGates...); err != nil { return err } - return Parallel(nctx, runs...) + return Parallel(runCtx, runs...) }, ) diff --git a/run/runner_test.go b/run/runner_test.go index 3121ad2..6a513a2 100644 --- a/run/runner_test.go +++ b/run/runner_test.go @@ -6,47 +6,19 @@ package run_test import ( "context" "errors" + "os" + "syscall" "testing" + "time" "github.com/nil-go/nilgo/internal/assert" "github.com/nil-go/nilgo/run" ) func TestRunner_Run(t *testing.T) { - for _, testcase := range testcases() { - testcase := testcase - - var ran bool - t.Run(testcase.description, func(t *testing.T) { - err := testcase.runner.Run( - context.Background(), - func(context.Context) error { - ran = true - if testcase.err != "" { - return errors.New(testcase.err) - } - - return nil - }, - ) + t.Parallel() - assert.Equal(t, testcase.ran, ran) - if testcase.err == "" { - assert.NoError(t, err) - } else { - assert.EqualError(t, err, testcase.err) - } - }) - } -} - -func testcases() []struct { - description string - runner *run.Runner - ran bool - err string -} { - return []struct { + testcases := []struct { description string runner *run.Runner ran bool @@ -61,6 +33,17 @@ func testcases() []struct { ran: true, err: "run error", }, + { + description: "with pre-run", + runner: run.New(run.WithPreRun(func(context.Context) error { return nil })), + ran: true, + }, + { + description: "pre-run error", + runner: run.New(run.WithPreRun(func(context.Context) error { return errors.New("pre-run error") })), + err: "pre-run error", + ran: true, + }, { description: "with start gate", runner: run.New(run.WithStartGate(func(context.Context) error { return nil })), @@ -72,19 +55,50 @@ func testcases() []struct { err: "start gate error", }, { - description: "with pre-run", - runner: run.New(run.WithPreRun(func(context.Context) error { return nil })), + description: "with stop gate", + runner: run.New(run.WithStopGate(func(context.Context) error { return nil })), ran: true, }, { - description: "pre-run error", - runner: run.New(run.WithPreRun(func(context.Context) error { return errors.New("pre-run error") })), - err: "pre-run error", + description: "stop gate error", + runner: run.New(run.WithStopGate(func(context.Context) error { return errors.New("stop gate error") })), + ran: true, + err: "stop gate error", }, } + + for _, testcase := range testcases { + testcase := testcase + + var ran bool + t.Run(testcase.description, func(t *testing.T) { + t.Parallel() + + err := testcase.runner.Run( + context.Background(), + func(context.Context) error { + ran = true + if testcase.err != "" { + return errors.New(testcase.err) + } + + return nil + }, + ) + + assert.Equal(t, testcase.ran, ran) + if testcase.err == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, testcase.err) + } + }) + } } func TestRunner_Run_twice(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -107,3 +121,25 @@ func TestRunner_Run_twice(t *testing.T) { // It should return an error as it's already running. assert.EqualError(t, runner.Run(ctx), "runner is already running") } + +func TestRunner_Run_signal(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var runner run.Runner + assert.NoError(t, runner.Run(ctx, + func(ctx context.Context) error { + timer := time.NewTimer(time.Minute) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil + case <-timer.C: + return errors.New("timeout") + } + }, + func(context.Context) error { + return syscall.Kill(os.Getpid(), syscall.SIGINT) + }, + )) +}