ssh: ensure that handshakeTransport goroutines have finished before Close returns

This fixes a data race in the tests for x/crypto/ssh, which expects to
be able to examine a transport's read and write counters without
locking after closing it.

(Given the number of goroutines, channels, and mutexes used in this
package, I wouldn't be surprised if other concurrency bugs remain.
I would suggest simplifying the concurrency in this package, but I
don't intend to follow up on that myself at the moment.)

Fixes golang/go#56957.

Change-Id: Ib1f1390b66707c66a3608e48f3f52483cff3c1f5
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/456758
Reviewed-by: Roland Shoemaker <roland@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 2b84c35..07a1843 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -58,11 +58,13 @@
 	incoming  chan []byte
 	readError error
 
-	mu             sync.Mutex
-	writeError     error
-	sentInitPacket []byte
-	sentInitMsg    *kexInitMsg
-	pendingPackets [][]byte // Used when a key exchange is in progress.
+	mu               sync.Mutex
+	writeError       error
+	sentInitPacket   []byte
+	sentInitMsg      *kexInitMsg
+	pendingPackets   [][]byte // Used when a key exchange is in progress.
+	writePacketsLeft uint32
+	writeBytesLeft   int64
 
 	// If the read loop wants to schedule a kex, it pings this
 	// channel, and the write loop will send out a kex
@@ -71,7 +73,8 @@
 
 	// If the other side requests or confirms a kex, its kexInit
 	// packet is sent here for the write loop to find it.
-	startKex chan *pendingKex
+	startKex    chan *pendingKex
+	kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits
 
 	// data for host key checking
 	hostKeyCallback HostKeyCallback
@@ -86,12 +89,10 @@
 	// Algorithms agreed in the last key exchange.
 	algorithms *algorithms
 
+	// Counters exclusively owned by readLoop.
 	readPacketsLeft uint32
 	readBytesLeft   int64
 
-	writePacketsLeft uint32
-	writeBytesLeft   int64
-
 	// The session ID or nil if first kex did not complete yet.
 	sessionID []byte
 }
@@ -108,7 +109,8 @@
 		clientVersion: clientVersion,
 		incoming:      make(chan []byte, chanSize),
 		requestKex:    make(chan struct{}, 1),
-		startKex:      make(chan *pendingKex, 1),
+		startKex:      make(chan *pendingKex),
+		kexLoopDone:   make(chan struct{}),
 
 		config: config,
 	}
@@ -340,16 +342,17 @@
 		t.mu.Unlock()
 	}
 
-	// drain startKex channel. We don't service t.requestKex
-	// because nobody does blocking sends there.
-	go func() {
-		for init := range t.startKex {
-			init.done <- t.writeError
-		}
-	}()
-
 	// Unblock reader.
 	t.conn.Close()
+
+	// drain startKex channel. We don't service t.requestKex
+	// because nobody does blocking sends there.
+	for request := range t.startKex {
+		request.done <- t.getWriteError()
+	}
+
+	// Mark that the loop is done so that Close can return.
+	close(t.kexLoopDone)
 }
 
 // The protocol uses uint32 for packet counters, so we can't let them
@@ -545,7 +548,16 @@
 }
 
 func (t *handshakeTransport) Close() error {
-	return t.conn.Close()
+	// Close the connection. This should cause the readLoop goroutine to wake up
+	// and close t.startKex, which will shut down kexLoop if running.
+	err := t.conn.Close()
+
+	// Wait for the kexLoop goroutine to complete.
+	// At that point we know that the readLoop goroutine is complete too,
+	// because kexLoop itself waits for readLoop to close the startKex channel.
+	<-t.kexLoopDone
+
+	return err
 }
 
 func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {