diff --git a/pkg/osquery/runtime/osqueryinstance.go b/pkg/osquery/runtime/osqueryinstance.go index 92ce4b088..398551431 100644 --- a/pkg/osquery/runtime/osqueryinstance.go +++ b/pkg/osquery/runtime/osqueryinstance.go @@ -250,6 +250,57 @@ func (i *OsqueryInstance) Launch() error { return fmt.Errorf("could not calculate osquery file paths: %w", err) } + // Register as many of our shutdown functions ahead of time as we can, so that we can make sure + // we fully clean up after any partially-launched erroring instances. + i.errgroup.AddShutdownGoroutine(ctx, "kill_osquery_process", func() error { + if i.cmd.Process == nil { + return nil + } + + // kill osqueryd and children + if err := killProcessGroup(i.cmd); err != nil { + if strings.Contains(err.Error(), "process already finished") || strings.Contains(err.Error(), "no such process") { + i.slogger.Log(ctx, slog.LevelDebug, + "tried to stop osquery, but process already gone", + ) + return nil + } + + return fmt.Errorf("killing osquery process: %w", err) + } + + return nil + }) + // Clean up PID file on shutdown + i.errgroup.AddShutdownGoroutine(ctx, "remove_pid_file", func() error { + // We do a couple retries -- on Windows, the PID file may still be in use + // and therefore unable to be removed. + if err := backoff.WaitFor(func() error { + if err := os.Remove(paths.pidfilePath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("removing PID file: %w", err) + } + return nil + }, 5*time.Second, 500*time.Millisecond); err != nil { + return fmt.Errorf("removing PID file %s failed with retries: %w", paths.pidfilePath, err) + } + return nil + }) + + // Clean up socket file on shutdown + i.errgroup.AddShutdownGoroutine(ctx, "remove_socket_file", func() error { + // We do a couple retries -- on Windows, the socket file may still be in use + // and therefore unable to be removed. + if err := backoff.WaitFor(func() error { + if err := os.Remove(paths.extensionSocketPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("removing socket file: %w", err) + } + return nil + }, 5*time.Second, 500*time.Millisecond); err != nil { + return fmt.Errorf("removing socket file %s failed with retries: %w", paths.extensionSocketPath, err) + } + return nil + }) + // Populate augeas lenses, if requested if i.opts.augeasLensFunc != nil { if err := os.MkdirAll(paths.augeasPath, 0755); err != nil { @@ -378,27 +429,6 @@ func (i *OsqueryInstance) Launch() error { } }) - // Kill osquery process on shutdown - i.errgroup.AddShutdownGoroutine(ctx, "kill_osquery_process", func() error { - if i.cmd.Process == nil { - return nil - } - - // kill osqueryd and children - if err := killProcessGroup(i.cmd); err != nil { - if strings.Contains(err.Error(), "process already finished") || strings.Contains(err.Error(), "no such process") { - i.slogger.Log(ctx, slog.LevelDebug, - "tried to stop osquery, but process already gone", - ) - return nil - } - - return fmt.Errorf("killing osquery process: %w", err) - } - - return nil - }) - // Start an extension manager for the extensions that osquery // needs for config/log/etc. i.extensionManagerClient, err = i.StartOsqueryClient(paths) @@ -450,36 +480,6 @@ func (i *OsqueryInstance) Launch() error { return nil }) - // Clean up PID file on shutdown - i.errgroup.AddShutdownGoroutine(ctx, "remove_pid_file", func() error { - // We do a couple retries -- on Windows, the PID file may still be in use - // and therefore unable to be removed. - if err := backoff.WaitFor(func() error { - if err := os.Remove(paths.pidfilePath); err != nil && !os.IsNotExist(err) { - return fmt.Errorf("removing PID file: %w", err) - } - return nil - }, 5*time.Second, 500*time.Millisecond); err != nil { - return fmt.Errorf("removing PID file %s failed with retries: %w", paths.pidfilePath, err) - } - return nil - }) - - // Clean up socket file on shutdown - i.errgroup.AddShutdownGoroutine(ctx, "remove_socket_file", func() error { - // We do a couple retries -- on Windows, the socket file may still be in use - // and therefore unable to be removed. - if err := backoff.WaitFor(func() error { - if err := os.Remove(paths.extensionSocketPath); err != nil && !os.IsNotExist(err) { - return fmt.Errorf("removing socket file: %w", err) - } - return nil - }, 5*time.Second, 500*time.Millisecond); err != nil { - return fmt.Errorf("removing socket file %s failed with retries: %w", paths.extensionSocketPath, err) - } - return nil - }) - return nil } diff --git a/pkg/osquery/runtime/runner.go b/pkg/osquery/runtime/runner.go index dfde2563b..7b7544c9e 100644 --- a/pkg/osquery/runtime/runner.go +++ b/pkg/osquery/runtime/runner.go @@ -136,17 +136,16 @@ func (r *Runner) runInstance(registrationId string) error { // It will retry until it succeeds, or until the runner is shut down. func (r *Runner) launchInstanceWithRetries(registrationId string) (*OsqueryInstance, error) { for { - // Lock to ensure we don't try to restart before launch is complete. + // Add the instance to our instances map right away, so that if we receive a shutdown + // request during launch, we can shut down the instance. r.instanceLock.Lock() instance := newInstance(registrationId, r.knapsack, r.serviceClient, r.opts...) + r.instances[registrationId] = instance + r.instanceLock.Unlock() err := instance.Launch() // Success! if err == nil { - // Now that the instance is running, we can add it to `r.instances` and remove the lock - r.instances[registrationId] = instance - r.instanceLock.Unlock() - r.slogger.Log(context.TODO(), slog.LevelInfo, "runner successfully launched instance", "registration_id", registrationId, @@ -155,13 +154,20 @@ func (r *Runner) launchInstanceWithRetries(registrationId string) (*OsqueryInsta return instance, nil } - // Launching was not successful. Unlock, log the error, and wait to retry. - r.instanceLock.Unlock() + // Launching was not successful. Shut down the instance, log the error, and wait to retry. r.slogger.Log(context.TODO(), slog.LevelWarn, "could not launch instance, will retry after delay", "err", err, "registration_id", registrationId, ) + instance.BeginShutdown() + if err := instance.WaitShutdown(); err != context.Canceled && err != nil { + r.slogger.Log(context.TODO(), slog.LevelWarn, + "error shutting down instance that failed to launch", + "err", err, + "registration_id", registrationId, + ) + } select { case <-r.shutdown: diff --git a/pkg/osquery/runtime/runtime_test.go b/pkg/osquery/runtime/runtime_test.go index b063d36f4..4fb39a2ed 100644 --- a/pkg/osquery/runtime/runtime_test.go +++ b/pkg/osquery/runtime/runtime_test.go @@ -560,8 +560,8 @@ func TestRunnerHandlesImmediateShutdownWithMultipleInstances(t *testing.T) { // Start the instance go runner.Run() - // Wait very briefly for the launch routines to begin, then shut it down - time.Sleep(100 * time.Millisecond) + // Wait briefly for the launch routines to begin, then shut it down + time.Sleep(10 * time.Second) waitShutdown(t, runner, logBytes) // Confirm the default instance was started, and then exited diff --git a/pkg/service/mock/service.go b/pkg/service/mock/service.go index de504d279..ff341df5d 100644 --- a/pkg/service/mock/service.go +++ b/pkg/service/mock/service.go @@ -4,6 +4,7 @@ package mock import ( "context" + "sync" "github.com/kolide/launcher/pkg/service" "github.com/osquery/osquery-go/plugin/distributed" @@ -42,34 +43,48 @@ type KolideService struct { CheckHealthFunc CheckHealthFunc CheckHealthFuncInvoked bool + + invokedLock sync.Mutex } func (s *KolideService) RequestEnrollment(ctx context.Context, enrollSecret string, hostIdentifier string, enrollDetails service.EnrollmentDetails) (string, bool, error) { + s.invokedLock.Lock() + defer s.invokedLock.Unlock() s.RequestEnrollmentFuncInvoked = true return s.RequestEnrollmentFunc(ctx, enrollSecret, hostIdentifier, enrollDetails) } func (s *KolideService) RequestConfig(ctx context.Context, nodeKey string) (string, bool, error) { + s.invokedLock.Lock() + defer s.invokedLock.Unlock() s.RequestConfigFuncInvoked = true return s.RequestConfigFunc(ctx, nodeKey) } func (s *KolideService) PublishLogs(ctx context.Context, nodeKey string, logType logger.LogType, logs []string) (string, string, bool, error) { + s.invokedLock.Lock() + defer s.invokedLock.Unlock() s.PublishLogsFuncInvoked = true return s.PublishLogsFunc(ctx, nodeKey, logType, logs) } func (s *KolideService) RequestQueries(ctx context.Context, nodeKey string) (*distributed.GetQueriesResult, bool, error) { + s.invokedLock.Lock() + defer s.invokedLock.Unlock() s.RequestQueriesFuncInvoked = true return s.RequestQueriesFunc(ctx, nodeKey) } func (s *KolideService) PublishResults(ctx context.Context, nodeKey string, results []distributed.Result) (string, string, bool, error) { + s.invokedLock.Lock() + defer s.invokedLock.Unlock() s.PublishResultsFuncInvoked = true return s.PublishResultsFunc(ctx, nodeKey, results) } func (s *KolideService) CheckHealth(ctx context.Context) (int32, error) { + s.invokedLock.Lock() + defer s.invokedLock.Unlock() s.CheckHealthFuncInvoked = true return s.CheckHealthFunc(ctx) }