ssh: limit the size of the internal packet queue while waiting for KEX
In the SSH protocol, clients and servers execute the key exchange to
generate one-time session keys used for encryption and authentication.
The key exchange is performed initially after the connection is
established and then periodically after a configurable amount of data.
While a key exchange is in progress, we add the packets to be sent to an
internal queue until we receive SSH_MSG_KEXINIT from the other side and the
key exchange completes.
This can result in high memory usage if the other party is slow to
respond to the SSH_MSG_KEXINIT packet, or memory exhaustion if a
malicious client never responds to an SSH_MSG_KEXINIT packet during a
large file transfer.
We now limit the internal queue to 64 packets: this means 2MB with the
typical 32KB packet size, and about 16.8MB with the maximum packet size.
When the internal queue is full we block further writes until the
pending key exchange is completed or there is a read or write error.
Thanks to Yuichi Watanabe for reporting this issue.
Change-Id: I1ce2214cc16e08b838d4bc346c74c72addafaeec
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/652135
Reviewed-by: Neal Patel <nealpatel@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/ssh/handshake.go b/ssh/handshake.go
index fef687d..c9202b0 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -25,6 +25,11 @@
// quickly.
const chanSize = 16
+// maxPendingPackets is the maximum number of packets queued while waiting for
+// KEX to complete; further writes block until it completes or a write error
+// occurs. Pending data is capped at maxPendingPackets * maxPacket, ~16.8MB.
+const maxPendingPackets = 64
+
// keyingTransport is a packet based transport that supports key
// changes. It need not be thread-safe. It should pass through
// msgNewKeys in both directions.
@@ -73,11 +78,19 @@
incoming chan []byte
readError error
- mu sync.Mutex
- writeError error
- sentInitPacket []byte
- sentInitMsg *kexInitMsg
- pendingPackets [][]byte // Used when a key exchange is in progress.
+ mu sync.Mutex
+ // Condition for the above mutex. It is used to notify a completed key
+ // exchange or a write failure. Writes can wait for this condition while a
+ // key exchange is in progress.
+ writeCond *sync.Cond
+ writeError error
+ sentInitPacket []byte
+ sentInitMsg *kexInitMsg
+ // Used to queue writes when a key exchange is in progress. The length is
+ // limited by maxPendingPackets. Once full, writes will block until the key
+ // exchange is completed or an error occurs. If not empty, it is emptied
+ // all at once when the key exchange is completed in kexLoop.
+ pendingPackets [][]byte
writePacketsLeft uint32
writeBytesLeft int64
userAuthComplete bool // whether the user authentication phase is complete
@@ -134,6 +147,7 @@
config: config,
}
+ t.writeCond = sync.NewCond(&t.mu)
t.resetReadThresholds()
t.resetWriteThresholds()
@@ -260,6 +274,7 @@
defer t.mu.Unlock()
if t.writeError == nil && err != nil {
t.writeError = err
+ t.writeCond.Broadcast()
}
}
@@ -363,6 +378,8 @@
}
}
t.pendingPackets = t.pendingPackets[:0]
+ // Unblock writePacket if waiting for KEX.
+ t.writeCond.Broadcast()
t.mu.Unlock()
}
@@ -577,11 +594,20 @@
}
if t.sentInitMsg != nil {
- // Copy the packet so the writer can reuse the buffer.
- cp := make([]byte, len(p))
- copy(cp, p)
- t.pendingPackets = append(t.pendingPackets, cp)
- return nil
+ if len(t.pendingPackets) < maxPendingPackets {
+ // Copy the packet so the writer can reuse the buffer.
+ cp := make([]byte, len(p))
+ copy(cp, p)
+ t.pendingPackets = append(t.pendingPackets, cp)
+ return nil
+ }
+ for t.sentInitMsg != nil {
+ // Block and wait for KEX to complete or an error.
+ t.writeCond.Wait()
+ if t.writeError != nil {
+ return t.writeError
+ }
+ }
}
if t.writeBytesLeft > 0 {
@@ -598,6 +624,7 @@
if err := t.pushPacket(p); err != nil {
t.writeError = err
+ t.writeCond.Broadcast()
}
return nil
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 2bc607b..019e47f 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -539,6 +539,191 @@
}
}
+// mockKeyingTransport is a keyingTransport that lets a test control when
+// SSH_MSG_KEXINIT packets are written: each msgKexInit write first waits for a
+// token on kexInitAllowed and signals kexInitSent before it is forwarded.
+// Withholding a side's token keeps a key exchange in progress. readPacket is
+// provided by the embedded packetConn.
+type mockKeyingTransport struct {
+	packetConn
+	kexInitAllowed chan struct{}
+	kexInitSent    chan struct{}
+}
+
+// prepareKeyChange is a no-op: the mock never switches keys, so packets keep
+// flowing through the embedded packetConn unchanged.
+func (n *mockKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
+	return nil
+}
+
+// writePacket forwards packet to the embedded packetConn, gating
+// SSH_MSG_KEXINIT packets on kexInitAllowed and reporting them on kexInitSent.
+func (n *mockKeyingTransport) writePacket(packet []byte) error {
+	if packet[0] == msgKexInit {
+		<-n.kexInitAllowed
+		n.kexInitSent <- struct{}{}
+	}
+	return n.packetConn.writePacket(packet)
+}
+
+func (n *mockKeyingTransport) setStrictMode() error { return nil }
+
+func (n *mockKeyingTransport) setInitialKEXDone() {}
+
+// newPendingPacketsTestTransports connects a client and a server
+// handshakeTransport over mock keying transports and completes the initial key
+// exchange. It then has the server start a second key exchange that the client
+// is not yet allowed to answer and fills the server's pending packets queue
+// with maxPendingPackets writes, so the next server write blocks. A send on
+// trC.kexInitAllowed lets the client reply to the key exchange.
+func newPendingPacketsTestTransports(t *testing.T) (client, server *handshakeTransport, trC *mockKeyingTransport) {
+	t.Helper()
+	a, b := memPipe()
+
+	trS := &mockKeyingTransport{
+		packetConn:     a,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trS.kexInitAllowed <- struct{}{}
+
+	trC = &mockKeyingTransport{
+		packetConn:     b,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trC.kexInitAllowed <- struct{}{}
+
+	clientConf := &ClientConfig{
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client = newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server = newServerTransport(trS, v, v, serverConf)
+
+	if err := server.waitSession(); err != nil {
+		t.Fatalf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+
+	<-trC.kexInitSent
+	<-trS.kexInitSent
+
+	// Allow and request new KEX server side.
+	trS.kexInitAllowed <- struct{}{}
+	server.requestKeyExchange()
+	// Wait until the KEX init is sent.
+	<-trS.kexInitSent
+	// The client is not allowed to respond to the KEX, so writes will be
+	// blocked on the server side once the packets queue is full.
+	for i := 0; i < maxPendingPackets; i++ {
+		p := []byte{msgRequestSuccess, byte(i)}
+		if err := server.writePacket(p); err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}
+	return client, server, trC
+}
+
+// discardPacketsUntilError reads and discards packets from tr in a new
+// goroutine until a read fails, reporting any error other than io.EOF. The
+// returned channel is closed when that goroutine exits.
+func discardPacketsUntilError(t *testing.T, tr *handshakeTransport) <-chan struct{} {
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		for {
+			if _, err := tr.readPacket(); err != nil {
+				if err != io.EOF {
+					t.Errorf("unexpected read error: %v", err)
+				}
+				return
+			}
+		}
+	}()
+	return done
+}
+
+// TestHandshakePendingPacketsWait checks that a write made while the pending
+// packets queue is full blocks until the key exchange completes and then
+// succeeds.
+func TestHandshakePendingPacketsWait(t *testing.T) {
+	client, server, trC := newPendingPacketsTestTransports(t)
+
+	// The packets queue is now full, the next write will block.
+	server.mu.Lock()
+	if len(server.pendingPackets) != maxPendingPackets {
+		t.Errorf("unexpected pending packets size; got: %d, want: %d", len(server.pendingPackets), maxPendingPackets)
+	}
+	server.mu.Unlock()
+
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+
+		p := []byte{msgRequestSuccess, byte(maxPendingPackets)}
+		// This write will block until KEX completes.
+		if err := server.writePacket(p); err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}()
+
+	// Consume packets on the client side.
+	readDone := discardPacketsUntilError(t, client)
+
+	// Allow the client to reply to the KEX and so unblock the write goroutine.
+	trC.kexInitAllowed <- struct{}{}
+	<-trC.kexInitSent
+	<-writeDone
+	// Close the client to unblock the read goroutine.
+	client.Close()
+	<-readDone
+	server.Close()
+}
+
+// TestHandshakePendingPacketsError checks that a write blocked on a full
+// pending packets queue fails with io.EOF once the server is closed, instead
+// of waiting for a key exchange that never completes.
+func TestHandshakePendingPacketsError(t *testing.T) {
+	client, server, trC := newPendingPacketsTestTransports(t)
+
+	// The packets queue is now full, the next write will block.
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+
+		p := []byte{msgRequestSuccess, byte(maxPendingPackets)}
+		// The client never answers the KEX, so this write only returns once
+		// server.Close below makes it fail.
+		if err := server.writePacket(p); err != io.EOF {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}()
+
+	// Consume packets on the client side.
+	readDone := discardPacketsUntilError(t, client)
+
+	// Close the server to unblock the write after an error.
+	server.Close()
+	<-writeDone
+	// Release any KEXINIT write the client is blocked on, then close the
+	// client to unblock the read goroutine.
+	trC.kexInitAllowed <- struct{}{}
+	client.Close()
+	<-readDone
+}
+
func TestHandshakeRekeyDefault(t *testing.T) {
clientConf := &ClientConfig{
Config: Config{