ssh: rewrite (re)keying logic.

Use channels and a dedicated write loop for managing the rekeying
process.  This lets us collect packets to be written while a key
exchange is in progress.

Previously, the read loop ran the key exchange, and writers would
block while a key exchange was in progress. If a reader wrote back a
packet while processing a received packet, it could block, stopping
the read loop and thus causing a deadlock.  Such coupled reads/writes
are inherent in handling requests that want a response (e.g.
keepalive, opening/closing channels, etc.). The buffered channels
(most channels have capacity 16) papered over these problems, but
under load SSH connections would occasionally deadlock.

Fixes #18439.

Change-Id: I7c14ff4991fa3100a5d36025125d0cf1119c471d
Reviewed-on: https://go-review.googlesource.com/35012
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/client.go b/ssh/client.go
index 0212a20..c841e8d 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -97,13 +97,11 @@
 	c.transport = newClientTransport(
 		newTransport(c.sshConn.conn, config.Rand, true /* is client */),
 		c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
-	if err := c.transport.requestInitialKeyChange(); err != nil {
+	if err := c.transport.waitSession(); err != nil {
 		return err
 	}
 
-	// We just did the key change, so the session ID is established.
 	c.sessionID = c.transport.getSessionID()
-
 	return c.clientAuthenticate(config)
 }
 
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 294af0d..fd1ec5d 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -30,8 +30,10 @@
 	// then any untried methods suggested by the server.
 	tried := make(map[string]bool)
 	var lastMethods []string
+
+	sessionID := c.transport.getSessionID()
 	for auth := AuthMethod(new(noneAuth)); auth != nil; {
-		ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand)
+		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand)
 		if err != nil {
 			return err
 		}
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 37d42e4..03c950d 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -53,6 +53,20 @@
 	incoming  chan []byte
 	readError error
 
+	mu             sync.Mutex
+	writeError     error
+	sentInitPacket []byte
+	sentInitMsg    *kexInitMsg
+	pendingPackets [][]byte // Used when a key exchange is in progress.
+
+	// If the read loop wants to schedule a kex, it pings this
+	// channel, and the write loop will send out a kex message.
+	requestKex chan struct{}
+
+	// If the other side requests or confirms a kex, its kexInit
+	// packet is sent here for the write loop to find it.
+	startKex chan *pendingKex
+
 	// data for host key checking
 	hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
 	dialAddress     string
@@ -60,27 +74,28 @@
 
 	readSinceKex uint64
 
-	// Protects the writing side of the connection
-	mu              sync.Mutex
-	cond            *sync.Cond
-	sentInitPacket  []byte
-	sentInitMsg     *kexInitMsg
 	writtenSinceKex uint64
-	writeError      error
 
 	// The session ID or nil if first kex did not complete yet.
 	sessionID []byte
 }
 
+type pendingKex struct {
+	otherInit []byte
+	done      chan error
+}
+
 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
 	t := &handshakeTransport{
 		conn:          conn,
 		serverVersion: serverVersion,
 		clientVersion: clientVersion,
 		incoming:      make(chan []byte, 16),
-		config:        config,
+		requestKex:    make(chan struct{}, 1),
+		startKex:      make(chan *pendingKex, 1),
+
+		config: config,
 	}
-	t.cond = sync.NewCond(&t.mu)
 	return t
 }
 
@@ -95,6 +110,7 @@
 		t.hostKeyAlgorithms = supportedHostKeyAlgos
 	}
 	go t.readLoop()
+	go t.kexLoop()
 	return t
 }
 
@@ -102,6 +118,7 @@
 	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
 	t.hostKeys = config.hostKeys
 	go t.readLoop()
+	go t.kexLoop()
 	return t
 }
 
@@ -109,6 +126,20 @@
 	return t.sessionID
 }
 
