Skip to content

Commit

Permalink
Add host resolution support to tsh scp (#48465)
Browse files Browse the repository at this point in the history
Extends tsh scp functionality to honor any defined proxy templates
when resolving the remote host.

Closes #45465
  • Loading branch information
rosstimothy authored Nov 19, 2024
1 parent bd21289 commit e92dd31
Show file tree
Hide file tree
Showing 7 changed files with 395 additions and 113 deletions.
138 changes: 41 additions & 97 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2411,112 +2411,69 @@ func PlayFile(ctx context.Context, filename, sid string, speed float64, skipIdle
}

// SFTP securely copies files between Nodes or SSH servers using SFTP
func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opts sftp.Options, quiet bool) (err error) {
func (tc *TeleportClient) SFTP(ctx context.Context, source []string, destination string, opts sftp.Options) (err error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/SFTP",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

if len(args) < 2 {
return trace.Errorf("local and remote destinations are required")
}
first := args[0]
last := args[len(args)-1]
isDownload := strings.ContainsRune(source[0], ':')
isUpload := strings.ContainsRune(destination, ':')

// local copy?
if !isRemoteDest(first) && !isRemoteDest(last) {
if !isUpload && !isDownload {
return trace.BadParameter("no remote destination specified")
}

var config *sftpConfig
if isRemoteDest(last) {
config, err = tc.uploadConfig(args, port, opts)
if err != nil {
return trace.Wrap(err)
}
} else {
config, err = tc.downloadConfig(args, port, opts)
if err != nil {
return trace.Wrap(err)
}
}
if config.hostLogin == "" {
config.hostLogin = tc.Config.HostLogin
}

if !quiet {
config.cfg.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter {
return sftp.NewProgressBar(fileInfo.Size(), fileInfo.Name(), tc.Stdout)
}
}

return trace.Wrap(tc.TransferFiles(ctx, config.hostLogin, config.addr, config.cfg))
}

type sftpConfig struct {
cfg *sftp.Config
addr string
hostLogin string
}

func (tc *TeleportClient) uploadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) {
// args are guaranteed to have len(args) > 1
srcPaths := args[:len(args)-1]
// copy everything except the last arg (the destination)
dstPath := args[len(args)-1]

dst, addr, err := getSFTPDestination(dstPath, port)
clt, err := tc.ConnectToCluster(ctx)
if err != nil {
return nil, trace.Wrap(err)
return trace.Wrap(err)
}
cfg, err := sftp.CreateUploadConfig(srcPaths, dst.Path, opts)
defer clt.Close()

// Respect any proxy templates and attempt host resolution.
resolvedNodes, err := tc.GetTargetNodes(ctx, clt.AuthClient, SSHOptions{})
if err != nil {
return nil, trace.Wrap(err)
return trace.Wrap(err)
}

return &sftpConfig{
cfg: cfg,
addr: addr,
hostLogin: dst.Login,
}, nil
}

func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) {
if len(args) > 2 {
return nil, trace.BadParameter("only one source file is supported when downloading files")
switch len(resolvedNodes) {
case 1:
case 0:
return trace.NotFound("no matching hosts found")
default:
return trace.BadParameter("multiple matching hosts found")
}

// args are guaranteed to have len(args) > 1
src, addr, err := getSFTPDestination(args[0], port)
if err != nil {
return nil, trace.Wrap(err)
}
cfg, err := sftp.CreateDownloadConfig(src.Path, args[1], opts)
if err != nil {
return nil, trace.Wrap(err)
var cfg *sftp.Config
switch {
case isDownload:
dest, err := sftp.ParseDestination(source[0])
if err != nil {
return trace.Wrap(err)
}
cfg, err = sftp.CreateDownloadConfig(dest.Path, destination, opts)
if err != nil {
return trace.Wrap(err)
}
case isUpload:
dest, err := sftp.ParseDestination(destination)
if err != nil {
return trace.Wrap(err)
}
cfg, err = sftp.CreateUploadConfig(source, dest.Path, opts)
if err != nil {
return trace.Wrap(err)
}
}

return &sftpConfig{
cfg: cfg,
addr: addr,
hostLogin: src.Login,
}, nil
}

func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) {
dest, err = sftp.ParseDestination(target)
if err != nil {
return nil, "", trace.Wrap(err)
}
addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port))
return dest, addr, nil
return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, resolvedNodes[0].Addr, cfg))
}

// TransferFiles copies files between the current machine and the
// specified Node using the supplied config
func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr string, cfg *sftp.Config) error {
func (tc *TeleportClient) TransferFiles(ctx context.Context, clt *ClusterClient, hostLogin, nodeAddr string, cfg *sftp.Config) error {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/TransferFiles",
Expand All @@ -2531,16 +2488,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr
return trace.BadParameter("node address is not specified")
}

if !tc.Config.ProxySpecified() {
return trace.BadParameter("proxy server is not specified")
}
clt, err := tc.ConnectToCluster(ctx)
if err != nil {
return trace.Wrap(err)
}
defer clt.Close()

client, err := tc.ConnectToNode(
nodeClient, err := tc.ConnectToNode(
ctx,
clt,
NodeDetails{
Expand All @@ -2554,11 +2502,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr
return trace.Wrap(err)
}

return trace.Wrap(client.TransferFiles(ctx, cfg))
}

func isRemoteDest(name string) bool {
return strings.ContainsRune(name, ':')
return trace.Wrap(nodeClient.TransferFiles(ctx, cfg))
}

// ListNodesWithFilters returns all nodes that match the filters in the current cluster
Expand Down
11 changes: 11 additions & 0 deletions lib/sshutils/sftp/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package sftp

import (
"cmp"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -51,6 +52,10 @@ type Options struct {
// PreserveAttrs preserves access and modification times
// from the original file
PreserveAttrs bool
// Quiet indicates whether progress should be displayed.
Quiet bool
// ProgressWriter is used to write the progress output.
ProgressWriter io.Writer
}

// Config describes the settings of a file transfer
Expand Down Expand Up @@ -228,6 +233,12 @@ func (c *Config) setDefaults() {
"PreserveAttrs": c.opts.PreserveAttrs,
},
})

if !c.opts.Quiet {
c.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter {
return NewProgressBar(fileInfo.Size(), fileInfo.Name(), cmp.Or(c.opts.ProgressWriter, io.Writer(os.Stdout)))
}
}
}

