ssh: rewrite (re)keying logic

Use channels and a dedicated write loop to manage the rekeying
process.  This lets us collect packets to be written while a key
exchange is in progress.

Previously, the read loop ran the key exchange, and writers would
block while a key exchange was in progress. If a reader wrote a reply
while processing an incoming packet, that write could block, stalling
the read loop and thus causing a deadlock.  Such coupled reads and
writes are inherent in handling requests that expect a response (e.g.
keepalives, opening/closing channels, etc.). The buffered channels
(most have a capacity of 16) papered over these problems, but under
load SSH connections would occasionally deadlock.

Fixes #18439.

Change-Id: I7c14ff4991fa3100a5d36025125d0cf1119c471d
Reviewed-on: https://go-review.googlesource.com/35012
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 4cf5368..530d7d2 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -110,6 +110,13 @@
 	serverConf.SetDefaults()
 	server = newServerTransport(trS, v, v, serverConf)
 
+	if err := server.waitSession(); err != nil {
+		return nil, nil, fmt.Errorf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		return nil, nil, fmt.Errorf("client.waitSession: %v", err)
+	}
+
 	return client, server, nil
 }
 
@@ -117,8 +124,9 @@
 	if runtime.GOOS == "plan9" {
 		t.Skip("see golang.org/issue/7237")
 	}
-	checker := &testChecker{}
-	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", true)
+
+	checker := &syncChecker{make(chan int, 10)}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
 	if err != nil {
 		t.Fatalf("handshakePair: %v", err)
 	}
@@ -126,7 +134,11 @@
 	defer trC.Close()
 	defer trS.Close()
 
+	<-checker.called
+
 	clientDone := make(chan int, 0)
+	gotHalf := make(chan int, 0)
+
 	go func() {
 		defer close(clientDone)
 		// Client writes a bunch of stuff, and does a key
@@ -138,33 +150,35 @@
 				t.Fatalf("sendPacket: %v", err)
 			}
 			if i == 5 {
+				<-gotHalf
 				// halfway through, we request a key change.
-				err := trC.sendKexInit(subsequentKeyExchange)
-				if err != nil {
-					t.Fatalf("sendKexInit: %v", err)
-				}
+				trC.requestKeyExchange()
+
+				// Wait until we can be sure the key
+				// change has really started before we
+				// write more.
+				<-checker.called
 			}
 		}
-		trC.Close()
 	}()
 
 	// Server checks that client messages come in cleanly
 	i := 0
 	err = nil
-	for {
+	for ; i < 10; i++ {
 		var p []byte
 		p, err = trS.readPacket()
 		if err != nil {
 			break
 		}
-		if p[0] == msgNewKeys {
-			continue
+		if i == 5 {
+			gotHalf <- 1
 		}
+
 		want := []byte{msgRequestSuccess, byte(i)}
 		if bytes.Compare(p, want) != 0 {
 			t.Errorf("message %d: got %q, want %q", i, p, want)
 		}
-		i++
 	}
 	<-clientDone
 	if err != nil && err != io.EOF {
@@ -174,150 +188,58 @@
 		t.Errorf("received %d messages, want 10.", i)
 	}
 
-	// If all went well, we registered exactly 1 key change.
-	if len(checker.calls) != 1 {
-		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
-	}
-
-	pub := testSigners["ecdsa"].PublicKey()
-	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
-	if want != checker.calls[0] {
-		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
-	}
-
-}
-
-func TestHandshakeError(t *testing.T) {
-	checker := &testChecker{}
-	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad", false)
-	if err != nil {
-		t.Fatalf("handshakePair: %v", err)
-	}
-	defer trC.Close()
-	defer trS.Close()
-
-	// send a packet
-	packet := []byte{msgRequestSuccess, 42}
-	if err := trC.writePacket(packet); err != nil {
-		t.Errorf("writePacket: %v", err)
-	}
-
-	// Now request a key change.
-	err = trC.sendKexInit(subsequentKeyExchange)
-	if err != nil {
-		t.Errorf("sendKexInit: %v", err)
-	}
-
-	// the key change will fail, and afterwards we can't write.
-	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
-		t.Errorf("writePacket after botched rekey succeeded.")
-	}
-
-	readback, err := trS.readPacket()
-	if err != nil {
-		t.Fatalf("server closed too soon: %v", err)
-	}
-	if bytes.Compare(readback, packet) != 0 {
-		t.Errorf("got %q want %q", readback, packet)
-	}
-	readback, err = trS.readPacket()
-	if err == nil {
-		t.Errorf("got a message %q after failed key change", readback)
+	close(checker.called)
+	if _, ok := <-checker.called; ok {
+		// If all went well, we registered exactly 2 key changes: one
+		// that establishes the session, and one that we requested
+		// additionally.
+		t.Fatalf("got another host key checks after 2 handshakes")
 	}
 }
 
 func TestForceFirstKex(t *testing.T) {
+	// like handshakePair, but must access the keyingTransport.
 	checker := &testChecker{}
-	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+	a, b, err := netPipe()
 	if err != nil {
-		t.Fatalf("handshakePair: %v", err)
+		t.Fatalf("netPipe: %v", err)
 	}
 
-	defer trC.Close()
-	defer trS.Close()
+	var trC, trS keyingTransport
 
+	trC = newTransport(a, rand.Reader, true)
+
+	// This is the disallowed packet:
 	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
 
+	// Rest of the setup.
+	trS = newTransport(b, rand.Reader, false)
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server := newServerTransport(trS, v, v, serverConf)
+
+	defer client.Close()
+	defer server.Close()
+
 	// We setup the initial key exchange, but the remote side
 	// tries to send serviceRequestMsg in cleartext, which is
 	// disallowed.
 
-	err = trS.sendKexInit(firstKeyExchange)
-	if err == nil {
+	if err := server.waitSession(); err == nil {
 		t.Errorf("server first kex init should reject unexpected packet")
 	}
 }
 
-func TestHandshakeTwice(t *testing.T) {
-	checker := &testChecker{}
-	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
-	if err != nil {
-		t.Fatalf("handshakePair: %v", err)
-	}
-
-	defer trC.Close()
-	defer trS.Close()
-
-	// Both sides should ask for the first key exchange first.
-	err = trS.sendKexInit(firstKeyExchange)
-	if err != nil {
-		t.Errorf("server sendKexInit: %v", err)
-	}
-
-	err = trC.sendKexInit(firstKeyExchange)
-	if err != nil {
-		t.Errorf("client sendKexInit: %v", err)
-	}
-
-	sent := 0
-	// send a packet
-	packet := make([]byte, 5)
-	packet[0] = msgRequestSuccess
-	if err := trC.writePacket(packet); err != nil {
-		t.Errorf("writePacket: %v", err)
-	}
-	sent++
-
-	// Send another packet. Use a fresh one, since writePacket destroys.
-	packet = make([]byte, 5)
-	packet[0] = msgRequestSuccess
-	if err := trC.writePacket(packet); err != nil {
-		t.Errorf("writePacket: %v", err)
-	}
-	sent++
-
-	// 2nd key change.
-	err = trC.sendKexInit(subsequentKeyExchange)
-	if err != nil {
-		t.Errorf("sendKexInit: %v", err)
-	}
-
-	packet = make([]byte, 5)
-	packet[0] = msgRequestSuccess
-	if err := trC.writePacket(packet); err != nil {
-		t.Errorf("writePacket: %v", err)
-	}
-	sent++
-
-	packet = make([]byte, 5)
-	packet[0] = msgRequestSuccess
-	for i := 0; i < sent; i++ {
-		msg, err := trS.readPacket()
-		if err != nil {
-			t.Fatalf("server closed too soon: %v", err)
-		}
-
-		if bytes.Compare(msg, packet) != 0 {
-			t.Errorf("packet %d: got %q want %q", i, msg, packet)
-		}
-	}
-	if len(checker.calls) != 2 {
-		t.Errorf("got %d key changes, want 2", len(checker.calls))
-	}
-}
-
 func TestHandshakeAutoRekeyWrite(t *testing.T) {
-	checker := &testChecker{}
+	checker := &syncChecker{make(chan int, 10)}
 	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
 	clientConf.RekeyThreshold = 500
 	trC, trS, err := handshakePair(clientConf, "addr", false)
@@ -327,12 +249,19 @@
 	defer trC.Close()
 	defer trS.Close()
 
+	<-checker.called
+
 	for i := 0; i < 5; i++ {
 		packet := make([]byte, 251)
 		packet[0] = msgRequestSuccess
 		if err := trC.writePacket(packet); err != nil {
 			t.Errorf("writePacket: %v", err)
 		}
+		if i == 2 {
+			// Make sure the kex is in progress.
+			<-checker.called
+		}
+
 	}
 
 	j := 0
@@ -346,18 +275,14 @@
 	if j != 5 {
 		t.Errorf("got %d, want 5 messages", j)
 	}
-
-	if len(checker.calls) != 2 {
-		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
-	}
 }
 
 type syncChecker struct {
 	called chan int
 }
 
-func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
-	t.called <- 1
+func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+	c.called <- 1
 	return nil
 }
 
@@ -399,6 +324,7 @@
 func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
 	return nil
 }