+// waitSession waits for the session to be established. This should be
+// the first thing to call after instantiating handshakeTransport.
+func (t *handshakeTransport) waitSession() error {
+	p, err := t.readPacket()
+	if err != nil {
+		return err
+	}
+	if p[0] != msgNewKeys {
+		return fmt.Errorf("ssh: first packet should be msgNewKeys")
+	}
+
+	return nil
+}
+
 func (t *handshakeTransport) id() string {
 	if len(t.hostKeys) > 0 {
 		return "server"
@@ -116,6 +147,19 @@
 	return "client"
 }
 
+func (t *handshakeTransport) printPacket(p []byte, write bool) {
+	action := "got"
+	if write {
+		action = "sent"
+	}
+	if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
+		log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
+	} else {
+		msg, err := decode(p)
+		log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
+	}
+}
+
 func (t *handshakeTransport) readPacket() ([]byte, error) {
 	p, ok := <-t.incoming
 	if !ok {
@@ -125,8 +169,16 @@
 }
 
 func (t *handshakeTransport) readLoop() {
+	// We always start with the mandatory key exchange.  We use
+	// the channel for simplicity, and this works if we can rely
+	// on the SSH package itself not doing anything else before
+	// waitSession has completed.
+	t.requestKeyExchange()
+
+	first := true
 	for {
-		p, err := t.readOnePacket()
+		p, err := t.readOnePacket(first)
+		first = false
 		if err != nil {
 			t.readError = err
 			close(t.incoming)
@@ -138,20 +190,123 @@
 		t.incoming <- p
 	}
 
-	// If we can't read, declare the writing part dead too.
-	t.mu.Lock()
-	defer t.mu.Unlock()
-	if t.writeError == nil {
-		t.writeError = t.readError
-	}
-	t.cond.Broadcast()
+	// Stop writers too.
+	t.recordWriteError(t.readError)
+
+	// Unblock the writer should it wait for this.
+	close(t.startKex)
+
+	// Don't close t.requestKex; it's also written to from writePacket.
 }
 
-func (t *handshakeTransport) readOnePacket() ([]byte, error) {
-	if t.readSinceKex > t.config.RekeyThreshold {
-		if err := t.requestKeyChange(); err != nil {
-			return nil, err
+func (t *handshakeTransport) pushPacket(p []byte) error {
+	if debugHandshake {
+		t.printPacket(p, true)
+	}
+	return t.conn.writePacket(p)
+}
+
+func (t *handshakeTransport) getWriteError() error {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	return t.writeError
+}
+
+func (t *handshakeTransport) recordWriteError(err error) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.writeError == nil && err != nil {
+		t.writeError = err
+	}
+}
+
+func (t *handshakeTransport) requestKeyExchange() {
+	select {
+	case t.requestKex <- struct{}{}:
+	default:
+		// something already requested a kex, so do nothing.
+	}
+
+}
+
+func (t *handshakeTransport) kexLoop() {
+write:
+	for t.getWriteError() == nil {
+		var request *pendingKex
+		var sent bool
+
+		for request == nil || !sent {
+			var ok bool
+			select {
+			case request, ok = <-t.startKex:
+				if !ok {
+					break write
+				}
+			case <-t.requestKex:
+			}
+
+			if !sent {
+				if err := t.sendKexInit(); err != nil {
+					t.recordWriteError(err)
+					break
+				}
+				sent = true
+			}
 		}
+
+		if err := t.getWriteError(); err != nil {
+			if request != nil {
+				request.done <- err
+			}
+			break
+		}
+
+		// We're not servicing t.requestKex, but that is OK:
+		// we never block on sending to t.requestKex.
+
+		// We're not servicing t.startKex, but the remote end
+		// has just sent us a kexInitMsg, so it can't send
+		// another key change request.
+
+		err := t.enterKeyExchange(request.otherInit)
+
+		t.mu.Lock()
+		t.writeError = err
+		t.sentInitPacket = nil
+		t.sentInitMsg = nil
+		t.writtenSinceKex = 0
+		request.done <- t.writeError
+
+		// kex finished. Push packets that we received while
+		// the kex was in progress. Don't look at t.startKex
+		// and don't increment writtenSinceKex: if we trigger
+		// another kex while we are still busy with the last
+		// one, things will become very confusing.
+		for _, p := range t.pendingPackets {
+			t.writeError = t.pushPacket(p)
+			if t.writeError != nil {
+				break
+			}
+		}
+		t.pendingPackets = t.pendingPackets[:0] // flushed: empty the queue, keep its capacity
+		t.mu.Unlock()
+	}
+
+	// drain startKex channel. We don't service t.requestKex
+	// because nobody does blocking sends there.
+	go func() {
+		for init := range t.startKex {
+			init.done <- t.getWriteError()
+		}
+	}()
+
+	// Unblock reader.
+	t.conn.Close()
+}
+
+func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
+	if t.readSinceKex > t.config.RekeyThreshold {
+		t.requestKeyExchange()
 	}
 
 	p, err := t.conn.readPacket()
@@ -161,39 +316,30 @@
 
 	t.readSinceKex += uint64(len(p))
 	if debugHandshake {
-		if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
-			log.Printf("%s got data (packet %d bytes)", t.id(), len(p))
-		} else {
-			msg, err := decode(p)
-			log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
-		}
+		t.printPacket(p, false)
 	}
+
+	if first && p[0] != msgKexInit {
+		return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
+	}
+
 	if p[0] != msgKexInit {
 		return p, nil
 	}
 
-	t.mu.Lock()
-
 	firstKex := t.sessionID == nil
 
-	err = t.enterKeyExchangeLocked(p)
-	if err != nil {
-		// drop connection
-		t.conn.Close()
-		t.writeError = err
+	kex := pendingKex{
+		done:      make(chan error, 1),
+		otherInit: p,
 	}
+	t.startKex <- &kex
+	err = <-kex.done
 
 	if debugHandshake {
 		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
 	}
 
-	// Unblock writers.
-	t.sentInitMsg = nil
-	t.sentInitPacket = nil
-	t.cond.Broadcast()
-	t.writtenSinceKex = 0
-	t.mu.Unlock()
-
 	if err != nil {
 		return nil, err
 	}
@@ -213,61 +359,16 @@
 	return successPacket, nil
 }
 
-// keyChangeCategory describes whether a key exchange is the first on a
-// connection, or a subsequent one.
-type keyChangeCategory bool
-
-const (
-	firstKeyExchange      keyChangeCategory = true
-	subsequentKeyExchange keyChangeCategory = false
-)
-
-// sendKexInit sends a key change message, and returns the message
-// that was sent. After initiating the key change, all writes will be
-// blocked until the change is done, and a failed key change will
-// close the underlying transport. This function is safe for
-// concurrent use by multiple goroutines.
-func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error {
-	var err error
-
+// sendKexInit sends a key change message.
+func (t *handshakeTransport) sendKexInit() error {
 	t.mu.Lock()
-	// If this is the initial key change, but we already have a sessionID,
-	// then do nothing because the key exchange has already completed
-	// asynchronously.
-	if !isFirst || t.sessionID == nil {
-		_, _, err = t.sendKexInitLocked(isFirst)
-	}
-	t.mu.Unlock()
-	if err != nil {
-		return err
-	}
-	if isFirst {
-		if packet, err := t.readPacket(); err != nil {
-			return err
-		} else if packet[0] != msgNewKeys {
-			return unexpectedMessageError(msgNewKeys, packet[0])
-		}
-	}
-	return nil
-}
-
-func (t *handshakeTransport) requestInitialKeyChange() error {
-	return t.sendKexInit(firstKeyExchange)
-}
-
-func (t *handshakeTransport) requestKeyChange() error {
-	return t.sendKexInit(subsequentKeyExchange)
-}
-
-// sendKexInitLocked sends a key change message. t.mu must be locked
-// while this happens.
-func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
-	// kexInits may be sent either in response to the other side,
-	// or because our side wants to initiate a key change, so we
-	// may have already sent a kexInit. In that case, don't send a
-	// second kexInit.
+	defer t.mu.Unlock()
 	if t.sentInitMsg != nil {
-		return t.sentInitMsg, t.sentInitPacket, nil
+		// kexInits may be sent either in response to the other side,
+		// or because our side wants to initiate a key change, so we
+		// may have already sent a kexInit. In that case, don't send a
+		// second kexInit.
+		return nil
 	}
 
 	msg := &kexInitMsg{
@@ -295,53 +396,57 @@
 	packetCopy := make([]byte, len(packet))
 	copy(packetCopy, packet)
 
-	if err := t.conn.writePacket(packetCopy); err != nil {
-		return nil, nil, err
+	if err := t.pushPacket(packetCopy); err != nil {
+		return err
 	}
 
 	t.sentInitMsg = msg
 	t.sentInitPacket = packet
-	return msg, packet, nil
+
+	return nil
 }
 
 func (t *handshakeTransport) writePacket(p []byte) error {
-	t.mu.Lock()
-	defer t.mu.Unlock()
-
-	if t.writtenSinceKex > t.config.RekeyThreshold {
-		t.sendKexInitLocked(subsequentKeyExchange)
-	}
-	for t.sentInitMsg != nil && t.writeError == nil {
-		t.cond.Wait()
-	}
-	if t.writeError != nil {
-		return t.writeError
-	}
-	t.writtenSinceKex += uint64(len(p))
-
 	switch p[0] {
 	case msgKexInit:
 		return errors.New("ssh: only handshakeTransport can send kexInit")
 	case msgNewKeys:
 		return errors.New("ssh: only handshakeTransport can send newKeys")
-	default:
-		return t.conn.writePacket(p)
 	}
+
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.writeError != nil {
+		return t.writeError
+	}
+
+	if t.sentInitMsg != nil {
+		// Copy the packet so the writer can reuse the buffer.
+		cp := make([]byte, len(p))
+		copy(cp, p)
+		t.pendingPackets = append(t.pendingPackets, cp)
+		return nil
+	}
+	t.writtenSinceKex += uint64(len(p))
+	if t.writtenSinceKex > t.config.RekeyThreshold {
+		t.requestKeyExchange()
+	}
+
+	if err := t.pushPacket(p); err != nil {
+		t.writeError = err
+	}
+
+	return nil
 }
 
 func (t *handshakeTransport) Close() error {
 	return t.conn.Close()
 }
 
-// enterKeyExchange runs the key exchange. t.mu must be held while running this.
-func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
+func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 	if debugHandshake {
 		log.Printf("%s entered key exchange", t.id())
 	}
-	myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
-	if err != nil {
-		return err
-	}
 
 	otherInit := &kexInitMsg{}
 	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
@@ -352,16 +457,15 @@
 		clientVersion: t.clientVersion,
 		serverVersion: t.serverVersion,
 		clientKexInit: otherInitPacket,
-		serverKexInit: myInitPacket,
+		serverKexInit: t.sentInitPacket,
 	}
 
 	clientInit := otherInit
-	serverInit := myInit
+	serverInit := t.sentInitMsg
 	if len(t.hostKeys) == 0 {
-		clientInit = myInit
-		serverInit = otherInit
+		clientInit, serverInit = serverInit, clientInit
 
-		magics.clientKexInit = myInitPacket
+		magics.clientKexInit = t.sentInitPacket
 		magics.serverKexInit = otherInitPacket
 	}
 
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 4cf5368..530d7d2 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -110,6 +110,13 @@
 	serverConf.SetDefaults()
 	server = newServerTransport(trS, v, v, serverConf)
 
+	if err := server.waitSession(); err != nil {
+		return nil, nil, fmt.Errorf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		return nil, nil, fmt.Errorf("client.waitSession: %v", err)
+	}
+
 	return client, server, nil
 }
 
@@ -117,8 +124,9 @@
 	if runtime.GOOS == "plan9" {
 		t.Skip("see golang.org/issue/7237")
 	}
-	checker := &testChecker{}
-	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", true)
+
+	checker := &syncChecker{make(chan int, 10)}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
 	if err != nil {
 		t.Fatalf("handshakePair: %v", err)
 	}
@@ -126,7 +134,11 @@
 	defer trC.Close()
 	defer trS.Close()
 
+	<-checker.called
+
 	clientDone := make(chan int, 0)
+	gotHalf := make(chan int, 0)
+
 	go func() {
 		defer close(clientDone)
 		// Client writes a bunch of stuff, and does a key
@@ -138,33 +150,35 @@
 				t.Fatalf("sendPacket: %v", err)
 			}
 			if i == 5 {
+				<-gotHalf
 				// halfway through, we request a key change.
-				err := trC.sendKexInit(subsequentKeyExchange)
-				if err != nil {
-					t.Fatalf("sendKexInit: %v", err)
-				}
+				trC.requestKeyExchange()
+
+				// Wait until we can be sure the key
+				// change has really started before we
+				// write more.
+				<-checker.called
 			}
 		}
-		trC.Close()
 	}()
 
 	// Server checks that client messages come in cleanly
 	i := 0
 	err = nil
-	for {
+	for ; i < 10; i++ {
 		var p []byte
 		p, err = trS.readPacket()
 		if err != nil {
 			break
 		}
-		if p[0] == msgNewKeys {
-			continue
+		if i == 5 {
+			gotHalf <- 1
 		}
+
 		want := []byte{msgRequestSuccess, byte(i)}
 		if bytes.Compare(p, want) != 0 {
 			t.Errorf("message %d: got %q, want %q", i, p, want)
 		}
-		i++
 	}
 	<-clientDone
 	if err != nil && err != io.EOF {
@@ -174,150 +188,58 @@
 		t.Errorf("received %d messages, want 10.", i)
 	}
 
-	// If all went well, we registered exactly 1 key change.
-	if len(checker.calls) != 1 {
-		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
-	}
-
-	pub := testSigners["ecdsa"].PublicKey()
-	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
-	if want != checker.calls[0] {
-		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
-	}
-
-}
-
-func TestHandshakeError(t *testing.T) {
-	checker := &testChecker{}
-	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad", false)
-	if err != nil {
-		t.Fatalf("handshakePair: %v", err)
-	}
-	defer trC.Close()
-	defer trS.Close()
-
-	// send a packet
-	packet := []byte{msgRequestSuccess, 42}
-	if err := trC.writePacket(packet); err != nil {
-		t.Errorf("writePacket: %v", err)
-	}
-
-	// Now request a key change.
-	err = trC.sendKexInit(subsequentKeyExchange)
-	if err != nil {
-		t.Errorf("sendKexInit: %v", err)
-	}
-
-	// the key change will fail, and afterwards we can't write.
-	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
-		t.Errorf("writePacket after botched rekey succeeded.")
-	}
-
-	readback, err := trS.readPacket()
-	if err != nil {
-		t.Fatalf("server closed too soon: %v", err)
-	}
-	if bytes.Compare(readback, packet) != 0 {
-		t.Errorf("got %q want %q", readback, packet)
-	}
-	readback, err = trS.readPacket()
-	if err == nil {
-		t.Errorf("got a message %q after failed key change", readback)
+	close(checker.called)
+	if _, ok := <-checker.called; ok {
+		// If all went well, we registered exactly 2 key changes: one
+		// that establishes the session, and one that we requested
+		// additionally.
+		t.Fatalf("got another host key checks after 2 handshakes")
 	}
 }
 
 func TestForceFirstKex(t *testing.T) {
+	// like handshakePair, but must access the keyingTransport.
 	checker := &testChecker{}
-	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+	a, b, err := netPipe()
 	if err != nil {
-		t.Fatalf("handshakePair: %v", err)
+		t.Fatalf("netPipe: %v", err)
 	}
 
-	defer trC.Close()
-	defer trS.Close()
+	var trC, trS keyingTransport
 
+	trC = newTransport(a, rand.Reader, true)
+
+	// This is the disallowed packet:
 	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
 
+	// Rest of the setup.
+	trS = newTransport(b, rand.Reader, false)
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server := newServerTransport(trS, v, v, serverConf)
+
+	defer client.Close()
+	defer server.Close()
+
 	// We setup the initial key exchange, but the remote side
 	// tries to send serviceRequestMsg in cleartext, which is
 	// disallowed.
 
-	err = trS.sendKexInit(firstKeyExchange)
-	if err == nil {
+	if err := server.waitSession(); err == nil {
 		t.Errorf("server first kex init should reject unexpected packet")
 	}
 }
 
-func TestHandshakeTwice(t *testing.T) {
-	checker := &testChecker{}
-	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
-	if err != nil {
-		t.Fatalf("handshakePair: %v", err)
-	}
-
-	defer trC.Close()
-	defer trS.Close()
-
-	// Both sides should ask for the first key exchange first.
-	err = trS.sendKexInit(firstKeyExchange)
-	if err != nil {
-		t.Errorf("server sendKexInit: %v", err)
-	}
-
-	err = trC.sendKexInit(firstKeyExchange)
-	if err != nil {
-		t.Errorf("client sendKexInit: %v", err)
-	}
-
-	sent := 0
-	// send a packet
-	packet := make([]byte, 5)
-	packet[0] = msgRequestSuccess
-	if err := trC.writePacket(packet); err != nil {
-		t.Errorf("writePacket: %v", err)
-	}
-	sent++
-
-	// Send another packet. Use a fresh one, since writePacket destroys.
-	packet = make([]byte, 5)
-	packet[0] = msgRequestSuccess
-	if err := trC.writePacket(packet); err != nil {
-		t.Errorf("writePacket: %v", err)
-	}
-	sent++
-
-	// 2nd key change.
-	err = trC.sendKexInit(subsequentKeyExchange)
-	if err != nil {
-		t.Errorf("sendKexInit: %v", err)
-	}
-
-	packet = make([]byte, 5)
-	packet[0] = msgRequestSuccess
-	if err := trC.writePacket(packet); err != nil {
-		t.Errorf("writePacket: %v", err)
-	}
-	sent++
-
-	packet = make([]byte, 5)
-	packet[0] = msgRequestSuccess
-	for i := 0; i < sent; i++ {
-		msg, err := trS.readPacket()
-		if err != nil {
-			t.Fatalf("server closed too soon: %v", err)
-		}
-
-		if bytes.Compare(msg, packet) != 0 {
-			t.Errorf("packet %d: got %q want %q", i, msg, packet)
-		}
-	}
-	if len(checker.calls) != 2 {
-		t.Errorf("got %d key changes, want 2", len(checker.calls))
-	}
-}
-
 func TestHandshakeAutoRekeyWrite(t *testing.T) {
-	checker := &testChecker{}
+	checker := &syncChecker{make(chan int, 10)}
 	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
 	clientConf.RekeyThreshold = 500
 	trC, trS, err := handshakePair(clientConf, "addr", false)
@@ -327,12 +249,19 @@
 	defer trC.Close()
 	defer trS.Close()
 
+	<-checker.called
+
 	for i := 0; i < 5; i++ {
 		packet := make([]byte, 251)
 		packet[0] = msgRequestSuccess
 		if err := trC.writePacket(packet); err != nil {
 			t.Errorf("writePacket: %v", err)
 		}
+		if i == 2 {
+			// Make sure the kex is in progress.
+			<-checker.called
+		}
+
 	}
 
 	j := 0
@@ -346,18 +275,14 @@
 	if j != 5 {
 		t.Errorf("got %d, want 5 messages", j)
 	}
-
-	if len(checker.calls) != 2 {
-		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
-	}
 }
 
 type syncChecker struct {
 	called chan int
 }
 
-func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
-	t.called <- 1
+func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+	c.called <- 1
 	return nil
 }
 
@@ -399,6 +324,7 @@
 func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
 	return nil
 }
+
 func (n *errorKeyingTransport) getSessionID() []byte {
 	return nil
 }
@@ -425,20 +351,32 @@
 
 func TestHandshakeErrorHandlingRead(t *testing.T) {
 	for i := 0; i < 20; i++ {
-		testHandshakeErrorHandlingN(t, i, -1)
+		testHandshakeErrorHandlingN(t, i, -1, false)
 	}
 }
 
 func TestHandshakeErrorHandlingWrite(t *testing.T) {
 	for i := 0; i < 20; i++ {
-		testHandshakeErrorHandlingN(t, -1, i)
+		testHandshakeErrorHandlingN(t, -1, i, false)
+	}
+}
+
+func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
+	for i := 0; i < 20; i++ {
+		testHandshakeErrorHandlingN(t, i, -1, true)
+	}
+}
+
+func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
+	for i := 0; i < 20; i++ {
+		testHandshakeErrorHandlingN(t, -1, i, true)
 	}
 }
 
 // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
 // handshakeTransport deadlocks, the go runtime will detect it and
 // panic.
