diff --git a/app.go b/app.go index 8d0132c62d..37d971952d 100644 --- a/app.go +++ b/app.go @@ -296,9 +296,12 @@ type App struct { errorHooks []ErrorHandler validate bool // Used to signal shutdowns. - donesMu sync.Mutex // guards dones and shutdownSig - dones []chan os.Signal - shutdownSig os.Signal + donesMu sync.Mutex // guards dones and shutdownSig + dones []chan os.Signal + shutdownSig os.Signal + waitsMu sync.Mutex // guards waits and shutdownCode + waits []chan ShutdownSignal + shutdownSignal *ShutdownSignal osExit func(code int) // os.Exit override; used for testing only } @@ -737,6 +740,31 @@ func (app *App) Done() <-chan os.Signal { return c } +func (app *App) wait() <-chan ShutdownSignal { + c := make(chan ShutdownSignal, 1) + + app.waitsMu.Lock() + defer app.waitsMu.Unlock() + + if app.shutdownSignal != nil { + c <- *app.shutdownSignal + return c + } + + app.waits = append(app.waits, c) + return c +} + +func (app *App) Wait(ctx context.Context) (ShutdownSignal, error) { + c := app.wait() + select { + case s := <-c: + return s, nil + case <-ctx.Done(): + return ShutdownSignal{}, ctx.Err() + } +} + // StartTimeout returns the configured startup timeout. Apps default to using // DefaultTimeout, but users can configure this behavior using the // StartTimeout option. diff --git a/shutdown.go b/shutdown.go index d5b8488c0c..d5f125d14e 100644 --- a/shutdown.go +++ b/shutdown.go @@ -23,6 +23,8 @@ package fx import ( "fmt" "os" + + "go.uber.org/multierr" ) // Shutdowner provides a method that can manually trigger the shutdown of the @@ -39,8 +41,26 @@ type ShutdownOption interface { apply(*shutdowner) } +type shutdownCode int + +func (c shutdownCode) apply(s *shutdowner) { + s.exitCode = int(c) +} + +// ShutdownCode implements a shutdown option that allows a user specify the +// os.Exit code that an application should exit with. +func ShutdownCode(code int) ShutdownOption { + return shutdownCode(code) +} + type shutdowner struct { - app *App + exitCode int + app *App +} + +type ShutdownSignal struct { + Signal os.Signal + ExitCode int } // Shutdown broadcasts a signal to all of the application's Done channels @@ -49,14 +69,25 @@ type shutdowner struct { // In practice this means Shutdowner.Shutdown should not be called from an // fx.Invoke, but from a fx.Lifecycle.OnStart hook. func (s *shutdowner) Shutdown(opts ...ShutdownOption) error { - return s.app.broadcastSignal(_sigTERM) + for _, opt := range opts { + opt.apply(s) + } + + return s.app.broadcastSignal(_sigTERM, s.exitCode) } func (app *App) shutdowner() Shutdowner { return &shutdowner{app: app} } -func (app *App) broadcastSignal(signal os.Signal) error { +func (app *App) broadcastSignal(signal os.Signal, code int) error { + return multierr.Combine( + app.broadcastDoneSignal(signal), + app.broadcastWaitSignal(signal, code), + ) +} + +func (app *App) broadcastDoneSignal(signal os.Signal) error { app.donesMu.Lock() defer app.donesMu.Unlock() @@ -81,3 +112,32 @@ func (app *App) broadcastSignal(signal os.Signal) error { return nil } + +func (app *App) broadcastWaitSignal(signal os.Signal, code int) error { + app.waitsMu.Lock() + defer app.waitsMu.Unlock() + + app.shutdownSignal = &ShutdownSignal{ + Signal: signal, + ExitCode: code, + } + + var unsent int + for _, wait := range app.waits { + select { + case wait <- *app.shutdownSignal: + default: + // shutdown called when wait channel has already received a + // termination signal that has not been cleared + unsent++ + } + } + + if unsent != 0 { + return fmt.Errorf("failed to send %v codes to %v out of %v channels", + signal, unsent, len(app.waits), + ) + } + + return nil +} diff --git a/shutdown_code_example_test.go b/shutdown_code_example_test.go new file mode 100644 index 0000000000..9ac69001a7 --- /dev/null +++ b/shutdown_code_example_test.go @@ -0,0 +1,55 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx_test + +import ( + "context" + "fmt" + "time" + + "go.uber.org/fx" +) + +func ExampleShutdownCode() { + app := fx.New( + fx.Invoke(func(shutdowner fx.Shutdowner) { + // Call the shutdowner Shutdown method with a shutdown code + // option + shutdowner.Shutdown(fx.ShutdownCode(1)) + }), + ) + + app.Run() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + shutdown, err := app.Wait(ctx) + + if err != nil { + panic(err) + } + + fmt.Printf("os.Exit(%v)\n", shutdown.ExitCode) + + // Output: + // os.Exit(1) +} diff --git a/shutdown_test.go b/shutdown_test.go index b6af93f131..6f96547e83 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -22,8 +22,10 @@ package fx_test import ( "context" + "fmt" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -87,6 +89,62 @@ func TestShutdown(t *testing.T) { assert.NotNil(t, <-done1, "done channel 1 did not receive signal") assert.NotNil(t, <-done2, "done channel 2 did not receive signal") }) + + t.Run("shutdown app with exit code(s)", func(t *testing.T) { + t.Parallel() + + t.Run("default", func(t *testing.T) { + t.Parallel() + var s fx.Shutdowner + app := fxtest.New(t, fx.Populate(&s)) + + done := app.Done() + defer app.RequireStart().RequireStop() + + assert.NoError(t, s.Shutdown(), "error returned from first shutdown call") + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + signal, err := app.Wait(ctx) + assert.NoError(t, err, "error in app wait") + assert.NotEmpty(t, signal, "no shutdown signal") + assert.NotNil(t, signal.Signal) + assert.Zero(t, signal.ExitCode) + assert.Equal(t, signal.Signal, <-done) + assert.NoError(t, ctx.Err()) + }) + + for expected := 0; expected <= 3; expected++ { + expected := expected + t.Run(fmt.Sprintf("with exit code %v", expected), func(t *testing.T) { + t.Parallel() + var s fx.Shutdowner + app := fxtest.New( + t, + fx.Populate(&s), + ) + + done := app.Done() + defer app.RequireStart().RequireStop() + + assert.NoError( + t, + s.Shutdown(fx.ShutdownCode(expected)), + "error in app shutdown", + ) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + signal, err := app.Wait(ctx) + assert.NoError(t, err, "error in app wait") + assert.NotEmpty(t, signal, "no shutdown signal") + assert.NotNil(t, signal.Signal) + assert.Equal(t, expected, signal.ExitCode) + assert.Equal(t, signal.Signal, <-done) + }) + } + }) } func TestDataRace(t *testing.T) {