diff --git a/lib/client/api.go b/lib/client/api.go index 012875f09b73c..5496769c47115 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2411,7 +2411,7 @@ func PlayFile(ctx context.Context, filename, sid string, speed float64, skipIdle } // SFTP securely copies files between Nodes or SSH servers using SFTP -func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opts sftp.Options, quiet bool) (err error) { +func (tc *TeleportClient) SFTP(ctx context.Context, source []string, destination string, opts sftp.Options) (err error) { ctx, span := tc.Tracer.Start( ctx, "teleportClient/SFTP", @@ -2419,104 +2419,61 @@ func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opt ) defer span.End() - if len(args) < 2 { - return trace.Errorf("local and remote destinations are required") - } - first := args[0] - last := args[len(args)-1] + isDownload := strings.ContainsRune(source[0], ':') + isUpload := strings.ContainsRune(destination, ':') - // local copy? - if !isRemoteDest(first) && !isRemoteDest(last) { + if !isUpload && !isDownload { return trace.BadParameter("no remote destination specified") } - var config *sftpConfig - if isRemoteDest(last) { - config, err = tc.uploadConfig(args, port, opts) - if err != nil { - return trace.Wrap(err) - } - } else { - config, err = tc.downloadConfig(args, port, opts) - if err != nil { - return trace.Wrap(err) - } - } - if config.hostLogin == "" { - config.hostLogin = tc.Config.HostLogin - } - - if !quiet { - config.cfg.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter { - return sftp.NewProgressBar(fileInfo.Size(), fileInfo.Name(), tc.Stdout) - } - } - - return trace.Wrap(tc.TransferFiles(ctx, config.hostLogin, config.addr, config.cfg)) -} - -type sftpConfig struct { - cfg *sftp.Config - addr string - hostLogin string -} - -func (tc *TeleportClient) uploadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) { - // args are guaranteed to have len(args) > 1 - srcPaths := args[:len(args)-1] - // copy everything except the last arg (the destination) - dstPath := args[len(args)-1] - - dst, addr, err := getSFTPDestination(dstPath, port) + clt, err := tc.ConnectToCluster(ctx) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } - cfg, err := sftp.CreateUploadConfig(srcPaths, dst.Path, opts) + defer clt.Close() + + // Respect any proxy templates and attempt host resolution. + resolvedNodes, err := tc.GetTargetNodes(ctx, clt.AuthClient, SSHOptions{}) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } - return &sftpConfig{ - cfg: cfg, - addr: addr, - hostLogin: dst.Login, - }, nil -} - -func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) { - if len(args) > 2 { - return nil, trace.BadParameter("only one source file is supported when downloading files") + switch len(resolvedNodes) { + case 1: + case 0: + return trace.NotFound("no matching hosts found") + default: + return trace.BadParameter("multiple matching hosts found") } - // args are guaranteed to have len(args) > 1 - src, addr, err := getSFTPDestination(args[0], port) - if err != nil { - return nil, trace.Wrap(err) - } - cfg, err := sftp.CreateDownloadConfig(src.Path, args[1], opts) - if err != nil { - return nil, trace.Wrap(err) + var cfg *sftp.Config + switch { + case isDownload: + dest, err := sftp.ParseDestination(source[0]) + if err != nil { + return trace.Wrap(err) + } + cfg, err = sftp.CreateDownloadConfig(dest.Path, destination, opts) + if err != nil { + return trace.Wrap(err) + } + case isUpload: + dest, err := sftp.ParseDestination(destination) + if err != nil { + return trace.Wrap(err) + } + cfg, err = sftp.CreateUploadConfig(source, dest.Path, opts) + if err != nil { + return trace.Wrap(err) + } } - return &sftpConfig{ - cfg: cfg, - addr: addr, - hostLogin: src.Login, - }, nil -} - -func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) { - dest, err = sftp.ParseDestination(target) - if err != nil { - return nil, "", trace.Wrap(err) - } - addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port)) - return dest, addr, nil + return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, resolvedNodes[0].Addr, cfg)) } // TransferFiles copies files between the current machine and the // specified Node using the supplied config -func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr string, cfg *sftp.Config) error { +func (tc *TeleportClient) TransferFiles(ctx context.Context, clt *ClusterClient, hostLogin, nodeAddr string, cfg *sftp.Config) error { ctx, span := tc.Tracer.Start( ctx, "teleportClient/TransferFiles", @@ -2531,16 +2488,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr return trace.BadParameter("node address is not specified") } - if !tc.Config.ProxySpecified() { - return trace.BadParameter("proxy server is not specified") - } - clt, err := tc.ConnectToCluster(ctx) - if err != nil { - return trace.Wrap(err) - } - defer clt.Close() - - client, err := tc.ConnectToNode( + nodeClient, err := tc.ConnectToNode( ctx, clt, NodeDetails{ @@ -2554,11 +2502,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr return trace.Wrap(err) } - return trace.Wrap(client.TransferFiles(ctx, cfg)) -} - -func isRemoteDest(name string) bool { - return strings.ContainsRune(name, ':') + return trace.Wrap(nodeClient.TransferFiles(ctx, cfg)) } // ListNodesWithFilters returns all nodes that match the filters in the current cluster diff --git a/lib/sshutils/sftp/sftp.go b/lib/sshutils/sftp/sftp.go index 3ad2469493a38..77bd23f3cda5a 100644 --- a/lib/sshutils/sftp/sftp.go +++ b/lib/sshutils/sftp/sftp.go @@ -20,6 +20,7 @@ package sftp import ( + "cmp" "context" "errors" "fmt" @@ -51,6 +52,10 @@ type Options struct { // PreserveAttrs preserves access and modification times // from the original file PreserveAttrs bool + // Quiet indicates whether progress should be displayed. + Quiet bool + // ProgressWriter is used to write the progress output. + ProgressWriter io.Writer } // Config describes the settings of a file transfer @@ -228,6 +233,12 @@ func (c *Config) setDefaults() { "PreserveAttrs": c.opts.PreserveAttrs, }, }) + + if !c.opts.Quiet { + c.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter { + return NewProgressBar(fileInfo.Size(), fileInfo.Name(), cmp.Or(c.opts.ProgressWriter, io.Writer(os.Stdout))) + } + } } // TransferFiles transfers files from the configured source paths to the diff --git a/lib/teleterm/clusters/cluster_file_transfer.go b/lib/teleterm/clusters/cluster_file_transfer.go index c55525123ddae..5047476c1f683 100644 --- a/lib/teleterm/clusters/cluster_file_transfer.go +++ b/lib/teleterm/clusters/cluster_file_transfer.go @@ -29,13 +29,14 @@ import ( "github.com/gravitational/trace" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" + "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/teleterm/api/uri" ) type FileTransferProgressSender = func(progress *api.FileTransferProgress) error -func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error { +func (c *Cluster) TransferFile(ctx context.Context, clt *client.ClusterClient, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error { config, err := getSftpConfig(request) if err != nil { return trace.Wrap(err) @@ -54,7 +55,7 @@ func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferReq } err = AddMetadataToRetryableError(ctx, func() error { - err := c.clusterClient.TransferFiles(ctx, request.GetLogin(), serverUUID+":0", config) + err := c.clusterClient.TransferFiles(ctx, clt, request.GetLogin(), serverUUID+":0", config) if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) { return trace.Errorf("transferring directories through Teleport Connect is not supported at the moment, please use tsh scp -r") } diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index dec9cbc08b166..19724daceb6b3 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -923,7 +923,12 @@ func (s *Service) TransferFile(ctx context.Context, request *api.FileTransferReq return trace.Wrap(err) } - return cluster.TransferFile(ctx, request, sendProgress) + clt, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return trace.Wrap(err) + } + + return cluster.TransferFile(ctx, clt, request, sendProgress) } // CreateConnectMyComputerRole creates a role which allows access to nodes with the label diff --git a/lib/web/files.go b/lib/web/files.go index e43c341d70022..53248258dd034 100644 --- a/lib/web/files.go +++ b/lib/web/files.go @@ -147,7 +147,13 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou ctx = context.WithValue(ctx, sftp.ModeratedSessionID, req.moderatedSessionID) } - err = tc.TransferFiles(ctx, req.login, req.serverID+":0", cfg) + cl, err := tc.ConnectToCluster(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + defer cl.Close() + + err = tc.TransferFiles(ctx, cl, req.login, req.serverID+":0", cfg) if err != nil { if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) { return nil, trace.Errorf("transferring directories through the Web UI is not supported at the moment, please use tsh scp -r") diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index a3523876116e4..ac3c0a5f7a779 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -3928,6 +3928,10 @@ func onJoin(cf *CLIConf) error { // onSCP executes 'tsh scp' command func onSCP(cf *CLIConf) error { + if len(cf.CopySpec) < 2 { + return trace.Errorf("local and remote destinations are required") + } + tc, err := makeClient(cf) if err != nil { return trace.Wrap(err) @@ -3940,13 +3944,27 @@ func onSCP(cf *CLIConf) error { cf.Context = ctx defer cancel() - opts := sftp.Options{ - Recursive: cf.RecursiveCopy, - PreserveAttrs: cf.PreserveAttrs, + executor := client.RetryWithRelogin + if !cf.Relogin { + executor = func(ctx context.Context, teleportClient *client.TeleportClient, f func() error, option ...client.RetryWithReloginOption) error { + return f() + } } - err = client.RetryWithRelogin(cf.Context, tc, func() error { - return tc.SFTP(cf.Context, cf.CopySpec, int(cf.NodePort), opts, cf.Quiet) + + err = executor(cf.Context, tc, func() error { + return trace.Wrap(tc.SFTP( + cf.Context, + cf.CopySpec[:len(cf.CopySpec)-1], + cf.CopySpec[len(cf.CopySpec)-1], + sftp.Options{ + Recursive: cf.RecursiveCopy, + PreserveAttrs: cf.PreserveAttrs, + Quiet: cf.Quiet, + ProgressWriter: cf.Stdout(), + }, + )) }) + // don't print context canceled errors to the user if err == nil || errors.Is(err, context.Canceled) { return nil @@ -4048,14 +4066,20 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err } else if cf.CopySpec != nil { for _, location := range cf.CopySpec { // Extract username and host from "username@host:file/path" - parts := strings.Split(location, ":") - parts = strings.Split(parts[0], "@") - partsLength := len(parts) - if partsLength > 1 { - hostLogin = strings.Join(parts[:partsLength-1], "@") - hostUser = parts[partsLength-1] - break + userHost, _, found := strings.Cut(location, ":") + if !found { + continue + } + + login, hostname, found := strings.Cut(userHost, "@") + if found { + hostLogin = login + hostUser = hostname + } else { + hostUser = userHost } + break + } } diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index 1bf5bd688ff66..1525282057d12 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -6832,3 +6832,294 @@ func TestVersionCompatibilityFlags(t *testing.T) { require.NoError(t, err, output) require.Equal(t, "Teleport CLI", string(bytes.TrimSpace(output))) } + +// TestSCP validates that tsh scp correctly copy file content while also +// ensuring that proxy templates are respected. +func TestSCP(t *testing.T) { + modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + accessRoleName := "access" + sshHostname := "test-ssh-server" + + accessUser, err := types.NewUser(accessRoleName) + require.NoError(t, err) + accessUser.SetRoles([]string{accessRoleName}) + + user, err := user.Current() + require.NoError(t, err) + accessUser.SetLogins([]string{user.Username}) + + traits := map[string][]string{ + constants.TraitLogins: {user.Username}, + } + accessUser.SetTraits(traits) + + connector := mockConnector(t) + rootServerOpts := []testserver.TestServerOptFunc{ + testserver.WithBootstrap(connector, accessUser), + testserver.WithHostname(sshHostname), + testserver.WithClusterName(t, "root"), + testserver.WithSSHPublicAddrs("127.0.0.1:0"), + testserver.WithConfig(func(cfg *servicecfg.Config) { + cfg.SSH.Enabled = true + cfg.SSH.PublicAddrs = []utils.NetAddr{cfg.SSH.Addr} + cfg.SSH.DisableCreateHostUser = true + cfg.SSH.Labels = map[string]string{ + "animal": "llama", + "env": "dev", + } + }), + } + rootServer := testserver.MakeTestServer(t, rootServerOpts...) + + // Create a second server to test ambiguous matching. + testserver.MakeTestServer(t, + testserver.WithConfig(func(cfg *servicecfg.Config) { + cfg.SetAuthServerAddresses(rootServer.Config.AuthServerAddresses()) + cfg.Hostname = "second-node" + cfg.Auth.Enabled = false + cfg.Proxy.Enabled = false + cfg.SSH.Enabled = true + cfg.SSH.DisableCreateHostUser = true + cfg.SSH.Labels = map[string]string{ + "animal": "shark", + "env": "dev", + } + })) + + rootProxyAddr, err := rootServer.ProxyWebAddr() + require.NoError(t, err) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + found, err := rootServer.GetAuthServer().GetNodes(ctx, apidefaults.Namespace) + if !assert.NoError(t, err) || !assert.Len(t, found, 2) { + return + } + }, 10*time.Second, 100*time.Millisecond) + + tmpHomePath := t.TempDir() + rootAuth := rootServer.GetAuthServer() + + err = Run(ctx, []string{ + "login", + "--insecure", + "--proxy", rootProxyAddr.String(), + "--user", user.Username, + }, setHomePath(tmpHomePath), setMockSSOLogin(rootAuth, accessUser, connector.GetName())) + require.NoError(t, err) + + sourceFile1 := filepath.Join(t.TempDir(), "source-file") + expectedFile1 := []byte{6, 7, 8, 9, 0} + require.NoError(t, os.WriteFile(sourceFile1, expectedFile1, 0o644)) + + sourceFile2 := filepath.Join(t.TempDir(), "source-file2") + expectedFile2 := []byte{1, 2, 3, 4, 5} + require.NoError(t, os.WriteFile(sourceFile2, expectedFile2, 0o644)) + + targetFile1 := uuid.NewString() + + createFile := func(t *testing.T, dir, file string) { + t.Helper() + f, err := os.CreateTemp(dir, file) + require.NoError(t, err) + require.NoError(t, f.Close()) + } + + tests := []struct { + name string + source []string + destination func(t *testing.T, dir string) string + assertion require.ErrorAssertionFunc + expected map[string][]byte + }{ + { + name: "no paths provided", + destination: func(*testing.T, string) string { return "" }, + assertion: func(tt require.TestingT, err error, i ...any) { + require.Error(tt, err, i...) + require.ErrorContains(tt, err, "local and remote destinations are required", i...) + }, + }, + { + name: "source resolved without using templates", + source: []string{sshHostname + ":" + sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return filepath.Join(dir, targetFile1) + }, + assertion: require.NoError, + expected: map[string][]byte{ + targetFile1: expectedFile1, + }, + }, + { + name: "source resolved via predicate from template", + source: []string{"2.3.4.5:" + sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return filepath.Join(dir, targetFile1) + }, + assertion: require.NoError, + expected: map[string][]byte{ + targetFile1: expectedFile1, + }, + }, + { + name: "source resolved via search from template", + source: []string{"llama.example.com:" + sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return filepath.Join(dir, targetFile1) + }, + assertion: require.NoError, + expected: map[string][]byte{ + targetFile1: expectedFile1, + }, + }, + { + name: "source no matching host", + source: []string{"asdf.example.com:" + sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return filepath.Join(dir, targetFile1) + }, + assertion: func(tt require.TestingT, err error, i ...any) { + require.Error(tt, err, i...) + require.ErrorContains(tt, err, "no matching hosts", i...) + }, + }, + { + name: "source multiple matching hosts", + source: []string{"dev.example.com:" + sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return filepath.Join(dir, targetFile1) + }, + assertion: func(tt require.TestingT, err error, i ...any) { + require.Error(tt, err, i...) + require.ErrorContains(tt, err, "multiple matching hosts", i...) + }, + }, + { + name: "destination resolved without using templates", + source: []string{sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return sshHostname + ":" + filepath.Join(dir, targetFile1) + }, + assertion: require.NoError, + expected: map[string][]byte{ + targetFile1: expectedFile1, + }, + }, + { + name: "destination resolved via predicate from template", + source: []string{sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return "2.3.4.5:" + filepath.Join(dir, targetFile1) + }, + assertion: require.NoError, + expected: map[string][]byte{ + targetFile1: expectedFile1, + }, + }, + { + name: "destination resolved via search from template", + source: []string{sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return "llama.example.com:" + filepath.Join(dir, targetFile1) + }, + assertion: require.NoError, + expected: map[string][]byte{ + targetFile1: expectedFile1, + }, + }, + { + name: "destination no matching host", + source: []string{sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return "asdf.example.com:" + filepath.Join(dir, targetFile1) + }, + assertion: func(tt require.TestingT, err error, i ...any) { + require.Error(tt, err, i...) + require.ErrorContains(tt, err, "no matching hosts", i...) + }, + }, + { + name: "destination multiple matching hosts", + source: []string{sourceFile1}, + destination: func(t *testing.T, dir string) string { + createFile(t, dir, targetFile1) + return "dev.example.com:" + filepath.Join(dir, targetFile1) + }, + assertion: func(tt require.TestingT, err error, i ...any) { + require.Error(tt, err, i...) + require.ErrorContains(tt, err, "multiple matching hosts", i...) + }, + }, + { + name: "upload multiple files", + source: []string{sourceFile1, sourceFile2}, + destination: func(t *testing.T, dir string) string { return "llama.example.com:" + dir }, + assertion: require.NoError, + expected: map[string][]byte{ + filepath.Base(sourceFile1): expectedFile1, + filepath.Base(sourceFile2): expectedFile2, + }, + }, + } + + for _, test := range tests { + test := test + ctx := context.Background() + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + outputDir := t.TempDir() + destination := test.destination(t, outputDir) + + args := []string{"scp", "-d", "--no-resume", "--insecure", "-q"} + args = append(args, test.source...) + args = append(args, destination) + err := Run(ctx, + args, + setHomePath(tmpHomePath), + setTSHConfig(client.TSHConfig{ + ProxyTemplates: client.ProxyTemplates{ + { + Template: `^([0-9\.]+):\d+$`, + Query: `labels["animal"] == "llama"`, + }, + { + Template: `^(.*).example.com:\d+$`, + Search: "$1", + }, + }, + }), + func(conf *CLIConf) error { + // Relogin is disabled since some of the error cases return a + // BadParameter error which triggers the re-authentication flow + // and may result in a different authentication related error + // being returned instead of the expected errors. + conf.Relogin = false + return nil + }, + ) + test.assertion(t, err) + if err != nil { + return + } + + for file, expected := range test.expected { + got, err := os.ReadFile(filepath.Join(outputDir, file)) + require.NoError(t, err) + require.Equal(t, expected, got) + } + }) + } +}