-func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
+func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
 	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
 
 	a, b := memPipe()
@@ -451,37 +389,57 @@
 	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
 	serverConn.hostKeys = []Signer{key}
 	go serverConn.readLoop()
+	go serverConn.kexLoop()
 
 	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
 	clientConf.SetDefaults()
 	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
 	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
 	go clientConn.readLoop()
+	go clientConn.kexLoop()
 
 	var wg sync.WaitGroup
-	wg.Add(4)
 
 	for _, hs := range []packetConn{serverConn, clientConn} {
-		go func(c packetConn) {
-			for {
-				err := c.writePacket(msg)
-				if err != nil {
-					break
+		if !coupled {
+			wg.Add(2)
+			go func(c packetConn) {
+				for i := 0; ; i++ {
+					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
+					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
+					if err != nil {
+						break
+					}
 				}
-			}
-			wg.Done()
-		}(hs)
-		go func(c packetConn) {
-			for {
-				_, err := c.readPacket()
-				if err != nil {
-					break
+				wg.Done()
+				c.Close()
+			}(hs)
+			go func(c packetConn) {
+				for {
+					_, err := c.readPacket()
+					if err != nil {
+						break
+					}
 				}
-			}
-			wg.Done()
-		}(hs)
-	}
+				wg.Done()
+			}(hs)
+		} else {
+			wg.Add(1)
+			go func(c packetConn) {
+				for {
+					_, err := c.readPacket()
+					if err != nil {
+						break
+					}
+					if err := c.writePacket(msg); err != nil {
+						break
+					}
 
+				}
+				wg.Done()
+			}(hs)
+		}
+	}
 	wg.Wait()
 }
 
