x/crypto/ssh: add function to parse known_hosts files.

Change-Id: I9258ecf2b38258e31bcb6e73ac042ad8125fd2d1
Reviewed-on: https://go-review.googlesource.com/18106
Reviewed-by: Peter Moody <peter.moody@gmail.com>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/keys.go b/ssh/keys.go
index 8ce4185..cfc970b 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -19,6 +19,7 @@
 	"fmt"
 	"io"
 	"math/big"
+	"strings"
 )
 
 // These constants represent the algorithm names for key types supported by this
@@ -77,6 +78,79 @@
 	return out, comment, nil
 }
 
+// ParseKnownHosts parses a single entry in the format of the known_hosts file.
+//
+// The known_hosts format is documented in the sshd(8) manual page. Blank
+// lines, lines starting with '#' and lines without any space or tab are
+// skipped. On successful return, marker holds the optional marker without its
+// leading '@' (e.g. "cert-authority" or "revoked") or is empty, hosts holds
+// the hosts of the entry split on commas, pubKey holds the public key and
+// comment holds any trailing comment. See the sshd(8) manual page for the
+// various forms that a host string can take.
+//
+// The unparsed remainder of the input is returned in rest, so this function
+// can be called repeatedly to parse multiple entries.
+//
+// If no entry is found in the input then err is io.EOF. Any other non-nil err
+// indicates a parse error; all other results, including rest, are then zero.
+func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) {
+	for len(in) > 0 {
+		end := bytes.IndexByte(in, '\n')
+		if end != -1 {
+			rest = in[end+1:]
+			in = in[:end]
+		} else {
+			rest = nil
+		}
+		// Cut the line at the first carriage return, if any (e.g. CRLF input).
+		end = bytes.IndexByte(in, '\r')
+		if end != -1 {
+			in = in[:end]
+		}
+
+		in = bytes.TrimSpace(in)
+		if len(in) == 0 || in[0] == '#' {
+			in = rest
+			continue
+		}
+		// A line without any space or tab is skipped, not reported as an error.
+		i := bytes.IndexAny(in, " \t")
+		if i == -1 {
+			in = rest
+			continue
+		}
+
+		// Split the line into whitespace-separated fields. The first field is
+		// an optional marker or a (set of) hostname(s); 3 to 5 fields are accepted.
+		keyFields := bytes.Fields(in)
+		if len(keyFields) < 3 || len(keyFields) > 5 {
+			return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data")
+		}
+
+		// keyFields[0] is either a marker such as "@cert-authority" or "@revoked",
+		// or a comma separated list of hosts. A marker is returned without its '@'.
+		marker := ""
+		if keyFields[0][0] == '@' {
+			marker = string(keyFields[0][1:])
+			keyFields = keyFields[1:]
+		}
+
+		hosts := string(keyFields[0])
+		// keyFields[1] contains the key type (e.g. "ssh-rsa").
+		// However, that information is duplicated inside the
+		// base64-encoded key and so is ignored here.
+
+		key := bytes.Join(keyFields[2:], []byte(" "))
+		if pubKey, comment, err = parseAuthorizedKey(key); err != nil {
+			return "", nil, nil, "", nil, err
+		}
+
+		return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil
+	}
+
+	return "", nil, nil, "", nil, io.EOF
+}
+
 // ParseAuthorizedKeys parses a public key from an authorized_keys
 // file used in OpenSSH according to the sshd(8) manual page.
 func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index b4cceaf..2756947 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -304,3 +304,134 @@
 		t.Errorf("got valid entry for %q", authInvalid)
 	}
 }
+
+var knownHostsParseTests = []struct {
+	input     string // known_hosts data; "{RSAPUB}" is replaced by the serialized test key
+	err       string // expected substring of the error, or "" if parsing must succeed
+	// Expected results; these are only compared when parsing succeeds.
+	marker   string
+	comment  string
+	hosts    []string
+	rest     string
+} {
+	{
+		"",
+		"EOF",
+
+		"", "", nil, "",
+	},
+	{
+		"# Just a comment",
+		"EOF",
+
+		"", "", nil, "",
+	},
+	{
+		"   \t   ",
+		"EOF",
+
+		"", "", nil, "",
+	},
+	{
+		"localhost ssh-rsa {RSAPUB}",
+		"",
+
+		"", "", []string{"localhost"}, "",
+	},
+	{
+		"localhost\tssh-rsa {RSAPUB}",
+		"",
+
+		"", "", []string{"localhost"}, "",
+	},
+	{
+		"localhost\tssh-rsa {RSAPUB}\tcomment comment",
+		"",
+
+		"", "comment comment", []string{"localhost"}, "",
+	},
+	{
+		"localhost\tssh-rsa {RSAPUB}\tcomment comment\n",
+		"",
+
+		"", "comment comment", []string{"localhost"}, "",
+	},
+	{
+		"localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n",
+		"",
+
+		"", "comment comment", []string{"localhost"}, "",
+	},
+	{
+		"localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line",
+		"",
+
+		"", "comment comment", []string{"localhost"}, "next line",
+	},
+	{
+		"localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment",
+		"",
+
+		"", "comment comment", []string{"localhost","[host2:123]"}, "",
+	},
+	{
+		"@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}",
+		"",
+
+		"marker", "", []string{"localhost","[host2:123]"}, "",
+	},
+	{
+		"@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd",
+		"short read",
+
+		"", "", nil, "",
+	},
+}
+
+func TestKnownHostsParsing(t *testing.T) {
+	rsaPub, rsaPubSerialized := getTestKey()
+
+	for i, test := range knownHostsParseTests {
+		var expectedKey PublicKey
+		const rsaKeyToken = "{RSAPUB}"
+		// Inputs refer to the test key by placeholder; expand it before parsing.
+		input := test.input
+		if strings.Contains(input, rsaKeyToken) {
+			expectedKey = rsaPub
+			input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1)
+		}
+
+		marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input))
+		if err != nil {
+			if len(test.err) == 0 {
+				t.Errorf("#%d: unexpectedly failed with %q", i, err)
+			} else if !strings.Contains(err.Error(), test.err) {
+				t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err)
+			}
+			continue
+		} else if len(test.err) != 0 {
+			t.Errorf("#%d: succeeded but expected error including %q", i, test.err)
+			continue
+		}
+		// Parsing succeeded as expected; compare each result with the table entry.
+		if !reflect.DeepEqual(expectedKey, pubKey) {
+			t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey)
+		}
+
+		if marker != test.marker {
+			t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker)
+		}
+
+		if comment != test.comment {
+			t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment)
+		}
+
+		if !reflect.DeepEqual(test.hosts, hosts) {
+			t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts)
+		}
+
+		if rest := string(rest); rest != test.rest {
+			t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest)
+		}
+	}
+}