diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 452a44237cf05..ed3472b45694e 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -17,12 +17,14 @@ limitations under the License. package client import ( + "bufio" "fmt" "io" "io/ioutil" "os" "os/user" "path/filepath" + "strings" "time" "github.com/gravitational/teleport/lib/sshutils" @@ -176,15 +178,42 @@ func (fs *FSLocalKeyStore) GetKey(host, username string) (*Key, error) { // AddKnownHost adds a new entry to 'known_CAs' file func (fs *FSLocalKeyStore) AddKnownCA(domainName string, hostKeys []ssh.PublicKey) error { - fp, err := os.OpenFile(filepath.Join(fs.KeyDir, fileNameKnownHosts), os.O_CREATE|os.O_APPEND|os.O_RDWR, 0640) + fp, err := os.OpenFile(filepath.Join(fs.KeyDir, fileNameKnownHosts), os.O_CREATE|os.O_RDWR, 0640) if err != nil { return trace.Wrap(err) } + defer fp.Sync() defer fp.Close() + // read all existing entries into a map (this removes any pre-existing dupes) + entries := make(map[string]int) + output := make([]string, 0) + scanner := bufio.NewScanner(fp) + for scanner.Scan() { + line := scanner.Text() + if _, exists := entries[line]; !exists { + output = append(output, line) + entries[line] = 1 + } + } + // add every host key to the list of entries for i := range hostKeys { - bytes := ssh.MarshalAuthorizedKey(hostKeys[i]) log.Infof("adding known CA %v %v", domainName, sshutils.Fingerprint(hostKeys[i])) - fmt.Fprintf(fp, "%s %s\n", domainName, bytes) + bytes := ssh.MarshalAuthorizedKey(hostKeys[i]) + line := strings.TrimSpace(fmt.Sprintf("%s %s", domainName, bytes)) + if _, exists := entries[line]; !exists { + output = append(output, line) + } + } + // re-create the file: + _, err = fp.Seek(0, 0) + if err != nil { + return trace.Wrap(err) + } + if err = fp.Truncate(0); err != nil { + return trace.Wrap(err) + } + for _, line := range output { + fmt.Fprintf(fp, "%s\n", line) } return nil } @@ -198,7 +227,6 @@ func (fs *FSLocalKeyStore) GetKnownCAs() ([]ssh.PublicKey, error) { } return nil, trace.Wrap(err) } - var ( pubKey ssh.PublicKey retval []ssh.PublicKey = make([]ssh.PublicKey, 0) diff --git a/lib/client/keystore_test.go b/lib/client/keystore_test.go index a99ae959e2f53..c51fc2b7e7e91 100644 --- a/lib/client/keystore_test.go +++ b/lib/client/keystore_test.go @@ -128,6 +128,13 @@ func (s *KeyStoreTestSuite) TestKnownHosts(c *check.C) { c.Assert(err, check.IsNil) c.Assert(keys, check.HasLen, 3) c.Assert(keys, check.DeepEquals, []ssh.PublicKey{pub, pub2, pub2}) + + // check against dupes: + before, _ := s.store.GetKnownCAs() + s.store.AddKnownCA("example.org", []ssh.PublicKey{pub2}) + s.store.AddKnownCA("example.org", []ssh.PublicKey{pub2}) + after, _ := s.store.GetKnownCAs() + c.Assert(len(before), check.Equals, len(after)) } // makeSIgnedKey helper returns all 3 components of a user key (signed by CAPriv key) diff --git a/lib/srv/sshserver_test.go b/lib/srv/sshserver_test.go index fca0ffbbee6c7..2869331c20833 100644 --- a/lib/srv/sshserver_test.go +++ b/lib/srv/sshserver_test.go @@ -216,7 +216,7 @@ func (s *SrvSuite) TestShell(c *C) { // send a few "keyboard inputs" into the session: _, err = io.WriteString(writer, "echo $((50+100))\n\r") c.Assert(err, IsNil) - time.Sleep(time.Millisecond * 3) + time.Sleep(time.Millisecond * 5) // read the output and make sure that "150" (output of $((50+100)) is there _, err = reader.Read(buf)