ssh: rationalize rekeying decisions.

1) Always force a key exchange if we exchange 2^31 packets. In the past
this might not happen if RekeyThreshold was set to a very large
interval.

2) Follow recommendations from RFC 4344 for block ciphers. For AES, we
can encrypt 2^(blocksize/4) blocks under the same keys.

On modern hardware, the previous default of 1Gb could force a key
exchange within ~10 seconds. Since the key exchange takes 3 roundtrips
(send kex init, send DH init, send NEW_KEYS), this is relatively
expensive on high-latency links.

Change-Id: I1297124a307c541b7bf22d814d136ec0c6d8ed97
Reviewed-on: https://go-review.googlesource.com/35410
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/common.go b/ssh/common.go
index 2c72ab5..faabb7e 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -104,6 +104,21 @@
 	Compression string
 }
 
+// rekeyBytes returns a rekeying intervals in bytes.
+func (a *directionAlgorithms) rekeyBytes() int64 {
+	// According to RFC4344 block ciphers should rekey after
+	// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
+	// 128.
+	switch a.Cipher {
+	case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID:
+		return 16 * (1 << 32)
+
+	}
+
+	// For others, stick with RFC4253 recommendation to rekey after 1 Gb of data.
+	return 1 << 30
+}
+
 type algorithms struct {
 	kex     string
 	hostKey string
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 93c23d1..57f2d3d 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -78,9 +78,14 @@
 	dialAddress     string
 	remoteAddr      net.Addr
 
-	readSinceKex uint64
+	// Algorithms agreed in the last key exchange.
+	algorithms *algorithms
 
-	writtenSinceKex uint64
+	readPacketsLeft uint32
+	readBytesLeft   int64
+
+	writePacketsLeft uint32
+	writeBytesLeft   int64
 
 	// The session ID or nil if first kex did not complete yet.
 	sessionID []byte
@@ -290,7 +295,12 @@
 		t.writeError = err
 		t.sentInitPacket = nil
 		t.sentInitMsg = nil
-		t.writtenSinceKex = 0
+		t.writePacketsLeft = packetRekeyThreshold
+		if t.config.RekeyThreshold > 0 {
+			t.writeBytesLeft = int64(t.config.RekeyThreshold)
+		} else if t.algorithms != nil {
+			t.writeBytesLeft = t.algorithms.w.rekeyBytes()
+		}
 		request.done <- t.writeError
 
 		// kex finished. Push packets that we received while
@@ -320,17 +330,31 @@
 	t.conn.Close()
 }
 
-func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
-	if t.readSinceKex > t.config.RekeyThreshold {
-		t.requestKeyExchange()
-	}
+// The protocol uses uint32 for packet counters, so we can't let them
+// reach 1<<32.  We will actually read and write more packets than
+// this, though: the other side may send more packets, and after we
+// hit this limit on writing we will send a few more packets for the
+// key exchange itself.
+const packetRekeyThreshold = (1 << 31)
 
+func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
 	p, err := t.conn.readPacket()
 	if err != nil {
 		return nil, err
 	}
 
-	t.readSinceKex += uint64(len(p))
+	if t.readPacketsLeft > 0 {
+		t.readPacketsLeft--
+	} else {
+		t.requestKeyExchange()
+	}
+
+	if t.readBytesLeft > 0 {
+		t.readBytesLeft -= int64(len(p))
+	} else {
+		t.requestKeyExchange()
+	}
+
 	if debugHandshake {
 		t.printPacket(p, false)
 	}
@@ -360,7 +384,12 @@
 		return nil, err
 	}
 
-	t.readSinceKex = 0
+	t.readPacketsLeft = packetRekeyThreshold
+	if t.config.RekeyThreshold > 0 {
+		t.readBytesLeft = int64(t.config.RekeyThreshold)
+	} else {
+		t.readBytesLeft = t.algorithms.r.rekeyBytes()
+	}
 
 	// By default, a key exchange is hidden from higher layers by
 	// translating it into msgIgnore.
@@ -443,8 +472,16 @@
 		t.pendingPackets = append(t.pendingPackets, cp)
 		return nil
 	}
-	t.writtenSinceKex += uint64(len(p))
-	if t.writtenSinceKex > t.config.RekeyThreshold {
+
+	if t.writeBytesLeft > 0 {
+		t.writeBytesLeft -= int64(len(p))
+	} else {
+		t.requestKeyExchange()
+	}
+
+	if t.writePacketsLeft > 0 {
+		t.writePacketsLeft--
+	} else {
 		t.requestKeyExchange()
 	}
 
@@ -485,7 +522,8 @@
 		magics.serverKexInit = otherInitPacket
 	}
 
-	algs, err := findAgreedAlgorithms(clientInit, serverInit)
+	var err error
+	t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
 	if err != nil {
 		return err
 	}
@@ -508,16 +546,16 @@
 		}
 	}
 
-	kex, ok := kexAlgoMap[algs.kex]
+	kex, ok := kexAlgoMap[t.algorithms.kex]
 	if !ok {
-		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
+		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
 	}
 
 	var result *kexResult
 	if len(t.hostKeys) > 0 {
-		result, err = t.server(kex, algs, &magics)
+		result, err = t.server(kex, t.algorithms, &magics)
 	} else {
-		result, err = t.client(kex, algs, &magics)
+		result, err = t.client(kex, t.algorithms, &magics)
 	}
 
 	if err != nil {
@@ -529,7 +567,7 @@
 	}
 	result.SessionID = t.sessionID
 
-	t.conn.prepareKeyChange(algs, result)
+	t.conn.prepareKeyChange(t.algorithms, result)
 	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 	}