Skip to content

Commit

Permalink
Merge pull request coreos#1529 from endocode/tixxdz/ssh_known_hosts_f…
Browse files Browse the repository at this point in the history
…ixes

ssh: define the list of Key Algorithms of remote hosts before handshake
  • Loading branch information
tixxdz committed Apr 4, 2016
2 parents ec514b3 + d81fe84 commit 2957742
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
42 changes: 42 additions & 0 deletions ssh/known_hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,48 @@ func NewHostKeyChecker(m HostKeyManager) *HostKeyChecker {
return &HostKeyChecker{m, askToTrustHost}
}

// Returns public key algorithms of the remote host that are listed
// inside known_hosts
func (kc *HostKeyChecker) GetHostKeyAlgorithms(addr string) []string {
var results []string
remoteAddr, err := kc.addrToHostPort(addr)
if err != nil {
log.Errorf("Failed to parse address %v: %v", addr, err)
return nil
}

hostKeys, err := kc.m.GetHostKeys()

_, ok := err.(*os.PathError)
if err != nil && !ok {
log.Errorf("Failed to read known_hosts file %v: %v", kc.m.String(), err)
return nil
}

for pattern, keys := range hostKeys {
if !matchHost(remoteAddr, pattern) {
remoteIP, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
log.Errorf("Failed to resolve TCP address %v: %v", addr, err)
continue
}
ipAddr, err := kc.addrToHostPort(remoteIP.String())
if err != nil {
log.Errorf("Failed to parse address %v: %v", remoteIP.String(), err)
continue
}
if !matchHost(ipAddr, pattern) {
continue
}
}
for _, hostKey := range keys {
results = append(results, hostKey.Type())
}
}

return results
}

// Check is called during the handshake to check the server's public key for
// unexpected changes. The key argument is in SSH wire format. It can be parsed
// using ssh.ParsePublicKey. The address before DNS resolution is passed in the
Expand Down
17 changes: 9 additions & 8 deletions ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func SSHAgentClient() (gosshagent.Agent, error) {
return gosshagent.NewClient(agent), nil
}

func sshClientConfig(user string, checker *HostKeyChecker) (*gossh.ClientConfig, error) {
func sshClientConfig(user string, checker *HostKeyChecker, addr string) (*gossh.ClientConfig, error) {
agentClient, err := SSHAgentClient()
if err != nil {
return nil, err
Expand All @@ -191,6 +191,7 @@ func sshClientConfig(user string, checker *HostKeyChecker) (*gossh.ClientConfig,

if checker != nil {
cfg.HostKeyCallback = checker.Check
cfg.HostKeyAlgorithms = checker.GetHostKeyAlgorithms(addr)
}

return &cfg, nil
Expand All @@ -204,13 +205,13 @@ func maybeAddDefaultPort(addr string) string {
}

func NewSSHClient(user, addr string, checker *HostKeyChecker, agentForwarding bool, timeout time.Duration) (*SSHForwardingClient, error) {
clientConfig, err := sshClientConfig(user, checker)
addr = maybeAddDefaultPort(addr)

clientConfig, err := sshClientConfig(user, checker, addr)
if err != nil {
return nil, err
}

addr = maybeAddDefaultPort(addr)

var client *gossh.Client
dialFunc := func(echan chan error) {
var err error
Expand All @@ -226,14 +227,14 @@ func NewSSHClient(user, addr string, checker *HostKeyChecker, agentForwarding bo
}

func NewTunnelledSSHClient(user, tunaddr, tgtaddr string, checker *HostKeyChecker, agentForwarding bool, timeout time.Duration) (*SSHForwardingClient, error) {
clientConfig, err := sshClientConfig(user, checker)
tunaddr = maybeAddDefaultPort(tunaddr)
tgtaddr = maybeAddDefaultPort(tgtaddr)

clientConfig, err := sshClientConfig(user, checker, tunaddr)
if err != nil {
return nil, err
}

tunaddr = maybeAddDefaultPort(tunaddr)
tgtaddr = maybeAddDefaultPort(tgtaddr)

var tunnelClient *gossh.Client
dialFunc := func(echan chan error) {
var err error
Expand Down

0 comments on commit 2957742

Please sign in to comment.