ssh: reset buffered packets after sending

Since encryption messes up the packets, the wrongly retained packets
look like noise and cause application protocol errors or panics in the
SSH library.

This normally triggers very rarely: the mandatory key exchange doesn't
have parallel writes, so this failure condition would be setup on the
first key exchange, take effect only after the second key exchange.

Fortunately, the tests against openssh exercise this. This change adds
also adds a unittest.

Fixes #18850.

Change-Id: I656c8b94bfb265831daa118f4d614a2f0c65d2af
Reviewed-on: https://go-review.googlesource.com/36056
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 57f2d3d..e3f82c4 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -314,7 +314,7 @@
 				break
 			}
 		}
-		t.pendingPackets = t.pendingPackets[0:]
+		t.pendingPackets = t.pendingPackets[:0]
 		t.mu.Unlock()
 	}
 
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index e61348f..4d64376 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -125,7 +125,12 @@
 		t.Skip("see golang.org/issue/7237")
 	}
 
-	checker := &syncChecker{make(chan int, 10)}
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
 	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
 	if err != nil {
 		t.Fatalf("handshakePair: %v", err)
@@ -134,22 +139,25 @@
 	defer trC.Close()
 	defer trS.Close()
 
+	// Let first kex complete normally.
 	<-checker.called
 
 	clientDone := make(chan int, 0)
 	gotHalf := make(chan int, 0)
+	const N = 20
 
 	go func() {
 		defer close(clientDone)
 		// Client writes a bunch of stuff, and does a key
 		// change in the middle. This should not confuse the
-		// handshake in progress
-		for i := 0; i < 10; i++ {
+		// handshake in progress. We do this twice, so we test
+		// that the packet buffer is reset correctly.
+		for i := 0; i < N; i++ {
 			p := []byte{msgRequestSuccess, byte(i)}
 			if err := trC.writePacket(p); err != nil {
 				t.Fatalf("sendPacket: %v", err)
 			}
-			if i == 5 {
+			if (i % 10) == 5 {
 				<-gotHalf
 				// halfway through, we request a key change.
 				trC.requestKeyExchange()
@@ -159,32 +167,38 @@
 				// write more.
 				<-checker.called
 			}
+			if (i % 10) == 7 {
+				// write some packets until the kex
+				// completes, to test buffering of
+				// packets.
+				checker.waitCall <- 1
+			}
 		}
 	}()
 
 	// Server checks that client messages come in cleanly
 	i := 0
 	err = nil
-	for ; i < 10; i++ {
+	for ; i < N; i++ {
 		var p []byte
 		p, err = trS.readPacket()
 		if err != nil {
 			break
 		}
-		if i == 5 {
+		if (i % 10) == 5 {
 			gotHalf <- 1
 		}
 
 		want := []byte{msgRequestSuccess, byte(i)}
 		if bytes.Compare(p, want) != 0 {
-			t.Errorf("message %d: got %q, want %q", i, p, want)
+			t.Errorf("message %d: got %v, want %v", i, p, want)
 		}
 	}
 	<-clientDone
 	if err != nil && err != io.EOF {
 		t.Fatalf("server error: %v", err)
 	}
-	if i != 10 {
+	if i != N {
 		t.Errorf("received %d messages, want 10.", i)
 	}
 
@@ -239,7 +253,10 @@
 }
 
 func TestHandshakeAutoRekeyWrite(t *testing.T) {
-	checker := &syncChecker{make(chan int, 10)}
+	checker := &syncChecker{
+		called:   make(chan int, 10),
+		waitCall: nil,
+	}
 	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
 	clientConf.RekeyThreshold = 500
 	trC, trS, err := handshakePair(clientConf, "addr", false)
@@ -249,14 +266,19 @@
 	defer trC.Close()
 	defer trS.Close()
 
+	input := make([]byte, 251)
+	input[0] = msgRequestSuccess
+
 	done := make(chan int, 1)
 	const numPacket = 5
 	go func() {
 		defer close(done)
 		j := 0
 		for ; j < numPacket; j++ {
-			if _, err := trS.readPacket(); err != nil {
+			if p, err := trS.readPacket(); err != nil {
 				break
+			} else if !bytes.Equal(input, p) {
+				t.Errorf("got packet type %d, want %d", p[0], input[0])
 			}
 		}
 
@@ -268,9 +290,9 @@
 	<-checker.called
 
 	for i := 0; i < numPacket; i++ {
-		packet := make([]byte, 251)
-		packet[0] = msgRequestSuccess
-		if err := trC.writePacket(packet); err != nil {
+		p := make([]byte, len(input))
+		copy(p, input)
+		if err := trC.writePacket(p); err != nil {
 			t.Errorf("writePacket: %v", err)
 		}
 		if i == 2 {
@@ -283,16 +305,23 @@
 }
 
 type syncChecker struct {
-	called chan int
+	waitCall chan int
+	called   chan int
 }
 
 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
 	c.called <- 1
+	if c.waitCall != nil {
+		<-c.waitCall
+	}
 	return nil
 }
 
 func TestHandshakeAutoRekeyRead(t *testing.T) {
-	sync := &syncChecker{make(chan int, 2)}
+	sync := &syncChecker{
+		called:   make(chan int, 2),
+		waitCall: nil,
+	}
 	clientConf := &ClientConfig{
 		HostKeyCallback: sync.Check,
 	}