diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index b5c684ebfb628..8b5bf3eed4d93 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -359,8 +359,8 @@ type ProxySettings struct { type AutoUpdateSettings struct { // ToolsVersion defines the version of {tsh, tctl} for client auto update. ToolsVersion string `json:"tools_version"` - // ToolsMode defines mode client auto update feature `enabled|disabled`. - ToolsMode string `json:"tools_mode"` + // ToolsAutoUpdate indicates if the requesting tools client should be updated. + ToolsAutoUpdate bool `json:"tools_auto_update"` // AgentVersion defines the version of teleport that agents enrolled into autoupdates should run. AgentVersion string `json:"agent_version"` // AgentAutoUpdate indicates if the requesting agent should attempt to update now. diff --git a/integration/autoupdate/tools/main_test.go b/integration/autoupdate/tools/main_test.go index a14a6dc9fc683..bbc3f559f65c0 100644 --- a/integration/autoupdate/tools/main_test.go +++ b/integration/autoupdate/tools/main_test.go @@ -37,7 +37,8 @@ import ( "github.com/gravitational/trace" - "github.com/gravitational/teleport/integration/helpers" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/integration/helpers/archive" ) const ( @@ -133,9 +134,9 @@ func buildAndArchiveApps(ctx context.Context, path string, toolsDir string, vers for _, app := range []string{"tsh", "tctl"} { output := filepath.Join(versionPath, app) switch runtime.GOOS { - case "windows": + case constants.WindowsOS: output = filepath.Join(versionPath, app+".exe") - case "darwin": + case constants.DarwinOS: output = filepath.Join(versionPath, app+".app", "Contents", "MacOS", app) } if err := buildBinary(output, toolsDir, version, baseURL); err != nil { @@ -143,15 +144,15 @@ func buildAndArchiveApps(ctx context.Context, path string, toolsDir string, vers } } switch runtime.GOOS { - case "darwin": + case constants.DarwinOS: archivePath := filepath.Join(path, fmt.Sprintf("teleport-%s.pkg", version)) - return trace.Wrap(helpers.CompressDirToPkgFile(ctx, versionPath, archivePath, "com.example.pkgtest")) - case "windows": + return trace.Wrap(archive.CompressDirToPkgFile(ctx, versionPath, archivePath, "com.example.pkgtest")) + case constants.WindowsOS: archivePath := filepath.Join(path, fmt.Sprintf("teleport-v%s-windows-amd64-bin.zip", version)) - return trace.Wrap(helpers.CompressDirToZipFile(ctx, versionPath, archivePath)) + return trace.Wrap(archive.CompressDirToZipFile(ctx, versionPath, archivePath)) default: archivePath := filepath.Join(path, fmt.Sprintf("teleport-v%s-linux-%s-bin.tar.gz", version, runtime.GOARCH)) - return trace.Wrap(helpers.CompressDirToTarGzFile(ctx, versionPath, archivePath)) + return trace.Wrap(archive.CompressDirToTarGzFile(ctx, versionPath, archivePath)) } } diff --git a/integration/autoupdate/tools/updater/main.go b/integration/autoupdate/tools/updater/main.go index e14c76e5d5aa8..775c7ab7b2e9d 100644 --- a/integration/autoupdate/tools/updater/main.go +++ b/integration/autoupdate/tools/updater/main.go @@ -25,11 +25,9 @@ import ( "log" "os" "os/signal" - "runtime" "syscall" "time" - "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/autoupdate/tools" ) @@ -40,17 +38,20 @@ var ( ) func main() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() ctx, _ = signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) updater := tools.NewUpdater( - clientTools(), + tools.DefaultClientTools(), toolsDir, version, tools.WithBaseURL(baseURL), ) - toolsVersion, reExec := updater.CheckLocal() + toolsVersion, reExec, err := updater.CheckLocal() + if err != nil { + log.Fatal(err) + } if reExec { // Download and update the version of client tools required by the cluster. // This is required if the user passed in the TELEPORT_TOOLS_VERSION explicitly. @@ -76,13 +77,3 @@ func main() { fmt.Printf("Teleport v%v git\n", version) } } - -// clientTools list of the client tools needs to be updated. -func clientTools() []string { - switch runtime.GOOS { - case constants.WindowsOS: - return []string{"tsh.exe", "tctl.exe"} - default: - return []string{"tsh", "tctl"} - } -} diff --git a/integration/autoupdate/tools/updater_test.go b/integration/autoupdate/tools/updater_test.go index 52d7a7b8ab385..d429a7483910a 100644 --- a/integration/autoupdate/tools/updater_test.go +++ b/integration/autoupdate/tools/updater_test.go @@ -26,7 +26,6 @@ import ( "os/exec" "path/filepath" "regexp" - "runtime" "strings" "testing" "time" @@ -34,7 +33,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/autoupdate/tools" ) @@ -51,7 +49,7 @@ func TestUpdate(t *testing.T) { // Fetch compiled test binary with updater logic and install to $TELEPORT_HOME. updater := tools.NewUpdater( - clientTools(), + tools.DefaultClientTools(), toolsDir, testVersions[0], tools.WithBaseURL(baseURL), @@ -93,7 +91,7 @@ func TestParallelUpdate(t *testing.T) { // Initial fetch the updater binary un-archive and replace. updater := tools.NewUpdater( - clientTools(), + tools.DefaultClientTools(), toolsDir, testVersions[0], tools.WithBaseURL(baseURL), @@ -167,7 +165,7 @@ func TestUpdateInterruptSignal(t *testing.T) { // Initial fetch the updater binary un-archive and replace. updater := tools.NewUpdater( - clientTools(), + tools.DefaultClientTools(), toolsDir, testVersions[0], tools.WithBaseURL(baseURL), @@ -220,12 +218,3 @@ func TestUpdateInterruptSignal(t *testing.T) { } assert.Contains(t, output.String(), "Update progress:") } - -func clientTools() []string { - switch runtime.GOOS { - case constants.WindowsOS: - return []string{"tsh.exe", "tctl.exe"} - default: - return []string{"tsh", "tctl"} - } -} diff --git a/integration/helpers/archive.go b/integration/helpers/archive/packaging.go similarity index 99% rename from integration/helpers/archive.go rename to integration/helpers/archive/packaging.go index 6e48108013d86..ee237749115a3 100644 --- a/integration/helpers/archive.go +++ b/integration/helpers/archive/packaging.go @@ -16,7 +16,7 @@ * along with this program. If not, see . */ -package helpers +package archive import ( "archive/tar" diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index c9459a5381189..030b29ba08341 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -184,6 +184,7 @@ require ( github.com/google/go-tspi v0.3.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af // indirect + github.com/google/renameio/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/google/safetext v0.0.0-20240104143208-7a7d9b3d812f // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 257c076fd1e5d..95ef579188b65 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -1252,7 +1252,6 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2Rrd27c3VGxi6a/6HNq8QmHRKM= github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= -github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg= github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4= diff --git a/lib/autoupdate/tools/progress.go b/lib/autoupdate/tools/progress.go index 95395003730ec..34bbad2ff888d 100644 --- a/lib/autoupdate/tools/progress.go +++ b/lib/autoupdate/tools/progress.go @@ -20,24 +20,59 @@ package tools import ( "fmt" + "io" "strings" + + "github.com/gravitational/trace" ) type progressWriter struct { - n int64 - limit int64 + n int64 + limit int64 + size int + progress int } -func (w *progressWriter) Write(p []byte) (int, error) { - w.n = w.n + int64(len(p)) +// newProgressWriter creates progress writer instance and prints empty +// progress bar right after initialisation. +func newProgressWriter(size int) (*progressWriter, func()) { + pw := &progressWriter{size: size} + pw.Print(0) + return pw, func() { + fmt.Print("\n") + } +} - n := int((w.n*100)/w.limit) / 10 - bricks := strings.Repeat("▒", n) + strings.Repeat(" ", 10-n) +// Print prints the update progress bar with `n` bricks. +func (w *progressWriter) Print(n int) { + bricks := strings.Repeat("▒", n) + strings.Repeat(" ", w.size-n) fmt.Print("\rUpdate progress: [" + bricks + "] (Ctrl-C to cancel update)") +} - if w.n == w.limit { - fmt.Print("\n") +func (w *progressWriter) Write(p []byte) (int, error) { + if w.limit == 0 || w.size == 0 { + return len(p), nil + } + + w.n += int64(len(p)) + bricks := int((w.n*100)/w.limit) / w.size + if w.progress != bricks { + w.Print(bricks) + w.progress = bricks } return len(p), nil } + +// CopyLimit sets the limit of writing bytes to the progress writer and initiate copying process. +func (w *progressWriter) CopyLimit(dst io.Writer, src io.Reader, limit int64) (written int64, err error) { + if limit < 0 { + n, err := io.Copy(dst, io.TeeReader(src, w)) + w.Print(w.size) + return n, trace.Wrap(err) + } + + w.limit = limit + n, err := io.CopyN(dst, io.TeeReader(src, w), limit) + return n, trace.Wrap(err) +} diff --git a/lib/autoupdate/tools/updater.go b/lib/autoupdate/tools/updater.go index 96991044ccc31..96352e34d9910 100644 --- a/lib/autoupdate/tools/updater.go +++ b/lib/autoupdate/tools/updater.go @@ -36,12 +36,12 @@ import ( "syscall" "time" + "github.com/coreos/go-semver/semver" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" - "github.com/gravitational/teleport/api/types/autoupdate" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/packaging" ) @@ -114,30 +114,37 @@ func NewUpdater(tools []string, toolsDir string, localVersion string, options .. // CheckLocal is run at client tool startup and will only perform local checks. // Returns the version needs to be updated and re-executed, by re-execution flag we // understand that update and re-execute is required. -func (u *Updater) CheckLocal() (version string, reExec bool) { +func (u *Updater) CheckLocal() (version string, reExec bool, err error) { // Check if the user has requested a specific version of client tools. requestedVersion := os.Getenv(teleportToolsVersionEnv) switch requestedVersion { // The user has turned off any form of automatic updates. case "off": - return "", false + return "", false, nil // Requested version already the same as client version. case u.localVersion: - return u.localVersion, false + return u.localVersion, false, nil + // No requested version, we continue. + case "": + // Requested version that is not the local one. + default: + if _, err := semver.NewVersion(requestedVersion); err != nil { + return "", false, trace.Wrap(err, "checking that request version is semantic") + } + return requestedVersion, true, nil } // If a version of client tools has already been downloaded to // tools directory, return that. - toolsVersion, err := checkToolVersion(u.toolsDir) - if err != nil { - return "", false + toolsVersion, err := CheckToolVersion(u.toolsDir) + if trace.IsNotFound(err) { + return u.localVersion, false, nil } - // The user has requested a specific version of client tools. - if requestedVersion != "" && requestedVersion != toolsVersion { - return requestedVersion, true + if err != nil { + return "", false, trace.Wrap(err) } - return toolsVersion, false + return toolsVersion, true, nil } // CheckRemote first checks the version set by the environment variable. If not set or disabled, @@ -145,7 +152,7 @@ func (u *Updater) CheckLocal() (version string, reExec bool) { // the `webapi/find` handler, which stores information about the required client tools version to // operate with this cluster. It returns the semantic version that needs updating and whether // re-execution is necessary, by re-execution flag we understand that update and re-execute is required. -func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version string, reExec bool, err error) { +func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string, insecure bool) (version string, reExec bool, err error) { // Check if the user has requested a specific version of client tools. requestedVersion := os.Getenv(teleportToolsVersionEnv) switch requestedVersion { @@ -155,6 +162,14 @@ func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version st // Requested version already the same as client version. case u.localVersion: return u.localVersion, false, nil + // No requested version, we continue. + case "": + // Requested version that is not the local one. + default: + if _, err := semver.NewVersion(requestedVersion); err != nil { + return "", false, trace.Wrap(err, "checking that request version is semantic") + } + return requestedVersion, true, nil } certPool, err := x509.SystemCertPool() @@ -165,7 +180,8 @@ func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version st Context: ctx, ProxyAddr: proxyAddr, Pool: certPool, - Timeout: 30 * time.Second, + Timeout: 10 * time.Second, + Insecure: insecure, }) if err != nil { return "", false, trace.Wrap(err) @@ -173,28 +189,28 @@ func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version st // If a version of client tools has already been downloaded to // tools directory, return that. - toolsVersion, err := checkToolVersion(u.toolsDir) - if err != nil { + toolsVersion, err := CheckToolVersion(u.toolsDir) + if err != nil && !trace.IsNotFound(err) { return "", false, trace.Wrap(err) } switch { - case requestedVersion != "" && requestedVersion != toolsVersion: - return requestedVersion, true, nil - case resp.AutoUpdate.ToolsMode != autoupdate.ToolsUpdateModeEnabled || resp.AutoUpdate.ToolsVersion == "": - return "", false, nil + case !resp.AutoUpdate.ToolsAutoUpdate || resp.AutoUpdate.ToolsVersion == "": + if toolsVersion == "" { + return u.localVersion, false, nil + } case u.localVersion == resp.AutoUpdate.ToolsVersion: - return resp.AutoUpdate.ToolsVersion, false, nil + return u.localVersion, false, nil case resp.AutoUpdate.ToolsVersion != toolsVersion: return resp.AutoUpdate.ToolsVersion, true, nil } - return toolsVersion, false, nil + return toolsVersion, true, nil } // UpdateWithLock acquires filesystem lock, downloads requested version package, // unarchive and replace existing one. -func (u *Updater) UpdateWithLock(ctx context.Context, toolsVersion string) (err error) { +func (u *Updater) UpdateWithLock(ctx context.Context, updateToolsVersion string) (err error) { // Create tools directory if it does not exist. if err := os.MkdirAll(u.toolsDir, 0o755); err != nil { return trace.Wrap(err) @@ -211,21 +227,20 @@ func (u *Updater) UpdateWithLock(ctx context.Context, toolsVersion string) (err // If the version of the running binary or the version downloaded to // tools directory is the same as the requested version of client tools, // nothing to be done, exit early. - teleportVersion, err := checkToolVersion(u.toolsDir) + toolsVersion, err := CheckToolVersion(u.toolsDir) if err != nil && !trace.IsNotFound(err) { return trace.Wrap(err) - } - if toolsVersion == u.localVersion || toolsVersion == teleportVersion { + if updateToolsVersion == toolsVersion { return nil } // Download and update client tools in tools directory. - if err := u.Update(ctx, toolsVersion); err != nil { + if err := u.Update(ctx, updateToolsVersion); err != nil { return trace.Wrap(err) } - return + return nil } // Update downloads requested version and replace it with existing one and cleanups the previous downloads @@ -237,10 +252,18 @@ func (u *Updater) Update(ctx context.Context, toolsVersion string) error { return trace.Wrap(err) } + var pkgNames []string for _, pkg := range packages { - if err := u.update(ctx, pkg); err != nil { + pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix) + if err := u.update(ctx, pkg, pkgName); err != nil { return trace.Wrap(err) } + pkgNames = append(pkgNames, pkgName) + } + + // Cleanup the tools directory with previously downloaded and un-archived versions. + if err := packaging.RemoveWithSuffix(u.toolsDir, updatePackageSuffix, pkgNames); err != nil { + slog.WarnContext(ctx, "failed to clean up tools directory", "error", err) } return nil @@ -248,16 +271,8 @@ func (u *Updater) Update(ctx context.Context, toolsVersion string) error { // update downloads the archive and validate against the hash. Download to a // temporary path within tools directory. -func (u *Updater) update(ctx context.Context, pkg packageURL) error { - hash, err := u.downloadHash(ctx, pkg.Hash) - if pkg.Optional && trace.IsNotFound(err) { - return nil - } - if err != nil { - return trace.Wrap(err) - } - - f, err := os.CreateTemp(u.toolsDir, "tmp-") +func (u *Updater) update(ctx context.Context, pkg packageURL, pkgName string) error { + f, err := os.CreateTemp("", "teleport-") if err != nil { return trace.Wrap(err) } @@ -275,11 +290,19 @@ func (u *Updater) update(ctx context.Context, pkg packageURL) error { if err != nil { return trace.Wrap(err) } + + hash, err := u.downloadHash(ctx, pkg.Hash) + if pkg.Optional && trace.IsNotFound(err) { + return nil + } + if err != nil { + return trace.Wrap(err) + } + if !bytes.Equal(archiveHash, hash) { return trace.BadParameter("hash of archive does not match downloaded archive") } - pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix) extractDir := filepath.Join(u.toolsDir, pkgName) if runtime.GOOS != constants.DarwinOS { if err := os.Mkdir(extractDir, 0o755); err != nil { @@ -291,10 +314,6 @@ func (u *Updater) update(ctx context.Context, pkg packageURL) error { if err := packaging.ReplaceToolsBinaries(u.toolsDir, f.Name(), extractDir, u.tools); err != nil { return trace.Wrap(err) } - // Cleanup the tools directory with previously downloaded and un-archived versions. - if err := packaging.RemoveWithSuffix(u.toolsDir, updatePackageSuffix, pkgName); err != nil { - return trace.Wrap(err) - } return nil } @@ -340,7 +359,7 @@ func (u *Updater) downloadHash(ctx context.Context, url string) ([]byte, error) } defer resp.Body.Close() if resp.StatusCode == http.StatusNotFound { - return nil, trace.NotFound("hash file is not found: %v", resp.StatusCode) + return nil, trace.NotFound("hash file is not found: %q", url) } if resp.StatusCode != http.StatusOK { return nil, trace.BadParameter("bad status when downloading archive hash: %v", resp.StatusCode) @@ -362,6 +381,11 @@ func (u *Updater) downloadHash(ctx context.Context, url string) ([]byte, error) // downloadArchive downloads the archive package by `url` and writes content to the writer interface, // return calculated sha256 hash sum of the content. func (u *Updater) downloadArchive(ctx context.Context, url string, f io.Writer) ([]byte, error) { + // Display a progress bar before initiating the update request to inform the user that + // an update is in progress, allowing them the option to cancel before actual response + // which might be delayed with slow internet connection or complete isolation to CDN. + pw, finish := newProgressWriter(10) + defer finish() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, trace.Wrap(err) @@ -385,14 +409,10 @@ func (u *Updater) downloadArchive(ctx context.Context, url string, f io.Writer) } h := sha256.New() - pw := &progressWriter{n: 0, limit: resp.ContentLength} - body := io.TeeReader(io.TeeReader(resp.Body, h), pw) - // It is a little inefficient to download the file to disk and then re-load // it into memory to unarchive later, but this is safer as it allows client // tools to validate the hash before trying to operate on the archive. - _, err = io.CopyN(f, body, resp.ContentLength) - if err != nil { + if _, err := pw.CopyLimit(f, io.TeeReader(resp.Body, h), resp.ContentLength); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/autoupdate/tools/utils.go b/lib/autoupdate/tools/utils.go index d552b31abefe4..f937d228b5cd4 100644 --- a/lib/autoupdate/tools/utils.go +++ b/lib/autoupdate/tools/utils.go @@ -33,6 +33,7 @@ import ( "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/utils" @@ -52,14 +53,26 @@ func Dir() (string, error) { return filepath.Join(home, ".tsh", "bin"), nil } -func checkToolVersion(toolsDir string) (string, error) { +// DefaultClientTools list of the client tools needs to be updated by default. +func DefaultClientTools() []string { + switch runtime.GOOS { + case constants.WindowsOS: + return []string{"tsh.exe", "tctl.exe"} + default: + return []string{"tsh", "tctl"} + } +} + +// CheckToolVersion returns current installed client tools version, must return NotFoundError if +// the client tools is not found in tools directory. +func CheckToolVersion(toolsDir string) (string, error) { // Find the path to the current executable. path, err := toolName(toolsDir) if err != nil { return "", trace.Wrap(err) } if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { - return "", nil + return "", trace.NotFound("autoupdate tool not found in %q", toolsDir) } else if err != nil { return "", trace.Wrap(err) } @@ -76,7 +89,7 @@ func checkToolVersion(toolsDir string) (string, error) { command.Env = []string{teleportToolsVersionEnv + "=off"} output, err := command.Output() if err != nil { - return "", trace.Wrap(err) + return "", trace.WrapWithMessage(err, "failed to determine version of %q tool", path) } // The output for "{tsh, tctl} version" can be multiple lines. Find the diff --git a/lib/client/api.go b/lib/client/api.go index d087fb02d1e34..d8a35dc95feee 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -74,6 +74,7 @@ import ( "github.com/gravitational/teleport/lib/auth/touchid" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/autoupdate/tools" libmfa "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/client/sso" "github.com/gravitational/teleport/lib/client/terminal" @@ -95,6 +96,7 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/agentconn" logutils "github.com/gravitational/teleport/lib/utils/log" + "github.com/gravitational/teleport/lib/utils/signal" ) const ( @@ -702,6 +704,39 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, return trace.Wrap(err) } + // The user has typed a command like `tsh ssh ...` without being logged in, + // if the running binary needs to be updated, update and re-exec. + // + // If needed, download the new version of {tsh, tctl} and re-exec. Make + // sure to exit this process with the same exit code as the child process. + // + toolsDir, err := tools.Dir() + if err != nil { + return trace.Wrap(err) + } + updater := tools.NewUpdater(tools.DefaultClientTools(), toolsDir, teleport.Version) + toolsVersion, reExec, err := updater.CheckRemote(ctx, tc.WebProxyAddr, tc.InsecureSkipVerify) + if err != nil { + return trace.Wrap(err) + } + if reExec { + ctxUpdate, cancel := signal.GetSignalHandler().NotifyContext(context.Background()) + defer cancel() + // Download the version of client tools required by the cluster. + err := updater.UpdateWithLock(ctxUpdate, toolsVersion) + if err != nil && !errors.Is(err, context.Canceled) { + utils.FatalError(err) + } + // Re-execute client tools with the correct version of client tools. + code, err := updater.Exec() + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Debugf("Failed to re-exec client tool: %v.", err) + os.Exit(code) + } else if err == nil { + os.Exit(code) + } + } + if opt.afterLoginHook != nil { if err := opt.afterLoginHook(); err != nil { return trace.Wrap(err) diff --git a/lib/utils/packaging/unarchive.go b/lib/utils/packaging/unarchive.go index f1a197e095b1a..6496afbc182c3 100644 --- a/lib/utils/packaging/unarchive.go +++ b/lib/utils/packaging/unarchive.go @@ -38,13 +38,13 @@ const ( ) // RemoveWithSuffix removes all that matches the provided suffix, except for file or directory with `skipName`. -func RemoveWithSuffix(dir, suffix, skipName string) error { +func RemoveWithSuffix(dir, suffix string, skipNames []string) error { var removePaths []string err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { if err != nil { return trace.Wrap(err) } - if skipName == info.Name() { + if slices.Contains(skipNames, info.Name()) { return nil } if !strings.HasSuffix(info.Name(), suffix) { @@ -59,12 +59,13 @@ func RemoveWithSuffix(dir, suffix, skipName string) error { if err != nil { return trace.Wrap(err) } + var aggErr []error for _, path := range removePaths { if err := os.RemoveAll(path); err != nil { - return trace.Wrap(err) + aggErr = append(aggErr, err) } } - return nil + return trace.NewAggregate(aggErr...) } // replaceZip un-archives the Teleport package in .zip format, iterates through @@ -118,7 +119,7 @@ func replaceZip(toolsDir string, archivePath string, extractDir string, execName defer file.Close() dest := filepath.Join(extractDir, baseName) - destFile, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) + destFile, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) if err != nil { return trace.Wrap(err) } @@ -131,7 +132,10 @@ func replaceZip(toolsDir string, archivePath string, extractDir string, execName if err := os.Remove(appPath); err != nil && !os.IsNotExist(err) { return trace.Wrap(err) } - if err := os.Symlink(dest, appPath); err != nil { + // For the Windows build we have to copy binary to be able + // to do this without administrative access as it required + // for symlinks. + if err := utils.CopyFile(dest, appPath, 0o755); err != nil { return trace.Wrap(err) } return trace.Wrap(destFile.Close()) diff --git a/lib/utils/packaging/unarchive_test.go b/lib/utils/packaging/unarchive_test.go index 30933bbb75927..b124b603b0fd5 100644 --- a/lib/utils/packaging/unarchive_test.go +++ b/lib/utils/packaging/unarchive_test.go @@ -30,7 +30,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/integration/helpers" + "github.com/gravitational/teleport/integration/helpers/archive" ) // TestPackaging verifies un-archiving of all supported teleport package formats. @@ -63,7 +63,7 @@ func TestPackaging(t *testing.T) { t.Run("tar.gz", func(t *testing.T) { archivePath := filepath.Join(toolsDir, "tsh.tar.gz") - err = helpers.CompressDirToTarGzFile(ctx, sourceDir, archivePath) + err = archive.CompressDirToTarGzFile(ctx, sourceDir, archivePath) require.NoError(t, err) require.FileExists(t, archivePath, "archive not created") @@ -85,7 +85,7 @@ func TestPackaging(t *testing.T) { t.Skip("unsupported platform") } archivePath := filepath.Join(toolsDir, "tsh.pkg") - err = helpers.CompressDirToPkgFile(ctx, sourceDir, archivePath, "com.example.pkgtest") + err = archive.CompressDirToPkgFile(ctx, sourceDir, archivePath, "com.example.pkgtest") require.NoError(t, err) require.FileExists(t, archivePath, "archive not created") @@ -101,7 +101,7 @@ func TestPackaging(t *testing.T) { t.Run("zip", func(t *testing.T) { archivePath := filepath.Join(toolsDir, "tsh.zip") - err = helpers.CompressDirToZipFile(ctx, sourceDir, archivePath) + err = archive.CompressDirToZipFile(ctx, sourceDir, archivePath) require.NoError(t, err) require.FileExists(t, archivePath, "archive not created") @@ -132,7 +132,7 @@ func TestRemoveWithSuffix(t *testing.T) { dirInSkipPath := filepath.Join(skipPath, dirForRemove) require.NoError(t, os.MkdirAll(skipPath, 0o755)) - err := RemoveWithSuffix(testDir, dirForRemove, skipName) + err := RemoveWithSuffix(testDir, dirForRemove, []string{skipName}) require.NoError(t, err) _, err = os.Stat(filepath.Join(testDir, dirForRemove)) diff --git a/lib/utils/packaging/unarchive_unix.go b/lib/utils/packaging/unarchive_unix.go index ea51afdbbc7f0..8daf1b3aa5525 100644 --- a/lib/utils/packaging/unarchive_unix.go +++ b/lib/utils/packaging/unarchive_unix.go @@ -186,9 +186,10 @@ func replacePkg(toolsDir string, archivePath string, extractDir string, execName // swap operations. This ensures that the "com.apple.macl" extended // attribute is set and macOS will not send a SIGKILL to the process // if multiple processes are trying to operate on it. - command := exec.Command(path, "version", "--client") + command := exec.Command(path, "version") + command.Env = []string{"TELEPORT_TOOLS_VERSION=off"} if err := command.Run(); err != nil { - return trace.Wrap(err) + return trace.WrapWithMessage(err, "failed to validate binary") } // Due to macOS applications not being a single binary (they are a diff --git a/lib/utils/signal/stack_handler.go b/lib/utils/signal/stack_handler.go new file mode 100644 index 0000000000000..0fcbefb081d11 --- /dev/null +++ b/lib/utils/signal/stack_handler.go @@ -0,0 +1,98 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package signal + +import ( + "container/list" + "context" + "os" + "os/signal" + "sync" + "syscall" +) + +// Handler implements stack for context cancellation. +type Handler struct { + mu sync.Mutex + list *list.List +} + +var handler = &Handler{ + list: list.New(), +} + +// GetSignalHandler returns global singleton instance of signal +func GetSignalHandler() *Handler { + return handler +} + +// NotifyContext creates context which going to be canceled after SIGINT, SIGTERM +// in order of adding them to the stack. When very first context is canceled +// we stop watching the OS signals. +func (s *Handler) NotifyContext(parent context.Context) (context.Context, context.CancelFunc) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.list.Len() == 0 { + s.listenSignals() + } + + ctx, cancel := context.WithCancel(parent) + element := s.list.PushBack(cancel) + + return ctx, func() { + s.mu.Lock() + defer s.mu.Unlock() + + s.list.Remove(element) + cancel() + } +} + +// listenSignals sets up the signal listener for SIGINT, SIGTERM. +func (s *Handler) listenSignals() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + for { + if sig := <-sigChan; sig == nil { + return + } + if !s.cancelNext() { + signal.Stop(sigChan) + return + } + } + }() +} + +// cancelNext calls the most recent cancel func in the stack. +func (s *Handler) cancelNext() bool { + s.mu.Lock() + defer s.mu.Unlock() + + if s.list.Len() > 0 { + cancel := s.list.Remove(s.list.Back()) + if cancel != nil { + cancel.(context.CancelFunc)() + } + } + + return s.list.Len() != 0 +} diff --git a/lib/utils/signal/stack_handler_test.go b/lib/utils/signal/stack_handler_test.go new file mode 100644 index 0000000000000..b900939b886e8 --- /dev/null +++ b/lib/utils/signal/stack_handler_test.go @@ -0,0 +1,88 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package signal + +import ( + "context" + "os" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGetSignalHandler verifies the cancellation stack order. +func TestGetSignalHandler(t *testing.T) { + testHandler := GetSignalHandler() + parent := context.Background() + + ctx1, cancel1 := testHandler.NotifyContext(parent) + ctx2, cancel2 := testHandler.NotifyContext(parent) + ctx3, cancel3 := testHandler.NotifyContext(parent) + ctx4, cancel4 := testHandler.NotifyContext(parent) + t.Cleanup(func() { + cancel4() + cancel2() + cancel1() + }) + + // Verify that all context not canceled. + require.NoError(t, ctx4.Err()) + require.NoError(t, ctx3.Err()) + require.NoError(t, ctx2.Err()) + require.NoError(t, ctx1.Err()) + + // Cancel context manually to ensure it was removed from stack in right order. + cancel3() + + // Check that last added context is canceled. + require.NoError(t, syscall.Kill(os.Getpid(), syscall.SIGINT)) + select { + case <-ctx4.Done(): + assert.ErrorIs(t, ctx3.Err(), context.Canceled) + assert.NoError(t, ctx2.Err()) + assert.NoError(t, ctx1.Err()) + case <-time.After(time.Second): + assert.Fail(t, "context 3 must be canceled") + } + + // Send interrupt signal to cancel next context in the stack. + require.NoError(t, syscall.Kill(os.Getpid(), syscall.SIGINT)) + select { + case <-ctx2.Done(): + assert.ErrorIs(t, ctx4.Err(), context.Canceled) + assert.ErrorIs(t, ctx3.Err(), context.Canceled) + assert.NoError(t, ctx1.Err()) + case <-time.After(time.Second): + assert.Fail(t, "context 2 must be canceled") + } + + // All context must be canceled. + require.NoError(t, syscall.Kill(os.Getpid(), syscall.SIGINT)) + select { + case <-ctx1.Done(): + assert.ErrorIs(t, ctx4.Err(), context.Canceled) + assert.ErrorIs(t, ctx3.Err(), context.Canceled) + assert.ErrorIs(t, ctx2.Err(), context.Canceled) + case <-time.After(time.Second): + assert.Fail(t, "context 1 must be canceled") + } +} diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 2bbb2bcbdd561..cc309efe1043f 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -1593,7 +1593,7 @@ func (h *Handler) automaticUpdateSettings(ctx context.Context) webclient.AutoUpd } return webclient.AutoUpdateSettings{ - ToolsMode: getToolsMode(autoUpdateConfig), + ToolsAutoUpdate: getToolsAutoUpdate(autoUpdateConfig), ToolsVersion: getToolsVersion(autoUpdateVersion), AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentVersion: getAgentVersion(autoUpdateVersion), @@ -5154,15 +5154,15 @@ func readEtagFromAppHash(fs http.FileSystem) (string, error) { return etag, nil } -func getToolsMode(config *autoupdatepb.AutoUpdateConfig) string { +func getToolsAutoUpdate(config *autoupdatepb.AutoUpdateConfig) bool { // If we can't get the AU config or if AUs are not configured, we default to "disabled". // This ensures we fail open and don't accidentally update agents if something is going wrong. // If we want to enable AUs by default, it would be better to create a default "autoupdate_config" resource // than changing this logic. - if config.GetSpec().GetTools() == nil { - return autoupdate.ToolsUpdateModeDisabled + if config.GetSpec().GetTools() != nil { + return config.GetSpec().GetTools().GetMode() == autoupdate.ToolsUpdateModeEnabled } - return config.GetSpec().GetTools().GetMode() + return false } func getToolsVersion(version *autoupdatepb.AutoUpdateVersion) string { diff --git a/lib/web/apiserver_ping_test.go b/lib/web/apiserver_ping_test.go index 5ce3720375c46..2bf325d4f7902 100644 --- a/lib/web/apiserver_ping_test.go +++ b/lib/web/apiserver_ping_test.go @@ -306,7 +306,7 @@ func TestPing_autoUpdateResources(t *testing.T) { name: "resources not defined", expected: webclient.AutoUpdateSettings{ ToolsVersion: api.Version, - ToolsMode: autoupdate.ToolsUpdateModeDisabled, + ToolsAutoUpdate: false, AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentAutoUpdate: false, AgentVersion: api.Version, @@ -320,7 +320,7 @@ func TestPing_autoUpdateResources(t *testing.T) { }, }, expected: webclient.AutoUpdateSettings{ - ToolsMode: autoupdate.ToolsUpdateModeEnabled, + ToolsAutoUpdate: true, ToolsVersion: api.Version, AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentAutoUpdate: false, @@ -346,7 +346,7 @@ func TestPing_autoUpdateResources(t *testing.T) { }, expected: webclient.AutoUpdateSettings{ ToolsVersion: api.Version, - ToolsMode: autoupdate.ToolsUpdateModeDisabled, + ToolsAutoUpdate: false, AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentAutoUpdate: true, AgentVersion: "1.2.4", @@ -371,7 +371,7 @@ func TestPing_autoUpdateResources(t *testing.T) { }, expected: webclient.AutoUpdateSettings{ ToolsVersion: api.Version, - ToolsMode: autoupdate.ToolsUpdateModeDisabled, + ToolsAutoUpdate: false, AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentAutoUpdate: false, AgentVersion: "1.2.4", @@ -384,7 +384,7 @@ func TestPing_autoUpdateResources(t *testing.T) { version: &autoupdatev1pb.AutoUpdateVersionSpec{}, expected: webclient.AutoUpdateSettings{ ToolsVersion: api.Version, - ToolsMode: autoupdate.ToolsUpdateModeDisabled, + ToolsAutoUpdate: false, AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentAutoUpdate: false, AgentVersion: api.Version, @@ -400,7 +400,7 @@ func TestPing_autoUpdateResources(t *testing.T) { }, expected: webclient.AutoUpdateSettings{ ToolsVersion: "1.2.3", - ToolsMode: autoupdate.ToolsUpdateModeDisabled, + ToolsAutoUpdate: false, AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentAutoUpdate: false, AgentVersion: api.Version, @@ -420,7 +420,7 @@ func TestPing_autoUpdateResources(t *testing.T) { }, }, expected: webclient.AutoUpdateSettings{ - ToolsMode: autoupdate.ToolsUpdateModeEnabled, + ToolsAutoUpdate: true, ToolsVersion: "1.2.3", AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentAutoUpdate: false, @@ -440,7 +440,7 @@ func TestPing_autoUpdateResources(t *testing.T) { }, }, expected: webclient.AutoUpdateSettings{ - ToolsMode: autoupdate.ToolsUpdateModeDisabled, + ToolsAutoUpdate: false, ToolsVersion: "3.2.1", AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds, AgentAutoUpdate: false, diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index 4a85be6e934b0..51ff5f8687f75 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/storage" + "github.com/gravitational/teleport/lib/autoupdate/tools" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/identityfile" libmfa "github.com/gravitational/teleport/lib/client/mfa" @@ -54,6 +55,7 @@ import ( "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/signal" "github.com/gravitational/teleport/tool/common" ) @@ -104,8 +106,43 @@ type CLICommand interface { // "distributions" like OSS or Enterprise // // distribution: name of the Teleport distribution -func Run(commands []CLICommand) { - err := TryRun(commands, os.Args[1:]) +func Run(ctx context.Context, commands []CLICommand) { + // The user has typed a command like `tsh ssh ...` without being logged in, + // if the running binary needs to be updated, update and re-exec. + // + // If needed, download the new version of {tsh, tctl} and re-exec. Make + // sure to exit this process with the same exit code as the child process. + // + toolsDir, err := tools.Dir() + if err != nil { + utils.FatalError(err) + } + updater := tools.NewUpdater(tools.DefaultClientTools(), toolsDir, teleport.Version) + toolsVersion, reExec, err := updater.CheckLocal() + if err != nil { + utils.FatalError(err) + } + if reExec { + ctxUpdate, cancel := signal.GetSignalHandler().NotifyContext(ctx) + defer cancel() + // Download the version of client tools required by the cluster. This + // is required if the user passed in the TELEPORT_TOOLS_VERSION + // explicitly. + err := updater.UpdateWithLock(ctxUpdate, toolsVersion) + if err != nil && !errors.Is(err, context.Canceled) { + utils.FatalError(err) + } + // Re-execute client tools with the correct version of client tools. + code, err := updater.Exec() + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Debugf("Failed to re-exec client tool: %v.", err) + os.Exit(code) + } else if err == nil { + os.Exit(code) + } + } + + err = TryRun(commands, os.Args[1:]) if err != nil { var exitError *common.ExitCodeError if errors.As(err, &exitError) { diff --git a/tool/tctl/main.go b/tool/tctl/main.go index f363e347f25c9..6dfae87ffdef2 100644 --- a/tool/tctl/main.go +++ b/tool/tctl/main.go @@ -19,9 +19,15 @@ package main import ( + "context" + + "github.com/gravitational/teleport/lib/utils/signal" "github.com/gravitational/teleport/tool/tctl/common" ) func main() { - common.Run(common.Commands()) + ctx, cancel := signal.GetSignalHandler().NotifyContext(context.Background()) + defer cancel() + + common.Run(ctx, common.Commands()) } diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 1adabe7b337c1..a1b0edc1d15c3 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -41,7 +41,6 @@ import ( "strconv" "strings" "sync" - "syscall" "time" "github.com/alecthomas/kingpin/v2" @@ -74,6 +73,7 @@ import ( "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/lib/auth/authclient" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" + "github.com/gravitational/teleport/lib/autoupdate/tools" "github.com/gravitational/teleport/lib/benchmark" benchmarkdb "github.com/gravitational/teleport/lib/benchmark/db" "github.com/gravitational/teleport/lib/client" @@ -94,6 +94,7 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/diagnostics/latency" "github.com/gravitational/teleport/lib/utils/mlock" + stacksignal "github.com/gravitational/teleport/lib/utils/signal" "github.com/gravitational/teleport/tool/common" "github.com/gravitational/teleport/tool/common/fido2" "github.com/gravitational/teleport/tool/common/touchid" @@ -609,7 +610,7 @@ func Main() { cmdLineOrig := os.Args[1:] var cmdLine []string - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + ctx, cancel := stacksignal.GetSignalHandler().NotifyContext(context.Background()) defer cancel() // lets see: if the executable name is 'ssh' or 'scp' we convert @@ -707,6 +708,38 @@ func initLogger(cf *CLIConf) { // // DO NOT RUN TESTS that call Run() in parallel (unless you taken precautions). func Run(ctx context.Context, args []string, opts ...CliOption) error { + // At process startup, check if a version has already been downloaded to + // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION + // environment variable. If so, re-exec that version of {tsh, tctl}. + toolsDir, err := tools.Dir() + if err != nil { + return trace.Wrap(err) + } + updater := tools.NewUpdater(tools.DefaultClientTools(), toolsDir, teleport.Version) + toolsVersion, reExec, err := updater.CheckLocal() + if err != nil { + return trace.Wrap(err) + } + if reExec { + ctxUpdate, cancel := stacksignal.GetSignalHandler().NotifyContext(ctx) + defer cancel() + // Download the version of client tools required by the cluster. This + // is required if the user passed in the TELEPORT_TOOLS_VERSION + // explicitly. + err := updater.UpdateWithLock(ctxUpdate, toolsVersion) + if err != nil && !errors.Is(err, context.Canceled) { + return trace.Wrap(err) + } + // Re-execute client tools with the correct version of client tools. + code, err := updater.Exec() + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Debugf("Failed to re-exec client tool: %v.", err) + os.Exit(code) + } else if err == nil { + os.Exit(code) + } + } + cf := CLIConf{ Context: ctx, TracingProvider: tracing.NoopProvider(), @@ -1239,8 +1272,6 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { bench.Hidden() } - var err error - cf.executablePath, err = os.Executable() if err != nil { return trace.Wrap(err) @@ -1866,6 +1897,14 @@ func onLogin(cf *CLIConf) error { } tc.HomePath = cf.HomePath + // The user is not logged in and has typed in `tsh --proxy=... login`, if + // the running binary needs to be updated, update and re-exec. + if profile == nil { + if err := updateAndRun(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil { + return trace.Wrap(err) + } + } + // client is already logged in and profile is not expired if profile != nil && !profile.IsExpired(time.Now()) { switch { @@ -1876,6 +1915,13 @@ func onLogin(cf *CLIConf) error { // current status case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "" || host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "" && cf.RequestID == "": + + // The user has typed `tsh login`, if the running binary needs to + // be updated, update and re-exec. + if err := updateAndRun(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil { + return trace.Wrap(err) + } + _, err := tc.PingAndShowMOTD(cf.Context) if err != nil { return trace.Wrap(err) @@ -1889,6 +1935,13 @@ func onLogin(cf *CLIConf) error { // if the proxy names match but nothing else is specified; show motd and update active profile and kube configs case host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "": + + // The user has typed `tsh login`, if the running binary needs to + // be updated, update and re-exec. + if err := updateAndRun(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil { + return trace.Wrap(err) + } + _, err := tc.PingAndShowMOTD(cf.Context) if err != nil { return trace.Wrap(err) @@ -1959,7 +2012,11 @@ func onLogin(cf *CLIConf) error { // otherwise just pass through to standard login default: - + // The user is logged in and has typed in `tsh --proxy=... login`, if + // the running binary needs to be updated, update and re-exec. + if err := updateAndRun(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil { + return trace.Wrap(err) + } } } @@ -5552,6 +5609,43 @@ const ( "https://goteleport.com/docs/access-controls/guides/headless/#troubleshooting" ) +func updateAndRun(ctx context.Context, proxy string, insecure bool) error { + // The user has typed a command like `tsh ssh ...` without being logged in, + // if the running binary needs to be updated, update and re-exec. + // + // If needed, download the new version of {tsh, tctl} and re-exec. Make + // sure to exit this process with the same exit code as the child process. + // + toolsDir, err := tools.Dir() + if err != nil { + return trace.Wrap(err) + } + updater := tools.NewUpdater(tools.DefaultClientTools(), toolsDir, teleport.Version) + toolsVersion, reExec, err := updater.CheckRemote(ctx, proxy, insecure) + if err != nil { + return trace.Wrap(err) + } + if reExec { + ctxUpdate, cancel := stacksignal.GetSignalHandler().NotifyContext(context.Background()) + defer cancel() + // Download the version of client tools required by the cluster. + err := updater.UpdateWithLock(ctxUpdate, toolsVersion) + if err != nil && !errors.Is(err, context.Canceled) { + return trace.Wrap(err) + } + // Re-execute client tools with the correct version of client tools. + code, err := updater.Exec() + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Debugf("Failed to re-exec client tool: %v.", err) + os.Exit(code) + } else if err == nil { + os.Exit(code) + } + } + + return nil +} + // Lock the process memory to prevent rsa keys and certificates in memory from being exposed in a swap. func tryLockMemory(cf *CLIConf) error { if cf.MlockMode == mlockModeAuto {