ssh/knownhosts: support hashed hostnames

Change-Id: I855a6542a2eb2ae1d223f03892c0f19da81a4f8d
Reviewed-on: https://go-review.googlesource.com/40532
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/knownhosts/knownhosts.go b/ssh/knownhosts/knownhosts.go
index 66bb425..47c9239 100644
--- a/ssh/knownhosts/knownhosts.go
+++ b/ssh/knownhosts/knownhosts.go
@@ -9,6 +9,9 @@
 import (
 	"bufio"
 	"bytes"
+	"crypto/hmac"
+	"crypto/rand"
+	"crypto/sha1"
 	"encoding/base64"
 	"errors"
 	"fmt"
@@ -27,11 +30,15 @@
 type addr struct{ host, port string }
 
 func (a *addr) String() string {
-	return a.host + ":" + a.port
+	h := a.host
+	if strings.Contains(h, ":") {
+		h = "[" + h + "]"
+	}
+	return h + ":" + a.port
 }
 
-func (a *addr) eq(b addr) bool {
-	return a.host == b.host && a.port == b.port
+type matcher interface {
+	match([]addr) bool
 }
 
 type hostPattern struct {
@@ -48,6 +55,25 @@
 	return n + p.addr.String()
 }
 
+type hostPatterns []hostPattern
+
+func (ps hostPatterns) match(addrs []addr) bool {
+	matched := false
+	for _, p := range ps {
+		for _, a := range addrs {
+			m := p.match(a)
+			if !m {
+				continue
+			}
+			if p.negate {
+				return false
+			}
+			matched = true
+		}
+	}
+	return matched
+}
+
 // See
 // https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
 // The matching of * has no regard for separators, unlike filesystem globs
@@ -88,48 +114,16 @@
 
 type keyDBLine struct {
 	cert     bool
-	patterns []*hostPattern
+	matcher  matcher
 	knownKey KnownKey
 }
 
-func (l *keyDBLine) String() string {
-	c := ""
-	if l.cert {
-		c = markerCert + " "
-	}
-
-	var ss []string
-	for _, p := range l.patterns {
-		ss = append(ss, p.String())
-	}
-
-	return c + strings.Join(ss, ",") + " " + serialize(l.knownKey.Key)
-}
-
 func serialize(k ssh.PublicKey) string {
 	return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
 }
 
 func (l *keyDBLine) match(addrs []addr) bool {
-	matched := false
-	for _, p := range l.patterns {
-		for _, a := range addrs {
-			m := p.match(a)
-			if p.negate {
-				if m {
-					return false
-				} else {
-					continue
-				}
-			}
-
-			if m {
-				matched = true
-			}
-		}
-	}
-
-	return matched
+	return l.matcher.match(addrs)
 }
 
 type hostKeyDB struct {
@@ -138,17 +132,6 @@
 	lines   []keyDBLine
 }
 
-func (db *hostKeyDB) String() string {
-	var ls []string
-	for _, k := range db.revoked {
-		ls = append(ls, markerRevoked+" * "+serialize(k.Key))
-	}
-	for _, l := range db.lines {
-		ls = append(ls, l.String())
-	}
-	return strings.Join(ls, "\n")
-}
-
 func newHostKeyDB() *hostKeyDB {
 	db := &hostKeyDB{
 		revoked: make(map[string]*KnownKey),
@@ -190,45 +173,39 @@
 	return string(line[:i]), bytes.TrimSpace(line[i:])
 }
 
-func parseLine(line []byte) (marker string, pattern []string, key ssh.PublicKey, err error) {
+func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
 	if w, next := nextWord(line); w == markerCert || w == markerRevoked {
 		marker = w
 		line = next
 	}
 
-	hostPart, line := nextWord(line)
+	host, line = nextWord(line)
 	if len(line) == 0 {
-		return "", nil, nil, errors.New("knownhosts: missing host pattern")
+		return "", "", nil, errors.New("knownhosts: missing host pattern")
 	}
 
-	if len(hostPart) > 0 && hostPart[0] == '|' {
-		return "", nil, nil, errors.New("knownhosts: hashed hostnames not implemented")
-	}
-
-	pattern = strings.Split(hostPart, ",")
-
 	// ignore the keytype as it's in the key blob anyway.
 	_, line = nextWord(line)
 	if len(line) == 0 {
-		return "", nil, nil, errors.New("knownhosts: missing key type pattern")
+		return "", "", nil, errors.New("knownhosts: missing key type pattern")
 	}
 
 	keyBlob, _ := nextWord(line)
 
 	keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
 	if err != nil {
-		return "", nil, nil, err
+		return "", "", nil, err
 	}
 	key, err = ssh.ParsePublicKey(keyBytes)
 	if err != nil {
-		return "", nil, nil, err
+		return "", "", nil, err
 	}
 
-	return marker, pattern, key, nil
+	return marker, host, key, nil
 }
 
 func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
-	marker, patterns, key, err := parseLine(line)
+	marker, pattern, key, err := parseLine(line)
 	if err != nil {
 		return err
 	}
@@ -252,7 +229,23 @@
 		},
 	}
 
-	for _, p := range patterns {
+	if pattern[0] == '|' {
+		entry.matcher, err = newHashedHost(pattern)
+	} else {
+		entry.matcher, err = newHostnameMatcher(pattern)
+	}
+
+	if err != nil {
+		return err
+	}
+
+	db.lines = append(db.lines, entry)
+	return nil
+}
+
+func newHostnameMatcher(pattern string) (matcher, error) {
+	var hps hostPatterns
+	for _, p := range strings.Split(pattern, ",") {
 		if len(p) == 0 {
 			continue
 		}
@@ -265,13 +258,14 @@
 		}
 
 		if len(p) == 0 {
-			return errors.New("knownhosts: negation without following hostname")
+			return nil, errors.New("knownhosts: negation without following hostname")
 		}
 
+		var err error
 		if p[0] == '[' {
 			a.host, a.port, err = net.SplitHostPort(p)
 			if err != nil {
-				return err
+				return nil, err
 			}
 		} else {
 			a.host, a.port, err = net.SplitHostPort(p)
@@ -280,15 +274,12 @@
 				a.port = "22"
 			}
 		}
-
-		entry.patterns = append(entry.patterns, &hostPattern{
+		hps = append(hps, hostPattern{
 			negate: negate,
 			addr:   a,
 		})
 	}
-
-	db.lines = append(db.lines, entry)
-	return nil
+	return hps, nil
 }
 
 // KnownKey represents a key declared in a known_hosts file.
@@ -446,24 +437,111 @@
 	return certChecker.CheckHostKey, nil
 }
 