diff --git a/ssh/server.go b/ssh/server.go
index 9037470..28b109a 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -188,7 +188,7 @@
 	tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
 	s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
 
-	if err := s.transport.requestInitialKeyChange(); err != nil {
+	if err := s.transport.waitSession(); err != nil {
 		return nil, err
 	}
 
@@ -260,7 +260,7 @@
 }
 
 func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
-	var err error
+	sessionID := s.transport.getSessionID()
 	var cache pubKeyCache
 	var perms *Permissions
 
@@ -385,7 +385,7 @@
 				if !isAcceptableAlgo(sig.Format) {
 					break
 				}
-				signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData)
+				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
 
 				if err := pubKey.Verify(signedData, sig); err != nil {
 					return nil, err
@@ -421,12 +421,12 @@
 			return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
 		}
 
-		if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil {
+		if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
 			return nil, err
 		}
 	}
 
-	if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
+	if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
 		return nil, err
 	}
 	return perms, nil
diff --git a/ssh/transport.go b/ssh/transport.go
index 7e43a12..fd19932 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -22,7 +22,9 @@
 	// Encrypt and send a packet of data to the remote peer.
 	writePacket(packet []byte) error
 
-	// Read a packet from the connection
+	// Read a packet from the connection. The read is blocking,
+	// i.e. if error is nil, then the returned byte slice is
+	// always non-empty.
 	readPacket() ([]byte, error)
 
 	// Close closes the write-side of the connection.