Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v17] Add host resolution support to tsh scp #49226

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 41 additions & 97 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2411,112 +2411,69 @@ func PlayFile(ctx context.Context, filename, sid string, speed float64, skipIdle
}

// SFTP securely copies files between Nodes or SSH servers using SFTP
func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opts sftp.Options, quiet bool) (err error) {
func (tc *TeleportClient) SFTP(ctx context.Context, source []string, destination string, opts sftp.Options) (err error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/SFTP",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

if len(args) < 2 {
return trace.Errorf("local and remote destinations are required")
}
first := args[0]
last := args[len(args)-1]
isDownload := strings.ContainsRune(source[0], ':')
isUpload := strings.ContainsRune(destination, ':')

// local copy?
if !isRemoteDest(first) && !isRemoteDest(last) {
if !isUpload && !isDownload {
return trace.BadParameter("no remote destination specified")
}

var config *sftpConfig
if isRemoteDest(last) {
config, err = tc.uploadConfig(args, port, opts)
if err != nil {
return trace.Wrap(err)
}
} else {
config, err = tc.downloadConfig(args, port, opts)
if err != nil {
return trace.Wrap(err)
}
}
if config.hostLogin == "" {
config.hostLogin = tc.Config.HostLogin
}

if !quiet {
config.cfg.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter {
return sftp.NewProgressBar(fileInfo.Size(), fileInfo.Name(), tc.Stdout)
}
}

return trace.Wrap(tc.TransferFiles(ctx, config.hostLogin, config.addr, config.cfg))
}

type sftpConfig struct {
cfg *sftp.Config
addr string
hostLogin string
}

func (tc *TeleportClient) uploadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) {
// args are guaranteed to have len(args) > 1
srcPaths := args[:len(args)-1]
// copy everything except the last arg (the destination)
dstPath := args[len(args)-1]

dst, addr, err := getSFTPDestination(dstPath, port)
clt, err := tc.ConnectToCluster(ctx)
if err != nil {
return nil, trace.Wrap(err)
return trace.Wrap(err)
}
cfg, err := sftp.CreateUploadConfig(srcPaths, dst.Path, opts)
defer clt.Close()

// Respect any proxy templates and attempt host resolution.
resolvedNodes, err := tc.GetTargetNodes(ctx, clt.AuthClient, SSHOptions{})
if err != nil {
return nil, trace.Wrap(err)
return trace.Wrap(err)
}

return &sftpConfig{
cfg: cfg,
addr: addr,
hostLogin: dst.Login,
}, nil
}

func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) {
if len(args) > 2 {
return nil, trace.BadParameter("only one source file is supported when downloading files")
switch len(resolvedNodes) {
case 1:
case 0:
return trace.NotFound("no matching hosts found")
default:
return trace.BadParameter("multiple matching hosts found")
}

// args are guaranteed to have len(args) > 1
src, addr, err := getSFTPDestination(args[0], port)
if err != nil {
return nil, trace.Wrap(err)
}
cfg, err := sftp.CreateDownloadConfig(src.Path, args[1], opts)
if err != nil {
return nil, trace.Wrap(err)
var cfg *sftp.Config
switch {
case isDownload:
dest, err := sftp.ParseDestination(source[0])
if err != nil {
return trace.Wrap(err)
}
cfg, err = sftp.CreateDownloadConfig(dest.Path, destination, opts)
if err != nil {
return trace.Wrap(err)
}
case isUpload:
dest, err := sftp.ParseDestination(destination)
if err != nil {
return trace.Wrap(err)
}
cfg, err = sftp.CreateUploadConfig(source, dest.Path, opts)
if err != nil {
return trace.Wrap(err)
}
}

return &sftpConfig{
cfg: cfg,
addr: addr,
hostLogin: src.Login,
}, nil
}

func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) {
dest, err = sftp.ParseDestination(target)
if err != nil {
return nil, "", trace.Wrap(err)
}
addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port))
return dest, addr, nil
return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, resolvedNodes[0].Addr, cfg))
}

// TransferFiles copies files between the current machine and the
// specified Node using the supplied config
func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr string, cfg *sftp.Config) error {
func (tc *TeleportClient) TransferFiles(ctx context.Context, clt *ClusterClient, hostLogin, nodeAddr string, cfg *sftp.Config) error {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/TransferFiles",
Expand All @@ -2531,16 +2488,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr
return trace.BadParameter("node address is not specified")
}

if !tc.Config.ProxySpecified() {
return trace.BadParameter("proxy server is not specified")
}
clt, err := tc.ConnectToCluster(ctx)
if err != nil {
return trace.Wrap(err)
}
defer clt.Close()

client, err := tc.ConnectToNode(
nodeClient, err := tc.ConnectToNode(
ctx,
clt,
NodeDetails{
Expand All @@ -2554,11 +2502,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr
return trace.Wrap(err)
}

return trace.Wrap(client.TransferFiles(ctx, cfg))
}

func isRemoteDest(name string) bool {
return strings.ContainsRune(name, ':')
return trace.Wrap(nodeClient.TransferFiles(ctx, cfg))
}

// ListNodesWithFilters returns all nodes that match the filters in the current cluster
Expand Down
11 changes: 11 additions & 0 deletions lib/sshutils/sftp/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package sftp

import (
"cmp"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -51,6 +52,10 @@ type Options struct {
// PreserveAttrs preserves access and modification times
// from the original file
PreserveAttrs bool
// Quiet indicates whether progress should be displayed.
Quiet bool
// ProgressWriter is used to write the progress output.
ProgressWriter io.Writer
}

// Config describes the settings of a file transfer
Expand Down Expand Up @@ -228,6 +233,12 @@ func (c *Config) setDefaults() {
"PreserveAttrs": c.opts.PreserveAttrs,
},
})

if !c.opts.Quiet {
c.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter {
return NewProgressBar(fileInfo.Size(), fileInfo.Name(), cmp.Or(c.opts.ProgressWriter, io.Writer(os.Stdout)))
}
}
}