+// Normalize normalizes an address into the form used in known_hosts
+func Normalize(address string) string {
+	host, port, err := net.SplitHostPort(address)
+	if err != nil {
+		host = address
+		port = "22"
+	}
+	entry := host
+	if port != "22" {
+		entry = "[" + entry + "]:" + port
+	} else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
+		entry = "[" + entry + "]"
+	}
+	return entry
+}
+
 // Line returns a line to add append to the known_hosts files.
 func Line(addresses []string, key ssh.PublicKey) string {
 	var trimmed []string
 	for _, a := range addresses {
-		host, port, err := net.SplitHostPort(a)
-		if err != nil {
-			host = a
-			port = "22"
-		}
-		entry := host
-		if port != "22" {
-			entry = "[" + entry + "]:" + port
-		} else if strings.Contains(host, ":") {
-			entry = "[" + entry + "]"
-		}
-
-		trimmed = append(trimmed, entry)
+		trimmed = append(trimmed, Normalize(a))
 	}
 
 	return strings.Join(trimmed, ",") + " " + serialize(key)
 }
+
+// HashHostname hashes the given hostname. The hostname is not
+// normalized before hashing.
+func HashHostname(hostname string) string {
+	// TODO(hanwen): check if we can safely normalize this always.
+	salt := make([]byte, sha1.Size)
+
+	_, err := rand.Read(salt)
+	if err != nil {
+		panic(fmt.Sprintf("crypto/rand failure %v", err))
+	}
+
+	hash := hashHost(hostname, salt)
+	return encodeHash(sha1HashType, salt, hash)
+}
+
+func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
+	if len(encoded) == 0 || encoded[0] != '|' {
+		err = errors.New("knownhosts: hashed host must start with '|'")
+		return
+	}
+	components := strings.Split(encoded, "|")
+	if len(components) != 4 {
+		err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
+		return
+	}
+
+	hashType = components[1]
+	if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
+		return
+	}
+	if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
+		return
+	}
+	return
+}
+
+func encodeHash(typ string, salt []byte, hash []byte) string {
+	return strings.Join([]string{"",
+		typ,
+		base64.StdEncoding.EncodeToString(salt),
+		base64.StdEncoding.EncodeToString(hash),
+	}, "|")
+}
+
+// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
+func hashHost(hostname string, salt []byte) []byte {
+	mac := hmac.New(sha1.New, salt)
+	mac.Write([]byte(hostname))
+	return mac.Sum(nil)
+}
+
+type hashedHost struct {
+	salt []byte
+	hash []byte
+}
+
+const sha1HashType = "1"
+
+func newHashedHost(encoded string) (*hashedHost, error) {
+	typ, salt, hash, err := decodeHash(encoded)
+	if err != nil {
+		return nil, err
+	}
+
+	// The type field seems for future algorithm agility, but it's
+	// actually hardcoded in openssh currently, see
+	// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
+	if typ != sha1HashType {
+		return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
+	}
+
+	return &hashedHost{salt: salt, hash: hash}, nil
+}
+
+func (h *hashedHost) match(addrs []addr) bool {
+	for _, a := range addrs {
+		if bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/ssh/knownhosts/knownhosts_test.go b/ssh/knownhosts/knownhosts_test.go
index 2d0834b..63aff99 100644
--- a/ssh/knownhosts/knownhosts_test.go
+++ b/ssh/knownhosts/knownhosts_test.go
@@ -235,3 +235,73 @@
 }
 
 // TODO(hanwen): test coverage for certificates.
+
+const testHostname = "hostname"
+
+// generated with keygen -H -f
+const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
+
+func TestHostHash(t *testing.T) {
+	testHostHash(t, testHostname, encodedTestHostnameHash)
+}
+
+func TestHashList(t *testing.T) {
+	encoded := HashHostname(testHostname)
+	testHostHash(t, testHostname, encoded)
+}
+
+func testHostHash(t *testing.T, hostname, encoded string) {
+	typ, salt, hash, err := decodeHash(encoded)
+	if err != nil {
+		t.Fatalf("decodeHash: %v", err)
+	}
+
+	if got := encodeHash(typ, salt, hash); got != encoded {
+		t.Errorf("got encoding %s want %s", got, encoded)
+	}
+
+	if typ != sha1HashType {
+		t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
+	}
+
+	got := hashHost(hostname, salt)
+	if !bytes.Equal(got, hash) {
+		t.Errorf("got hash %x want %x", got, hash)
+	}
+}
+
+func TestNormalize(t *testing.T) {
+	for in, want := range map[string]string{
+		"127.0.0.1:22":             "127.0.0.1",
+		"[127.0.0.1]:22":           "127.0.0.1",
+		"[127.0.0.1]:23":           "[127.0.0.1]:23",
+		"127.0.0.1:23":             "[127.0.0.1]:23",
+		"[a.b.c]:22":               "a.b.c",
+		"[abcd:abcd:abcd:abcd]":    "[abcd:abcd:abcd:abcd]",
+		"[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]",
+		"[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23",
+	} {
+		got := Normalize(in)
+		if got != want {
+			t.Errorf("Normalize(%q) = %q, want %q", in, got, want)
+		}
+	}
+}
+
+func TestHashedHostkeyCheck(t *testing.T) {
+	str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)
+	db := testDB(t, str)
+	if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
+		t.Errorf("check(%s): %v", testHostname, err)
+	}
+	want := &KeyError{
+		Want: []KnownKey{{
+			Filename: "testdb",
+			Line:     1,
+			Key:      edKey,
+		}},
+	}
+	if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
+		t.Errorf("got error %v, want %v", got, want)
+	}
+}