From e5dcb70ea30b8e21a84da15cc09771fa7dec8b19 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Thu, 14 Nov 2024 21:35:52 -0500 Subject: [PATCH 01/14] test config writer --- lib/autoupdate/agent/config.go | 95 ++++++++++++++++++++++ lib/autoupdate/agent/config_test.go | 66 +++++++++++++++ lib/autoupdate/agent/installer.go | 23 ++++-- lib/autoupdate/agent/process.go | 42 +++++++--- lib/autoupdate/agent/updater.go | 46 +++++++++-- lib/autoupdate/agent/updater_test.go | 86 +++++++++++++------- tool/teleport-update/main.go | 116 +++++++++++++-------------- 7 files changed, 366 insertions(+), 108 deletions(-) create mode 100644 lib/autoupdate/agent/config.go create mode 100644 lib/autoupdate/agent/config_test.go diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go new file mode 100644 index 0000000000000..6eb9320c66ac2 --- /dev/null +++ b/lib/autoupdate/agent/config.go @@ -0,0 +1,95 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +package agent + +import ( + "os" + "path/filepath" + "text/template" + + "github.com/gravitational/trace" +) + +const ( + teleportDropinTemplate = `# teleport-update +[Service] +ExecStopPost=/bin/bash -c 'date +%%s > {{.DataDir}}/last-restart' +` + updateServiceTemplate = `# teleport-update +[Unit] +Description=Teleport update service + +[Service] +Type=oneshot +ExecStart={{.LinkDir}}/bin/teleport-update update +` + updateTimerTemplate = `# teleport-update +[Unit] +Description=Teleport update timer unit + +[Timer] +OnActiveSec=1m +OnUnitActiveSec=5m +RandomizedDelaySec=1m + +[Install] +WantedBy=teleport.service +` +) + +func WriteConfigFiles(linkDir, dataDir string) error { + // TODO(sclevine): revert on failure + + dropinPath := filepath.Join(linkDir, serviceDir, serviceName+".d", serviceDropinName) + err := writeTemplate(dropinPath, teleportDropinTemplate, linkDir, dataDir) + if err != nil { + return trace.Wrap(err) + } + servicePath := filepath.Join(linkDir, serviceDir, updateServiceName) + err = writeTemplate(servicePath, updateServiceTemplate, linkDir, dataDir) + if err != nil { + return trace.Wrap(err) + } + timerPath := filepath.Join(linkDir, serviceDir, updateTimerName) + err = writeTemplate(timerPath, updateTimerTemplate, linkDir, dataDir) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +func writeTemplate(path, t, linkDir, dataDir string) error { + if err := os.MkdirAll(filepath.Dir(path), systemDirMode); err != nil { + return trace.Wrap(err) + } + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, configFileMode) + if err != nil { + return trace.Wrap(err) + } + defer f.Close() + tmpl, err := template.New(filepath.Base(path)).Parse(t) + if err != nil { + return trace.Wrap(err) + } + err = tmpl.Execute(f, struct { + LinkDir string + DataDir string + }{linkDir, dataDir}) + return trace.Wrap(f.Close()) +} diff --git a/lib/autoupdate/agent/config_test.go b/lib/autoupdate/agent/config_test.go new file mode 100644 index 
0000000000000..1530bae9b868b --- /dev/null +++ b/lib/autoupdate/agent/config_test.go @@ -0,0 +1,66 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + libdefaults "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils/golden" +) + +func TestWriteConfigFiles(t *testing.T) { + t.Parallel() + linkDir := t.TempDir() + dataDir := t.TempDir() + err := WriteConfigFiles(linkDir, dataDir) + require.NoError(t, err) + + for _, p := range []string{ + filepath.Join(linkDir, serviceDir, serviceName+".d", serviceDropinName), + filepath.Join(linkDir, serviceDir, updateServiceName), + filepath.Join(linkDir, serviceDir, updateTimerName), + } { + t.Run(filepath.Base(p), func(t *testing.T) { + data, err := os.ReadFile(p) + require.NoError(t, err) + data = replaceValues(data, map[string]string{ + DefaultLinkDir: linkDir, + libdefaults.DataDir: dataDir, + }) + if golden.ShouldSet() { + golden.Set(t, data) + } + require.Equal(t, string(golden.Get(t)), string(data)) + }) + } +} + +func replaceValues(data []byte, m map[string]string) []byte { + for k, v := range m { + data = bytes.ReplaceAll(data, []byte(v), + []byte(k)) + } + return data +} diff --git 
a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go index 957e90779c2ab..3312d8bae9190 100644 --- a/lib/autoupdate/agent/installer.go +++ b/lib/autoupdate/agent/installer.go @@ -55,11 +55,17 @@ const ( systemDirMode = 0755 ) -var ( +const ( // serviceDir contains the relative path to the Teleport SystemD service dir. - serviceDir = filepath.Join("lib", "systemd", "system") + serviceDir = "lib/systemd/system" // serviceName contains the name of the Teleport SystemD service file. serviceName = "teleport.service" + // serviceDropinName contains the name of the Teleport Systemd service drop-in to support updates. + serviceDropinName = "teleport-update.conf" + // updateServiceName contains the name of the Teleport Update Systemd service + updateServiceName = "teleport-update.service" + // updateTimerName contains the name of the Teleport Update Systemd timer + updateTimerName = "teleport-update.timer" ) // LocalInstaller manages the creation and removal of installations @@ -539,7 +545,7 @@ func (li *LocalInstaller) forceLinks(ctx context.Context, binDir, svcDir string) dst := filepath.Join(li.LinkServiceDir, serviceName) orig, err := forceCopy(dst, src, maxServiceFileSize) if err != nil && !errors.Is(err, os.ErrExist) { - return revert, trace.Errorf("failed to create file for %s: %w", serviceName, err) + return revert, trace.Errorf("failed to write file %s: %w", serviceName, err) } if orig != nil { revertFiles = append(revertFiles, *orig) @@ -598,6 +604,13 @@ func forceCopy(dst, src string, n int64) (orig *smallFile, err error) { if err != nil { return nil, trace.Wrap(err) } + return forceWrite(dst, srcData, n) +} + +func forceWrite(dst string, data []byte, n int64) (orig *smallFile, err error) { + if l := len(data); int64(l) > n { + return nil, trace.Errorf("data too large for file (%d > %d)", l, n) + } fi, err := os.Lstat(dst) if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, trace.Wrap(err) @@ -614,11 +627,11 @@ func forceCopy(dst, 
src string, n int64) (orig *smallFile, err error) { if err != nil { return nil, trace.Wrap(err) } - if bytes.Equal(srcData, orig.data) { + if bytes.Equal(data, orig.data) { return nil, trace.Wrap(os.ErrExist) } } - err = renameio.WriteFile(dst, srcData, configFileMode) + err = renameio.WriteFile(dst, data, configFileMode) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index eba70aa56a690..0dffd6d79b089 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -41,6 +41,8 @@ type SystemdService struct { // Attempts a graceful reload before a hard restart. // See Process interface for more details. func (s SystemdService) Reload(ctx context.Context) error { + // TODO(sclevine): allow server to force restart instead of reload + if err := s.checkSystem(ctx); err != nil { return trace.Wrap(err) } @@ -106,9 +108,35 @@ func (s SystemdService) checkSystem(ctx context.Context) error { // Output sent to stdout is logged at debug level. // Output sent to stderr is logged at the level specified by errLevel. func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int { - cmd := exec.CommandContext(ctx, "systemctl", args...) - stderr := &lineLogger{ctx: ctx, log: s.Log, level: errLevel} - stdout := &lineLogger{ctx: ctx, log: s.Log, level: slog.LevelDebug} + cmd := &LocalExec{ + Log: s.Log, + ErrLevel: errLevel, + OutLevel: slog.LevelDebug, + } + code, err := cmd.Run(ctx, "systemctl", args...) + if err != nil { + s.Log.Log(ctx, errLevel, "Failed to run systemctl.", + "args", args, + "code", code, + errorKey, err) + } + return code +} + +type LocalExec struct { + // Log contains a slog logger. + // Defaults to slog.Default() if nil. + Log *slog.Logger + // ErrLevel is the log level for stderr. + ErrLevel slog.Level + // OutLevel is the log level for stdout. 
+ OutLevel slog.Level +} + +func (c *LocalExec) Run(ctx context.Context, name string, args ...string) (int, error) { + cmd := exec.CommandContext(ctx, name, args...) + stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel} + stdout := &lineLogger{ctx: ctx, log: c.Log, level: c.OutLevel} cmd.Stderr = stderr cmd.Stdout = stdout err := cmd.Run() @@ -122,13 +150,7 @@ func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args if code == 255 { code = -1 } - if err != nil { - s.Log.Log(ctx, errLevel, "Failed to run systemctl.", - "args", args, - "code", code, - "error", err) - } - return code + return code, trace.Wrap(err) } // lineLogger logs each line written to it. diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index 9625481df2cd2..60c487fbb0bf7 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -46,6 +46,10 @@ const ( DefaultLinkDir = "/usr/local" // DefaultSystemDir is the location where packaged Teleport binaries and services are installed. DefaultSystemDir = "/usr/local/teleport-system" + // VersionsDirName specifies the name of the subdirectory inside the Teleport data dir for storing Teleport versions. + VersionsDirName = "versions" + // BinaryName specifies the name of the updater binary. 
+ BinaryName = "teleport-update" ) const ( @@ -136,16 +140,20 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { if cfg.SystemDir == "" { cfg.SystemDir = DefaultSystemDir } - if cfg.VersionsDir == "" { - cfg.VersionsDir = filepath.Join(libdefaults.DataDir, "versions") + if cfg.DataDir == "" { + cfg.DataDir = libdefaults.DataDir + } + installDir := filepath.Join(cfg.DataDir, VersionsDirName) + if err := os.MkdirAll(installDir, systemDirMode); err != nil { + return nil, trace.Errorf("failed to create install directory: %w", err) } return &Updater{ Log: cfg.Log, Pool: certPool, InsecureSkipVerify: cfg.InsecureSkipVerify, - ConfigPath: filepath.Join(cfg.VersionsDir, updateConfigName), + ConfigPath: filepath.Join(installDir, updateConfigName), Installer: &LocalInstaller{ - InstallDir: cfg.VersionsDir, + InstallDir: installDir, LinkBinDir: filepath.Join(cfg.LinkDir, "bin"), // For backwards-compatibility with symlinks created by package-based installs, we always // link into /lib/systemd/system, even though, e.g., /usr/local/lib/systemd/system would work. @@ -161,6 +169,18 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { ServiceName: "teleport.service", Log: cfg.Log, }, + Setup: func(ctx context.Context) error { + exec := &LocalExec{ + Log: cfg.Log, + ErrLevel: slog.LevelError, + OutLevel: slog.LevelDebug, + } + _, err := exec.Run(ctx, filepath.Join(cfg.LinkDir, "bin", BinaryName), + "--data-dir", cfg.DataDir, + "--link-dir", cfg.LinkDir, + "setup") + return err + }, }, nil } @@ -174,8 +194,8 @@ type LocalUpdaterConfig struct { // DownloadTimeout is a timeout for file download requests. // Defaults to no timeout. DownloadTimeout time.Duration - // VersionsDir for installing Teleport (usually /var/lib/teleport/versions). - VersionsDir string + // DataDir for Teleport (usually /var/lib/teleport). + DataDir string // LinkDir for installing Teleport (usually /usr/local). 
LinkDir string // SystemDir for package-installed Teleport installations (usually /usr/local/teleport-system). @@ -196,6 +216,8 @@ type Updater struct { Installer Installer // Process manages a running instance of Teleport. Process Process + // Setup installs the Teleport updater service using the linked installation. + Setup func(ctx context.Context) error } // Installer provides an API for installing Teleport agents. @@ -476,7 +498,7 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s } } - // Install the desired version (or validate existing installation) + // Install and link the desired version (or validate existing installation) template := cfg.Spec.URLTemplate if template == "" { @@ -491,6 +513,16 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s return trace.Errorf("failed to link: %w", err) } + // Verify that the linked installation contains a valid updater binary, + // and use it to update the updater's service files. + + if err := u.Setup(ctx); err != nil { + if ok := revert(ctx); !ok { + u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.") + } + return trace.Errorf("failed to setup updater: %w", err) + } + // If we fail to revert after this point, the next update/enable will // fix the link to restore the active version. 
diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go index 1197ac3d5a795..8f919762f53fb 100644 --- a/lib/autoupdate/agent/updater_test.go +++ b/lib/autoupdate/agent/updater_test.go @@ -83,7 +83,13 @@ func TestUpdater_Disable(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml") + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + DataDir: dir, + }) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -92,11 +98,7 @@ func TestUpdater_Disable(t *testing.T) { err = os.WriteFile(cfgPath, b, 0600) require.NoError(t, err) } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) + err = updater.Disable(context.Background()) if tt.errMatch != "" { require.Error(t, err) @@ -142,6 +144,7 @@ func TestUpdater_Update(t *testing.T) { syncCalls int reloadCalls int revertCalls int + setupCalls int errMatch string }{ { @@ -166,6 +169,7 @@ func TestUpdater_Update(t *testing.T) { requestGroup: "group", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "updates disabled during window", @@ -295,6 +299,7 @@ func TestUpdater_Update(t *testing.T) { removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "backup version kept when no change", @@ -338,6 +343,7 @@ func TestUpdater_Update(t *testing.T) { removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "invalid metadata", @@ -368,6 +374,7 @@ func TestUpdater_Update(t *testing.T) { syncCalls: 2, reloadCalls: 0, revertCalls: 1, + setupCalls: 1, errMatch: "sync error", }, { @@ -394,6 +401,7 @@ func TestUpdater_Update(t *testing.T) { syncCalls: 2, reloadCalls: 2, revertCalls: 1, + setupCalls: 1, errMatch: "reload 
error", }, } @@ -419,7 +427,13 @@ func TestUpdater_Update(t *testing.T) { t.Cleanup(server.Close) dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml") + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + DataDir: dir, + }) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -430,12 +444,6 @@ func TestUpdater_Update(t *testing.T) { require.NoError(t, err) } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) - var ( installedVersion string installedTemplate string @@ -481,6 +489,12 @@ func TestUpdater_Update(t *testing.T) { }, } + var setupCalls int + updater.Setup = func(_ context.Context) error { + setupCalls++ + return nil + } + ctx := context.Background() err = updater.Update(ctx) if tt.errMatch != "" { @@ -498,6 +512,7 @@ func TestUpdater_Update(t *testing.T) { require.Equal(t, tt.syncCalls, syncCalls) require.Equal(t, tt.reloadCalls, reloadCalls) require.Equal(t, tt.revertCalls, revertCalls) + require.Equal(t, tt.setupCalls, setupCalls) if tt.cfg == nil { _, err := os.Stat(cfgPath) @@ -594,7 +609,13 @@ func TestUpdater_LinkPackage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml") + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + DataDir: dir, + }) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -604,12 +625,6 @@ func TestUpdater_LinkPackage(t *testing.T) { require.NoError(t, err) } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) - var tryLinkSystemCalls int updater.Installer = &testInstaller{ 
FuncTryLinkSystem: func(_ context.Context) error { @@ -659,6 +674,7 @@ func TestUpdater_Enable(t *testing.T) { syncCalls int reloadCalls int revertCalls int + setupCalls int errMatch string }{ { @@ -681,6 +697,7 @@ func TestUpdater_Enable(t *testing.T) { requestGroup: "group", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "config from user", @@ -706,6 +723,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "new-version", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "already enabled", @@ -725,6 +743,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "insecure URL", @@ -766,6 +785,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 0, + setupCalls: 1, }, { name: "backup version removed on install", @@ -784,6 +804,7 @@ func TestUpdater_Enable(t *testing.T) { removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "backup version kept for validation", @@ -802,6 +823,7 @@ func TestUpdater_Enable(t *testing.T) { removedVersion: "", syncCalls: 1, reloadCalls: 0, + setupCalls: 1, }, { name: "config does not exist", @@ -811,6 +833,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "FIPS and Enterprise flags", @@ -820,6 +843,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "invalid metadata", @@ -836,6 +860,7 @@ func TestUpdater_Enable(t *testing.T) { syncCalls: 2, reloadCalls: 0, revertCalls: 1, + setupCalls: 1, errMatch: "sync error", }, { @@ -848,6 +873,7 @@ func TestUpdater_Enable(t *testing.T) { syncCalls: 2, reloadCalls: 2, revertCalls: 1, + setupCalls: 1, errMatch: "reload error", }, } @@ -855,7 +881,13 @@ func TestUpdater_Enable(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() 
- cfgPath := filepath.Join(dir, "update.yaml") + cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml") + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + DataDir: dir, + }) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -886,12 +918,6 @@ func TestUpdater_Enable(t *testing.T) { tt.userCfg.Proxy = strings.TrimPrefix(server.URL, "https://") } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) - var ( installedVersion string installedTemplate string @@ -936,6 +962,11 @@ func TestUpdater_Enable(t *testing.T) { return tt.reloadErr }, } + var setupCalls int + updater.Setup = func(_ context.Context) error { + setupCalls++ + return nil + } ctx := context.Background() err = updater.Enable(ctx, tt.userCfg) @@ -954,6 +985,7 @@ func TestUpdater_Enable(t *testing.T) { require.Equal(t, tt.syncCalls, syncCalls) require.Equal(t, tt.reloadCalls, reloadCalls) require.Equal(t, tt.revertCalls, revertCalls) + require.Equal(t, tt.setupCalls, setupCalls) if tt.cfg == nil && err != nil { _, err := os.Stat(cfgPath) diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go index d559ad3e75cdd..7e6ebffa93403 100644 --- a/tool/teleport-update/main.go +++ b/tool/teleport-update/main.go @@ -21,6 +21,7 @@ package main import ( "context" "errors" + "fmt" "log/slog" "os" "os/signal" @@ -58,10 +59,8 @@ const ( ) const ( - // versionsDirName specifies the name of the subdirectory inside of the Teleport data dir for storing Teleport versions. - versionsDirName = "versions" - // lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution. - lockFileName = ".lock" + // lockFileName specifies the name of the file containing the flock lock preventing concurrent updater execution. 
+ lockFileName = ".update-lock" ) var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater) @@ -91,7 +90,7 @@ func Run(args []string) error { ctx := context.Background() ctx, _ = signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) - app := libutils.InitCLIParser("teleport-update", appHelp).Interspersed(false) + app := libutils.InitCLIParser(autoupdate.BinaryName, appHelp).Interspersed(false) app.Flag("debug", "Verbose logging to stdout."). Short('d').BoolVar(&ccfg.Debug) app.Flag("data-dir", "Teleport data directory. Access to this directory should be limited."). @@ -103,7 +102,7 @@ func Run(args []string) error { app.HelpFlag.Short('h') - versionCmd := app.Command("version", "Print the version of your teleport-updater binary.") + versionCmd := app.Command("version", fmt.Sprintf("Print the version of your %s binary.", autoupdate.BinaryName)) enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial update.") enableCmd.Flag("proxy", "Address of the Teleport Proxy."). @@ -122,6 +121,9 @@ func Run(args []string) error { linkCmd := app.Command("link", "Link the system installation of Teleport from the Teleport package, if auto-updates is disabled.") + setupCmd := app.Command("setup", "Write configuration files that run the update subcommand on a timer."). + Hidden() + libutils.UpdateAppUsageTemplate(app, args) command, err := app.Parse(args) if err != nil { @@ -143,6 +145,8 @@ func Run(args []string) error { err = cmdUpdate(ctx, &ccfg) case linkCmd.FullCommand(): err = cmdLink(ctx, &ccfg) + case setupCmd.FullCommand(): + err = cmdSetup(ctx, &ccfg) case versionCmd.FullCommand(): modules.GetModules().PrintVersion() default: @@ -172,12 +176,16 @@ func setupLogger(debug bool, format string) error { // cmdDisable disables updates. 
func cmdDisable(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) - if err := os.MkdirAll(versionsDir, 0755); err != nil { - return trace.Errorf("failed to create versions directory: %w", err) + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + DataDir: ccfg.DataDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) } - - unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName)) if err != nil { return trace.Errorf("failed to grab concurrent execution lock: %w", err) } @@ -186,15 +194,6 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error { plog.DebugContext(ctx, "Failed to close lock file", "error", err) } }() - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - LinkDir: ccfg.LinkDir, - SystemDir: autoupdate.DefaultSystemDir, - Log: plog, - }) - if err != nil { - return trace.Errorf("failed to setup updater: %w", err) - } if err := updater.Disable(ctx); err != nil { return trace.Wrap(err) } @@ -203,13 +202,18 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error { // cmdEnable enables updates and triggers an initial update. func cmdEnable(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) - if err := os.MkdirAll(versionsDir, 0755); err != nil { - return trace.Errorf("failed to create versions directory: %w", err) + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + DataDir: ccfg.DataDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) } // Ensure enable can't run concurrently. 
- unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName)) if err != nil { return trace.Errorf("failed to grab concurrent execution lock: %w", err) } @@ -218,16 +222,6 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error { plog.DebugContext(ctx, "Failed to close lock file", "error", err) } }() - - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - LinkDir: ccfg.LinkDir, - SystemDir: autoupdate.DefaultSystemDir, - Log: plog, - }) - if err != nil { - return trace.Errorf("failed to setup updater: %w", err) - } if err := updater.Enable(ctx, ccfg.OverrideConfig); err != nil { return trace.Wrap(err) } @@ -236,13 +230,17 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error { // cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address. func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) - if err := os.MkdirAll(versionsDir, 0755); err != nil { - return trace.Errorf("failed to create versions directory: %w", err) + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + DataDir: ccfg.DataDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) } - // Ensure update can't run concurrently. 
- unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName)) if err != nil { return trace.Errorf("failed to grab concurrent execution lock: %w", err) } @@ -252,15 +250,6 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { } }() - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - LinkDir: ccfg.LinkDir, - SystemDir: autoupdate.DefaultSystemDir, - Log: plog, - }) - if err != nil { - return trace.Errorf("failed to setup updater: %w", err) - } if err := updater.Update(ctx); err != nil { return trace.Wrap(err) } @@ -269,10 +258,18 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { // cmdLink creates system package links if no version is linked and auto-updates is disabled. func cmdLink(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + DataDir: ccfg.DataDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) + } // Skip operation and warn if the updater is currently running. - unlock, err := libutils.FSTryReadLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSTryReadLock(filepath.Join(ccfg.DataDir, lockFileName)) if errors.Is(err, libutils.ErrUnsuccessfulLockTry) { plog.WarnContext(ctx, "Updater is currently running. 
Skipping package linking.") return nil @@ -285,17 +282,18 @@ func cmdLink(ctx context.Context, ccfg *cliConfig) error { plog.DebugContext(ctx, "Failed to close lock file", "error", err) } }() - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - LinkDir: ccfg.LinkDir, - SystemDir: autoupdate.DefaultSystemDir, - Log: plog, - }) - if err != nil { - return trace.Errorf("failed to setup updater: %w", err) - } + if err := updater.LinkPackage(ctx); err != nil { return trace.Wrap(err) } return nil } + +// cmdSetup writes configuration files that are needed to run teleport-update update. +func cmdSetup(ctx context.Context, ccfg *cliConfig) error { + err := autoupdate.WriteConfigFiles(ccfg.LinkDir, ccfg.DataDir) + if err != nil { + return trace.Errorf("failed to write config files: %w", err) + } + return nil +} From df302161f7f4f13cf06d516877d9a165e5aa2548 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Thu, 14 Nov 2024 21:38:02 -0500 Subject: [PATCH 02/14] test fixtures --- .../TestWriteConfigFiles/teleport-update.conf.golden | 3 +++ .../teleport-update.service.golden | 7 +++++++ .../TestWriteConfigFiles/teleport-update.timer.golden | 11 +++++++++++ 3 files changed, 21 insertions(+) create mode 100644 lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.conf.golden create mode 100644 lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden create mode 100644 lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.conf.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.conf.golden new file mode 100644 index 0000000000000..1e2cfb333bcbc --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.conf.golden @@ -0,0 +1,3 @@ +# teleport-update +[Service] +ExecStopPost=/bin/bash -c 'date +%%s > /var/lib/teleport/last-restart' diff 
--git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden new file mode 100644 index 0000000000000..b8d6f7f75ae72 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden @@ -0,0 +1,7 @@ +# teleport-update +[Unit] +Description=Teleport update service + +[Service] +Type=oneshot +ExecStart=/usr/local/bin/teleport-update update diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden new file mode 100644 index 0000000000000..dbc8e1d12c404 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden @@ -0,0 +1,11 @@ +# teleport-update +[Unit] +Description=Teleport update timer unit + +[Timer] +OnActiveSec=1m +OnUnitActiveSec=5m +RandomizedDelaySec=1m + +[Install] +WantedBy=teleport.service From a293634dfe0385f78cbefade878eb8dadcce1e93 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Thu, 14 Nov 2024 21:51:43 -0500 Subject: [PATCH 03/14] cleanup --- lib/autoupdate/agent/installer.go | 11 ++--------- lib/autoupdate/agent/process.go | 3 +++ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go index 3312d8bae9190..530911a3c1184 100644 --- a/lib/autoupdate/agent/installer.go +++ b/lib/autoupdate/agent/installer.go @@ -604,13 +604,6 @@ func forceCopy(dst, src string, n int64) (orig *smallFile, err error) { if err != nil { return nil, trace.Wrap(err) } - return forceWrite(dst, srcData, n) -} - -func forceWrite(dst string, data []byte, n int64) (orig *smallFile, err error) { - if l := len(data); int64(l) > n { - return nil, trace.Errorf("data too large for file (%d > %d)", l, n) - } fi, err := os.Lstat(dst) if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, trace.Wrap(err) 
@@ -627,11 +620,11 @@ func forceWrite(dst string, data []byte, n int64) (orig *smallFile, err error) { if err != nil { return nil, trace.Wrap(err) } - if bytes.Equal(data, orig.data) { + if bytes.Equal(srcData, orig.data) { return nil, trace.Wrap(os.ErrExist) } } - err = renameio.WriteFile(dst, data, configFileMode) + err = renameio.WriteFile(dst, srcData, configFileMode) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index 0dffd6d79b089..5161b47cac86c 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -123,6 +123,7 @@ func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args return code } +// LocalExec runs a command locally, logging any output. type LocalExec struct { // Log contains a slog logger. // Defaults to slog.Default() if nil. @@ -133,6 +134,8 @@ type LocalExec struct { OutLevel slog.Level } +// Run the command. Same arguments as exec.CommandContext. +// Outputs the status code, or -1 if out-of-range or unstarted. func (c *LocalExec) Run(ctx context.Context, name string, args ...string) (int, error) { cmd := exec.CommandContext(ctx, name, args...) stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel} From 3b8ca5ff198abecb28f26d4b4427f4ae7af67098 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Thu, 14 Nov 2024 21:53:01 -0500 Subject: [PATCH 04/14] cleanup --- lib/autoupdate/agent/process.go | 8 ++++---- lib/autoupdate/agent/updater.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index 5161b47cac86c..e2bc76f1c4be9 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -108,7 +108,7 @@ func (s SystemdService) checkSystem(ctx context.Context) error { // Output sent to stdout is logged at debug level. // Output sent to stderr is logged at the level specified by errLevel. 
func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int { - cmd := &LocalExec{ + cmd := &localExec{ Log: s.Log, ErrLevel: errLevel, OutLevel: slog.LevelDebug, @@ -123,8 +123,8 @@ func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args return code } -// LocalExec runs a command locally, logging any output. -type LocalExec struct { +// localExec runs a command locally, logging any output. +type localExec struct { // Log contains a slog logger. // Defaults to slog.Default() if nil. Log *slog.Logger @@ -136,7 +136,7 @@ type LocalExec struct { // Run the command. Same arguments as exec.CommandContext. // Outputs the status code, or -1 if out-of-range or unstarted. -func (c *LocalExec) Run(ctx context.Context, name string, args ...string) (int, error) { +func (c *localExec) Run(ctx context.Context, name string, args ...string) (int, error) { cmd := exec.CommandContext(ctx, name, args...) stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel} stdout := &lineLogger{ctx: ctx, log: c.Log, level: c.OutLevel} diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index 60c487fbb0bf7..e6e3eb26f734e 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -170,7 +170,7 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { Log: cfg.Log, }, Setup: func(ctx context.Context) error { - exec := &LocalExec{ + exec := &localExec{ Log: cfg.Log, ErrLevel: slog.LevelError, OutLevel: slog.LevelDebug, From 344ce35f2320f2885a2b2d9d9350c7b227dd1c73 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Fri, 15 Nov 2024 02:00:31 -0500 Subject: [PATCH 05/14] healthcheck --- lib/autoupdate/agent/process.go | 81 ++++++++++++- lib/autoupdate/agent/process_test.go | 173 +++++++++++++++++++++++++++ 2 files changed, 252 insertions(+), 2 deletions(-) diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index e2bc76f1c4be9..1c0f0431c380f 
100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -25,14 +25,19 @@ import ( "log/slog" "os" "os/exec" + "strconv" + "time" "github.com/gravitational/trace" + "golang.org/x/sync/errgroup" ) // SystemdService manages a Teleport systemd service. type SystemdService struct { // ServiceName specifies the systemd service name. ServiceName string + // LastRestartPath is a path to a file containing the last restart time. + LastRestartPath string // Log contains a logger. Log *slog.Logger } @@ -75,10 +80,82 @@ func (s SystemdService) Reload(ctx context.Context) error { default: s.Log.InfoContext(ctx, "Teleport gracefully reloaded.") } + s.Log.InfoContext(ctx, "Monitoring for excessive restarts.") + return trace.Wrap(s.monitor(ctx)) +} - // TODO(sclevine): Ensure restart was successful and verify healthcheck. +func (s SystemdService) monitor(ctx context.Context) error { + tickC := time.NewTicker(2 * time.Second).C + restartC := make(chan int64) + g := &errgroup.Group{} + g.Go(func() error { + return s.tickRestarts(ctx, restartC, tickC) + }) + err := s.monitorRestarts(ctx, restartC, 2, 6) + if err := g.Wait(); err != nil { + s.Log.WarnContext(ctx, "Unable to determine last restart time. 
Failed to monitor for crash loops.", errorKey, err) + } + return trace.Wrap(err) +} - return nil +func (s SystemdService) monitorRestarts(ctx context.Context, timeCh <-chan int64, maxStops, minClean int) error { + var ( + clean, stops int + restartTime int64 + ) + // TODO: thread init value of restartTime + for { + // wait first to ensure we initial stop has completed + select { + case <-ctx.Done(): + return ctx.Err() + case t := <-timeCh: + if t != restartTime { + clean = 0 + restartTime = t + stops++ + } else { + clean++ + } + } + switch { + case stops > maxStops: + return trace.Errorf("detected crash loop") + case clean >= minClean: + return nil + } + } +} + +func (s SystemdService) tickRestarts(ctx context.Context, ch chan<- int64, tickC <-chan time.Time) error { + var err error + for { + // two select statements -> never skip restarts + select { + case <-tickC: + case <-ctx.Done(): + return err + } + var t int64 + t, err = s.getRestartTime() + select { + case ch <- t: + case <-ctx.Done(): + return err + } + } +} + +func (s SystemdService) getRestartTime() (int64, error) { + b, err := os.ReadFile(s.LastRestartPath) + if err != nil { + return 0, trace.Wrap(err) + } + restart, err := strconv.ParseInt(string(bytes.TrimSpace(b)), 10, 64) + if err != nil { + return 0, trace.Wrap(err) + } + return restart, nil } // Sync systemd service configuration by running systemctl daemon-reload. 
diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go index 5ffa70dd0091e..6d40d95068767 100644 --- a/lib/autoupdate/agent/process_test.go +++ b/lib/autoupdate/agent/process_test.go @@ -21,8 +21,13 @@ package agent import ( "bytes" "context" + "errors" + "fmt" "log/slog" + "os" + "path/filepath" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -69,3 +74,171 @@ func msgOnly(_ []string, a slog.Attr) slog.Attr { } return slog.Attr{Key: a.Key, Value: a.Value} } + +func TestMonitor(t *testing.T) { + t.Parallel() + + svc := &SystemdService{ + Log: slog.Default(), + } + + for _, tt := range []struct { + name string + ticks []int64 + maxStops int + minClean int + errored bool + canceled bool + }{ + { + name: "one restart", + ticks: []int64{1, 1, 1, 1}, + maxStops: 2, + minClean: 3, + errored: false, + }, + { + name: "two restarts", + ticks: []int64{1, 1, 1, 2, 2, 2, 2}, + maxStops: 2, + minClean: 3, + errored: false, + }, + { + name: "too many restarts long", + ticks: []int64{1, 1, 1, 2, 2, 2, 3}, + maxStops: 2, + minClean: 3, + errored: true, + }, + { + name: "too many restarts short", + ticks: []int64{1, 2, 3}, + maxStops: 2, + minClean: 3, + errored: true, + }, + { + name: "too many restarts after okay", + ticks: []int64{1, 1, 1, 1, 2, 3}, + maxStops: 2, + minClean: 3, + errored: false, + }, + { + name: "too many restarts before okay", + ticks: []int64{1, 2, 3, 3, 3, 3}, + maxStops: 2, + minClean: 3, + errored: true, + }, + { + name: "no error if no minClean", + ticks: []int64{1, 2, 3}, + maxStops: 2, + minClean: 0, + errored: false, + }, + { + name: "cancel", + maxStops: 2, + minClean: 3, + canceled: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + ch := make(chan int64) + go func() { + defer cancel() // always quit after last tick + for _, tick := range tt.ticks { + ch <- tick + } + }() + err := svc.monitorRestarts(ctx, ch, 
tt.maxStops, tt.minClean) + require.Equal(t, tt.canceled, errors.Is(err, context.Canceled)) + if !tt.canceled { + require.Equal(t, tt.errored, err != nil) + } + }) + } +} + +func TestTicks(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + restartPath := filepath.Join(dir, "restart") + svc := &SystemdService{ + Log: slog.Default(), + LastRestartPath: restartPath, + } + + for _, tt := range []struct { + name string + ticks []int64 + errored bool + }{ + { + name: "consistent", + ticks: []int64{1, 1, 1}, + errored: false, + }, + { + name: "divergent", + ticks: []int64{1, 2, 3}, + errored: false, + }, + { + name: "start error", + ticks: []int64{-1, 1, 1}, + errored: false, + }, + { + name: "ephemeral error", + ticks: []int64{1, -1, 1}, + errored: false, + }, + { + name: "end error", + ticks: []int64{1, 1, -1}, + errored: true, + }, + { + name: "cancel", + }, + } { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + tickC := make(chan time.Time) + ch := make(chan int64) + + go func() { + defer cancel() // always quit after last tick or fail + for _, tick := range tt.ticks { + if tick >= 0 { + err := os.WriteFile(restartPath, []byte(fmt.Sprintln(tick)), os.ModePerm) + require.NoError(t, err) + } else { + _ = os.Remove(restartPath) + } + tickC <- time.Now() + res := <-ch + if tick < 0 { + tick = 0 + } + require.Equal(t, tick, res) + } + }() + err := svc.tickRestarts(ctx, ch, tickC) + require.Equal(t, tt.errored, err != nil) + if err != nil { + require.ErrorIs(t, err, os.ErrNotExist) + } + }) + } +} From a73558a58473a7f786f1e7b73e85ffda7a89ff4a Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Fri, 15 Nov 2024 15:36:28 -0500 Subject: [PATCH 06/14] cleanup --- lib/autoupdate/agent/process.go | 70 ++++++++++++---- lib/autoupdate/agent/process_test.go | 120 +++++++++++++++------------ 2 files changed, 122 insertions(+), 68 deletions(-) diff --git a/lib/autoupdate/agent/process.go 
b/lib/autoupdate/agent/process.go index 1c0f0431c380f..9d6ec3a87d878 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -51,6 +51,10 @@ func (s SystemdService) Reload(ctx context.Context) error { if err := s.checkSystem(ctx); err != nil { return trace.Wrap(err) } + + // If getRestartTime fails consistently, error will be returned from monitor. + initRestartTime, _ := s.getRestartTime() + // Command error codes < 0 indicate that we are unable to run the command. // Errors from s.systemctl are logged along with stderr and stdout (debug only). @@ -81,54 +85,85 @@ func (s SystemdService) Reload(ctx context.Context) error { s.Log.InfoContext(ctx, "Teleport gracefully reloaded.") } s.Log.InfoContext(ctx, "Monitoring for excessive restarts.") - return trace.Wrap(s.monitor(ctx)) + return trace.Wrap(s.monitor(ctx, initRestartTime)) } -func (s SystemdService) monitor(ctx context.Context) error { - tickC := time.NewTicker(2 * time.Second).C +const ( + restartMonitorInterval = 2 * time.Second + minCleanIntervalsBeforeSuccess = 6 + maxRestartsBeforeFailure = 2 +) + +// monitor for excessive restarts by polling the LastRestartPath file. +// This function detects crash-looping while minimizing its own runtime during updates. +// To accomplish this, monitor fails after seeing maxRestartsBeforeFailure, and stops checking +// after seeing minCleanIntervalsBeforeSuccess clean intervals. +// initRestartTime may be provided as a baseline restart time, to ensure we catch the initial restart. 
+func (s SystemdService) monitor(ctx context.Context, initRestartTime int64) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + tickC := time.NewTicker(restartMonitorInterval).C restartC := make(chan int64) g := &errgroup.Group{} g.Go(func() error { - return s.tickRestarts(ctx, restartC, tickC) + return s.tickRestarts(ctx, restartC, tickC, initRestartTime) }) - err := s.monitorRestarts(ctx, restartC, 2, 6) + err := s.monitorRestarts(ctx, restartC, maxRestartsBeforeFailure, minCleanIntervalsBeforeSuccess) + cancel() if err := g.Wait(); err != nil { s.Log.WarnContext(ctx, "Unable to determine last restart time. Failed to monitor for crash loops.", errorKey, err) } return trace.Wrap(err) } -func (s SystemdService) monitorRestarts(ctx context.Context, timeCh <-chan int64, maxStops, minClean int) error { +// monitorRestarts receives restart times on timeCh. +// Each restart time that differs from the preceding restart time counts as a restart. +// If maxRestarts is exceeded, monitorRestarts returns an error. +// Each restart time that matches the preceding restart time counts as a clean reading. +// If minClean is reached before maxRestarts is exceeded, monitorRestarts returns nil.
+func (s SystemdService) monitorRestarts(ctx context.Context, timeCh <-chan int64, maxRestarts, minClean int) error { var ( - clean, stops int - restartTime int64 + same, diff int + restartTime int64 ) - // TODO: thread init value of restartTime for { // wait first to ensure we initial stop has completed select { case <-ctx.Done(): return ctx.Err() case t := <-timeCh: - if t != restartTime { - clean = 0 + switch t { + case restartTime: + same++ + default: + same = 0 restartTime = t - stops++ - } else { - clean++ + diff++ } + } switch { - case stops > maxStops: + case diff > maxRestarts+1: return trace.Errorf("detected crash loop") - case clean >= minClean: + case same >= minClean: return nil } } } -func (s SystemdService) tickRestarts(ctx context.Context, ch chan<- int64, tickC <-chan time.Time) error { +// tickRestarts reads the current time on tickC, and outputs the last restart time on ch for each received tick. +// If the current time cannot be read, tickRestarts sends 0 on ch. +// Any error from the last attempt to receive restart times is returned when ctx is cancelled. +// The baseline restart time is sent as soon as the method is called +func (s SystemdService) tickRestarts(ctx context.Context, ch chan<- int64, tickC <-chan time.Time, baseline int64) error { + t := baseline var err error + select { + case ch <- t: + case <-ctx.Done(): + return err + } for { // two select statements -> never skip restarts select { @@ -146,6 +181,7 @@ func (s SystemdService) tickRestarts(ctx context.Context, ch chan<- int64, tickC } } +// getRestartTime returns the last restart time from the file at LastRestartPath. 
func (s SystemdService) getRestartTime() (int64, error) { b, err := os.ReadFile(s.LastRestartPath) if err != nil { diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go index 6d40d95068767..e6440356bf8d0 100644 --- a/lib/autoupdate/agent/process_test.go +++ b/lib/autoupdate/agent/process_test.go @@ -75,7 +75,7 @@ func msgOnly(_ []string, a slog.Attr) slog.Attr { return slog.Attr{Key: a.Key, Value: a.Value} } -func TestMonitor(t *testing.T) { +func TestRestartMonitor(t *testing.T) { t.Parallel() svc := &SystemdService{ @@ -83,67 +83,75 @@ func TestMonitor(t *testing.T) { } for _, tt := range []struct { - name string - ticks []int64 - maxStops int - minClean int - errored bool - canceled bool + name string + ticks []int64 + maxRestarts int + minClean int + errored bool + canceled bool }{ { - name: "one restart", - ticks: []int64{1, 1, 1, 1}, - maxStops: 2, - minClean: 3, - errored: false, + name: "no restarts", + ticks: []int64{1, 1, 1, 1}, + maxRestarts: 2, + minClean: 3, + errored: false, }, { - name: "two restarts", - ticks: []int64{1, 1, 1, 2, 2, 2, 2}, - maxStops: 2, - minClean: 3, - errored: false, + name: "one restart then stable", + ticks: []int64{1, 1, 1, 2, 2, 2, 2}, + maxRestarts: 2, + minClean: 3, + errored: false, }, { - name: "too many restarts long", - ticks: []int64{1, 1, 1, 2, 2, 2, 3}, - maxStops: 2, - minClean: 3, - errored: true, + name: "two restarts then stable", + ticks: []int64{1, 2, 3, 3, 3, 3}, + maxRestarts: 2, + minClean: 3, + errored: false, }, { - name: "too many restarts short", - ticks: []int64{1, 2, 3}, - maxStops: 2, - minClean: 3, - errored: true, + name: "too many restarts (slow)", + ticks: []int64{1, 1, 1, 2, 2, 2, 3, 3, 3, 4}, + maxRestarts: 2, + minClean: 3, + errored: true, }, { - name: "too many restarts after okay", - ticks: []int64{1, 1, 1, 1, 2, 3}, - maxStops: 2, - minClean: 3, - errored: false, + name: "too many restarts (fast)", + ticks: []int64{1, 2, 3, 4}, + maxRestarts: 2, + 
minClean: 3, + errored: true, }, { - name: "too many restarts before okay", - ticks: []int64{1, 2, 3, 3, 3, 3}, - maxStops: 2, - minClean: 3, - errored: true, + name: "too many restarts after stable", + ticks: []int64{1, 1, 1, 1, 2, 3, 4}, + maxRestarts: 2, + minClean: 3, + errored: false, }, { - name: "no error if no minClean", - ticks: []int64{1, 2, 3}, - maxStops: 2, - minClean: 0, - errored: false, + name: "too many restarts before okay", + ticks: []int64{1, 2, 3, 4, 3, 3, 3}, + maxRestarts: 2, + minClean: 3, + errored: true, }, { - name: "cancel", - maxStops: 2, - minClean: 3, - canceled: true, + name: "no error if no minClean", + ticks: []int64{1, 2, 3, 4}, + maxRestarts: 2, + minClean: 0, + errored: false, + }, + { + name: "cancel", + ticks: []int64{1, 1, 1}, + maxRestarts: 2, + minClean: 3, + canceled: true, }, } { t.Run(tt.name, func(t *testing.T) { @@ -157,7 +165,7 @@ func TestMonitor(t *testing.T) { ch <- tick } }() - err := svc.monitorRestarts(ctx, ch, tt.maxStops, tt.minClean) + err := svc.monitorRestarts(ctx, ch, tt.maxRestarts, tt.minClean) require.Equal(t, tt.canceled, errors.Is(err, context.Canceled)) if !tt.canceled { require.Equal(t, tt.errored, err != nil) @@ -166,7 +174,7 @@ func TestMonitor(t *testing.T) { } } -func TestTicks(t *testing.T) { +func TestRestartTicks(t *testing.T) { t.Parallel() dir := t.TempDir() @@ -178,36 +186,42 @@ func TestTicks(t *testing.T) { for _, tt := range []struct { name string + init int64 ticks []int64 errored bool }{ { name: "consistent", + init: 1, ticks: []int64{1, 1, 1}, errored: false, }, { name: "divergent", + init: 1, ticks: []int64{1, 2, 3}, errored: false, }, { name: "start error", + init: 1, ticks: []int64{-1, 1, 1}, errored: false, }, { name: "ephemeral error", + init: 1, ticks: []int64{1, -1, 1}, errored: false, }, { name: "end error", + init: 1, ticks: []int64{1, 1, -1}, errored: true, }, { - name: "cancel", + name: "init-only", }, } { t.Run(tt.name, func(t *testing.T) { @@ -219,12 +233,16 @@ func 
TestTicks(t *testing.T) { go func() { defer cancel() // always quit after last tick or fail + require.Equal(t, tt.init, <-ch) for _, tick := range tt.ticks { if tick >= 0 { err := os.WriteFile(restartPath, []byte(fmt.Sprintln(tick)), os.ModePerm) require.NoError(t, err) } else { - _ = os.Remove(restartPath) + err := os.Remove(restartPath) + if err != nil { + require.ErrorIs(t, err, os.ErrNotExist) + } } tickC <- time.Now() res := <-ch @@ -234,7 +252,7 @@ func TestTicks(t *testing.T) { require.Equal(t, tick, res) } }() - err := svc.tickRestarts(ctx, ch, tickC) + err := svc.tickRestarts(ctx, ch, tickC, tt.init) require.Equal(t, tt.errored, err != nil) if err != nil { require.ErrorIs(t, err, os.ErrNotExist) From 664aabfcea1f60a44ec861240c93502f64f338f1 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Fri, 15 Nov 2024 15:39:21 -0500 Subject: [PATCH 07/14] cleanup --- lib/autoupdate/agent/process.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index 9d6ec3a87d878..408d18a4db3b5 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -32,6 +32,15 @@ import ( "golang.org/x/sync/errgroup" ) +const ( + // restartMonitorInterval is the polling interval for determining restart times from LastRestartPath. + restartMonitorInterval = 2 * time.Second + // minCleanIntervalsBeforeStable is the number of consecutive intervals before the service is determined stable. + minCleanIntervalsBeforeStable = 6 + // maxRestartsBeforeFailure is the number of total restarts allowed before the service is marked as crash-looping. + maxRestartsBeforeFailure = 2 +) + // SystemdService manages a Teleport systemd service. type SystemdService struct { // ServiceName specifies the systemd service name. 
@@ -88,16 +97,10 @@ func (s SystemdService) Reload(ctx context.Context) error { return trace.Wrap(s.monitor(ctx, initRestartTime)) } -const ( - restartMonitorInterval = 2 * time.Second - minCleanIntervalsBeforeSuccess = 6 - maxRestartsBeforeFailure = 2 -) - // monitor for excessive restarts by polling the LastRestartPath file. // This function detects crash-looping while minimizing its own runtime during updates. // To accomplish this, monitor fails after seeing maxRestartsBeforeFailure, and stops checking -// after seeing minCleanIntervalsBeforeSuccess clean intervals. +// after seeing minCleanIntervalsBeforeStable clean intervals. // initRestartTime may be provided as a baseline restart time, to ensure we catch the initial restart. func (s SystemdService) monitor(ctx context.Context, initRestartTime int64) error { ctx, cancel := context.WithCancel(ctx) @@ -109,7 +112,7 @@ func (s SystemdService) monitor(ctx context.Context, initRestartTime int64) erro g.Go(func() error { return s.tickRestarts(ctx, restartC, tickC, initRestartTime) }) - err := s.monitorRestarts(ctx, restartC, maxRestartsBeforeFailure, minCleanIntervalsBeforeSuccess) + err := s.monitorRestarts(ctx, restartC, maxRestartsBeforeFailure, minCleanIntervalsBeforeStable) cancel() if err := g.Wait(); err != nil { s.Log.WarnContext(ctx, "Unable to determine last restart time. 
Failed to monitor for crash loops.", errorKey, err) From db53627216871f397e565a2e9b3cfbb796d25a30 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Fri, 15 Nov 2024 17:20:48 -0500 Subject: [PATCH 08/14] self-reloading --- lib/autoupdate/agent/config.go | 23 +++++++++++++++- lib/autoupdate/agent/config_test.go | 2 +- lib/autoupdate/agent/process.go | 42 ++++++++++++++++++++++------- lib/autoupdate/agent/updater.go | 15 ++++++++--- tool/teleport-update/main.go | 14 ++++++++-- 5 files changed, 79 insertions(+), 17 deletions(-) diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go index 6eb9320c66ac2..2c42487b11ff6 100644 --- a/lib/autoupdate/agent/config.go +++ b/lib/autoupdate/agent/config.go @@ -19,6 +19,8 @@ package agent import ( + "context" + "log/slog" "os" "path/filepath" "text/template" @@ -53,7 +55,25 @@ WantedBy=teleport.service ` ) -func WriteConfigFiles(linkDir, dataDir string) error { +func Setup(ctx context.Context, log *slog.Logger, linkDir, dataDir string) error { + err := writeConfigFiles(linkDir, dataDir) + if err != nil { + return trace.Errorf("failed to write teleport-update systemd config files: %w", err) + } + svc := &SystemdService{ + ServiceName: "teleport-update.timer", + Log: log, + } + if err := svc.Reload(ctx); err != nil { + return trace.Errorf("failed to reload systemd config: %w", err) + } + if err := svc.Enable(ctx, true); err != nil { + return trace.Errorf("failed to enable teleport-update systemd timer: %w", err) + } + return nil +} + +func writeConfigFiles(linkDir, dataDir string) error { // TODO(sclevine): revert on failure dropinPath := filepath.Join(linkDir, serviceDir, serviceName+".d", serviceDropinName) @@ -71,6 +91,7 @@ func WriteConfigFiles(linkDir, dataDir string) error { if err != nil { return trace.Wrap(err) } + return nil } diff --git a/lib/autoupdate/agent/config_test.go b/lib/autoupdate/agent/config_test.go index 1530bae9b868b..d5f79571f8a56 100644 --- a/lib/autoupdate/agent/config_test.go +++ 
b/lib/autoupdate/agent/config_test.go @@ -34,7 +34,7 @@ func TestWriteConfigFiles(t *testing.T) { t.Parallel() linkDir := t.TempDir() dataDir := t.TempDir() - err := WriteConfigFiles(linkDir, dataDir) + err := writeConfigFiles(linkDir, dataDir) require.NoError(t, err) for _, p := range []string{ diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index 408d18a4db3b5..bc8a74425ee0d 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -41,7 +41,12 @@ const ( maxRestartsBeforeFailure = 2 ) -// SystemdService manages a Teleport systemd service. +// log keys +const ( + unitKey = "unit" +) + +// SystemdService manages a systemd service (e.g., teleport or teleport-update). type SystemdService struct { // ServiceName specifies the systemd service name. ServiceName string @@ -51,7 +56,7 @@ type SystemdService struct { Log *slog.Logger } -// Reload a systemd service. +// Reload the systemd service. // Attempts a graceful reload before a hard restart. // See Process interface for more details. func (s SystemdService) Reload(ctx context.Context) error { @@ -75,25 +80,25 @@ func (s SystemdService) Reload(ctx context.Context) error { case code < 0: return trace.Errorf("unable to determine if systemd service is active") case code > 0: - s.Log.WarnContext(ctx, "Teleport systemd service not running.") + s.Log.WarnContext(ctx, "Systemd service not running.", unitKey, s.ServiceName) return trace.Wrap(ErrNotNeeded) } // Attempt graceful reload of running service. code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName) switch { case code < 0: - return trace.Errorf("unable to attempt reload of Teleport systemd service") + return trace.Errorf("unable to reload systemd service") case code > 0: // Graceful reload fails, try hard restart. 
code = s.systemctl(ctx, slog.LevelError, "try-restart", s.ServiceName) if code != 0 { - return trace.Errorf("hard restart of Teleport systemd service failed") + return trace.Errorf("hard restart of systemd service failed") } - s.Log.WarnContext(ctx, "Teleport ungracefully restarted. Connections potentially dropped.") + s.Log.WarnContext(ctx, "Service ungracefully restarted. Connections potentially dropped.", unitKey, s.ServiceName) default: - s.Log.InfoContext(ctx, "Teleport gracefully reloaded.") + s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName) } - s.Log.InfoContext(ctx, "Monitoring for excessive restarts.") + s.Log.InfoContext(ctx, "Monitoring for excessive restarts.", unitKey, s.ServiceName) return trace.Wrap(s.monitor(ctx, initRestartTime)) } @@ -115,7 +120,8 @@ func (s SystemdService) monitor(ctx context.Context, initRestartTime int64) erro err := s.monitorRestarts(ctx, restartC, maxRestartsBeforeFailure, minCleanIntervalsBeforeStable) cancel() if err := g.Wait(); err != nil { - s.Log.WarnContext(ctx, "Unable to determine last restart time. Failed to monitor for crash loops.", errorKey, err) + s.Log.WarnContext(ctx, "Unable to determine last restart time. Cannot detect crash loops.", unitKey, s.ServiceName) + s.Log.DebugContext(ctx, "Error monitoring for crash loops.", errorKey, err, unitKey, s.ServiceName) } return trace.Wrap(err) } @@ -207,6 +213,24 @@ func (s SystemdService) Sync(ctx context.Context) error { if code != 0 { return trace.Errorf("unable to reload systemd configuration") } + s.Log.InfoContext(ctx, "Systemd configuration synced.") + return nil +} + +// Enable the systemd service. +func (s SystemdService) Enable(ctx context.Context, now bool) error { + if err := s.checkSystem(ctx); err != nil { + return trace.Wrap(err) + } + args := []string{"enable", s.ServiceName} + if now { + args = append(args, "--now") + } + code := s.systemctl(ctx, slog.LevelError, args...) 
+ if code != 0 { + return trace.Errorf("unable to enable systemd service") + } + s.Log.InfoContext(ctx, "Service enabled.", unitKey, s.ServiceName) return nil } diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index e6e3eb26f734e..f9caaa16ead41 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -166,16 +166,21 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { ReservedFreeInstallDisk: reservedFreeDisk, }, Process: &SystemdService{ - ServiceName: "teleport.service", - Log: cfg.Log, + ServiceName: "teleport.service", + LastRestartPath: filepath.Join(cfg.DataDir, "last-restart"), + Log: cfg.Log, }, Setup: func(ctx context.Context) error { exec := &localExec{ Log: cfg.Log, - ErrLevel: slog.LevelError, + ErrLevel: slog.LevelInfo, OutLevel: slog.LevelDebug, } - _, err := exec.Run(ctx, filepath.Join(cfg.LinkDir, "bin", BinaryName), + name := filepath.Join(cfg.LinkDir, "bin", BinaryName) + if cfg.SelfSetup { + name = "/proc/self/exe" + } + _, err := exec.Run(ctx, name, "--data-dir", cfg.DataDir, "--link-dir", cfg.LinkDir, "setup") @@ -200,6 +205,8 @@ type LocalUpdaterConfig struct { LinkDir string // SystemDir for package-installed Teleport installations (usually /usr/local/teleport-system). SystemDir string + // SelfSetup mode for using the current version of the teleport-update to setup the update service. + SelfSetup bool } // Updater implements the agent-local logic for Teleport agent auto-updates. diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go index 7e6ebffa93403..c2d8f07165d25 100644 --- a/tool/teleport-update/main.go +++ b/tool/teleport-update/main.go @@ -83,6 +83,8 @@ type cliConfig struct { DataDir string // LinkDir for linking binaries and systemd services LinkDir string + // SelfSetup mode for using the current version of the teleport-update to setup the update service. 
+ SelfSetup bool } func Run(args []string) error { @@ -113,11 +115,15 @@ func Run(args []string) error { Short('t').Envar(templateEnvVar).StringVar(&ccfg.URLTemplate) enableCmd.Flag("force-version", "Force the provided version instead of querying it from the Teleport cluster."). Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion) + enableCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for auto-updates."). + Short('s').Hidden().BoolVar(&ccfg.SelfSetup) // TODO(sclevine): add force-fips and force-enterprise as hidden flags disableCmd := app.Command("disable", "Disable agent auto-updates.") updateCmd := app.Command("update", "Update agent to the latest version, if a new version is available.") + updateCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for auto-updates."). + Short('s').Hidden().BoolVar(&ccfg.SelfSetup) linkCmd := app.Command("link", "Link the system installation of Teleport from the Teleport package, if auto-updates is disabled.") @@ -180,6 +186,7 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error { DataDir: ccfg.DataDir, LinkDir: ccfg.LinkDir, SystemDir: autoupdate.DefaultSystemDir, + SelfSetup: ccfg.SelfSetup, Log: plog, }) if err != nil { @@ -206,6 +213,7 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error { DataDir: ccfg.DataDir, LinkDir: ccfg.LinkDir, SystemDir: autoupdate.DefaultSystemDir, + SelfSetup: ccfg.SelfSetup, Log: plog, }) if err != nil { @@ -234,6 +242,7 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { DataDir: ccfg.DataDir, LinkDir: ccfg.LinkDir, SystemDir: autoupdate.DefaultSystemDir, + SelfSetup: ccfg.SelfSetup, Log: plog, }) if err != nil { @@ -262,6 +271,7 @@ func cmdLink(ctx context.Context, ccfg *cliConfig) error { DataDir: ccfg.DataDir, LinkDir: ccfg.LinkDir, SystemDir: autoupdate.DefaultSystemDir, + SelfSetup: ccfg.SelfSetup, Log: plog, }) if err != nil { @@ -291,9 +301,9 @@ 
func cmdLink(ctx context.Context, ccfg *cliConfig) error { // cmdSetup writes configuration files that are needed to run teleport-update update. func cmdSetup(ctx context.Context, ccfg *cliConfig) error { - err := autoupdate.WriteConfigFiles(ccfg.LinkDir, ccfg.DataDir) + err := autoupdate.Setup(ctx, plog, ccfg.LinkDir, ccfg.DataDir) if err != nil { - return trace.Errorf("failed to write config files: %w", err) + return trace.Errorf("failed to setup teleport-update service: %w", err) } return nil } From 5d3148d31eda25f69aa4b216c5df3d51b996e0ce Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Fri, 15 Nov 2024 17:25:33 -0500 Subject: [PATCH 09/14] bugs --- lib/autoupdate/agent/config.go | 4 ++-- lib/autoupdate/agent/updater.go | 12 +++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go index 2c42487b11ff6..81b57f81a4cd5 100644 --- a/lib/autoupdate/agent/config.go +++ b/lib/autoupdate/agent/config.go @@ -64,8 +64,8 @@ func Setup(ctx context.Context, log *slog.Logger, linkDir, dataDir string) error ServiceName: "teleport-update.timer", Log: log, } - if err := svc.Reload(ctx); err != nil { - return trace.Errorf("failed to reload systemd config: %w", err) + if err := svc.Sync(ctx); err != nil { + return trace.Errorf("failed to sync systemd config: %w", err) } if err := svc.Enable(ctx, true); err != nil { return trace.Errorf("failed to enable teleport-update systemd timer: %w", err) diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index f9caaa16ead41..a0a4b0fea4c43 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -27,6 +27,7 @@ import ( "log/slog" "net/http" "os" + "os/exec" "path/filepath" "strings" "time" @@ -171,20 +172,17 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { Log: cfg.Log, }, Setup: func(ctx context.Context) error { - exec := &localExec{ - Log: cfg.Log, - ErrLevel: slog.LevelInfo, - 
OutLevel: slog.LevelDebug, - } name := filepath.Join(cfg.LinkDir, "bin", BinaryName) if cfg.SelfSetup { name = "/proc/self/exe" } - _, err := exec.Run(ctx, name, + cmd := exec.CommandContext(ctx, name, "--data-dir", cfg.DataDir, "--link-dir", cfg.LinkDir, "setup") - return err + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + return cmd.Run() }, }, nil } From b55bea9f5b181d1761a49c80f14e2f00c514c16b Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Mon, 18 Nov 2024 02:56:15 -0500 Subject: [PATCH 10/14] pid wip --- lib/autoupdate/agent/config.go | 8 +- lib/autoupdate/agent/installer.go | 10 +-- lib/autoupdate/agent/process.go | 128 +++++++++++++++++++++++---- lib/autoupdate/agent/process_test.go | 103 ++++++++++++++++++++- lib/autoupdate/agent/updater.go | 40 ++++++--- 5 files changed, 252 insertions(+), 37 deletions(-) diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go index 81b57f81a4cd5..457c4bffcd1f1 100644 --- a/lib/autoupdate/agent/config.go +++ b/lib/autoupdate/agent/config.go @@ -20,6 +20,7 @@ package agent import ( "context" + "errors" "log/slog" "os" "path/filepath" @@ -64,7 +65,12 @@ func Setup(ctx context.Context, log *slog.Logger, linkDir, dataDir string) error ServiceName: "teleport-update.timer", Log: log, } - if err := svc.Sync(ctx); err != nil { + err = svc.Sync(ctx) + if errors.Is(err, ErrNotSupported) { + log.WarnContext(ctx, "Not enabling systemd service because systemd is not running.") + return nil + } + if err != nil { return trace.Errorf("failed to sync systemd config: %w", err) } if err := svc.Enable(ctx, true); err != nil { diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go index 530911a3c1184..e9b2caadea780 100644 --- a/lib/autoupdate/agent/installer.go +++ b/lib/autoupdate/agent/installer.go @@ -788,13 +788,5 @@ func (li *LocalInstaller) isLinked(versionDir string) (bool, error) { return true, nil } } - linkData, err := readFileN(filepath.Join(li.LinkServiceDir, serviceName), 
maxServiceFileSize) - if err != nil { - return false, nil - } - versionData, err := readFileN(filepath.Join(versionDir, serviceDir, serviceName), maxServiceFileSize) - if err != nil { - return false, nil - } - return bytes.Equal(linkData, versionData), nil + return false, nil } diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index bc8a74425ee0d..547a7a7e6be3e 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -26,6 +26,7 @@ import ( "os" "os/exec" "strconv" + "syscall" "time" "github.com/gravitational/trace" @@ -52,6 +53,8 @@ type SystemdService struct { ServiceName string // LastRestartPath is a path to a file containing the last restart time. LastRestartPath string + // PIDPath is a path to a file containing the service's PID. + PIDPath string // Log contains a logger. Log *slog.Logger } @@ -66,9 +69,6 @@ func (s SystemdService) Reload(ctx context.Context) error { return trace.Wrap(err) } - // If getRestartTime fails consistently, error will be returned from monitor. - initRestartTime, _ := s.getRestartTime() - // Command error codes < 0 indicate that we are unable to run the command. // Errors from s.systemctl are logged along with stderr and stdout (debug only). @@ -83,6 +83,19 @@ func (s SystemdService) Reload(ctx context.Context) error { s.Log.WarnContext(ctx, "Systemd service not running.", unitKey, s.ServiceName) return trace.Wrap(ErrNotNeeded) } + + // Get initial restart time and initial PID. + + // If getRestartTime fails consistently, error will be returned from monitor. + initRestartTime, err := readInt64(s.LastRestartPath) + if err != nil { + s.Log.DebugContext(ctx, "Initial restart time not present.", unitKey, s.ServiceName) + } + initPID, err := readInt(s.PIDPath) + if err != nil { + s.Log.DebugContext(ctx, "Initial PID not present.", unitKey, s.ServiceName) + } + // Attempt graceful reload of running service. 
code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName) switch { @@ -97,11 +110,85 @@ func (s SystemdService) Reload(ctx context.Context) error { s.Log.WarnContext(ctx, "Service ungracefully restarted. Connections potentially dropped.", unitKey, s.ServiceName) default: s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName) + if err := s.verifyPID(ctx, initPID); err != nil { + return trace.Wrap(err) + } } s.Log.InfoContext(ctx, "Monitoring for excessive restarts.", unitKey, s.ServiceName) return trace.Wrap(s.monitor(ctx, initRestartTime)) } +func (s SystemdService) verifyPID(ctx context.Context, initPID int64) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + tickC := time.NewTicker(restartMonitorInterval).C + + pidC := make(chan int64) + g := &errgroup.Group{} + g.Go(func() error { + return tickFile(ctx, s.PIDPath, pidC, tickC, initPID) + }) + err := s.waitForStablePID(ctx, initPID, tickC) + cancel() + if err := g.Wait(); err != nil { + s.Log.WarnContext(ctx, "Unable to determine PID. 
Cannot failed reload.", unitKey, s.ServiceName) + s.Log.DebugContext(ctx, "Error monitoring for crashing fork.", errorKey, err, unitKey, s.ServiceName) + } + return trace.Wrap(err) +} + +func (s SystemdService) waitForStablePID(ctx context.Context, baseline int, pidC <-chan int64) error { + var warnPID int + var pid int + for n, last := 0, 0; n < 3; n++ { + select { + case <-ctx.Done(): + return ctx.Err() + case p := <-pidC: + last = pid + pid = int(p) + } + if pid != last || + pid == baseline || + pid == warnPID { + n = 0 + continue + } + process, err := os.FindProcess(pid) + if err != nil { + return trace.Wrap(err) + } + err = process.Signal(syscall.Signal(0)) + if errors.Is(err, syscall.ESRCH) { + if pid != warnPID && + pid != baseline { + s.Log.WarnContext(ctx, "Detecting crashing fork.", unitKey, s.ServiceName, "pid", pid) + warnPID = pid + } + n = 0 + continue + } + } + return nil +} + +func readInt64(path string) (int64, error) { + p, err := readFileN(path, 32) + if err != nil { + return 0, trace.Wrap(err) + } + i, err := strconv.ParseInt(string(bytes.TrimSpace(p)), 10, 64) + if err != nil { + return 0, trace.Wrap(err) + } + return i, nil +} + +func readInt(path string) (int, error) { + i, err := readInt64(path) + return int(i), trace.Wrap(err) +} + // monitor for excessive restarts by polling the LastRestartPath file. // This function detects crash-looping while minimizing its own runtime during updates. 
// To accomplish this, monitor fails after seeing maxRestartsBeforeFailure, and stops checking @@ -115,7 +202,7 @@ func (s SystemdService) monitor(ctx context.Context, initRestartTime int64) erro restartC := make(chan int64) g := &errgroup.Group{} g.Go(func() error { - return s.tickRestarts(ctx, restartC, tickC, initRestartTime) + return tickFile(ctx, s.LastRestartPath, restartC, tickC, initRestartTime) }) err := s.monitorRestarts(ctx, restartC, maxRestartsBeforeFailure, minCleanIntervalsBeforeStable) cancel() @@ -150,7 +237,6 @@ func (s SystemdService) monitorRestarts(ctx context.Context, timeCh <-chan int64 restartTime = t diff++ } - } switch { case diff > maxRestarts+1: @@ -181,7 +267,7 @@ func (s SystemdService) tickRestarts(ctx context.Context, ch chan<- int64, tickC return err } var t int64 - t, err = s.getRestartTime() + t, err = readInt64(s.LastRestartPath) select { case ch <- t: case <-ctx.Done(): @@ -190,17 +276,29 @@ func (s SystemdService) tickRestarts(ctx context.Context, ch chan<- int64, tickC } } -// getRestartTime returns the last restart time from the file at LastRestartPath. -func (s SystemdService) getRestartTime() (int64, error) { - b, err := os.ReadFile(s.LastRestartPath) - if err != nil { - return 0, trace.Wrap(err) +func tickFile(ctx context.Context, path string, ch chan<- int64, tickC <-chan time.Time, baseline int64) error { + t := baseline + var err error + select { + case ch <- t: + case <-ctx.Done(): + return err } - restart, err := strconv.ParseInt(string(bytes.TrimSpace(b)), 10, 64) - if err != nil { - return 0, trace.Wrap(err) + for { + // two select statements -> never skip reads + select { + case <-tickC: + case <-ctx.Done(): + return err + } + var t int64 + t, err = readInt64(path) + select { + case ch <- t: + case <-ctx.Done(): + return err + } } - return restart, nil } // Sync systemd service configuration by running systemctl daemon-reload. 
diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go index e6440356bf8d0..b29315719ec55 100644 --- a/lib/autoupdate/agent/process_test.go +++ b/lib/autoupdate/agent/process_test.go @@ -75,7 +75,7 @@ func msgOnly(_ []string, a slog.Attr) slog.Attr { return slog.Attr{Key: a.Key, Value: a.Value} } -func TestRestartMonitor(t *testing.T) { +func TestWaitForStablePID(t *testing.T) { t.Parallel() svc := &SystemdService{ @@ -174,7 +174,106 @@ func TestRestartMonitor(t *testing.T) { } } -func TestRestartTicks(t *testing.T) { +func TestMonitorRestarts(t *testing.T) { + t.Parallel() + + svc := &SystemdService{ + Log: slog.Default(), + } + + for _, tt := range []struct { + name string + ticks []int64 + maxRestarts int + minClean int + errored bool + canceled bool + }{ + { + name: "no restarts", + ticks: []int64{1, 1, 1, 1}, + maxRestarts: 2, + minClean: 3, + errored: false, + }, + { + name: "one restart then stable", + ticks: []int64{1, 1, 1, 2, 2, 2, 2}, + maxRestarts: 2, + minClean: 3, + errored: false, + }, + { + name: "two restarts then stable", + ticks: []int64{1, 2, 3, 3, 3, 3}, + maxRestarts: 2, + minClean: 3, + errored: false, + }, + { + name: "too many restarts (slow)", + ticks: []int64{1, 1, 1, 2, 2, 2, 3, 3, 3, 4}, + maxRestarts: 2, + minClean: 3, + errored: true, + }, + { + name: "too many restarts (fast)", + ticks: []int64{1, 2, 3, 4}, + maxRestarts: 2, + minClean: 3, + errored: true, + }, + { + name: "too many restarts after stable", + ticks: []int64{1, 1, 1, 1, 2, 3, 4}, + maxRestarts: 2, + minClean: 3, + errored: false, + }, + { + name: "too many restarts before okay", + ticks: []int64{1, 2, 3, 4, 3, 3, 3}, + maxRestarts: 2, + minClean: 3, + errored: true, + }, + { + name: "no error if no minClean", + ticks: []int64{1, 2, 3, 4}, + maxRestarts: 2, + minClean: 0, + errored: false, + }, + { + name: "cancel", + ticks: []int64{1, 1, 1}, + maxRestarts: 2, + minClean: 3, + canceled: true, + }, + } { + t.Run(tt.name, func(t 
*testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + ch := make(chan int64) + go func() { + defer cancel() // always quit after last tick + for _, tick := range tt.ticks { + ch <- tick + } + }() + err := svc.monitorRestarts(ctx, ch, tt.maxRestarts, tt.minClean) + require.Equal(t, tt.canceled, errors.Is(err, context.Canceled)) + if !tt.canceled { + require.Equal(t, tt.errored, err != nil) + } + }) + } +} + +func TestTickRestarts(t *testing.T) { t.Parallel() dir := t.TempDir() diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index a0a4b0fea4c43..d96dcee5663fd 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -182,7 +182,9 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { "setup") cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout - return cmd.Run() + cfg.Log.InfoContext(ctx, "Executing new teleport-update binary to update configuration.") + defer cfg.Log.InfoContext(ctx, "Finished executing new teleport-update binary.") + return trace.Wrap(cmd.Run()) }, }, nil } @@ -363,6 +365,8 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { return trace.Errorf("agent version not available from Teleport cluster") } + u.Log.InfoContext(ctx, "Initiating initial update.", targetVersionKey, targetVersion, activeVersionKey, cfg.Status.ActiveVersion) + if err := u.update(ctx, cfg, targetVersion, flags); err != nil { return trace.Wrap(err) } @@ -513,6 +517,12 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s if err != nil { return trace.Errorf("failed to install: %w", err) } + + // TODO(slevine): if the target version has fewer binaries, this will + // leave old binaries linked. This may prevent the installation from + // being removed. To fix this, we should look for orphaned binaries + // and remove them. 
+ revert, err := u.Installer.Link(ctx, targetVersion) if err != nil { return trace.Errorf("failed to link: %w", err) @@ -533,10 +543,12 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s // Sync process configuration after linking. - if err := u.Process.Sync(ctx); err != nil { - if errors.Is(err, context.Canceled) { - return trace.Errorf("sync canceled") - } + err = u.Process.Sync(ctx) + if errors.Is(err, ErrNotSupported) { + u.Log.WarnContext(ctx, "Not syncing systemd configuration because systemd is not running.") + } else if errors.Is(err, context.Canceled) { + return trace.Errorf("sync canceled") + } else if err != nil { // If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync. u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.") if ok := revert(ctx); !ok { @@ -554,10 +566,14 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s if cfg.Status.ActiveVersion != targetVersion { u.Log.InfoContext(ctx, "Target version successfully installed.", targetVersionKey, targetVersion) - if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) { - if errors.Is(err, context.Canceled) { - return trace.Errorf("reload canceled") - } + err := u.Process.Reload(ctx) + if errors.Is(err, context.Canceled) { + return trace.Errorf("reload canceled") + } + if err != nil && + !errors.Is(err, ErrNotNeeded) && // no output if restart not needed + !errors.Is(err, ErrNotSupported) { // already logged above for Sync + // If reloading Teleport at the new version fails, revert, resync, and reload. 
u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.") if ok := revert(ctx); !ok { @@ -645,7 +661,11 @@ func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error { if override.Group != "" { spec.Group = override.Group } - if override.URLTemplate != "" { + switch override.URLTemplate { + case "": + case "default": + spec.URLTemplate = "" + default: spec.URLTemplate = override.URLTemplate } if spec.URLTemplate != "" && From fd2fd1910011d3a3e6cd6d92fad376874d222348 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Mon, 18 Nov 2024 20:53:05 -0500 Subject: [PATCH 11/14] switch to PID monitoring --- lib/autoupdate/agent/config.go | 15 +- lib/autoupdate/agent/config_test.go | 1 - lib/autoupdate/agent/process.go | 240 +++++-------- lib/autoupdate/agent/process_test.go | 319 ++++++++---------- .../teleport-update.conf.golden | 3 - .../teleport-update.service.golden | 2 +- .../teleport-update.timer.golden | 2 +- lib/autoupdate/agent/updater.go | 6 +- 8 files changed, 246 insertions(+), 342 deletions(-) delete mode 100644 lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.conf.golden diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go index 457c4bffcd1f1..59bb7751a00a8 100644 --- a/lib/autoupdate/agent/config.go +++ b/lib/autoupdate/agent/config.go @@ -30,13 +30,9 @@ import ( ) const ( - teleportDropinTemplate = `# teleport-update -[Service] -ExecStopPost=/bin/bash -c 'date +%%s > {{.DataDir}}/last-restart' -` updateServiceTemplate = `# teleport-update [Unit] -Description=Teleport update service +Description=Teleport auto-update service [Service] Type=oneshot @@ -44,7 +40,7 @@ ExecStart={{.LinkDir}}/bin/teleport-update update ` updateTimerTemplate = `# teleport-update [Unit] -Description=Teleport update timer unit +Description=Teleport auto-update timer unit [Timer] OnActiveSec=1m @@ -82,13 +78,8 @@ func Setup(ctx context.Context, log *slog.Logger, linkDir, dataDir string) error func 
writeConfigFiles(linkDir, dataDir string) error { // TODO(sclevine): revert on failure - dropinPath := filepath.Join(linkDir, serviceDir, serviceName+".d", serviceDropinName) - err := writeTemplate(dropinPath, teleportDropinTemplate, linkDir, dataDir) - if err != nil { - return trace.Wrap(err) - } servicePath := filepath.Join(linkDir, serviceDir, updateServiceName) - err = writeTemplate(servicePath, updateServiceTemplate, linkDir, dataDir) + err := writeTemplate(servicePath, updateServiceTemplate, linkDir, dataDir) if err != nil { return trace.Wrap(err) } diff --git a/lib/autoupdate/agent/config_test.go b/lib/autoupdate/agent/config_test.go index d5f79571f8a56..16cbdb5374fb6 100644 --- a/lib/autoupdate/agent/config_test.go +++ b/lib/autoupdate/agent/config_test.go @@ -38,7 +38,6 @@ func TestWriteConfigFiles(t *testing.T) { require.NoError(t, err) for _, p := range []string{ - filepath.Join(linkDir, serviceDir, serviceName+".d", serviceDropinName), filepath.Join(linkDir, serviceDir, updateServiceName), filepath.Join(linkDir, serviceDir, updateTimerName), } { diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index 547a7a7e6be3e..0c3a603837532 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -34,12 +34,15 @@ import ( ) const ( - // restartMonitorInterval is the polling interval for determining restart times from LastRestartPath. - restartMonitorInterval = 2 * time.Second - // minCleanIntervalsBeforeStable is the number of consecutive intervals before the service is determined stable. - minCleanIntervalsBeforeStable = 6 - // maxRestartsBeforeFailure is the number of total restarts allowed before the service is marked as crash-looping. - maxRestartsBeforeFailure = 2 + // crashMonitorInterval is the polling interval for determining restart times from LastRestartPath. 
+ crashMonitorInterval = 2 * time.Second + // minRunningIntervalsBeforeStable is the number of consecutive intervals with the same running PID detect + // before the service is determined stable. + minRunningIntervalsBeforeStable = 6 + // maxCrashesBeforeFailure is the number of total crashes detected before the service is marked as crash-looping. + maxCrashesBeforeFailure = 2 + // crashMonitorTimeout + crashMonitorTimeout = 30 * time.Second ) // log keys @@ -51,8 +54,6 @@ const ( type SystemdService struct { // ServiceName specifies the systemd service name. ServiceName string - // LastRestartPath is a path to a file containing the last restart time. - LastRestartPath string // PIDPath is a path to a file containing the service's PID. PIDPath string // Log contains a logger. @@ -84,16 +85,13 @@ func (s SystemdService) Reload(ctx context.Context) error { return trace.Wrap(ErrNotNeeded) } - // Get initial restart time and initial PID. + // Get initial PID for crash monitoring. - // If getRestartTime fails consistently, error will be returned from monitor. - initRestartTime, err := readInt64(s.LastRestartPath) - if err != nil { - s.Log.DebugContext(ctx, "Initial restart time not present.", unitKey, s.ServiceName) - } initPID, err := readInt(s.PIDPath) - if err != nil { - s.Log.DebugContext(ctx, "Initial PID not present.", unitKey, s.ServiceName) + if errors.Is(err, os.ErrNotExist) { + s.Log.InfoContext(ctx, "No existing process detected. Skipping crash monitoring.", unitKey, s.ServiceName) + } else if err != nil { + s.Log.ErrorContext(ctx, "Error reading initial PID value. Skipping crash monitoring.", unitKey, s.ServiceName, errorKey, err) } // Attempt graceful reload of running service. @@ -110,69 +108,101 @@ func (s SystemdService) Reload(ctx context.Context) error { s.Log.WarnContext(ctx, "Service ungracefully restarted. 
Connections potentially dropped.", unitKey, s.ServiceName) default: s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName) - if err := s.verifyPID(ctx, initPID); err != nil { - return trace.Wrap(err) - } } - s.Log.InfoContext(ctx, "Monitoring for excessive restarts.", unitKey, s.ServiceName) - return trace.Wrap(s.monitor(ctx, initRestartTime)) + if initPID != 0 { + s.Log.InfoContext(ctx, "Monitoring PID file to detecting crashes.", unitKey, s.ServiceName) + return trace.Wrap(s.monitor(ctx, initPID)) + } + return nil } -func (s SystemdService) verifyPID(ctx context.Context, initPID int64) error { - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) +// monitor for the started process to ensure it's running by polling PIDFile. +// This function detects several types of crashes while minimizing its own runtime during updates. +// For example, the process may crash by failing to fork (non-running PID), or looping (repeatedly changing PID), +// or getting stuck on quit (no change in PID). +// initPID is the PID before the restart operation has been issued. +func (s SystemdService) monitor(ctx context.Context, initPID int) error { + ctx, cancel := context.WithTimeout(ctx, crashMonitorTimeout) defer cancel() - tickC := time.NewTicker(restartMonitorInterval).C + tickC := time.NewTicker(crashMonitorInterval).C - pidC := make(chan int64) + pidC := make(chan int) g := &errgroup.Group{} g.Go(func() error { - return tickFile(ctx, s.PIDPath, pidC, tickC, initPID) + return tickFile(ctx, s.PIDPath, pidC, tickC) }) - err := s.waitForStablePID(ctx, initPID, tickC) + err := s.waitForStablePID(ctx, minRunningIntervalsBeforeStable, maxCrashesBeforeFailure, + initPID, pidC, func(pid int) error { + process, err := os.FindProcess(pid) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(process.Signal(syscall.Signal(0))) + }) cancel() if err := g.Wait(); err != nil { - s.Log.WarnContext(ctx, "Unable to determine PID. 
Cannot failed reload.", unitKey, s.ServiceName) - s.Log.DebugContext(ctx, "Error monitoring for crashing fork.", errorKey, err, unitKey, s.ServiceName) + s.Log.ErrorContext(ctx, "Error monitoring for crashing process.", errorKey, err, unitKey, s.ServiceName) } return trace.Wrap(err) } -func (s SystemdService) waitForStablePID(ctx context.Context, baseline int, pidC <-chan int64) error { - var warnPID int - var pid int - for n, last := 0, 0; n < 3; n++ { +// monitorRestarts receives restart times on timeCh. +// Each restart time that differs from the preceding restart time counts as a restart. +// If maxRestarts is exceeded, monitorRestarts returns an error. +// Each restart time that matches the proceeding restart time counts as a clean reading. +// If minClean is reached before maxRestarts is exceeded, monitorRestarts runs nil. +func (s SystemdService) waitForStablePID(ctx context.Context, minStable, maxCrashes, baselinePID int, pidC <-chan int, findPID func(pid int) error) error { + pid := baselinePID + var last, stale int + var crashes int + for stable := 0; stable < minStable; stable++ { select { case <-ctx.Done(): return ctx.Err() case p := <-pidC: last = pid - pid = int(p) + pid = p } - if pid != last || - pid == baseline || - pid == warnPID { - n = 0 - continue + // A "crash" is defined as a transition away from a new (non-baseline) PID, or + // an interval where the current PID remains non-running since the last check. + if (last != 0 && pid != last && last != baselinePID) || + (stale != 0 && pid == stale && last == stale) { + crashes++ } - process, err := os.FindProcess(pid) - if err != nil { - return trace.Wrap(err) + if crashes > maxCrashes { + return trace.Errorf("detected crashing process") + } + + // PID can only be stable if it is a real PID that is not new, has changed at least once, + // and hasn't been observed as missing. 
+ if pid == 0 || + pid == baselinePID || + pid == stale || + pid != last { + stable = -1 + continue } - err = process.Signal(syscall.Signal(0)) + err := findPID(pid) + // A stale PID most likely indicates that the process forked and crashed without systemd noticing. + // There is a small chance that we read the PID file before systemd removed it. + // Note: we only perform this check on PIDs that survive one iteration. if errors.Is(err, syscall.ESRCH) { - if pid != warnPID && - pid != baseline { - s.Log.WarnContext(ctx, "Detecting crashing fork.", unitKey, s.ServiceName, "pid", pid) - warnPID = pid + if pid != stale && + pid != baselinePID { + stale = pid + s.Log.WarnContext(ctx, "Detected stale PID.", unitKey, s.ServiceName, "pid", stale) } - n = 0 + stable = -1 continue } + if err != nil { + return trace.Wrap(err) + } } return nil } -func readInt64(path string) (int64, error) { +func readInt(path string) (int, error) { p, err := readFileN(path, 32) if err != nil { return 0, trace.Wrap(err) @@ -181,109 +211,14 @@ func readInt64(path string) (int64, error) { if err != nil { return 0, trace.Wrap(err) } - return i, nil -} - -func readInt(path string) (int, error) { - i, err := readInt64(path) - return int(i), trace.Wrap(err) + return int(i), nil } -// monitor for excessive restarts by polling the LastRestartPath file. -// This function detects crash-looping while minimizing its own runtime during updates. -// To accomplish this, monitor fails after seeing maxRestartsBeforeFailure, and stops checking -// after seeing minCleanIntervalsBeforeStable clean intervals. -// initRestartTime may be provided as a baseline restart time, to ensure we catch the initial restart. 
-func (s SystemdService) monitor(ctx context.Context, initRestartTime int64) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - tickC := time.NewTicker(restartMonitorInterval).C - restartC := make(chan int64) - g := &errgroup.Group{} - g.Go(func() error { - return tickFile(ctx, s.LastRestartPath, restartC, tickC, initRestartTime) - }) - err := s.monitorRestarts(ctx, restartC, maxRestartsBeforeFailure, minCleanIntervalsBeforeStable) - cancel() - if err := g.Wait(); err != nil { - s.Log.WarnContext(ctx, "Unable to determine last restart time. Cannot detect crash loops.", unitKey, s.ServiceName) - s.Log.DebugContext(ctx, "Error monitoring for crash loops.", errorKey, err, unitKey, s.ServiceName) - } - return trace.Wrap(err) -} - -// monitorRestarts receives restart times on timeCh. -// Each restart time that differs from the preceding restart time counts as a restart. -// If maxRestarts is exceeded, monitorRestarts returns an error. -// Each restart time that matches the proceeding restart time counts as a clean reading. -// If minClean is reached before maxRestarts is exceeded, monitorRestarts runs nil. -func (s SystemdService) monitorRestarts(ctx context.Context, timeCh <-chan int64, maxRestarts, minClean int) error { - var ( - same, diff int - restartTime int64 - ) - for { - // wait first to ensure we initial stop has completed - select { - case <-ctx.Done(): - return ctx.Err() - case t := <-timeCh: - switch t { - case restartTime: - same++ - default: - same = 0 - restartTime = t - diff++ - } - } - switch { - case diff > maxRestarts+1: - return trace.Errorf("detected crash loop") - case same >= minClean: - return nil - } - } -} - -// tickRestarts reads the current time on tickC, and outputs the last restart time on ch for each received tick. -// If the current time cannot be read, tickRestarts sends 0 on ch. -// Any error from the last attempt to receive restart times is returned when ctx is cancelled. 
-// The baseline restart time is sent as soon as the method is called -func (s SystemdService) tickRestarts(ctx context.Context, ch chan<- int64, tickC <-chan time.Time, baseline int64) error { - t := baseline +// tickFile reads the current time on tickC, and outputs the last read int from path on ch for each received tick. +// If the path cannot be read, tickFile sends 0 on ch. +// Any error from the last attempt to read path is returned when ctx is cancelled, unless the error is os.ErrNotExist. +func tickFile(ctx context.Context, path string, ch chan<- int, tickC <-chan time.Time) error { var err error - select { - case ch <- t: - case <-ctx.Done(): - return err - } - for { - // two select statements -> never skip restarts - select { - case <-tickC: - case <-ctx.Done(): - return err - } - var t int64 - t, err = readInt64(s.LastRestartPath) - select { - case ch <- t: - case <-ctx.Done(): - return err - } - } -} - -func tickFile(ctx context.Context, path string, ch chan<- int64, tickC <-chan time.Time, baseline int64) error { - t := baseline - var err error - select { - case ch <- t: - case <-ctx.Done(): - return err - } for { // two select statements -> never skip reads select { @@ -291,8 +226,11 @@ func tickFile(ctx context.Context, path string, ch chan<- int64, tickC <-chan ti case <-ctx.Done(): return err } - var t int64 - t, err = readInt64(path) + var t int + t, err = readInt(path) + if errors.Is(err, os.ErrNotExist) { + err = nil + } select { case ch <- t: case <-ctx.Done(): diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go index b29315719ec55..3962ec8da1dbb 100644 --- a/lib/autoupdate/agent/process_test.go +++ b/lib/autoupdate/agent/process_test.go @@ -26,6 +26,7 @@ import ( "log/slog" "os" "path/filepath" + "syscall" "testing" "time" @@ -83,188 +84,165 @@ func TestWaitForStablePID(t *testing.T) { } for _, tt := range []struct { - name string - ticks []int64 - maxRestarts int - minClean int - errored bool - canceled 
bool + name string + ticks []int + baseline int + minStable int + maxCrashes int + findErrs map[int]error + + errored bool + canceled bool }{ { - name: "no restarts", - ticks: []int64{1, 1, 1, 1}, - maxRestarts: 2, - minClean: 3, - errored: false, - }, - { - name: "one restart then stable", - ticks: []int64{1, 1, 1, 2, 2, 2, 2}, - maxRestarts: 2, - minClean: 3, - errored: false, + name: "immediate restart", + ticks: []int{2, 2}, + baseline: 1, + minStable: 1, + maxCrashes: 1, }, { - name: "two restarts then stable", - ticks: []int64{1, 2, 3, 3, 3, 3}, - maxRestarts: 2, - minClean: 3, - errored: false, + name: "zero stable", }, { - name: "too many restarts (slow)", - ticks: []int64{1, 1, 1, 2, 2, 2, 3, 3, 3, 4}, - maxRestarts: 2, - minClean: 3, - errored: true, + name: "immediate crash", + ticks: []int{2, 3}, + baseline: 1, + minStable: 1, + maxCrashes: 0, + errored: true, }, { - name: "too many restarts (fast)", - ticks: []int64{1, 2, 3, 4}, - maxRestarts: 2, - minClean: 3, - errored: true, + name: "no changes times out", + ticks: []int{1, 1, 1, 1}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + canceled: true, }, { - name: "too many restarts after stable", - ticks: []int64{1, 1, 1, 1, 2, 3, 4}, - maxRestarts: 2, - minClean: 3, - errored: false, + name: "baseline restart", + ticks: []int{2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, }, { - name: "too many restarts before okay", - ticks: []int64{1, 2, 3, 4, 3, 3, 3}, - maxRestarts: 2, - minClean: 3, - errored: true, + name: "one restart then stable", + ticks: []int{1, 2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, }, { - name: "no error if no minClean", - ticks: []int64{1, 2, 3, 4}, - maxRestarts: 2, - minClean: 0, - errored: false, + name: "two restarts then stable", + ticks: []int{1, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, }, { - name: "cancel", - ticks: []int64{1, 1, 1}, - maxRestarts: 2, - minClean: 3, - canceled: true, + name: "three restarts then stable", 
+ ticks: []int{1, 2, 3, 4, 4, 4, 4}, + baseline: 1, + minStable: 3, + maxCrashes: 2, }, - } { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - ch := make(chan int64) - go func() { - defer cancel() // always quit after last tick - for _, tick := range tt.ticks { - ch <- tick - } - }() - err := svc.monitorRestarts(ctx, ch, tt.maxRestarts, tt.minClean) - require.Equal(t, tt.canceled, errors.Is(err, context.Canceled)) - if !tt.canceled { - require.Equal(t, tt.errored, err != nil) - } - }) - } -} - -func TestMonitorRestarts(t *testing.T) { - t.Parallel() - - svc := &SystemdService{ - Log: slog.Default(), - } - - for _, tt := range []struct { - name string - ticks []int64 - maxRestarts int - minClean int - errored bool - canceled bool - }{ { - name: "no restarts", - ticks: []int64{1, 1, 1, 1}, - maxRestarts: 2, - minClean: 3, - errored: false, + name: "too many restarts excluding baseline", + ticks: []int{1, 2, 3, 4, 5}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + errored: true, }, { - name: "one restart then stable", - ticks: []int64{1, 1, 1, 2, 2, 2, 2}, - maxRestarts: 2, - minClean: 3, - errored: false, + name: "too many restarts including baseline", + ticks: []int{1, 2, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, }, { - name: "two restarts then stable", - ticks: []int64{1, 2, 3, 3, 3, 3}, - maxRestarts: 2, - minClean: 3, - errored: false, + name: "too many restarts slow", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, }, { - name: "too many restarts (slow)", - ticks: []int64{1, 1, 1, 2, 2, 2, 3, 3, 3, 4}, - maxRestarts: 2, - minClean: 3, - errored: true, + name: "too many restarts after stable", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, }, { - name: "too many restarts (fast)", - ticks: []int64{1, 2, 3, 4}, - maxRestarts: 2, - minClean: 3, 
- errored: true, + name: "stable after too many restarts", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, }, { - name: "too many restarts after stable", - ticks: []int64{1, 1, 1, 1, 2, 3, 4}, - maxRestarts: 2, - minClean: 3, - errored: false, + name: "cancel", + ticks: []int{1, 1, 1}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + canceled: true, }, { - name: "too many restarts before okay", - ticks: []int64{1, 2, 3, 4, 3, 3, 3}, - maxRestarts: 2, - minClean: 3, - errored: true, + name: "stale PID crash", + ticks: []int{2, 2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: syscall.ESRCH, + }, + errored: true, }, { - name: "no error if no minClean", - ticks: []int64{1, 2, 3, 4}, - maxRestarts: 2, - minClean: 0, - errored: false, + name: "stale PID but fixed", + ticks: []int{2, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: syscall.ESRCH, + }, }, { - name: "cancel", - ticks: []int64{1, 1, 1}, - maxRestarts: 2, - minClean: 3, - canceled: true, + name: "error PID", + ticks: []int{2, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: errors.New("bad"), + }, + errored: true, }, } { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() - ch := make(chan int64) + ch := make(chan int) go func() { defer cancel() // always quit after last tick for _, tick := range tt.ticks { ch <- tick } }() - err := svc.monitorRestarts(ctx, ch, tt.maxRestarts, tt.minClean) + err := svc.waitForStablePID(ctx, tt.minStable, tt.maxCrashes, + tt.baseline, ch, func(pid int) error { + return tt.findErrs[pid] + }) require.Equal(t, tt.canceled, errors.Is(err, context.Canceled)) if !tt.canceled { require.Equal(t, tt.errored, err != nil) @@ -273,75 +251,79 @@ func TestMonitorRestarts(t *testing.T) { } } -func TestTickRestarts(t 
*testing.T) { +func TestTickFile(t *testing.T) { t.Parallel() - dir := t.TempDir() - restartPath := filepath.Join(dir, "restart") - svc := &SystemdService{ - Log: slog.Default(), - LastRestartPath: restartPath, - } - for _, tt := range []struct { name string - init int64 - ticks []int64 + ticks []int errored bool }{ { name: "consistent", - init: 1, - ticks: []int64{1, 1, 1}, + ticks: []int{1, 1, 1}, errored: false, }, { name: "divergent", - init: 1, - ticks: []int64{1, 2, 3}, + ticks: []int{1, 2, 3}, errored: false, }, { name: "start error", - init: 1, - ticks: []int64{-1, 1, 1}, + ticks: []int{-1, 1, 1}, errored: false, }, { name: "ephemeral error", - init: 1, - ticks: []int64{1, -1, 1}, + ticks: []int{1, -1, 1}, errored: false, }, { name: "end error", - init: 1, - ticks: []int64{1, 1, -1}, + ticks: []int{1, 1, -1}, errored: true, }, { - name: "init-only", + name: "start missing", + ticks: []int{0, 1, 1}, + errored: false, + }, + { + name: "ephemeral missing", + ticks: []int{1, 0, 1}, + errored: false, + }, + { + name: "end missing", + ticks: []int{1, 1, 0}, + errored: false, + }, + { + name: "cancel-only", + errored: false, }, } { t.Run(tt.name, func(t *testing.T) { + filePath := filepath.Join(t.TempDir(), "file") + ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() tickC := make(chan time.Time) - ch := make(chan int64) + ch := make(chan int) go func() { defer cancel() // always quit after last tick or fail - require.Equal(t, tt.init, <-ch) for _, tick := range tt.ticks { - if tick >= 0 { - err := os.WriteFile(restartPath, []byte(fmt.Sprintln(tick)), os.ModePerm) + _ = os.RemoveAll(filePath) + switch { + case tick > 0: + err := os.WriteFile(filePath, []byte(fmt.Sprintln(tick)), os.ModePerm) + require.NoError(t, err) + case tick < 0: + err := os.Mkdir(filePath, os.ModePerm) require.NoError(t, err) - } else { - err := os.Remove(restartPath) - if err != nil { - require.ErrorIs(t, err, os.ErrNotExist) - } } tickC <- time.Now() res := 
<-ch @@ -351,11 +333,8 @@ func TestTickRestarts(t *testing.T) { require.Equal(t, tick, res) } }() - err := svc.tickRestarts(ctx, ch, tickC, tt.init) + err := tickFile(ctx, filePath, ch, tickC) require.Equal(t, tt.errored, err != nil) - if err != nil { - require.ErrorIs(t, err, os.ErrNotExist) - } }) } } diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.conf.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.conf.golden deleted file mode 100644 index 1e2cfb333bcbc..0000000000000 --- a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.conf.golden +++ /dev/null @@ -1,3 +0,0 @@ -# teleport-update -[Service] -ExecStopPost=/bin/bash -c 'date +%%s > /var/lib/teleport/last-restart' diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden index b8d6f7f75ae72..185b4f07a1aa9 100644 --- a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden @@ -1,6 +1,6 @@ # teleport-update [Unit] -Description=Teleport update service +Description=Teleport auto-update service [Service] Type=oneshot diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden index dbc8e1d12c404..acca095d9825f 100644 --- a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden @@ -1,6 +1,6 @@ # teleport-update [Unit] -Description=Teleport update timer unit +Description=Teleport auto-update timer unit [Timer] OnActiveSec=1m diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index d96dcee5663fd..49dfe40fd27da 100644 --- a/lib/autoupdate/agent/updater.go +++ 
b/lib/autoupdate/agent/updater.go @@ -167,9 +167,9 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { ReservedFreeInstallDisk: reservedFreeDisk, }, Process: &SystemdService{ - ServiceName: "teleport.service", - LastRestartPath: filepath.Join(cfg.DataDir, "last-restart"), - Log: cfg.Log, + ServiceName: "teleport.service", + PIDPath: "/run/teleport.pid", + Log: cfg.Log, }, Setup: func(ctx context.Context) error { name := filepath.Join(cfg.LinkDir, "bin", BinaryName) From 24c18546313d38166f4b433a8b385f14834413c6 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Mon, 18 Nov 2024 21:50:32 -0500 Subject: [PATCH 12/14] cleanup --- lib/autoupdate/agent/process.go | 44 +++++++++++++++++----------- lib/autoupdate/agent/process_test.go | 5 ++-- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index 0c3a603837532..e8ed9659c2473 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -110,7 +110,7 @@ func (s SystemdService) Reload(ctx context.Context) error { s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName) } if initPID != 0 { - s.Log.InfoContext(ctx, "Monitoring PID file to detecting crashes.", unitKey, s.ServiceName) + s.Log.InfoContext(ctx, "Monitoring PID file to detect crashes.", unitKey, s.ServiceName) return trace.Wrap(s.monitor(ctx, initPID)) } return nil @@ -133,11 +133,11 @@ func (s SystemdService) monitor(ctx context.Context, initPID int) error { }) err := s.waitForStablePID(ctx, minRunningIntervalsBeforeStable, maxCrashesBeforeFailure, initPID, pidC, func(pid int) error { - process, err := os.FindProcess(pid) + p, err := os.FindProcess(pid) if err != nil { return trace.Wrap(err) } - return trace.Wrap(process.Signal(syscall.Signal(0))) + return trace.Wrap(p.Signal(syscall.Signal(0))) }) cancel() if err := g.Wait(); err != nil { @@ -186,7 +186,8 @@ func (s SystemdService) waitForStablePID(ctx context.Context, 
minStable, maxCras // A stale PID most likely indicates that the process forked and crashed without systemd noticing. // There is a small chance that we read the PID file before systemd removed it. // Note: we only perform this check on PIDs that survive one iteration. - if errors.Is(err, syscall.ESRCH) { + if errors.Is(err, os.ErrProcessDone) || + errors.Is(err, syscall.ESRCH) { if pid != stale && pid != baselinePID { stale = pid @@ -290,12 +291,16 @@ func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args OutLevel: slog.LevelDebug, } code, err := cmd.Run(ctx, "systemctl", args...) - if err != nil { - s.Log.Log(ctx, errLevel, "Failed to run systemctl.", - "args", args, - "code", code, - errorKey, err) + if err == nil { + return code + } + if code >= 0 { + s.Log.Log(ctx, errLevel, "Error running systemctl.", + "args", args, "code", code) + return code } + s.Log.Log(ctx, errLevel, "Unable to run systemctl.", + "args", args, "code", code, errorKey, err) return code } @@ -314,8 +319,8 @@ type localExec struct { // Outputs the status code, or -1 if out-of-range or unstarted. func (c *localExec) Run(ctx context.Context, name string, args ...string) (int, error) { cmd := exec.CommandContext(ctx, name, args...) - stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel} - stdout := &lineLogger{ctx: ctx, log: c.Log, level: c.OutLevel} + stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel, prefix: "[stderr] "} + stdout := &lineLogger{ctx: ctx, log: c.Log, level: c.OutLevel, prefix: "[stdout] "} cmd.Stderr = stderr cmd.Stdout = stdout err := cmd.Run() @@ -334,13 +339,18 @@ func (c *localExec) Run(ctx context.Context, name string, args ...string) (int, // lineLogger logs each line written to it. 
type lineLogger struct { - ctx context.Context - log *slog.Logger - level slog.Level + ctx context.Context + log *slog.Logger + level slog.Level + prefix string last bytes.Buffer } +func (w *lineLogger) out(s string) { + w.log.Log(w.ctx, w.level, w.prefix+s) //nolint:sloglint // msg cannot be constant +} + func (w *lineLogger) Write(p []byte) (n int, err error) { lines := bytes.Split(p, []byte("\n")) // Finish writing line @@ -354,13 +364,13 @@ func (w *lineLogger) Write(p []byte) (n int, err error) { } // Newline found, log line - w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant + w.out(w.last.String()) n += 1 w.last.Reset() // Log lines that are already newline-terminated for _, line := range lines[:len(lines)-1] { - w.log.Log(w.ctx, w.level, string(line)) //nolint:sloglint // msg cannot be constant + w.out(string(line)) n += len(line) + 1 } @@ -375,6 +385,6 @@ func (w *lineLogger) Flush() { if w.last.Len() == 0 { return } - w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant + w.out(w.last.String()) w.last.Reset() } diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go index 3962ec8da1dbb..c558a7539831a 100644 --- a/lib/autoupdate/agent/process_test.go +++ b/lib/autoupdate/agent/process_test.go @@ -26,7 +26,6 @@ import ( "log/slog" "os" "path/filepath" - "syscall" "testing" "time" @@ -202,7 +201,7 @@ func TestWaitForStablePID(t *testing.T) { minStable: 3, maxCrashes: 2, findErrs: map[int]error{ - 2: syscall.ESRCH, + 2: os.ErrProcessDone, }, errored: true, }, @@ -213,7 +212,7 @@ func TestWaitForStablePID(t *testing.T) { minStable: 3, maxCrashes: 2, findErrs: map[int]error{ - 2: syscall.ESRCH, + 2: os.ErrProcessDone, }, }, { From dcf911b4c61ff0e7e3e5c8715991a24968be2e53 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Mon, 18 Nov 2024 22:00:36 -0500 Subject: [PATCH 13/14] atomic templates --- lib/autoupdate/agent/config.go | 15 ++++++++++++--- 1 file 
changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go index 59bb7751a00a8..6664fcaa94131 100644 --- a/lib/autoupdate/agent/config.go +++ b/lib/autoupdate/agent/config.go @@ -26,6 +26,7 @@ import ( "path/filepath" "text/template" + "github.com/google/renameio/v2" "github.com/gravitational/trace" ) @@ -96,11 +97,16 @@ func writeTemplate(path, t, linkDir, dataDir string) error { if err := os.MkdirAll(filepath.Dir(path), systemDirMode); err != nil { return trace.Wrap(err) } - f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, configFileMode) + opts := []renameio.Option{ + renameio.WithPermissions(configFileMode), + renameio.WithExistingPermissions(), + } + f, err := renameio.NewPendingFile(path, opts...) if err != nil { return trace.Wrap(err) } - defer f.Close() + defer f.Cleanup() + tmpl, err := template.New(filepath.Base(path)).Parse(t) if err != nil { return trace.Wrap(err) @@ -109,5 +115,8 @@ func writeTemplate(path, t, linkDir, dataDir string) error { LinkDir string DataDir string }{linkDir, dataDir}) - return trace.Wrap(f.Close()) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(f.CloseAtomicallyReplace()) } From de5aaccda2306805f25240573f26e12cc786cd63 Mon Sep 17 00:00:00 2001 From: Stephen Levine Date: Mon, 18 Nov 2024 22:30:12 -0500 Subject: [PATCH 14/14] lint --- lib/autoupdate/agent/installer.go | 2 -- lib/autoupdate/agent/process.go | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go index e9b2caadea780..2d31d26fd8262 100644 --- a/lib/autoupdate/agent/installer.go +++ b/lib/autoupdate/agent/installer.go @@ -60,8 +60,6 @@ const ( serviceDir = "lib/systemd/system" // serviceName contains the name of the Teleport SystemD service file. serviceName = "teleport.service" - // serviceDropinName contains the name of the Teleport Systemd service drop-in to support updates. 
- serviceDropinName = "teleport-update.conf" // updateServiceName contains the name of the Teleport Update Systemd service updateServiceName = "teleport-update.service" // updateTimerName contains the name of the Teleport Update Systemd timer diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index e8ed9659c2473..082e61156369b 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -217,7 +217,7 @@ func readInt(path string) (int, error) { // tickFile reads the current time on tickC, and outputs the last read int from path on ch for each received tick. // If the path cannot be read, tickFile sends 0 on ch. -// Any error from the last attempt to read path is returned when ctx is cancelled, unless the error is os.ErrNotExist. +// Any error from the last attempt to read path is returned when ctx is canceled, unless the error is os.ErrNotExist. func tickFile(ctx context.Context, path string, ch chan<- int, tickC <-chan time.Time) error { var err error for {