diff --git a/cmd/launcher/launcher.go b/cmd/launcher/launcher.go index e33d79a86..9e67d59c7 100644 --- a/cmd/launcher/launcher.go +++ b/cmd/launcher/launcher.go @@ -37,6 +37,7 @@ import ( "github.com/kolide/launcher/ee/control/consumers/flareconsumer" "github.com/kolide/launcher/ee/control/consumers/keyvalueconsumer" "github.com/kolide/launcher/ee/control/consumers/notificationconsumer" + "github.com/kolide/launcher/ee/control/consumers/remoterestartconsumer" "github.com/kolide/launcher/ee/control/consumers/uninstallconsumer" "github.com/kolide/launcher/ee/debug/checkups" desktopRunner "github.com/kolide/launcher/ee/desktop/runner" @@ -469,6 +470,10 @@ func runLauncher(ctx context.Context, cancel func(), multiSlogger, systemMultiSl // register notifications consumer actionsQueue.RegisterActor(notificationconsumer.NotificationSubsystem, notificationConsumer) + remoteRestartConsumer := remoterestartconsumer.New(k) + runGroup.Add("remoteRestart", remoteRestartConsumer.Execute, remoteRestartConsumer.Interrupt) + actionsQueue.RegisterActor(remoterestartconsumer.RemoteRestartActorType, remoteRestartConsumer) + // Set up our tracing instrumentation authTokenConsumer := keyvalueconsumer.New(k.TokenStore()) if err := controlService.RegisterConsumer(authTokensSubsystemName, authTokenConsumer); err != nil { diff --git a/cmd/launcher/main.go b/cmd/launcher/main.go index 235ac6fd5..60dff9946 100644 --- a/cmd/launcher/main.go +++ b/cmd/launcher/main.go @@ -15,6 +15,7 @@ import ( "github.com/kolide/kit/env" "github.com/kolide/kit/logutil" "github.com/kolide/kit/version" + "github.com/kolide/launcher/ee/control/consumers/remoterestartconsumer" "github.com/kolide/launcher/ee/tuf" "github.com/kolide/launcher/ee/watchdog" "github.com/kolide/launcher/pkg/contexts/ctxlog" @@ -153,11 +154,11 @@ func runMain() int { ctx = ctxlog.NewContext(ctx, logger) if err := runLauncher(ctx, cancel, slogger, systemSlogger, opts); err != nil { - if !tuf.IsLauncherReloadNeededErr(err) { + if !tuf.IsLauncherReloadNeededErr(err) && !errors.Is(err, remoterestartconsumer.ErrRemoteRestartRequested) { level.Debug(logger).Log("msg", "run launcher", "stack", fmt.Sprintf("%+v", err)) return 1 } - level.Debug(logger).Log("msg", "runLauncher exited to run newer version of launcher", "err", err.Error()) + level.Debug(logger).Log("msg", "runLauncher exited to reload launcher", "err", err.Error()) if err := runNewerLauncherIfAvailable(ctx, slogger.Logger); err != nil { return 1 } diff --git a/ee/control/consumers/remoterestartconsumer/remoterestartconsumer.go b/ee/control/consumers/remoterestartconsumer/remoterestartconsumer.go new file mode 100644 index 000000000..46e21bb20 --- /dev/null +++ b/ee/control/consumers/remoterestartconsumer/remoterestartconsumer.go @@ -0,0 +1,130 @@ +package remoterestartconsumer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "time" + + "github.com/kolide/launcher/ee/agent/types" +) + +const ( + // RemoteRestartActorType identifies this action/actor type, which performs + // a launcher restart when requested by the control server. This actor type + // belongs to the action subsystem. + RemoteRestartActorType = "remote_restart" + + // restartDelay is the delay after receiving action before triggering the restart. + // We have a delay to allow the actionqueue. + restartDelay = 15 * time.Second +) + +var ( + ErrRemoteRestartRequested = errors.New("need to reload launcher: remote restart requested") +) + +type RemoteRestartConsumer struct { + knapsack types.Knapsack + slogger *slog.Logger + signalRestart chan error + interrupt chan struct{} + interrupted bool +} + +type remoteRestartAction struct { + RunID string `json:"run_id"` // the run ID for the launcher run to restart +} + +func New(knapsack types.Knapsack) *RemoteRestartConsumer { + return &RemoteRestartConsumer{ + knapsack: knapsack, + slogger: knapsack.Slogger().With("component", "remote_restart_consumer"), + signalRestart: make(chan error, 1), + interrupt: make(chan struct{}, 1), + } +} + +// Do implements the `actionqueue.actor` interface, and allows the actionqueue +// to pass `remote_restart` type actions to this consumer. The actionqueue validates +// that this action has not already been performed and that this action is still +// valid (i.e. not expired). `Do` additionally validates that the `run_id` given in +// the action matches the current launcher run ID. +func (r *RemoteRestartConsumer) Do(data io.Reader) error { + var restartAction remoteRestartAction + + if err := json.NewDecoder(data).Decode(&restartAction); err != nil { + return fmt.Errorf("decoding restart action: %w", err) + } + + // The action's run ID indicates the current `runLauncher` that should be restarted. + // If the action's run ID does not match the current run ID, we assume the restart + // has already happened and does not need to happen again. + if restartAction.RunID == "" { + r.slogger.Log(context.TODO(), slog.LevelInfo, + "received remote restart action with blank launcher run ID -- discarding", + ) + return nil + } + if restartAction.RunID != r.knapsack.GetRunID() { + r.slogger.Log(context.TODO(), slog.LevelInfo, + "received remote restart action for incorrect (assuming past) launcher run ID -- discarding", + "action_run_id", restartAction.RunID, + ) + return nil + } + + // Perform the restart by signaling actor shutdown, but delay slightly to give + // the actionqueue a chance to process all actions and store their statuses. + go func() { + r.slogger.Log(context.TODO(), slog.LevelInfo, + "received remote restart action for current launcher run ID -- signaling for restart shortly", + "action_run_id", restartAction.RunID, + "restart_delay", restartDelay.String(), + ) + + select { + case <-r.interrupt: + r.slogger.Log(context.TODO(), slog.LevelDebug, + "received external interrupt before remote restart could be performed", + ) + return + case <-time.After(restartDelay): + r.signalRestart <- ErrRemoteRestartRequested + r.slogger.Log(context.TODO(), slog.LevelInfo, + "signaled for restart after delay", + "action_run_id", restartAction.RunID, + ) + return + } + }() + + return nil +} + +// Execute allows the remote restart consumer to run in the main launcher rungroup. +// It waits until it receives a remote restart action from `Do`, or until it receives +// a `Interrupt` request. +func (r *RemoteRestartConsumer) Execute() (err error) { + select { + case <-r.interrupt: + return nil + case signalRestartErr := <-r.signalRestart: + return signalRestartErr + } +} + +// Interrupt allows the remote restart consumer to run in the main launcher rungroup +// and be shut down when the rungroup shuts down. +func (r *RemoteRestartConsumer) Interrupt(_ error) { + // Only perform shutdown tasks on first call to interrupt -- no need to repeat on potential extra calls. + if r.interrupted { + return + } + r.interrupted = true + + r.interrupt <- struct{}{} +} diff --git a/ee/control/consumers/remoterestartconsumer/remoterestartconsumer_test.go b/ee/control/consumers/remoterestartconsumer/remoterestartconsumer_test.go new file mode 100644 index 000000000..b9a455e96 --- /dev/null +++ b/ee/control/consumers/remoterestartconsumer/remoterestartconsumer_test.go @@ -0,0 +1,160 @@ +package remoterestartconsumer + +import ( + "bytes" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/kolide/kit/ulid" + typesmocks "github.com/kolide/launcher/ee/agent/types/mocks" + "github.com/kolide/launcher/pkg/log/multislogger" + "github.com/stretchr/testify/require" +) + +func TestDo(t *testing.T) { + t.Parallel() + + currentRunId := ulid.New() + + mockKnapsack := typesmocks.NewKnapsack(t) + mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) + mockKnapsack.On("GetRunID").Return(currentRunId) + + remoteRestarter := New(mockKnapsack) + + testAction := remoteRestartAction{ + RunID: currentRunId, + } + testActionRaw, err := json.Marshal(testAction) + require.NoError(t, err) + + // We don't expect an error because we should process the action + require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "expected no error processing valid remote restart action") + + // The restarter should delay before sending an error to `signalRestart` + require.Len(t, remoteRestarter.signalRestart, 0, "expected restarter to delay before signal for restart but channel is already has item in it") + time.Sleep(restartDelay + 2*time.Second) + require.Len(t, remoteRestarter.signalRestart, 1, "expected restarter to signal for restart but channel is empty after delay") +} + +func TestDo_DoesNotSignalRestartWhenRunIDDoesNotMatch(t *testing.T) { + t.Parallel() + + currentRunId := ulid.New() + + mockKnapsack := typesmocks.NewKnapsack(t) + mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) + mockKnapsack.On("GetRunID").Return(currentRunId) + + remoteRestarter := New(mockKnapsack) + + testAction := remoteRestartAction{ + RunID: ulid.New(), // run ID will not match `currentRunId` + } + testActionRaw, err := json.Marshal(testAction) + require.NoError(t, err) + + // We don't expect an error because we want to discard this action + require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "should not return error for old run ID") + + // The restarter should not send an error to `signalRestart` + time.Sleep(restartDelay + 2*time.Second) + require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have signaled for a restart, but channel is not empty") +} + +func TestDo_DoesNotSignalRestartWhenRunIDIsEmpty(t *testing.T) { + t.Parallel() + + mockKnapsack := typesmocks.NewKnapsack(t) + mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) + + remoteRestarter := New(mockKnapsack) + + testAction := remoteRestartAction{ + RunID: "", // run ID is empty + } + testActionRaw, err := json.Marshal(testAction) + require.NoError(t, err) + + // We don't expect an error because we want to discard this action + require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "should not return error for empty run ID") + + // The restarter should not send an error to `signalRestart` + time.Sleep(restartDelay + 2*time.Second) + require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have signaled for a restart, but channel is not empty") +} + +func TestDo_DoesNotRestartIfInterruptedDuringDelay(t *testing.T) { + t.Parallel() + + currentRunId := ulid.New() + + mockKnapsack := typesmocks.NewKnapsack(t) + mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) + mockKnapsack.On("GetRunID").Return(currentRunId) + + remoteRestarter := New(mockKnapsack) + + testAction := remoteRestartAction{ + RunID: currentRunId, + } + testActionRaw, err := json.Marshal(testAction) + require.NoError(t, err) + + // We don't expect an error because the run ID is correct + require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "expected no error processing valid remote restart action") + + // The restarter should delay before sending an error to `signalRestart` + require.Len(t, remoteRestarter.signalRestart, 0, "expected restarter to delay before signal for restart but channel is already has item in it") + + // Now, send an interrupt + remoteRestarter.Interrupt(errors.New("test error")) + + // Sleep beyond the interrupt delay, and confirm we don't try to do a restart when we're already shutting down + time.Sleep(restartDelay + 2*time.Second) + require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have tried to signal for restart when interrupted during restart delay") +} + +func TestInterrupt_Multiple(t *testing.T) { + t.Parallel() + + mockKnapsack := typesmocks.NewKnapsack(t) + mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) + + remoteRestarter := New(mockKnapsack) + + // Let the remote restarter run for a bit + go remoteRestarter.Execute() + time.Sleep(3 * time.Second) + remoteRestarter.Interrupt(errors.New("test error")) + + // Confirm we can call Interrupt multiple times without blocking + interruptComplete := make(chan struct{}) + expectedInterrupts := 3 + for i := 0; i < expectedInterrupts; i += 1 { + go func() { + remoteRestarter.Interrupt(nil) + interruptComplete <- struct{}{} + }() + } + + receivedInterrupts := 0 + for { + if receivedInterrupts >= expectedInterrupts { + break + } + + select { + case <-interruptComplete: + receivedInterrupts += 1 + continue + case <-time.After(5 * time.Second): + t.Errorf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts) + t.FailNow() + } + } + + require.Equal(t, expectedInterrupts, receivedInterrupts) +}