ssh/knownhosts: support hashed hostnames
Change-Id: I855a6542a2eb2ae1d223f03892c0f19da81a4f8d
Reviewed-on: https://go-review.googlesource.com/40532
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/knownhosts/knownhosts.go b/ssh/knownhosts/knownhosts.go
index 66bb425..47c9239 100644
--- a/ssh/knownhosts/knownhosts.go
+++ b/ssh/knownhosts/knownhosts.go
@@ -9,6 +9,9 @@
import (
"bufio"
"bytes"
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha1"
"encoding/base64"
"errors"
"fmt"
@@ -27,11 +30,15 @@
type addr struct{ host, port string }
func (a *addr) String() string {
- return a.host + ":" + a.port
+ h := a.host
+ if strings.Contains(h, ":") {
+ h = "[" + h + "]"
+ }
+ return h + ":" + a.port
}
-func (a *addr) eq(b addr) bool {
- return a.host == b.host && a.port == b.port
+type matcher interface {
+ match([]addr) bool
}
type hostPattern struct {
@@ -48,6 +55,25 @@
return n + p.addr.String()
}
+type hostPatterns []hostPattern
+
+func (ps hostPatterns) match(addrs []addr) bool {
+ matched := false
+ for _, p := range ps {
+ for _, a := range addrs {
+ m := p.match(a)
+ if !m {
+ continue
+ }
+ if p.negate {
+ return false
+ }
+ matched = true
+ }
+ }
+ return matched
+}
+
// See
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
// The matching of * has no regard for separators, unlike filesystem globs
@@ -88,48 +114,16 @@
type keyDBLine struct {
cert bool
- patterns []*hostPattern
+ matcher matcher
knownKey KnownKey
}
-func (l *keyDBLine) String() string {
- c := ""
- if l.cert {
- c = markerCert + " "
- }
-
- var ss []string
- for _, p := range l.patterns {
- ss = append(ss, p.String())
- }
-
- return c + strings.Join(ss, ",") + " " + serialize(l.knownKey.Key)
-}
-
func serialize(k ssh.PublicKey) string {
return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
}
func (l *keyDBLine) match(addrs []addr) bool {
- matched := false
- for _, p := range l.patterns {
- for _, a := range addrs {
- m := p.match(a)
- if p.negate {
- if m {
- return false
- } else {
- continue
- }
- }
-
- if m {
- matched = true
- }
- }
- }
-
- return matched
+ return l.matcher.match(addrs)
}
type hostKeyDB struct {
@@ -138,17 +132,6 @@
lines []keyDBLine
}
-func (db *hostKeyDB) String() string {
- var ls []string
- for _, k := range db.revoked {
- ls = append(ls, markerRevoked+" * "+serialize(k.Key))
- }
- for _, l := range db.lines {
- ls = append(ls, l.String())
- }
- return strings.Join(ls, "\n")
-}
-
func newHostKeyDB() *hostKeyDB {
db := &hostKeyDB{
revoked: make(map[string]*KnownKey),
@@ -190,45 +173,39 @@
return string(line[:i]), bytes.TrimSpace(line[i:])
}
-func parseLine(line []byte) (marker string, pattern []string, key ssh.PublicKey, err error) {
+func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
if w, next := nextWord(line); w == markerCert || w == markerRevoked {
marker = w
line = next
}
- hostPart, line := nextWord(line)
+ host, line = nextWord(line)
if len(line) == 0 {
- return "", nil, nil, errors.New("knownhosts: missing host pattern")
+ return "", "", nil, errors.New("knownhosts: missing host pattern")
}
- if len(hostPart) > 0 && hostPart[0] == '|' {
- return "", nil, nil, errors.New("knownhosts: hashed hostnames not implemented")
- }
-
- pattern = strings.Split(hostPart, ",")
-
// ignore the keytype as it's in the key blob anyway.
_, line = nextWord(line)
if len(line) == 0 {
- return "", nil, nil, errors.New("knownhosts: missing key type pattern")
+ return "", "", nil, errors.New("knownhosts: missing key type pattern")
}
keyBlob, _ := nextWord(line)
keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
if err != nil {
- return "", nil, nil, err
+ return "", "", nil, err
}
key, err = ssh.ParsePublicKey(keyBytes)
if err != nil {
- return "", nil, nil, err
+ return "", "", nil, err
}
- return marker, pattern, key, nil
+ return marker, host, key, nil
}
func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
- marker, patterns, key, err := parseLine(line)
+ marker, pattern, key, err := parseLine(line)
if err != nil {
return err
}
@@ -252,7 +229,23 @@
},
}
- for _, p := range patterns {
+ if pattern[0] == '|' {
+ entry.matcher, err = newHashedHost(pattern)
+ } else {
+ entry.matcher, err = newHostnameMatcher(pattern)
+ }
+
+ if err != nil {
+ return err
+ }
+
+ db.lines = append(db.lines, entry)
+ return nil
+}
+
+func newHostnameMatcher(pattern string) (matcher, error) {
+ var hps hostPatterns
+ for _, p := range strings.Split(pattern, ",") {
if len(p) == 0 {
continue
}
@@ -265,13 +258,14 @@
}
if len(p) == 0 {
- return errors.New("knownhosts: negation without following hostname")
+ return nil, errors.New("knownhosts: negation without following hostname")
}
+ var err error
if p[0] == '[' {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
- return err
+ return nil, err
}
} else {
a.host, a.port, err = net.SplitHostPort(p)
@@ -280,15 +274,12 @@
a.port = "22"
}
}
-
- entry.patterns = append(entry.patterns, &hostPattern{
+ hps = append(hps, hostPattern{
negate: negate,
addr: a,
})
}
-
- db.lines = append(db.lines, entry)
- return nil
+ return hps, nil
}
// KnownKey represents a key declared in a known_hosts file.
@@ -446,24 +437,111 @@
return certChecker.CheckHostKey, nil
}
+// Normalize normalizes an address into the form used in known_hosts
+func Normalize(address string) string {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ host = address
+ port = "22"
+ }
+ entry := host
+ if port != "22" {
+ entry = "[" + entry + "]:" + port
+ } else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
+ entry = "[" + entry + "]"
+ }
+ return entry
+}
+
// Line returns a line to add append to the known_hosts files.
func Line(addresses []string, key ssh.PublicKey) string {
var trimmed []string
for _, a := range addresses {
- host, port, err := net.SplitHostPort(a)
- if err != nil {
- host = a
- port = "22"
- }
- entry := host
- if port != "22" {
- entry = "[" + entry + "]:" + port
- } else if strings.Contains(host, ":") {
- entry = "[" + entry + "]"
- }
-
- trimmed = append(trimmed, entry)
+ trimmed = append(trimmed, Normalize(a))
}
return strings.Join(trimmed, ",") + " " + serialize(key)
}
+
+// HashHostname hashes the given hostname. The hostname is not
+// normalized before hashing.
+func HashHostname(hostname string) string {
+ // TODO(hanwen): check if we can safely normalize this always.
+ salt := make([]byte, sha1.Size)
+
+ _, err := rand.Read(salt)
+ if err != nil {
+ panic(fmt.Sprintf("crypto/rand failure %v", err))
+ }
+
+ hash := hashHost(hostname, salt)
+ return encodeHash(sha1HashType, salt, hash)
+}
+
+func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
+ if len(encoded) == 0 || encoded[0] != '|' {
+ err = errors.New("knownhosts: hashed host must start with '|'")
+ return
+ }
+ components := strings.Split(encoded, "|")
+ if len(components) != 4 {
+ err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
+ return
+ }
+
+ hashType = components[1]
+ if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
+ return
+ }
+ if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
+ return
+ }
+ return
+}
+
+func encodeHash(typ string, salt []byte, hash []byte) string {
+ return strings.Join([]string{"",
+ typ,
+ base64.StdEncoding.EncodeToString(salt),
+ base64.StdEncoding.EncodeToString(hash),
+ }, "|")
+}
+
+// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
+func hashHost(hostname string, salt []byte) []byte {
+ mac := hmac.New(sha1.New, salt)
+ mac.Write([]byte(hostname))
+ return mac.Sum(nil)
+}
+
+type hashedHost struct {
+ salt []byte
+ hash []byte
+}
+
+const sha1HashType = "1"
+
+func newHashedHost(encoded string) (*hashedHost, error) {
+ typ, salt, hash, err := decodeHash(encoded)
+ if err != nil {
+ return nil, err
+ }
+
+ // The type field seems for future algorithm agility, but it's
+ // actually hardcoded in openssh currently, see
+ // https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
+ if typ != sha1HashType {
+ return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
+ }
+
+ return &hashedHost{salt: salt, hash: hash}, nil
+}
+
+func (h *hashedHost) match(addrs []addr) bool {
+ for _, a := range addrs {
+ if bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/ssh/knownhosts/knownhosts_test.go b/ssh/knownhosts/knownhosts_test.go
index 2d0834b..63aff99 100644
--- a/ssh/knownhosts/knownhosts_test.go
+++ b/ssh/knownhosts/knownhosts_test.go
@@ -235,3 +235,73 @@
}
// TODO(hanwen): test coverage for certificates.
+
+const testHostname = "hostname"
+
+// generated with keygen -H -f
+const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
+
+func TestHostHash(t *testing.T) {
+ testHostHash(t, testHostname, encodedTestHostnameHash)
+}
+
+func TestHashList(t *testing.T) {
+ encoded := HashHostname(testHostname)
+ testHostHash(t, testHostname, encoded)
+}
+
+func testHostHash(t *testing.T, hostname, encoded string) {
+ typ, salt, hash, err := decodeHash(encoded)
+ if err != nil {
+ t.Fatalf("decodeHash: %v", err)
+ }
+
+ if got := encodeHash(typ, salt, hash); got != encoded {
+ t.Errorf("got encoding %s want %s", got, encoded)
+ }
+
+ if typ != sha1HashType {
+ t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
+ }
+
+ got := hashHost(hostname, salt)
+ if !bytes.Equal(got, hash) {
+ t.Errorf("got hash %x want %x", got, hash)
+ }
+}
+
+func TestNormalize(t *testing.T) {
+ for in, want := range map[string]string{
+ "127.0.0.1:22": "127.0.0.1",
+ "[127.0.0.1]:22": "127.0.0.1",
+ "[127.0.0.1]:23": "[127.0.0.1]:23",
+ "127.0.0.1:23": "[127.0.0.1]:23",
+ "[a.b.c]:22": "a.b.c",
+ "[abcd:abcd:abcd:abcd]": "[abcd:abcd:abcd:abcd]",
+ "[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]",
+ "[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23",
+ } {
+ got := Normalize(in)
+ if got != want {
+ t.Errorf("Normalize(%q) = %q, want %q", in, got, want)
+ }
+ }
+}
+
+func TestHashedHostkeyCheck(t *testing.T) {
+ str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)
+ db := testDB(t, str)
+ if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
+ t.Errorf("check(%s): %v", testHostname, err)
+ }
+ want := &KeyError{
+ Want: []KnownKey{{
+ Filename: "testdb",
+ Line: 1,
+ Key: edKey,
+ }},
+ }
+ if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
+ t.Errorf("got error %v, want %v", got, want)
+ }
+}