diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go
index b5c684ebfb628..8b5bf3eed4d93 100644
--- a/api/client/webclient/webclient.go
+++ b/api/client/webclient/webclient.go
@@ -359,8 +359,8 @@ type ProxySettings struct {
type AutoUpdateSettings struct {
// ToolsVersion defines the version of {tsh, tctl} for client auto update.
ToolsVersion string `json:"tools_version"`
- // ToolsMode defines mode client auto update feature `enabled|disabled`.
- ToolsMode string `json:"tools_mode"`
+ // ToolsAutoUpdate indicates if the requesting tools client should be updated.
+ ToolsAutoUpdate bool `json:"tools_auto_update"`
// AgentVersion defines the version of teleport that agents enrolled into autoupdates should run.
AgentVersion string `json:"agent_version"`
// AgentAutoUpdate indicates if the requesting agent should attempt to update now.
diff --git a/integration/autoupdate/tools/main_test.go b/integration/autoupdate/tools/main_test.go
index a14a6dc9fc683..bbc3f559f65c0 100644
--- a/integration/autoupdate/tools/main_test.go
+++ b/integration/autoupdate/tools/main_test.go
@@ -37,7 +37,8 @@ import (
"github.com/gravitational/trace"
- "github.com/gravitational/teleport/integration/helpers"
+ "github.com/gravitational/teleport/api/constants"
+ "github.com/gravitational/teleport/integration/helpers/archive"
)
const (
@@ -133,9 +134,9 @@ func buildAndArchiveApps(ctx context.Context, path string, toolsDir string, vers
for _, app := range []string{"tsh", "tctl"} {
output := filepath.Join(versionPath, app)
switch runtime.GOOS {
- case "windows":
+ case constants.WindowsOS:
output = filepath.Join(versionPath, app+".exe")
- case "darwin":
+ case constants.DarwinOS:
output = filepath.Join(versionPath, app+".app", "Contents", "MacOS", app)
}
if err := buildBinary(output, toolsDir, version, baseURL); err != nil {
@@ -143,15 +144,15 @@ func buildAndArchiveApps(ctx context.Context, path string, toolsDir string, vers
}
}
switch runtime.GOOS {
- case "darwin":
+ case constants.DarwinOS:
archivePath := filepath.Join(path, fmt.Sprintf("teleport-%s.pkg", version))
- return trace.Wrap(helpers.CompressDirToPkgFile(ctx, versionPath, archivePath, "com.example.pkgtest"))
- case "windows":
+ return trace.Wrap(archive.CompressDirToPkgFile(ctx, versionPath, archivePath, "com.example.pkgtest"))
+ case constants.WindowsOS:
archivePath := filepath.Join(path, fmt.Sprintf("teleport-v%s-windows-amd64-bin.zip", version))
- return trace.Wrap(helpers.CompressDirToZipFile(ctx, versionPath, archivePath))
+ return trace.Wrap(archive.CompressDirToZipFile(ctx, versionPath, archivePath))
default:
archivePath := filepath.Join(path, fmt.Sprintf("teleport-v%s-linux-%s-bin.tar.gz", version, runtime.GOARCH))
- return trace.Wrap(helpers.CompressDirToTarGzFile(ctx, versionPath, archivePath))
+ return trace.Wrap(archive.CompressDirToTarGzFile(ctx, versionPath, archivePath))
}
}
diff --git a/integration/autoupdate/tools/updater/main.go b/integration/autoupdate/tools/updater/main.go
index e14c76e5d5aa8..775c7ab7b2e9d 100644
--- a/integration/autoupdate/tools/updater/main.go
+++ b/integration/autoupdate/tools/updater/main.go
@@ -25,11 +25,9 @@ import (
"log"
"os"
"os/signal"
- "runtime"
"syscall"
"time"
- "github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/lib/autoupdate/tools"
)
@@ -40,17 +38,20 @@ var (
)
func main() {
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
ctx, _ = signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
updater := tools.NewUpdater(
- clientTools(),
+ tools.DefaultClientTools(),
toolsDir,
version,
tools.WithBaseURL(baseURL),
)
- toolsVersion, reExec := updater.CheckLocal()
+ toolsVersion, reExec, err := updater.CheckLocal()
+ if err != nil {
+ log.Fatal(err)
+ }
if reExec {
// Download and update the version of client tools required by the cluster.
// This is required if the user passed in the TELEPORT_TOOLS_VERSION explicitly.
@@ -76,13 +77,3 @@ func main() {
fmt.Printf("Teleport v%v git\n", version)
}
}
-
-// clientTools list of the client tools needs to be updated.
-func clientTools() []string {
- switch runtime.GOOS {
- case constants.WindowsOS:
- return []string{"tsh.exe", "tctl.exe"}
- default:
- return []string{"tsh", "tctl"}
- }
-}
diff --git a/integration/autoupdate/tools/updater_test.go b/integration/autoupdate/tools/updater_test.go
index 52d7a7b8ab385..d429a7483910a 100644
--- a/integration/autoupdate/tools/updater_test.go
+++ b/integration/autoupdate/tools/updater_test.go
@@ -26,7 +26,6 @@ import (
"os/exec"
"path/filepath"
"regexp"
- "runtime"
"strings"
"testing"
"time"
@@ -34,7 +33,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/lib/autoupdate/tools"
)
@@ -51,7 +49,7 @@ func TestUpdate(t *testing.T) {
// Fetch compiled test binary with updater logic and install to $TELEPORT_HOME.
updater := tools.NewUpdater(
- clientTools(),
+ tools.DefaultClientTools(),
toolsDir,
testVersions[0],
tools.WithBaseURL(baseURL),
@@ -93,7 +91,7 @@ func TestParallelUpdate(t *testing.T) {
// Initial fetch the updater binary un-archive and replace.
updater := tools.NewUpdater(
- clientTools(),
+ tools.DefaultClientTools(),
toolsDir,
testVersions[0],
tools.WithBaseURL(baseURL),
@@ -167,7 +165,7 @@ func TestUpdateInterruptSignal(t *testing.T) {
// Initial fetch the updater binary un-archive and replace.
updater := tools.NewUpdater(
- clientTools(),
+ tools.DefaultClientTools(),
toolsDir,
testVersions[0],
tools.WithBaseURL(baseURL),
@@ -220,12 +218,3 @@ func TestUpdateInterruptSignal(t *testing.T) {
}
assert.Contains(t, output.String(), "Update progress:")
}
-
-func clientTools() []string {
- switch runtime.GOOS {
- case constants.WindowsOS:
- return []string{"tsh.exe", "tctl.exe"}
- default:
- return []string{"tsh", "tctl"}
- }
-}
diff --git a/integration/helpers/archive.go b/integration/helpers/archive/packaging.go
similarity index 99%
rename from integration/helpers/archive.go
rename to integration/helpers/archive/packaging.go
index 6e48108013d86..ee237749115a3 100644
--- a/integration/helpers/archive.go
+++ b/integration/helpers/archive/packaging.go
@@ -16,7 +16,7 @@
* along with this program. If not, see .
*/
-package helpers
+package archive
import (
"archive/tar"
diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod
index c9459a5381189..030b29ba08341 100644
--- a/integrations/terraform/go.mod
+++ b/integrations/terraform/go.mod
@@ -184,6 +184,7 @@ require (
github.com/google/go-tspi v0.3.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af // indirect
+ github.com/google/renameio/v2 v2.0.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/safetext v0.0.0-20240104143208-7a7d9b3d812f // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum
index 257c076fd1e5d..95ef579188b65 100644
--- a/integrations/terraform/go.sum
+++ b/integrations/terraform/go.sum
@@ -1252,7 +1252,6 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2Rrd27c3VGxi6a/6HNq8QmHRKM=
github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
-github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
diff --git a/lib/autoupdate/tools/progress.go b/lib/autoupdate/tools/progress.go
index 95395003730ec..34bbad2ff888d 100644
--- a/lib/autoupdate/tools/progress.go
+++ b/lib/autoupdate/tools/progress.go
@@ -20,24 +20,59 @@ package tools
import (
"fmt"
+ "io"
"strings"
+
+ "github.com/gravitational/trace"
)
type progressWriter struct {
- n int64
- limit int64
+ n int64
+ limit int64
+ size int
+ progress int
}
-func (w *progressWriter) Write(p []byte) (int, error) {
- w.n = w.n + int64(len(p))
+// newProgressWriter creates progress writer instance and prints empty
+// progress bar right after initialisation.
+func newProgressWriter(size int) (*progressWriter, func()) {
+ pw := &progressWriter{size: size}
+ pw.Print(0)
+ return pw, func() {
+ fmt.Print("\n")
+ }
+}
- n := int((w.n*100)/w.limit) / 10
- bricks := strings.Repeat("▒", n) + strings.Repeat(" ", 10-n)
+// Print prints the update progress bar with `n` bricks.
+func (w *progressWriter) Print(n int) {
+ bricks := strings.Repeat("▒", n) + strings.Repeat(" ", w.size-n)
fmt.Print("\rUpdate progress: [" + bricks + "] (Ctrl-C to cancel update)")
+}
- if w.n == w.limit {
- fmt.Print("\n")
+func (w *progressWriter) Write(p []byte) (int, error) {
+ if w.limit == 0 || w.size == 0 {
+ return len(p), nil
+ }
+
+ w.n += int64(len(p))
+ bricks := int((w.n*100)/w.limit) / w.size
+ if w.progress != bricks {
+ w.Print(bricks)
+ w.progress = bricks
}
return len(p), nil
}
+
+// CopyLimit sets the limit of writing bytes to the progress writer and initiate copying process.
+func (w *progressWriter) CopyLimit(dst io.Writer, src io.Reader, limit int64) (written int64, err error) {
+ if limit < 0 {
+ n, err := io.Copy(dst, io.TeeReader(src, w))
+ w.Print(w.size)
+ return n, trace.Wrap(err)
+ }
+
+ w.limit = limit
+ n, err := io.CopyN(dst, io.TeeReader(src, w), limit)
+ return n, trace.Wrap(err)
+}
diff --git a/lib/autoupdate/tools/updater.go b/lib/autoupdate/tools/updater.go
index 96991044ccc31..96352e34d9910 100644
--- a/lib/autoupdate/tools/updater.go
+++ b/lib/autoupdate/tools/updater.go
@@ -36,12 +36,12 @@ import (
"syscall"
"time"
+ "github.com/coreos/go-semver/semver"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/client/webclient"
"github.com/gravitational/teleport/api/constants"
- "github.com/gravitational/teleport/api/types/autoupdate"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/packaging"
)
@@ -114,30 +114,37 @@ func NewUpdater(tools []string, toolsDir string, localVersion string, options ..
// CheckLocal is run at client tool startup and will only perform local checks.
// Returns the version needs to be updated and re-executed, by re-execution flag we
// understand that update and re-execute is required.
-func (u *Updater) CheckLocal() (version string, reExec bool) {
+func (u *Updater) CheckLocal() (version string, reExec bool, err error) {
// Check if the user has requested a specific version of client tools.
requestedVersion := os.Getenv(teleportToolsVersionEnv)
switch requestedVersion {
// The user has turned off any form of automatic updates.
case "off":
- return "", false
+ return "", false, nil
// Requested version already the same as client version.
case u.localVersion:
- return u.localVersion, false
+ return u.localVersion, false, nil
+ // No requested version, we continue.
+ case "":
+ // Requested version that is not the local one.
+ default:
+ if _, err := semver.NewVersion(requestedVersion); err != nil {
+ return "", false, trace.Wrap(err, "checking that request version is semantic")
+ }
+ return requestedVersion, true, nil
}
// If a version of client tools has already been downloaded to
// tools directory, return that.
- toolsVersion, err := checkToolVersion(u.toolsDir)
- if err != nil {
- return "", false
+ toolsVersion, err := CheckToolVersion(u.toolsDir)
+ if trace.IsNotFound(err) {
+ return u.localVersion, false, nil
}
- // The user has requested a specific version of client tools.
- if requestedVersion != "" && requestedVersion != toolsVersion {
- return requestedVersion, true
+ if err != nil {
+ return "", false, trace.Wrap(err)
}
- return toolsVersion, false
+ return toolsVersion, true, nil
}
// CheckRemote first checks the version set by the environment variable. If not set or disabled,
@@ -145,7 +152,7 @@ func (u *Updater) CheckLocal() (version string, reExec bool) {
// the `webapi/find` handler, which stores information about the required client tools version to
// operate with this cluster. It returns the semantic version that needs updating and whether
// re-execution is necessary, by re-execution flag we understand that update and re-execute is required.
-func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version string, reExec bool, err error) {
+func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string, insecure bool) (version string, reExec bool, err error) {
// Check if the user has requested a specific version of client tools.
requestedVersion := os.Getenv(teleportToolsVersionEnv)
switch requestedVersion {
@@ -155,6 +162,14 @@ func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version st
// Requested version already the same as client version.
case u.localVersion:
return u.localVersion, false, nil
+ // No requested version, we continue.
+ case "":
+ // Requested version that is not the local one.
+ default:
+ if _, err := semver.NewVersion(requestedVersion); err != nil {
+ return "", false, trace.Wrap(err, "checking that request version is semantic")
+ }
+ return requestedVersion, true, nil
}
certPool, err := x509.SystemCertPool()
@@ -165,7 +180,8 @@ func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version st
Context: ctx,
ProxyAddr: proxyAddr,
Pool: certPool,
- Timeout: 30 * time.Second,
+ Timeout: 10 * time.Second,
+ Insecure: insecure,
})
if err != nil {
return "", false, trace.Wrap(err)
@@ -173,28 +189,28 @@ func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version st
// If a version of client tools has already been downloaded to
// tools directory, return that.
- toolsVersion, err := checkToolVersion(u.toolsDir)
- if err != nil {
+ toolsVersion, err := CheckToolVersion(u.toolsDir)
+ if err != nil && !trace.IsNotFound(err) {
return "", false, trace.Wrap(err)
}
switch {
- case requestedVersion != "" && requestedVersion != toolsVersion:
- return requestedVersion, true, nil
- case resp.AutoUpdate.ToolsMode != autoupdate.ToolsUpdateModeEnabled || resp.AutoUpdate.ToolsVersion == "":
- return "", false, nil
+ case !resp.AutoUpdate.ToolsAutoUpdate || resp.AutoUpdate.ToolsVersion == "":
+ if toolsVersion == "" {
+ return u.localVersion, false, nil
+ }
case u.localVersion == resp.AutoUpdate.ToolsVersion:
- return resp.AutoUpdate.ToolsVersion, false, nil
+ return u.localVersion, false, nil
case resp.AutoUpdate.ToolsVersion != toolsVersion:
return resp.AutoUpdate.ToolsVersion, true, nil
}
- return toolsVersion, false, nil
+ return toolsVersion, true, nil
}
// UpdateWithLock acquires filesystem lock, downloads requested version package,
// unarchive and replace existing one.
-func (u *Updater) UpdateWithLock(ctx context.Context, toolsVersion string) (err error) {
+func (u *Updater) UpdateWithLock(ctx context.Context, updateToolsVersion string) (err error) {
// Create tools directory if it does not exist.
if err := os.MkdirAll(u.toolsDir, 0o755); err != nil {
return trace.Wrap(err)
@@ -211,21 +227,20 @@ func (u *Updater) UpdateWithLock(ctx context.Context, toolsVersion string) (err
// If the version of the running binary or the version downloaded to
// tools directory is the same as the requested version of client tools,
// nothing to be done, exit early.
- teleportVersion, err := checkToolVersion(u.toolsDir)
+ toolsVersion, err := CheckToolVersion(u.toolsDir)
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)
-
}
- if toolsVersion == u.localVersion || toolsVersion == teleportVersion {
+ if updateToolsVersion == toolsVersion {
return nil
}
// Download and update client tools in tools directory.
- if err := u.Update(ctx, toolsVersion); err != nil {
+ if err := u.Update(ctx, updateToolsVersion); err != nil {
return trace.Wrap(err)
}
- return
+ return nil
}
// Update downloads requested version and replace it with existing one and cleanups the previous downloads
@@ -237,10 +252,18 @@ func (u *Updater) Update(ctx context.Context, toolsVersion string) error {
return trace.Wrap(err)
}
+ var pkgNames []string
for _, pkg := range packages {
- if err := u.update(ctx, pkg); err != nil {
+ pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix)
+ if err := u.update(ctx, pkg, pkgName); err != nil {
return trace.Wrap(err)
}
+ pkgNames = append(pkgNames, pkgName)
+ }
+
+ // Cleanup the tools directory with previously downloaded and un-archived versions.
+ if err := packaging.RemoveWithSuffix(u.toolsDir, updatePackageSuffix, pkgNames); err != nil {
+ slog.WarnContext(ctx, "failed to clean up tools directory", "error", err)
}
return nil
@@ -248,16 +271,8 @@ func (u *Updater) Update(ctx context.Context, toolsVersion string) error {
// update downloads the archive and validate against the hash. Download to a
// temporary path within tools directory.
-func (u *Updater) update(ctx context.Context, pkg packageURL) error {
- hash, err := u.downloadHash(ctx, pkg.Hash)
- if pkg.Optional && trace.IsNotFound(err) {
- return nil
- }
- if err != nil {
- return trace.Wrap(err)
- }
-
- f, err := os.CreateTemp(u.toolsDir, "tmp-")
+func (u *Updater) update(ctx context.Context, pkg packageURL, pkgName string) error {
+ f, err := os.CreateTemp("", "teleport-")
if err != nil {
return trace.Wrap(err)
}
@@ -275,11 +290,19 @@ func (u *Updater) update(ctx context.Context, pkg packageURL) error {
if err != nil {
return trace.Wrap(err)
}
+
+ hash, err := u.downloadHash(ctx, pkg.Hash)
+ if pkg.Optional && trace.IsNotFound(err) {
+ return nil
+ }
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
if !bytes.Equal(archiveHash, hash) {
return trace.BadParameter("hash of archive does not match downloaded archive")
}
- pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix)
extractDir := filepath.Join(u.toolsDir, pkgName)
if runtime.GOOS != constants.DarwinOS {
if err := os.Mkdir(extractDir, 0o755); err != nil {
@@ -291,10 +314,6 @@ func (u *Updater) update(ctx context.Context, pkg packageURL) error {
if err := packaging.ReplaceToolsBinaries(u.toolsDir, f.Name(), extractDir, u.tools); err != nil {
return trace.Wrap(err)
}
- // Cleanup the tools directory with previously downloaded and un-archived versions.
- if err := packaging.RemoveWithSuffix(u.toolsDir, updatePackageSuffix, pkgName); err != nil {
- return trace.Wrap(err)
- }
return nil
}
@@ -340,7 +359,7 @@ func (u *Updater) downloadHash(ctx context.Context, url string) ([]byte, error)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
- return nil, trace.NotFound("hash file is not found: %v", resp.StatusCode)
+ return nil, trace.NotFound("hash file is not found: %q", url)
}
if resp.StatusCode != http.StatusOK {
return nil, trace.BadParameter("bad status when downloading archive hash: %v", resp.StatusCode)
@@ -362,6 +381,11 @@ func (u *Updater) downloadHash(ctx context.Context, url string) ([]byte, error)
// downloadArchive downloads the archive package by `url` and writes content to the writer interface,
// return calculated sha256 hash sum of the content.
func (u *Updater) downloadArchive(ctx context.Context, url string, f io.Writer) ([]byte, error) {
+ // Display a progress bar before initiating the update request to inform the user that
+ // an update is in progress, allowing them the option to cancel before actual response
+ // which might be delayed with slow internet connection or complete isolation to CDN.
+ pw, finish := newProgressWriter(10)
+ defer finish()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, trace.Wrap(err)
@@ -385,14 +409,10 @@ func (u *Updater) downloadArchive(ctx context.Context, url string, f io.Writer)
}
h := sha256.New()
- pw := &progressWriter{n: 0, limit: resp.ContentLength}
- body := io.TeeReader(io.TeeReader(resp.Body, h), pw)
-
// It is a little inefficient to download the file to disk and then re-load
// it into memory to unarchive later, but this is safer as it allows client
// tools to validate the hash before trying to operate on the archive.
- _, err = io.CopyN(f, body, resp.ContentLength)
- if err != nil {
+ if _, err := pw.CopyLimit(f, io.TeeReader(resp.Body, h), resp.ContentLength); err != nil {
return nil, trace.Wrap(err)
}
diff --git a/lib/autoupdate/tools/utils.go b/lib/autoupdate/tools/utils.go
index d552b31abefe4..f937d228b5cd4 100644
--- a/lib/autoupdate/tools/utils.go
+++ b/lib/autoupdate/tools/utils.go
@@ -33,6 +33,7 @@ import (
"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
+ "github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/utils"
@@ -52,14 +53,26 @@ func Dir() (string, error) {
return filepath.Join(home, ".tsh", "bin"), nil
}
-func checkToolVersion(toolsDir string) (string, error) {
+// DefaultClientTools list of the client tools needs to be updated by default.
+func DefaultClientTools() []string {
+ switch runtime.GOOS {
+ case constants.WindowsOS:
+ return []string{"tsh.exe", "tctl.exe"}
+ default:
+ return []string{"tsh", "tctl"}
+ }
+}
+
+// CheckToolVersion returns current installed client tools version, must return NotFoundError if
+// the client tools is not found in tools directory.
+func CheckToolVersion(toolsDir string) (string, error) {
// Find the path to the current executable.
path, err := toolName(toolsDir)
if err != nil {
return "", trace.Wrap(err)
}
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
- return "", nil
+ return "", trace.NotFound("autoupdate tool not found in %q", toolsDir)
} else if err != nil {
return "", trace.Wrap(err)
}
@@ -76,7 +89,7 @@ func checkToolVersion(toolsDir string) (string, error) {
command.Env = []string{teleportToolsVersionEnv + "=off"}
output, err := command.Output()
if err != nil {
- return "", trace.Wrap(err)
+ return "", trace.WrapWithMessage(err, "failed to determine version of %q tool", path)
}
// The output for "{tsh, tctl} version" can be multiple lines. Find the
diff --git a/lib/client/api.go b/lib/client/api.go
index d087fb02d1e34..d8a35dc95feee 100644
--- a/lib/client/api.go
+++ b/lib/client/api.go
@@ -74,6 +74,7 @@ import (
"github.com/gravitational/teleport/lib/auth/touchid"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
"github.com/gravitational/teleport/lib/authz"
+ "github.com/gravitational/teleport/lib/autoupdate/tools"
libmfa "github.com/gravitational/teleport/lib/client/mfa"
"github.com/gravitational/teleport/lib/client/sso"
"github.com/gravitational/teleport/lib/client/terminal"
@@ -95,6 +96,7 @@ import (
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/agentconn"
logutils "github.com/gravitational/teleport/lib/utils/log"
+ "github.com/gravitational/teleport/lib/utils/signal"
)
const (
@@ -702,6 +704,39 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error,
return trace.Wrap(err)
}
+ // The user has typed a command like `tsh ssh ...` without being logged in,
+ // if the running binary needs to be updated, update and re-exec.
+ //
+ // If needed, download the new version of {tsh, tctl} and re-exec. Make
+ // sure to exit this process with the same exit code as the child process.
+ //
+ toolsDir, err := tools.Dir()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ updater := tools.NewUpdater(tools.DefaultClientTools(), toolsDir, teleport.Version)
+ toolsVersion, reExec, err := updater.CheckRemote(ctx, tc.WebProxyAddr, tc.InsecureSkipVerify)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if reExec {
+ ctxUpdate, cancel := signal.GetSignalHandler().NotifyContext(context.Background())
+ defer cancel()
+ // Download the version of client tools required by the cluster.
+ err := updater.UpdateWithLock(ctxUpdate, toolsVersion)
+ if err != nil && !errors.Is(err, context.Canceled) {
+ utils.FatalError(err)
+ }
+ // Re-execute client tools with the correct version of client tools.
+ code, err := updater.Exec()
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ log.Debugf("Failed to re-exec client tool: %v.", err)
+ os.Exit(code)
+ } else if err == nil {
+ os.Exit(code)
+ }
+ }
+
if opt.afterLoginHook != nil {
if err := opt.afterLoginHook(); err != nil {
return trace.Wrap(err)
diff --git a/lib/utils/packaging/unarchive.go b/lib/utils/packaging/unarchive.go
index f1a197e095b1a..6496afbc182c3 100644
--- a/lib/utils/packaging/unarchive.go
+++ b/lib/utils/packaging/unarchive.go
@@ -38,13 +38,13 @@ const (
)
// RemoveWithSuffix removes all that matches the provided suffix, except for file or directory with `skipName`.
-func RemoveWithSuffix(dir, suffix, skipName string) error {
+func RemoveWithSuffix(dir, suffix string, skipNames []string) error {
var removePaths []string
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return trace.Wrap(err)
}
- if skipName == info.Name() {
+ if slices.Contains(skipNames, info.Name()) {
return nil
}
if !strings.HasSuffix(info.Name(), suffix) {
@@ -59,12 +59,13 @@ func RemoveWithSuffix(dir, suffix, skipName string) error {
if err != nil {
return trace.Wrap(err)
}
+ var aggErr []error
for _, path := range removePaths {
if err := os.RemoveAll(path); err != nil {
- return trace.Wrap(err)
+ aggErr = append(aggErr, err)
}
}
- return nil
+ return trace.NewAggregate(aggErr...)
}
// replaceZip un-archives the Teleport package in .zip format, iterates through
@@ -118,7 +119,7 @@ func replaceZip(toolsDir string, archivePath string, extractDir string, execName
defer file.Close()
dest := filepath.Join(extractDir, baseName)
- destFile, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755)
+ destFile, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return trace.Wrap(err)
}
@@ -131,7 +132,10 @@ func replaceZip(toolsDir string, archivePath string, extractDir string, execName
if err := os.Remove(appPath); err != nil && !os.IsNotExist(err) {
return trace.Wrap(err)
}
- if err := os.Symlink(dest, appPath); err != nil {
+ // For the Windows build we have to copy binary to be able
+ // to do this without administrative access as it required
+ // for symlinks.
+ if err := utils.CopyFile(dest, appPath, 0o755); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(destFile.Close())
diff --git a/lib/utils/packaging/unarchive_test.go b/lib/utils/packaging/unarchive_test.go
index 30933bbb75927..b124b603b0fd5 100644
--- a/lib/utils/packaging/unarchive_test.go
+++ b/lib/utils/packaging/unarchive_test.go
@@ -30,7 +30,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/gravitational/teleport/integration/helpers"
+ "github.com/gravitational/teleport/integration/helpers/archive"
)
// TestPackaging verifies un-archiving of all supported teleport package formats.
@@ -63,7 +63,7 @@ func TestPackaging(t *testing.T) {
t.Run("tar.gz", func(t *testing.T) {
archivePath := filepath.Join(toolsDir, "tsh.tar.gz")
- err = helpers.CompressDirToTarGzFile(ctx, sourceDir, archivePath)
+ err = archive.CompressDirToTarGzFile(ctx, sourceDir, archivePath)
require.NoError(t, err)
require.FileExists(t, archivePath, "archive not created")
@@ -85,7 +85,7 @@ func TestPackaging(t *testing.T) {
t.Skip("unsupported platform")
}
archivePath := filepath.Join(toolsDir, "tsh.pkg")
- err = helpers.CompressDirToPkgFile(ctx, sourceDir, archivePath, "com.example.pkgtest")
+ err = archive.CompressDirToPkgFile(ctx, sourceDir, archivePath, "com.example.pkgtest")
require.NoError(t, err)
require.FileExists(t, archivePath, "archive not created")
@@ -101,7 +101,7 @@ func TestPackaging(t *testing.T) {
t.Run("zip", func(t *testing.T) {
archivePath := filepath.Join(toolsDir, "tsh.zip")
- err = helpers.CompressDirToZipFile(ctx, sourceDir, archivePath)
+ err = archive.CompressDirToZipFile(ctx, sourceDir, archivePath)
require.NoError(t, err)
require.FileExists(t, archivePath, "archive not created")
@@ -132,7 +132,7 @@ func TestRemoveWithSuffix(t *testing.T) {
dirInSkipPath := filepath.Join(skipPath, dirForRemove)
require.NoError(t, os.MkdirAll(skipPath, 0o755))
- err := RemoveWithSuffix(testDir, dirForRemove, skipName)
+ err := RemoveWithSuffix(testDir, dirForRemove, []string{skipName})
require.NoError(t, err)
_, err = os.Stat(filepath.Join(testDir, dirForRemove))
diff --git a/lib/utils/packaging/unarchive_unix.go b/lib/utils/packaging/unarchive_unix.go
index ea51afdbbc7f0..8daf1b3aa5525 100644
--- a/lib/utils/packaging/unarchive_unix.go
+++ b/lib/utils/packaging/unarchive_unix.go
@@ -186,9 +186,10 @@ func replacePkg(toolsDir string, archivePath string, extractDir string, execName
// swap operations. This ensures that the "com.apple.macl" extended
// attribute is set and macOS will not send a SIGKILL to the process
// if multiple processes are trying to operate on it.
- command := exec.Command(path, "version", "--client")
+ command := exec.Command(path, "version")
+ command.Env = []string{"TELEPORT_TOOLS_VERSION=off"}
if err := command.Run(); err != nil {
- return trace.Wrap(err)
+ return trace.WrapWithMessage(err, "failed to validate binary")
}
// Due to macOS applications not being a single binary (they are a
diff --git a/lib/utils/signal/stack_handler.go b/lib/utils/signal/stack_handler.go
new file mode 100644
index 0000000000000..0fcbefb081d11
--- /dev/null
+++ b/lib/utils/signal/stack_handler.go
@@ -0,0 +1,98 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package signal
+
+import (
+ "container/list"
+ "context"
+ "os"
+ "os/signal"
+ "sync"
+ "syscall"
+)
+
+// Handler implements stack for context cancellation.
+type Handler struct {
+ mu sync.Mutex
+ list *list.List
+}
+
+var handler = &Handler{
+ list: list.New(),
+}
+
+// GetSignalHandler returns global singleton instance of signal
+func GetSignalHandler() *Handler {
+ return handler
+}
+
+// NotifyContext creates context which going to be canceled after SIGINT, SIGTERM
+// in order of adding them to the stack. When very first context is canceled
+// we stop watching the OS signals.
+func (s *Handler) NotifyContext(parent context.Context) (context.Context, context.CancelFunc) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.list.Len() == 0 {
+ s.listenSignals()
+ }
+
+ ctx, cancel := context.WithCancel(parent)
+ element := s.list.PushBack(cancel)
+
+ return ctx, func() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.list.Remove(element)
+ cancel()
+ }
+}
+
+// listenSignals sets up the signal listener for SIGINT, SIGTERM.
+func (s *Handler) listenSignals() {
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+ go func() {
+ for {
+ if sig := <-sigChan; sig == nil {
+ return
+ }
+ if !s.cancelNext() {
+ signal.Stop(sigChan)
+ return
+ }
+ }
+ }()
+}
+
+// cancelNext calls the most recent cancel func in the stack.
+func (s *Handler) cancelNext() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.list.Len() > 0 {
+ cancel := s.list.Remove(s.list.Back())
+ if cancel != nil {
+ cancel.(context.CancelFunc)()
+ }
+ }
+
+ return s.list.Len() != 0
+}
diff --git a/lib/utils/signal/stack_handler_test.go b/lib/utils/signal/stack_handler_test.go
new file mode 100644
index 0000000000000..b900939b886e8
--- /dev/null
+++ b/lib/utils/signal/stack_handler_test.go
@@ -0,0 +1,88 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package signal
+
+import (
+ "context"
+ "os"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestGetSignalHandler verifies the cancellation stack order.
+func TestGetSignalHandler(t *testing.T) {
+ testHandler := GetSignalHandler()
+ parent := context.Background()
+
+ ctx1, cancel1 := testHandler.NotifyContext(parent)
+ ctx2, cancel2 := testHandler.NotifyContext(parent)
+ ctx3, cancel3 := testHandler.NotifyContext(parent)
+ ctx4, cancel4 := testHandler.NotifyContext(parent)
+ t.Cleanup(func() {
+ cancel4()
+ cancel2()
+ cancel1()
+ })
+
+ // Verify that all context not canceled.
+ require.NoError(t, ctx4.Err())
+ require.NoError(t, ctx3.Err())
+ require.NoError(t, ctx2.Err())
+ require.NoError(t, ctx1.Err())
+
+ // Cancel context manually to ensure it was removed from stack in right order.
+ cancel3()
+
+ // Check that last added context is canceled.
+ require.NoError(t, syscall.Kill(os.Getpid(), syscall.SIGINT))
+ select {
+ case <-ctx4.Done():
+ assert.ErrorIs(t, ctx3.Err(), context.Canceled)
+ assert.NoError(t, ctx2.Err())
+ assert.NoError(t, ctx1.Err())
+ case <-time.After(time.Second):
+ assert.Fail(t, "context 3 must be canceled")
+ }
+
+ // Send interrupt signal to cancel next context in the stack.
+ require.NoError(t, syscall.Kill(os.Getpid(), syscall.SIGINT))
+ select {
+ case <-ctx2.Done():
+ assert.ErrorIs(t, ctx4.Err(), context.Canceled)
+ assert.ErrorIs(t, ctx3.Err(), context.Canceled)
+ assert.NoError(t, ctx1.Err())
+ case <-time.After(time.Second):
+ assert.Fail(t, "context 2 must be canceled")
+ }
+
+ // All context must be canceled.
+ require.NoError(t, syscall.Kill(os.Getpid(), syscall.SIGINT))
+ select {
+ case <-ctx1.Done():
+ assert.ErrorIs(t, ctx4.Err(), context.Canceled)
+ assert.ErrorIs(t, ctx3.Err(), context.Canceled)
+ assert.ErrorIs(t, ctx2.Err(), context.Canceled)
+ case <-time.After(time.Second):
+ assert.Fail(t, "context 1 must be canceled")
+ }
+}
diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go
index 2bbb2bcbdd561..cc309efe1043f 100644
--- a/lib/web/apiserver.go
+++ b/lib/web/apiserver.go
@@ -1593,7 +1593,7 @@ func (h *Handler) automaticUpdateSettings(ctx context.Context) webclient.AutoUpd
}
return webclient.AutoUpdateSettings{
- ToolsMode: getToolsMode(autoUpdateConfig),
+ ToolsAutoUpdate: getToolsAutoUpdate(autoUpdateConfig),
ToolsVersion: getToolsVersion(autoUpdateVersion),
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentVersion: getAgentVersion(autoUpdateVersion),
@@ -5154,15 +5154,15 @@ func readEtagFromAppHash(fs http.FileSystem) (string, error) {
return etag, nil
}
-func getToolsMode(config *autoupdatepb.AutoUpdateConfig) string {
+func getToolsAutoUpdate(config *autoupdatepb.AutoUpdateConfig) bool {
// If we can't get the AU config or if AUs are not configured, we default to "disabled".
// This ensures we fail open and don't accidentally update agents if something is going wrong.
// If we want to enable AUs by default, it would be better to create a default "autoupdate_config" resource
// than changing this logic.
- if config.GetSpec().GetTools() == nil {
- return autoupdate.ToolsUpdateModeDisabled
+ if config.GetSpec().GetTools() != nil {
+ return config.GetSpec().GetTools().GetMode() == autoupdate.ToolsUpdateModeEnabled
}
- return config.GetSpec().GetTools().GetMode()
+ return false
}
func getToolsVersion(version *autoupdatepb.AutoUpdateVersion) string {
diff --git a/lib/web/apiserver_ping_test.go b/lib/web/apiserver_ping_test.go
index 5ce3720375c46..2bf325d4f7902 100644
--- a/lib/web/apiserver_ping_test.go
+++ b/lib/web/apiserver_ping_test.go
@@ -306,7 +306,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
name: "resources not defined",
expected: webclient.AutoUpdateSettings{
ToolsVersion: api.Version,
- ToolsMode: autoupdate.ToolsUpdateModeDisabled,
+ ToolsAutoUpdate: false,
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentAutoUpdate: false,
AgentVersion: api.Version,
@@ -320,7 +320,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
},
},
expected: webclient.AutoUpdateSettings{
- ToolsMode: autoupdate.ToolsUpdateModeEnabled,
+ ToolsAutoUpdate: true,
ToolsVersion: api.Version,
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentAutoUpdate: false,
@@ -346,7 +346,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
},
expected: webclient.AutoUpdateSettings{
ToolsVersion: api.Version,
- ToolsMode: autoupdate.ToolsUpdateModeDisabled,
+ ToolsAutoUpdate: false,
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentAutoUpdate: true,
AgentVersion: "1.2.4",
@@ -371,7 +371,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
},
expected: webclient.AutoUpdateSettings{
ToolsVersion: api.Version,
- ToolsMode: autoupdate.ToolsUpdateModeDisabled,
+ ToolsAutoUpdate: false,
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentAutoUpdate: false,
AgentVersion: "1.2.4",
@@ -384,7 +384,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
version: &autoupdatev1pb.AutoUpdateVersionSpec{},
expected: webclient.AutoUpdateSettings{
ToolsVersion: api.Version,
- ToolsMode: autoupdate.ToolsUpdateModeDisabled,
+ ToolsAutoUpdate: false,
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentAutoUpdate: false,
AgentVersion: api.Version,
@@ -400,7 +400,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
},
expected: webclient.AutoUpdateSettings{
ToolsVersion: "1.2.3",
- ToolsMode: autoupdate.ToolsUpdateModeDisabled,
+ ToolsAutoUpdate: false,
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentAutoUpdate: false,
AgentVersion: api.Version,
@@ -420,7 +420,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
},
},
expected: webclient.AutoUpdateSettings{
- ToolsMode: autoupdate.ToolsUpdateModeEnabled,
+ ToolsAutoUpdate: true,
ToolsVersion: "1.2.3",
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentAutoUpdate: false,
@@ -440,7 +440,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
},
},
expected: webclient.AutoUpdateSettings{
- ToolsMode: autoupdate.ToolsUpdateModeDisabled,
+ ToolsAutoUpdate: false,
ToolsVersion: "3.2.1",
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
AgentAutoUpdate: false,
diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go
index 4a85be6e934b0..51ff5f8687f75 100644
--- a/tool/tctl/common/tctl.go
+++ b/tool/tctl/common/tctl.go
@@ -44,6 +44,7 @@ import (
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/storage"
+ "github.com/gravitational/teleport/lib/autoupdate/tools"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/identityfile"
libmfa "github.com/gravitational/teleport/lib/client/mfa"
@@ -54,6 +55,7 @@ import (
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/utils"
+ "github.com/gravitational/teleport/lib/utils/signal"
"github.com/gravitational/teleport/tool/common"
)
@@ -104,8 +106,43 @@ type CLICommand interface {
// "distributions" like OSS or Enterprise
//
// distribution: name of the Teleport distribution
-func Run(commands []CLICommand) {
- err := TryRun(commands, os.Args[1:])
+func Run(ctx context.Context, commands []CLICommand) {
+ // The user has typed a command like `tsh ssh ...` without being logged in,
+ // if the running binary needs to be updated, update and re-exec.
+ //
+ // If needed, download the new version of {tsh, tctl} and re-exec. Make
+ // sure to exit this process with the same exit code as the child process.
+ //
+ toolsDir, err := tools.Dir()
+ if err != nil {
+ utils.FatalError(err)
+ }
+ updater := tools.NewUpdater(tools.DefaultClientTools(), toolsDir, teleport.Version)
+ toolsVersion, reExec, err := updater.CheckLocal()
+ if err != nil {
+ utils.FatalError(err)
+ }
+ if reExec {
+ ctxUpdate, cancel := signal.GetSignalHandler().NotifyContext(ctx)
+ defer cancel()
+ // Download the version of client tools required by the cluster. This
+ // is required if the user passed in the TELEPORT_TOOLS_VERSION
+ // explicitly.
+ err := updater.UpdateWithLock(ctxUpdate, toolsVersion)
+ if err != nil && !errors.Is(err, context.Canceled) {
+ utils.FatalError(err)
+ }
+ // Re-execute client tools with the correct version of client tools.
+ code, err := updater.Exec()
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ log.Debugf("Failed to re-exec client tool: %v.", err)
+ os.Exit(code)
+ } else if err == nil {
+ os.Exit(code)
+ }
+ }
+
+ err = TryRun(commands, os.Args[1:])
if err != nil {
var exitError *common.ExitCodeError
if errors.As(err, &exitError) {
diff --git a/tool/tctl/main.go b/tool/tctl/main.go
index f363e347f25c9..6dfae87ffdef2 100644
--- a/tool/tctl/main.go
+++ b/tool/tctl/main.go
@@ -19,9 +19,15 @@
package main
import (
+ "context"
+
+ "github.com/gravitational/teleport/lib/utils/signal"
"github.com/gravitational/teleport/tool/tctl/common"
)
func main() {
- common.Run(common.Commands())
+ ctx, cancel := signal.GetSignalHandler().NotifyContext(context.Background())
+ defer cancel()
+
+ common.Run(ctx, common.Commands())
}
diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go
index 1adabe7b337c1..a1b0edc1d15c3 100644
--- a/tool/tsh/common/tsh.go
+++ b/tool/tsh/common/tsh.go
@@ -41,7 +41,6 @@ import (
"strconv"
"strings"
"sync"
- "syscall"
"time"
"github.com/alecthomas/kingpin/v2"
@@ -74,6 +73,7 @@ import (
"github.com/gravitational/teleport/lib/asciitable"
"github.com/gravitational/teleport/lib/auth/authclient"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
+ "github.com/gravitational/teleport/lib/autoupdate/tools"
"github.com/gravitational/teleport/lib/benchmark"
benchmarkdb "github.com/gravitational/teleport/lib/benchmark/db"
"github.com/gravitational/teleport/lib/client"
@@ -94,6 +94,7 @@ import (
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/diagnostics/latency"
"github.com/gravitational/teleport/lib/utils/mlock"
+ stacksignal "github.com/gravitational/teleport/lib/utils/signal"
"github.com/gravitational/teleport/tool/common"
"github.com/gravitational/teleport/tool/common/fido2"
"github.com/gravitational/teleport/tool/common/touchid"
@@ -609,7 +610,7 @@ func Main() {
cmdLineOrig := os.Args[1:]
var cmdLine []string
- ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
+ ctx, cancel := stacksignal.GetSignalHandler().NotifyContext(context.Background())
defer cancel()
// lets see: if the executable name is 'ssh' or 'scp' we convert
@@ -707,6 +708,38 @@ func initLogger(cf *CLIConf) {
//
// DO NOT RUN TESTS that call Run() in parallel (unless you taken precautions).
func Run(ctx context.Context, args []string, opts ...CliOption) error {
+ // At process startup, check if a version has already been downloaded to
+ // $TELEPORT_HOME/bin or if the user has set the TELEPORT_TOOLS_VERSION
+ // environment variable. If so, re-exec that version of {tsh, tctl}.
+ toolsDir, err := tools.Dir()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ updater := tools.NewUpdater(tools.DefaultClientTools(), toolsDir, teleport.Version)
+ toolsVersion, reExec, err := updater.CheckLocal()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if reExec {
+ ctxUpdate, cancel := stacksignal.GetSignalHandler().NotifyContext(ctx)
+ defer cancel()
+ // Download the version of client tools required by the cluster. This
+ // is required if the user passed in the TELEPORT_TOOLS_VERSION
+ // explicitly.
+ err := updater.UpdateWithLock(ctxUpdate, toolsVersion)
+ if err != nil && !errors.Is(err, context.Canceled) {
+ return trace.Wrap(err)
+ }
+ // Re-execute client tools with the correct version of client tools.
+ code, err := updater.Exec()
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ log.Debugf("Failed to re-exec client tool: %v.", err)
+ os.Exit(code)
+ } else if err == nil {
+ os.Exit(code)
+ }
+ }
+
cf := CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
@@ -1239,8 +1272,6 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
bench.Hidden()
}
- var err error
-
cf.executablePath, err = os.Executable()
if err != nil {
return trace.Wrap(err)
@@ -1866,6 +1897,14 @@ func onLogin(cf *CLIConf) error {
}
tc.HomePath = cf.HomePath
+ // The user is not logged in and has typed in `tsh --proxy=... login`, if
+ // the running binary needs to be updated, update and re-exec.
+ if profile == nil {
+ if err := updateAndRun(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+
// client is already logged in and profile is not expired
if profile != nil && !profile.IsExpired(time.Now()) {
switch {
@@ -1876,6 +1915,13 @@ func onLogin(cf *CLIConf) error {
// current status
case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "" ||
host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "" && cf.RequestID == "":
+
+ // The user has typed `tsh login`, if the running binary needs to
+ // be updated, update and re-exec.
+ if err := updateAndRun(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil {
+ return trace.Wrap(err)
+ }
+
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
@@ -1889,6 +1935,13 @@ func onLogin(cf *CLIConf) error {
// if the proxy names match but nothing else is specified; show motd and update active profile and kube configs
case host(cf.Proxy) == host(profile.ProxyURL.Host) &&
cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "":
+
+ // The user has typed `tsh login`, if the running binary needs to
+ // be updated, update and re-exec.
+ if err := updateAndRun(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil {
+ return trace.Wrap(err)
+ }
+
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
@@ -1959,7 +2012,11 @@ func onLogin(cf *CLIConf) error {
// otherwise just pass through to standard login
default:
-
+ // The user is logged in and has typed in `tsh --proxy=... login`, if
+ // the running binary needs to be updated, update and re-exec.
+ if err := updateAndRun(cf.Context, tc.WebProxyAddr, tc.InsecureSkipVerify); err != nil {
+ return trace.Wrap(err)
+ }
}
}
@@ -5552,6 +5609,43 @@ const (
"https://goteleport.com/docs/access-controls/guides/headless/#troubleshooting"
)
+func updateAndRun(ctx context.Context, proxy string, insecure bool) error {
+ // The user has typed a command like `tsh ssh ...` without being logged in,
+ // if the running binary needs to be updated, update and re-exec.
+ //
+ // If needed, download the new version of {tsh, tctl} and re-exec. Make
+ // sure to exit this process with the same exit code as the child process.
+ //
+ toolsDir, err := tools.Dir()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ updater := tools.NewUpdater(tools.DefaultClientTools(), toolsDir, teleport.Version)
+ toolsVersion, reExec, err := updater.CheckRemote(ctx, proxy, insecure)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if reExec {
+ ctxUpdate, cancel := stacksignal.GetSignalHandler().NotifyContext(context.Background())
+ defer cancel()
+ // Download the version of client tools required by the cluster.
+ err := updater.UpdateWithLock(ctxUpdate, toolsVersion)
+ if err != nil && !errors.Is(err, context.Canceled) {
+ return trace.Wrap(err)
+ }
+ // Re-execute client tools with the correct version of client tools.
+ code, err := updater.Exec()
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ log.Debugf("Failed to re-exec client tool: %v.", err)
+ os.Exit(code)
+ } else if err == nil {
+ os.Exit(code)
+ }
+ }
+
+ return nil
+}
+
// Lock the process memory to prevent rsa keys and certificates in memory from being exposed in a swap.
func tryLockMemory(cf *CLIConf) error {
if cf.MlockMode == mlockModeAuto {