Skip to content

Commit

Permalink
Merge pull request #27 from loft-sh/fix/POD-605-ssh-provider-windows
Browse files Browse the repository at this point in the history
fix: openssh on windows doesn't play nicely with piping data through Stdin
  • Loading branch information
89luca89 authored May 29, 2024
2 parents 907a94e + d7585cb commit a9c6447
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@ package ssh

import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strings"
"time"

"github.com/kballard/go-shellquote"
"github.com/loft-sh/devpod-provider-ssh/pkg/options"
"github.com/loft-sh/devpod/pkg/log"
"github.com/loft-sh/devpod/pkg/ssh"
)

type SSHProvider struct {
Expand Down Expand Up @@ -68,6 +73,49 @@ func getSSHCommand(provider *SSHProvider) ([]string, error) {
}

func execSSHCommand(provider *SSHProvider, command string, output io.Writer) error {
if runtime.GOOS == "windows" {
// get ssh config for host
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
sshConfig, err := exec.CommandContext(ctx, "ssh", "-G", provider.Config.Host).Output()
if err != nil {
return fmt.Errorf("read ssh config for host %s: %w", provider.Config.Host, err)
}
hostname, user, port, identityfile := parseConfig(string(sshConfig))
if hostname == "" || user == "" || port == "" {
return fmt.Errorf("resolve ssh config. Hostname='%s', User='%s', Port='%s'", hostname, user, port)
}

// expand identityfile path
if strings.HasPrefix(identityfile, "~") {
identityfile = strings.Replace(identityfile, "~", "$userprofile", 1)
identityfile = os.ExpandEnv(identityfile)
}
abs, err := filepath.Abs(identityfile)
if err != nil {
return fmt.Errorf("absolute filepath: %w", err)
}
key, err := os.ReadFile(abs)
if err != nil {
return fmt.Errorf("read identifiyfile: %w", err)
}

// create ssh session
addr := net.JoinHostPort(hostname, port)
client, err := ssh.NewSSHClient(user, addr, key)
if err != nil {
return fmt.Errorf("create ssh client: %w", err)
}
sess, err := client.NewSession()
if err != nil {
return fmt.Errorf("create ssh session: %w", err)
}
sess.Stdin = os.Stdin
sess.Stdout = output

return sess.Run(command)
}

commandToRun, err := getSSHCommand(provider)
if err != nil {
return err
Expand Down Expand Up @@ -230,3 +278,31 @@ func Init(provider *SSHProvider) error {
func Command(provider *SSHProvider, command string) error {
return execSSHCommand(provider, command, os.Stdout)
}

func parseConfig(config string) (hostname string, user string, port string, identityfile string) {
for _, line := range strings.Split(config, "\n") {
fields := strings.Fields(line)
if len(fields) != 2 {
continue
}
if fields[0] == "hostname" {
hostname = fields[1]
continue
}
if fields[0] == "user" {
user = fields[1]
continue
}
if fields[0] == "port" {
port = fields[1]
continue
}
// just take the first one
if fields[0] == "identityfile" && identityfile == "" {
identityfile = fields[1]
continue
}
}

return hostname, user, port, identityfile
}

0 comments on commit a9c6447

Please sign in to comment.