x/crypto/ssh: hide msgNewKeys in the transport layer

This ensures that extraneous key exchanges cannot confuse
application-level code.

Change-Id: I1a333e2b7b46f1e484406a79db7a949294e79c6d
Reviewed-on: https://go-review.googlesource.com/22417
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index ef97d57..da53d3a 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -104,7 +104,7 @@
 			}
 			if i == 5 {
 				// halfway through, we request a key change.
-				_, _, err := trC.sendKexInit(subsequentKeyExchange)
+				err := trC.sendKexInit(subsequentKeyExchange)
 				if err != nil {
 					t.Fatalf("sendKexInit: %v", err)
 				}
@@ -161,7 +161,7 @@
 	}
 
 	// Now request a key change.
-	_, _, err = trC.sendKexInit(subsequentKeyExchange)
+	err = trC.sendKexInit(subsequentKeyExchange)
 	if err != nil {
 		t.Errorf("sendKexInit: %v", err)
 	}
@@ -184,6 +184,28 @@
 	}
 }
 
+func TestForceFirstKex(t *testing.T) {
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+
+	defer trC.Close()
+	defer trS.Close()
+
+	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
+
+	// We set up the initial key exchange, but the remote side
+	// tries to send serviceRequestMsg in cleartext, which is
+	// disallowed.
+
+	err = trS.sendKexInit(firstKeyExchange)
+	if err == nil {
+		t.Errorf("server first kex init should reject unexpected packet")
+	}
+}
+
 func TestHandshakeTwice(t *testing.T) {
 	checker := &testChecker{}
 	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
@@ -194,18 +216,25 @@
 	defer trC.Close()
 	defer trS.Close()
 
+	// Both sides must request the initial key exchange before sending other packets.
+	err = trS.sendKexInit(firstKeyExchange)
+	if err != nil {
+		t.Errorf("server sendKexInit: %v", err)
+	}
+
+	err = trC.sendKexInit(firstKeyExchange)
+	if err != nil {
+		t.Errorf("client sendKexInit: %v", err)
+	}
+
+	sent := 0
 	// send a packet
 	packet := make([]byte, 5)
 	packet[0] = msgRequestSuccess
 	if err := trC.writePacket(packet); err != nil {
 		t.Errorf("writePacket: %v", err)
 	}
-
-	// Now request a key change.
-	_, _, err = trC.sendKexInit(subsequentKeyExchange)
-	if err != nil {
-		t.Errorf("sendKexInit: %v", err)
-	}
+	sent++
 
 	// Send another packet. Use a fresh one, since writePacket destroys.
 	packet = make([]byte, 5)
@@ -213,9 +242,10 @@
 	if err := trC.writePacket(packet); err != nil {
 		t.Errorf("writePacket: %v", err)
 	}
+	sent++
 
 	// 2nd key change.
-	_, _, err = trC.sendKexInit(subsequentKeyExchange)
+	err = trC.sendKexInit(subsequentKeyExchange)
 	if err != nil {
 		t.Errorf("sendKexInit: %v", err)
 	}
@@ -225,17 +255,15 @@
 	if err := trC.writePacket(packet); err != nil {
 		t.Errorf("writePacket: %v", err)
 	}
+	sent++
 
 	packet = make([]byte, 5)
 	packet[0] = msgRequestSuccess
-	for i := 0; i < 5; i++ {
+	for i := 0; i < sent; i++ {
 		msg, err := trS.readPacket()
 		if err != nil {
 			t.Fatalf("server closed too soon: %v", err)
 		}
-		if msg[0] == msgNewKeys {
-			continue
-		}
 
 		if bytes.Compare(msg, packet) != 0 {
 			t.Errorf("packet %d: got %q want %q", i, msg, packet)