diff --git a/gcp/gcp.go b/gcp/gcp.go index 6d4794b..90ff011 100644 --- a/gcp/gcp.go +++ b/gcp/gcp.go @@ -80,17 +80,11 @@ func Options(opts ...Option) ([]any, error) { //nolint:cyclop,funlen if err != nil { return nil, fmt.Errorf("create otlp trace exporter: %w", err) } - provider := trace.NewTracerProvider( - trace.WithBatcher(exporter), - trace.WithResource(res), - ) appOpts = append(appOpts, - provider, - func(ctx context.Context) error { - <-ctx.Done() - - return provider.Shutdown(context.WithoutCancel(ctx)) - }, + trace.NewTracerProvider( + trace.WithBatcher(exporter), + trace.WithResource(res), + ), ) } if option.metricOpts != nil { @@ -99,17 +93,11 @@ func Options(opts ...Option) ([]any, error) { //nolint:cyclop,funlen return nil, fmt.Errorf("create otlp metric exporter: %w", err) } - provider := metric.NewMeterProvider( - metric.WithReader(metric.NewPeriodicReader(exporter)), - metric.WithResource(res), - ) appOpts = append(appOpts, - provider, - func(ctx context.Context) error { - <-ctx.Done() - - return provider.Shutdown(context.WithoutCancel(ctx)) - }, + metric.NewMeterProvider( + metric.WithReader(metric.NewPeriodicReader(exporter)), + metric.WithResource(res), + ), ) } diff --git a/gcp/gcp_test.go b/gcp/gcp_test.go index dfb7d94..fe73c5e 100644 --- a/gcp/gcp_test.go +++ b/gcp/gcp_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/sdk/metric" - "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/sdk/trace" "github.com/nil-go/nilgo/gcp" ) @@ -83,12 +83,10 @@ func TestOptions(t *testing.T) { assertion: func(t *testing.T, opts []any) { t.Helper() - assert.Len(t, opts, 3) + assert.Len(t, opts, 2) _, ok := opts[0].(slog.Handler) assert.True(t, ok) - _, ok = opts[1].(trace.TracerProvider) - assert.True(t, ok) - _, ok = opts[2].(func(context.Context) error) + _, ok = opts[1].(*trace.TracerProvider) assert.True(t, ok) }, }, @@ -101,13 +99,11 @@ func TestOptions(t *testing.T) { assertion: func(t *testing.T, opts []any) { t.Helper() - assert.Len(t, opts, 3) + assert.Len(t, opts, 2) _, ok := opts[0].(slog.Handler) assert.True(t, ok) _, ok = opts[1].(*metric.MeterProvider) assert.True(t, ok) - _, ok = opts[2].(func(context.Context) error) - assert.True(t, ok) }, }, { diff --git a/run.go b/run.go index 5cadde7..7903022 100644 --- a/run.go +++ b/run.go @@ -48,8 +48,18 @@ func Run(args ...any) error { //nolint:cyclop logOpts = append(logOpts, opt) case trace.TracerProvider: otel.SetTracerProvider(opt) + if provider, ok := opt.(interface { + Shutdown(ctx context.Context) error + }); ok { + runOpts = append(runOpts, run.WithPostRun(provider.Shutdown)) + } case metric.MeterProvider: otel.SetMeterProvider(opt) + if provider, ok := opt.(interface { + Shutdown(ctx context.Context) error + }); ok { + runOpts = append(runOpts, run.WithPostRun(provider.Shutdown)) + } case run.Option: runOpts = append(runOpts, opt) case func(context.Context) error: diff --git a/run/option.go b/run/option.go index 811e7c6..fa02ab2 100644 --- a/run/option.go +++ b/run/option.go @@ -15,6 +15,13 @@ func WithPreRun(runs ...func(context.Context) error) Option { } } +// WithPostRun provides runs to execute after the main runs provided in Runner.Run. +func WithPostRun(runs ...func(context.Context) error) Option { + return func(options *options) { + options.postRuns = append(options.postRuns, runs...) + } +} + // WithStartGate provides gates to block the start of main runs provided in Runner.Run, // until all start gates returns without error. // diff --git a/run/runner.go b/run/runner.go index 0ead789..39c4faf 100644 --- a/run/runner.go +++ b/run/runner.go @@ -16,6 +16,7 @@ import ( // To create an Runner, use [New]. type Runner struct { preRuns []func(context.Context) error + postRuns []func(context.Context) error startGates []func(context.Context) error stopGates []func(context.Context) error } @@ -39,7 +40,7 @@ func New(opts ...Option) Runner { // The execution can be interrupted if any run returns non-nil error, // or it receives an OS signal syscall.SIGINT or syscall.SIGTERM. // It waits all run return unless it's forcefully terminated by OS. -func (r Runner) Run(ctx context.Context, runs ...func(context.Context) error) error { +func (r Runner) Run(ctx context.Context, runs ...func(context.Context) error) error { //nolint:funlen allRuns := make([]func(context.Context) error, 0, len(r.preRuns)+1) startGates := slices.Clone(r.startGates) if len(r.preRuns) > 0 { @@ -93,6 +94,13 @@ func (r Runner) Run(ctx context.Context, runs ...func(context.Context) error) er if err = Parallel(signalCtx, startGates...); err != nil { return err } + defer func() { + // Wait for all post-runs to finish. + e := Parallel(runCtx, r.postRuns...) + if err == nil { + err = e + } + }() return Parallel(runCtx, runs...) }, diff --git a/run/runner_test.go b/run/runner_test.go index cb19682..ce6aaee 100644 --- a/run/runner_test.go +++ b/run/runner_test.go @@ -48,6 +48,22 @@ func TestRunner_Run(t *testing.T) { err: "pre-run error", ran: true, }, + { + description: "with post-run", + runner: run.New(run.WithPostRun(func(context.Context) error { + return nil + })), + ran: true, + }, + { + description: "post-run error", + runner: run.New(run.WithPostRun(func(context.Context) error { + return errors.New("post-run error") + }), + ), + err: "post-run error", + ran: true, + }, { description: "with start gate", runner: run.New(run.WithStartGate(func(context.Context) error { return nil })),