ssh: limit the size of the internal packet queue while waiting for KEX

In the SSH protocol, clients and servers execute the key exchange to
generate one-time session keys used for encryption and authentication.
The key exchange is performed initially after the connection is
established and then periodically after a configurable amount of data.
While a key exchange is in progress, we add the outgoing packets to an
internal queue until we receive SSH_MSG_KEXINIT from the other side.
This can result in high memory usage if the other party is slow to
respond to the SSH_MSG_KEXINIT packet, or memory exhaustion if a
malicious client never responds to an SSH_MSG_KEXINIT packet during a
large file transfer.
We now limit the internal queue to 64 packets: this means 2MB with the
typical 32KB packet size.
When the internal queue is full we block further writes until the
pending key exchange is completed or there is a read or write error.

Thanks to Yuichi Watanabe for reporting this issue.

Change-Id: I1ce2214cc16e08b838d4bc346c74c72addafaeec
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/652135
Reviewed-by: Neal Patel <nealpatel@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/ssh/handshake.go b/ssh/handshake.go
index fef687d..c9202b0 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -25,6 +25,11 @@
 // quickly.
 const chanSize = 16
 
+// maxPendingPackets sets the maximum number of packets to queue while waiting
+// for KEX to complete. This limits the total pending data to maxPendingPackets
+// * maxPacket bytes, which is ~16.8MB.
+const maxPendingPackets = 64
+
 // keyingTransport is a packet based transport that supports key
 // changes. It need not be thread-safe. It should pass through
 // msgNewKeys in both directions.
@@ -73,11 +78,19 @@
 	incoming  chan []byte
 	readError error
 
-	mu               sync.Mutex
-	writeError       error
-	sentInitPacket   []byte
-	sentInitMsg      *kexInitMsg
-	pendingPackets   [][]byte // Used when a key exchange is in progress.
+	mu sync.Mutex
+	// Condition for the above mutex. It is used to notify a completed key
+	// exchange or a write failure. Writes can wait for this condition while a
+	// key exchange is in progress.
+	writeCond      *sync.Cond
+	writeError     error
+	sentInitPacket []byte
+	sentInitMsg    *kexInitMsg
+	// Used to queue writes when a key exchange is in progress. The length is
+	// limited by maxPendingPackets. Once full, writes will block until the key
+	// exchange is completed or an error occurs. If not empty, it is emptied
+	// all at once when the key exchange is completed in kexLoop.
+	pendingPackets   [][]byte
 	writePacketsLeft uint32
 	writeBytesLeft   int64
 	userAuthComplete bool // whether the user authentication phase is complete
@@ -134,6 +147,7 @@
 
 		config: config,
 	}
+	t.writeCond = sync.NewCond(&t.mu)
 	t.resetReadThresholds()
 	t.resetWriteThresholds()
 
@@ -260,6 +274,7 @@
 	defer t.mu.Unlock()
 	if t.writeError == nil && err != nil {
 		t.writeError = err
+		t.writeCond.Broadcast()
 	}
 }
 
@@ -363,6 +378,8 @@
 			}
 		}
 		t.pendingPackets = t.pendingPackets[:0]
+		// Unblock writePacket if waiting for KEX.
+		t.writeCond.Broadcast()
 		t.mu.Unlock()
 	}
 
@@ -577,11 +594,20 @@
 	}
 
 	if t.sentInitMsg != nil {
-		// Copy the packet so the writer can reuse the buffer.
-		cp := make([]byte, len(p))
-		copy(cp, p)
-		t.pendingPackets = append(t.pendingPackets, cp)
-		return nil
+		if len(t.pendingPackets) < maxPendingPackets {
+			// Copy the packet so the writer can reuse the buffer.
+			cp := make([]byte, len(p))
+			copy(cp, p)
+			t.pendingPackets = append(t.pendingPackets, cp)
+			return nil
+		}
+		for t.sentInitMsg != nil {
+			// Block and wait for KEX to complete or an error.
+			t.writeCond.Wait()
+			if t.writeError != nil {
+				return t.writeError
+			}
+		}
 	}
 
 	if t.writeBytesLeft > 0 {
@@ -598,6 +624,7 @@
 
 	if err := t.pushPacket(p); err != nil {
 		t.writeError = err
+		t.writeCond.Broadcast()
 	}
 
 	return nil
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 2bc607b..019e47f 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -539,6 +539,226 @@
 	}
 }
 
