diff --git a/fxtest/app.go b/fxtest/app.go index 427f95422..9dddddea9 100644 --- a/fxtest/app.go +++ b/fxtest/app.go @@ -74,3 +74,26 @@ func (app *App) RequireStop() { app.tb.FailNow() } } + +func (app *App) RequireRun() { + startCtx, cancelStart := context.WithTimeout(context.Background(), app.StartTimeout()) + defer cancelStart() + if err := app.Start(startCtx); err != nil { + app.tb.Errorf("application didn't start cleanly: %v", err) + app.tb.FailNow() + } + + sig := <-app.Wait() + + stopCtx, cancelStop := context.WithTimeout(context.Background(), app.StopTimeout()) + defer cancelStop() + if err := app.Stop(stopCtx); err != nil { + app.tb.Errorf("application didn't stop cleanly: %v", err) + app.tb.FailNow() + } + + if sig.ExitCode != 0 { + app.tb.Errorf("application exited with code %v", sig.ExitCode) + app.tb.FailNow() + } +} diff --git a/fxtest/app_test.go b/fxtest/app_test.go index 7109908a5..cbd845962 100644 --- a/fxtest/app_test.go +++ b/fxtest/app_test.go @@ -96,4 +96,76 @@ func TestApp(t *testing.T) { assert.Equal(t, 1, spy.failures, "Expected Stop to fail.") assert.Contains(t, spy.errors.String(), "didn't stop cleanly", "Expected to write errors to TB.") }) + + t.Run("RunFailures", func(t *testing.T) { + t.Parallel() + + t.Run("RunFailure during Start", func(t *testing.T) { + t.Parallel() + spy := newTB() + + New( + spy, + fx.Invoke(func(lc fx.Lifecycle, s fx.Shutdowner) { + lc.Append(fx.Hook{ + OnStart: func(context.Context) error { + go s.Shutdown() + return errors.New("fail") + }, + OnStop: func(context.Context) error { + return nil + }, + }) + }), + ).RequireRun() + + assert.Equal(t, 1, spy.failures, "Expected RequireRun to fail.") + assert.Contains(t, spy.errors.String(), "didn't start cleanly", "Expected to write errors to TB.") + }) + t.Run("RunFailure during Stop", func(t *testing.T) { + t.Parallel() + spy := newTB() + + New( + spy, + fx.Invoke(func(lc fx.Lifecycle, s fx.Shutdowner) { + lc.Append(fx.Hook{ + OnStart: func(context.Context) error { + go s.Shutdown() + return nil + }, + OnStop: func(context.Context) error { + return errors.New("fail") + }, + }) + }), + ).RequireRun() + + assert.Equal(t, 1, spy.failures, "Expected RequireRun to fail.") + assert.Contains(t, spy.errors.String(), "didn't stop cleanly", "Expected to write errors to TB.") + }) + t.Run("RunFailure via exit code", func(t *testing.T) { + t.Parallel() + spy := newTB() + + New( + spy, + fx.Invoke(func(lc fx.Lifecycle, s fx.Shutdowner) { + lc.Append(fx.Hook{ + OnStart: func(context.Context) error { + go s.Shutdown(fx.ExitCode(42)) + return nil + }, + OnStop: func(context.Context) error { + return nil + }, + }) + }), + ).RequireRun() + + assert.Equal(t, 1, spy.failures, "Expected RequireRun to fail.") + assert.Contains(t, spy.errors.String(), "application exited with code 42", "Expected to write errors to TB.") + }) + + }) }