// TransferFiles transfers files from the configured source paths to the
Expand Down
5 changes: 3 additions & 2 deletions lib/teleterm/clusters/cluster_file_transfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ import (
"github.com/gravitational/trace"

api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/teleterm/api/uri"
)

type FileTransferProgressSender = func(progress *api.FileTransferProgress) error

func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error {
func (c *Cluster) TransferFile(ctx context.Context, clt *client.ClusterClient, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error {
config, err := getSftpConfig(request)
if err != nil {
return trace.Wrap(err)
Expand All @@ -54,7 +55,7 @@ func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferReq
}

err = AddMetadataToRetryableError(ctx, func() error {
err := c.clusterClient.TransferFiles(ctx, request.GetLogin(), serverUUID+":0", config)
err := c.clusterClient.TransferFiles(ctx, clt, request.GetLogin(), serverUUID+":0", config)
if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) {
return trace.Errorf("transferring directories through Teleport Connect is not supported at the moment, please use tsh scp -r")
}
Expand Down
7 changes: 6 additions & 1 deletion lib/teleterm/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,12 @@ func (s *Service) TransferFile(ctx context.Context, request *api.FileTransferReq
return trace.Wrap(err)
}

return cluster.TransferFile(ctx, request, sendProgress)
clt, err := s.GetCachedClient(ctx, cluster.URI)
if err != nil {
return trace.Wrap(err)
}

return cluster.TransferFile(ctx, clt, request, sendProgress)
}

// CreateConnectMyComputerRole creates a role which allows access to nodes with the label
Expand Down
8 changes: 7 additions & 1 deletion lib/web/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,13 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
ctx = context.WithValue(ctx, sftp.ModeratedSessionID, req.moderatedSessionID)
}

err = tc.TransferFiles(ctx, req.login, req.serverID+":0", cfg)
cl, err := tc.ConnectToCluster(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
defer cl.Close()

err = tc.TransferFiles(ctx, cl, req.login, req.serverID+":0", cfg)
if err != nil {
if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) {
return nil, trace.Errorf("transferring directories through the Web UI is not supported at the moment, please use tsh scp -r")
Expand Down
49 changes: 37 additions & 12 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
scp.Flag("preserve", "Preserves access and modification times from the original file").Short('p').BoolVar(&cf.PreserveAttrs)
scp.Flag("quiet", "Quiet mode").Short('q').BoolVar(&cf.Quiet)
scp.Flag("no-resume", "Disable SSH connection resumption").Envar(noResumeEnvVar).BoolVar(&cf.DisableSSHResumption)
scp.Flag("relogin", "Permit performing an authentication attempt on a failed command").Default("true").BoolVar(&cf.Relogin)
// ls
ls := app.Command("ls", "List remote SSH nodes.")
ls.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName)
Expand Down Expand Up @@ -3928,6 +3929,10 @@ func onJoin(cf *CLIConf) error {

// onSCP executes 'tsh scp' command
func onSCP(cf *CLIConf) error {
if len(cf.CopySpec) < 2 {
return trace.Errorf("local and remote destinations are required")
}

tc, err := makeClient(cf)
if err != nil {
return trace.Wrap(err)
Expand All @@ -3940,13 +3945,27 @@ func onSCP(cf *CLIConf) error {
cf.Context = ctx
defer cancel()

opts := sftp.Options{
Recursive: cf.RecursiveCopy,
PreserveAttrs: cf.PreserveAttrs,
executor := client.RetryWithRelogin
if !cf.Relogin {
executor = func(ctx context.Context, teleportClient *client.TeleportClient, f func() error, option ...client.RetryWithReloginOption) error {
return f()
}
}
err = client.RetryWithRelogin(cf.Context, tc, func() error {
return tc.SFTP(cf.Context, cf.CopySpec, int(cf.NodePort), opts, cf.Quiet)

err = executor(cf.Context, tc, func() error {
return trace.Wrap(tc.SFTP(
cf.Context,
cf.CopySpec[:len(cf.CopySpec)-1],
cf.CopySpec[len(cf.CopySpec)-1],
sftp.Options{
Recursive: cf.RecursiveCopy,
PreserveAttrs: cf.PreserveAttrs,
Quiet: cf.Quiet,
ProgressWriter: cf.Stdout(),
},
))
})

// don't print context canceled errors to the user
if err == nil || errors.Is(err, context.Canceled) {
return nil
Expand Down Expand Up @@ -4048,14 +4067,20 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err
} else if cf.CopySpec != nil {
for _, location := range cf.CopySpec {
// Extract username and host from "username@host:file/path"
parts := strings.Split(location, ":")
parts = strings.Split(parts[0], "@")
partsLength := len(parts)
if partsLength > 1 {
hostLogin = strings.Join(parts[:partsLength-1], "@")
hostUser = parts[partsLength-1]
break
userHost, _, found := strings.Cut(location, ":")
if !found {
continue
}

login, hostname, found := strings.Cut(userHost, "@")
if found {
hostLogin = login
hostUser = hostname
} else {
hostUser = userHost
}
break

}
}

Expand Down
Loading
Loading