+
 func (n *errorKeyingTransport) getSessionID() []byte {
 	return nil
 }
@@ -425,20 +351,32 @@
 
 func TestHandshakeErrorHandlingRead(t *testing.T) {
 	for i := 0; i < 20; i++ {
-		testHandshakeErrorHandlingN(t, i, -1)
+		testHandshakeErrorHandlingN(t, i, -1, false)
 	}
 }
 
 func TestHandshakeErrorHandlingWrite(t *testing.T) {
 	for i := 0; i < 20; i++ {
-		testHandshakeErrorHandlingN(t, -1, i)
+		testHandshakeErrorHandlingN(t, -1, i, false)
+	}
+}
+
+func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
+	for i := 0; i < 20; i++ {
+		testHandshakeErrorHandlingN(t, i, -1, true)
+	}
+}
+
+func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
+	for i := 0; i < 20; i++ {
+		testHandshakeErrorHandlingN(t, -1, i, true)
 	}
 }
 
 // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
 // handshakeTransport deadlocks, the go runtime will detect it and
 // panic.
-func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
+func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
 	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
 
 	a, b := memPipe()
@@ -451,37 +389,57 @@
 	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
 	serverConn.hostKeys = []Signer{key}
 	go serverConn.readLoop()
+	go serverConn.kexLoop()
 
 	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
 	clientConf.SetDefaults()
 	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
 	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
 	go clientConn.readLoop()
+	go clientConn.kexLoop()
 
 	var wg sync.WaitGroup
-	wg.Add(4)
 
 	for _, hs := range []packetConn{serverConn, clientConn} {
-		go func(c packetConn) {
-			for {
-				err := c.writePacket(msg)
-				if err != nil {
-					break
+		if !coupled {
+			wg.Add(2)
+			go func(c packetConn) {
+				for i := 0; ; i++ {
+					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
+					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
+					if err != nil {
+						break
+					}
 				}
-			}
-			wg.Done()
-		}(hs)
-		go func(c packetConn) {
-			for {
-				_, err := c.readPacket()
-				if err != nil {
-					break
+				wg.Done()
+				c.Close()
+			}(hs)
+			go func(c packetConn) {
+				for {
+					_, err := c.readPacket()
+					if err != nil {
+						break
+					}
 				}
-			}
-			wg.Done()
-		}(hs)
-	}
+				wg.Done()
+			}(hs)
+		} else {
+			wg.Add(1)
+			go func(c packetConn) {
+				for {
+					_, err := c.readPacket()
+					if err != nil {
+						break
+					}
+					if err := c.writePacket(msg); err != nil {
+						break
+					}
 
+				}
+				wg.Done()
+			}(hs)
+		}
+	}
 	wg.Wait()
 }