ssh: make sure we execute the initial key exchange only once

The initial kex is started from both sides simultaneously, and before,
we could consume the the incoming kex request before we consumed from
our internal channel. This would result in initiating a key exchange
just after completing the initial one, which is not only an extra
delay, but also an error when using OpenSSH (OpenSSH does not support
key exchanges during user authentication).

Change-Id: Ia7e0748ea2bca80ae97d187bcf2931ab6422276b
Reviewed-on: https://go-review.googlesource.com/35851
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/handshake.go b/ssh/handshake.go
index e68e058..93c23d1 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -65,8 +65,9 @@
 	pendingPackets [][]byte // Used when a key exchange is in progress.
 
 	// If the read loop wants to schedule a kex, it pings this
-	// channel, and the write loop will send out a kex message.
-	requestKex chan struct{}
+	// channel, and the write loop will send out a kex
+	// message. The boolean is whether this is the first request or not.
+	requestKex chan bool
 
 	// If the other side requests or confirms a kex, its kexInit
 	// packet is sent here for the write loop to find it.
@@ -96,11 +97,14 @@
 		serverVersion: serverVersion,
 		clientVersion: clientVersion,
 		incoming:      make(chan []byte, chanSize),
-		requestKex:    make(chan struct{}, 1),
+		requestKex:    make(chan bool, 1),
 		startKex:      make(chan *pendingKex, 1),
 
 		config: config,
 	}
+
+	// We always start with a mandatory key exchange.
+	t.requestKex <- true
 	return t
 }
 
@@ -174,12 +178,6 @@
 }
 
 func (t *handshakeTransport) readLoop() {
-	// We always start with the mandatory key exchange.  We use
-	// the channel for simplicity, and this works if we can rely
-	// on the SSH package itself not doing anything else before
-	// waitSession has completed.
-	t.requestKeyExchange()
-
 	first := true
 	for {
 		p, err := t.readOnePacket(first)
@@ -227,14 +225,15 @@
 
 func (t *handshakeTransport) requestKeyExchange() {
 	select {
-	case t.requestKex <- struct{}{}:
+	case t.requestKex <- false:
 	default:
 		// something already requested a kex, so do nothing.
 	}
-
 }
 
 func (t *handshakeTransport) kexLoop() {
+	firstSent := false
+
 write:
 	for t.getWriteError() == nil {
 		var request *pendingKex
@@ -247,7 +246,18 @@
 				if !ok {
 					break write
 				}
-			case <-t.requestKex:
+			case requestFirst := <-t.requestKex:
+				// For the first key exchange, both
+				// sides will initiate a key exchange,
+				// and both channels will fire. To
+				// avoid doing two key exchanges in a
+				// row, ignore our own request for an
+				// initial kex if we have already sent
+				// it out.
+				if firstSent && requestFirst {
+
+					continue
+				}
 			}
 
 			if !sent {
@@ -255,6 +265,7 @@
 					t.recordWriteError(err)
 					break
 				}
+				firstSent = true
 				sent = true
 			}
 		}