+type mockKeyingTransport struct {
+	packetConn
+	kexInitAllowed chan struct{}
+	kexInitSent    chan struct{}
+}
+
+func (n *mockKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
+	return nil
+}
+
+func (n *mockKeyingTransport) writePacket(packet []byte) error {
+	if packet[0] == msgKexInit {
+		<-n.kexInitAllowed
+		n.kexInitSent <- struct{}{}
+	}
+	return n.packetConn.writePacket(packet)
+}
+
+func (n *mockKeyingTransport) readPacket() ([]byte, error) {
+	return n.packetConn.readPacket()
+}
+
+func (n *mockKeyingTransport) setStrictMode() error { return nil }
+
+func (n *mockKeyingTransport) setInitialKEXDone() {}
+
+func TestHandshakePendingPacketsWait(t *testing.T) {
+	a, b := memPipe()
+
+	trS := &mockKeyingTransport{
+		packetConn:     a,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trS.kexInitAllowed <- struct{}{}
+
+	trC := &mockKeyingTransport{
+		packetConn:     b,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trC.kexInitAllowed <- struct{}{}
+
+	clientConf := &ClientConfig{
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server := newServerTransport(trS, v, v, serverConf)
+
+	if err := server.waitSession(); err != nil {
+		t.Fatalf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+
+	<-trC.kexInitSent
+	<-trS.kexInitSent
+
+	// Allow and request new KEX server side.
+	trS.kexInitAllowed <- struct{}{}
+	server.requestKeyExchange()
+	// Wait until the KEX init is sent.
+	<-trS.kexInitSent
+	// The client is not allowed to respond to the KEX, so writes will be
+	// blocked on the server side once the packets queue is full.
+	for i := 0; i < maxPendingPackets; i++ {
+		p := []byte{msgRequestSuccess, byte(i)}
+		if err := server.writePacket(p); err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}
+	// The packets queue is now full, the next write will block.
+	server.mu.Lock()
+	if len(server.pendingPackets) != maxPendingPackets {
+		t.Errorf("unexpected pending packets size; got: %d, want: %d", len(server.pendingPackets), maxPendingPackets)
+	}
+	server.mu.Unlock()
+
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+
+		p := []byte{msgRequestSuccess, byte(65)}
+		// This write will block until KEX completes.
+		err := server.writePacket(p)
+		if err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}()
+
+	// Consume packets on the client side
+	readDone := make(chan bool)
+	go func() {
+		defer close(readDone)
+
+		for {
+			if _, err := client.readPacket(); err != nil {
+				if err != io.EOF {
+					t.Errorf("unexpected read error: %v", err)
+				}
+				break
+			}
+		}
+	}()
+
+	// Allow the client to reply to the KEX and so unblock the write goroutine.
+	trC.kexInitAllowed <- struct{}{}
+	<-trC.kexInitSent
+	<-writeDone
+	// Close the client to unblock the read goroutine.
+	client.Close()
+	<-readDone
+	server.Close()
+}
+
+func TestHandshakePendingPacketsError(t *testing.T) {
+	a, b := memPipe()
+
+	trS := &mockKeyingTransport{
+		packetConn:     a,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trS.kexInitAllowed <- struct{}{}
+
+	trC := &mockKeyingTransport{
+		packetConn:     b,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trC.kexInitAllowed <- struct{}{}
+
+	clientConf := &ClientConfig{
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server := newServerTransport(trS, v, v, serverConf)
+
+	if err := server.waitSession(); err != nil {
+		t.Fatalf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+
+	<-trC.kexInitSent
+	<-trS.kexInitSent
+
+	// Allow and request new KEX server side.
+	trS.kexInitAllowed <- struct{}{}
+	server.requestKeyExchange()
+	// Wait until the KEX init is sent.
+	<-trS.kexInitSent
+	// The client is not allowed to respond to the KEX, so writes will be
+	// blocked on the server side once the packets queue is full.
+	for i := 0; i < maxPendingPackets; i++ {
+		p := []byte{msgRequestSuccess, byte(i)}
+		if err := server.writePacket(p); err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}
+	// The packets queue is now full, the next write will block.
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+
+		p := []byte{msgRequestSuccess, byte(65)}
+		// This write will block until the server is closed; it must fail with io.EOF.
+		err := server.writePacket(p)
+		if err != io.EOF {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}()
+
+	// Consume packets on the client side
+	readDone := make(chan bool)
+	go func() {
+		defer close(readDone)
+
+		for {
+			if _, err := client.readPacket(); err != nil {
+				if err != io.EOF {
+					t.Errorf("unexpected read error: %v", err)
+				}
+				break
+			}
+		}
+	}()
+
+	// Close the server to unblock the write after an error
+	server.Close()
+	<-writeDone
+	// Allow the client's held KEX init write to proceed and close the client to
+	// unblock the read goroutine.
+	trC.kexInitAllowed <- struct{}{}
+	client.Close()
+	<-readDone
+}
+
 func TestHandshakeRekeyDefault(t *testing.T) {
 	clientConf := &ClientConfig{
 		Config: Config{