// TransferFiles transfers files from the configured source paths to the
Expand Down
5 changes: 3 additions & 2 deletions lib/teleterm/clusters/cluster_file_transfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ import (
"github.com/gravitational/trace"

api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/teleterm/api/uri"
)

type FileTransferProgressSender = func(progress *api.FileTransferProgress) error

func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error {
func (c *Cluster) TransferFile(ctx context.Context, clt *client.ClusterClient, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error {
config, err := getSftpConfig(request)
if err != nil {
return trace.Wrap(err)
Expand All @@ -54,7 +55,7 @@ func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferReq
}

err = AddMetadataToRetryableError(ctx, func() error {
err := c.clusterClient.TransferFiles(ctx, request.GetLogin(), serverUUID+":0", config)
err := c.clusterClient.TransferFiles(ctx, clt, request.GetLogin(), serverUUID+":0", config)
if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) {
return trace.Errorf("transferring directories through Teleport Connect is not supported at the moment, please use tsh scp -r")
}
Expand Down
7 changes: 6 additions & 1 deletion lib/teleterm/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,12 @@ func (s *Service) TransferFile(ctx context.Context, request *api.FileTransferReq
return trace.Wrap(err)
}

return cluster.TransferFile(ctx, request, sendProgress)
clt, err := s.GetCachedClient(ctx, cluster.URI)
if err != nil {
return trace.Wrap(err)
}

return cluster.TransferFile(ctx, clt, request, sendProgress)
}

// CreateConnectMyComputerRole creates a role which allows access to nodes with the label
Expand Down
8 changes: 7 additions & 1 deletion lib/web/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,13 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
ctx = context.WithValue(ctx, sftp.ModeratedSessionID, req.moderatedSessionID)
}

err = tc.TransferFiles(ctx, req.login, req.serverID+":0", cfg)
cl, err := tc.ConnectToCluster(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
defer cl.Close()

err = tc.TransferFiles(ctx, cl, req.login, req.serverID+":0", cfg)
if err != nil {
if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) {
return nil, trace.Errorf("transferring directories through the Web UI is not supported at the moment, please use tsh scp -r")
Expand Down
48 changes: 36 additions & 12 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3928,6 +3928,10 @@ func onJoin(cf *CLIConf) error {

// onSCP executes 'tsh scp' command
func onSCP(cf *CLIConf) error {
if len(cf.CopySpec) < 2 {
return trace.Errorf("local and remote destinations are required")
}

tc, err := makeClient(cf)
if err != nil {
return trace.Wrap(err)
Expand All @@ -3940,13 +3944,27 @@ func onSCP(cf *CLIConf) error {
cf.Context = ctx
defer cancel()

opts := sftp.Options{
Recursive: cf.RecursiveCopy,
PreserveAttrs: cf.PreserveAttrs,
executor := client.RetryWithRelogin
if !cf.Relogin {
executor = func(ctx context.Context, teleportClient *client.TeleportClient, f func() error, option ...client.RetryWithReloginOption) error {
return f()
}
}
err = client.RetryWithRelogin(cf.Context, tc, func() error {
return tc.SFTP(cf.Context, cf.CopySpec, int(cf.NodePort), opts, cf.Quiet)

err = executor(cf.Context, tc, func() error {
return trace.Wrap(tc.SFTP(
cf.Context,
cf.CopySpec[:len(cf.CopySpec)-1],
cf.CopySpec[len(cf.CopySpec)-1],
sftp.Options{
Recursive: cf.RecursiveCopy,
PreserveAttrs: cf.PreserveAttrs,
Quiet: cf.Quiet,
ProgressWriter: cf.Stdout(),
},
))
})

// don't print context canceled errors to the user
if err == nil || errors.Is(err, context.Canceled) {
return nil
Expand Down Expand Up @@ -4048,14 +4066,20 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err
} else if cf.CopySpec != nil {
for _, location := range cf.CopySpec {
// Extract username and host from "username@host:file/path"
parts := strings.Split(location, ":")
parts = strings.Split(parts[0], "@")
partsLength := len(parts)
if partsLength > 1 {
hostLogin = strings.Join(parts[:partsLength-1], "@")
hostUser = parts[partsLength-1]
break
userHost, _, found := strings.Cut(location, ":")
if !found {
continue
}

login, hostname, found := strings.Cut(userHost, "@")
if found {
hostLogin = login
hostUser = hostname
} else {
hostUser = userHost
}
break

}
}

Expand Down
Loading

0 comments on commit e92dd31

Please sign in to comment.