ssh: implement strict KEX protocol changes

Implement the "strict KEX" protocol changes, as described in section
1.9 of the OpenSSH PROTOCOL file (as of OpenSSH version 9.6/9.6p1).

Namely this makes the following changes:
  * Both the server and the client add an additional algorithm to the
    initial KEXINIT message, indicating support for the strict KEX mode.
  * When one side of the connection sees the strict KEX extension
    algorithm, the strict KEX mode is enabled for messages originating
    from the other side of the connection. If the receive sequence
    number for packets from the side which requested the extension is
    not 1 (indicating that non-KEXINIT packets have already been
    received from it), the connection is terminated.
  * When strict KEX mode is enabled, unexpected messages during the
    handshake are considered fatal. Additionally, when a key change
    occurs (on sending or receiving the NEWKEYS message), the message
    sequence numbers are reset.

Thanks to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk from Ruhr
University Bochum for reporting this issue.

Fixes CVE-2023-48795
Fixes golang/go#64784

Change-Id: I96b53afd2bd2fb94d2b6f2a46a5dacf325357604
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/550715
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Roland Shoemaker <roland@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 49bbba7..56cdc7c 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -35,6 +35,16 @@
 	// direction will be effected if a msgNewKeys message is sent
 	// or received.
 	prepareKeyChange(*algorithms, *kexResult) error
+
+	// setStrictMode sets the strict KEX mode, notably triggering
+	// sequence number resets on sending or receiving msgNewKeys.
+	// If the sequence number is already > 1 when setStrictMode
+	// is called, an error is returned.
+	setStrictMode() error
+
+	// setInitialKEXDone indicates to the transport that the initial key exchange
+	// was completed.
+	setInitialKEXDone()
 }
 
 // handshakeTransport implements rekeying on top of a keyingTransport
@@ -100,6 +110,10 @@
 
 	// The session ID or nil if first kex did not complete yet.
 	sessionID []byte
+
+	// strictMode indicates if the other side of the handshake indicated
+	// that we should be following the strict KEX protocol restrictions.
+	strictMode bool
 }
 
 type pendingKex struct {
@@ -209,7 +223,10 @@
 			close(t.incoming)
 			break
 		}
-		if p[0] == msgIgnore || p[0] == msgDebug {
+		// If this is the first kex, and strict KEX mode is enabled,
+		// we don't ignore any messages, as they may be used to manipulate
+		// the packet sequence numbers.
+		if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
 			continue
 		}
 		t.incoming <- p
@@ -441,6 +458,11 @@
 	return successPacket, nil
 }
 
+const (
+	kexStrictClient = "kex-strict-c-v00@openssh.com"
+	kexStrictServer = "kex-strict-s-v00@openssh.com"
+)
+
 // sendKexInit sends a key change message.
 func (t *handshakeTransport) sendKexInit() error {
 	t.mu.Lock()
@@ -454,7 +476,6 @@
 	}
 
 	msg := &kexInitMsg{
-		KexAlgos:                t.config.KeyExchanges,
 		CiphersClientServer:     t.config.Ciphers,
 		CiphersServerClient:     t.config.Ciphers,
 		MACsClientServer:        t.config.MACs,
@@ -464,6 +485,13 @@
 	}
 	io.ReadFull(rand.Reader, msg.Cookie[:])
 
+	// We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm,
+	// and possibly to add the ext-info extension algorithm. Since the slice may be the
+	// user owned KeyExchanges, we create our own slice in order to avoid using user
+	// owned memory by mistake.
+	msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info
+	msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
+
 	isServer := len(t.hostKeys) > 0
 	if isServer {
 		for _, k := range t.hostKeys {
@@ -488,17 +516,24 @@
 				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
 			}
 		}
+
+		if t.sessionID == nil {
+			msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
+		}
 	} else {
 		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
 
 		// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
 		// algorithms the server supports for public key authentication. See RFC
 		// 8308, Section 2.1.
+		//
+		// We also send the strict KEX mode extension algorithm, in order to opt
+		// into the strict KEX mode.
 		if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
-			msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1)
-			msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
 			msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
+			msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
 		}
+
 	}
 
 	packet := Marshal(msg)
@@ -604,6 +639,13 @@
 		return err
 	}
 
+	if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) {
+		t.strictMode = true
+		if err := t.conn.setStrictMode(); err != nil {
+			return err
+		}
+	}
+
 	// We don't send FirstKexFollows, but we handle receiving it.
 	//
 	// RFC 4253 section 7 defines the kex and the agreement method for
@@ -679,6 +721,12 @@
 		return unexpectedMessageError(msgNewKeys, packet[0])
 	}
 
+	if firstKeyExchange {
+		// Indicates to the transport that the first key exchange is completed
+		// after receiving SSH_MSG_NEWKEYS.
+		t.conn.setInitialKEXDone()
+	}
+
 	return nil
 }
 
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 65afc20..2bc607b 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -395,6 +395,10 @@
 	return n.packetConn.readPacket()
 }
 
+func (n *errorKeyingTransport) setStrictMode() error { return nil }
+
+func (n *errorKeyingTransport) setInitialKEXDone() {}
+
 func TestHandshakeErrorHandlingRead(t *testing.T) {
 	for i := 0; i < 20; i++ {
 		testHandshakeErrorHandlingN(t, i, -1, false)
@@ -710,3 +714,308 @@
 		t.Fatal("incompatible signer returned")
 	}
 }
+
+func TestStrictKEXResetSeqFirstKEX(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	<-checker.called
+
+	t.Cleanup(func() {
+		trC.Close()
+		trS.Close()
+	})
+
+	// Throw away the msgExtInfo packet sent during the handshake by the server
+	_, err = trC.readPacket()
+	if err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	}
+
+	// close the handshake transports before checking the sequence number to
+	// avoid races.
+	trC.Close()
+	trS.Close()
+
+	// Check the sequence number counters. We reset after msgNewKeys, but
+	// then the server immediately writes msgExtInfo, and we close the
+	// transports, so we expect read 2, write 0 on the client and read 1, write 1
+	// on the server.
+	if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 ||
+		trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 {
+		t.Errorf(
+			"unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
+			trC.conn.(*transport).reader.seqNum,
+			trC.conn.(*transport).writer.seqNum,
+			trS.conn.(*transport).reader.seqNum,
+			trS.conn.(*transport).writer.seqNum,
+		)
+	}
+}
+
+func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	<-checker.called
+
+	t.Cleanup(func() {
+		trC.Close()
+		trS.Close()
+	})
+
+	// Throw away the msgExtInfo packet sent during the handshake by the server
+	_, err = trC.readPacket()
+	if err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	}
+
+	// write and read five packets on either side to bump the sequence numbers
+	for i := 0; i < 5; i++ {
+		if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
+			t.Fatalf("writePacket failed: %s", err)
+		}
+		if _, err := trS.readPacket(); err != nil {
+			t.Fatalf("readPacket failed: %s", err)
+		}
+		if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
+			t.Fatalf("writePacket failed: %s", err)
+		}
+		if _, err := trC.readPacket(); err != nil {
+			t.Fatalf("readPacket failed: %s", err)
+		}
+	}
+
+	// Request a key exchange, which should cause the sequence numbers to reset
+	checker.waitCall <- 1
+	trC.requestKeyExchange()
+	<-checker.called
+
+	// write a packet on the server, and then read it on the client, to verify the key change has actually
+	// happened, since the HostKeyCallback is called _during_ the handshake, so isn't actually indicative
+	// of the handshake finishing.
+	dummyPacket := []byte{99}
+	if err := trS.writePacket(dummyPacket); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+	if p, err := trC.readPacket(); err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	} else if !bytes.Equal(p, dummyPacket) {
+		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
+	}
+
+	// close the handshake transports before checking the sequence number to
+	// avoid races.
+	trC.Close()
+	trS.Close()
+
+	if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 ||
+		trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 {
+		t.Errorf(
+			"unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
+			trC.conn.(*transport).reader.seqNum,
+			trC.conn.(*transport).writer.seqNum,
+			trS.conn.(*transport).reader.seqNum,
+			trS.conn.(*transport).writer.seqNum,
+		)
+	}
+}
+
+func TestSeqNumIncrease(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	<-checker.called
+
+	t.Cleanup(func() {
+		trC.Close()
+		trS.Close()
+	})
+
+	// Throw away the msgExtInfo packet sent during the handshake by the server
+	_, err = trC.readPacket()
+	if err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	}
+
+	// write and read five packets on either side to bump the sequence numbers
+	for i := 0; i < 5; i++ {
+		if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
+			t.Fatalf("writePacket failed: %s", err)
+		}
+		if _, err := trS.readPacket(); err != nil {
+			t.Fatalf("readPacket failed: %s", err)
+		}
+		if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
+			t.Fatalf("writePacket failed: %s", err)
+		}
+		if _, err := trC.readPacket(); err != nil {
+			t.Fatalf("readPacket failed: %s", err)
+		}
+	}
+
+	// close the handshake transports before checking the sequence number to
+	// avoid races.
+	trC.Close()
+	trS.Close()
+
+	if trC.conn.(*transport).reader.seqNum != 7 || trC.conn.(*transport).writer.seqNum != 5 ||
+		trS.conn.(*transport).reader.seqNum != 6 || trS.conn.(*transport).writer.seqNum != 6 {
+		t.Errorf(
+			"unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)",
+			trC.conn.(*transport).reader.seqNum,
+			trC.conn.(*transport).writer.seqNum,
+			trS.conn.(*transport).reader.seqNum,
+			trS.conn.(*transport).writer.seqNum,
+		)
+	}
+}
+
+func TestStrictKEXUnexpectedMsg(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	// Check that unexpected messages during the handshake cause failure
+	_, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true)
+	if err == nil {
+		t.Fatal("handshake should fail when there are unexpected messages during the handshake")
+	}
+
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshake failed: %s", err)
+	}
+
+	// Check that ignore/debug packets are still ignored outside of the handshake
+	if err := trC.writePacket([]byte{msgIgnore}); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+	if err := trC.writePacket([]byte{msgDebug}); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+	dummyPacket := []byte{99}
+	if err := trC.writePacket(dummyPacket); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+
+	if p, err := trS.readPacket(); err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	} else if !bytes.Equal(p, dummyPacket) {
+		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
+	}
+}
+
+func TestStrictKEXMixed(t *testing.T) {
+	// Test that we still support a mixed connection, where one side sends kex-strict but the other
+	// side doesn't.
+
+	a, b, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe failed: %s", err)
+	}
+
+	var trC, trS keyingTransport
+
+	trC = newTransport(a, rand.Reader, true)
+	trS = newTransport(b, rand.Reader, false)
+	trS = addNoiseTransport(trS)
+
+	clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+
+	transport := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version"))
+	transport.hostKeys = serverConf.hostKeys
+	transport.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms
+
+	readOneFailure := make(chan error, 1)
+	go func() {
+		if _, err := transport.readOnePacket(true); err != nil {
+			readOneFailure <- err
+		}
+	}()
+
+	// Basically sendKexInit, but without the kex-strict extension algorithm
+	msg := &kexInitMsg{
+		KexAlgos:                transport.config.KeyExchanges,
+		CiphersClientServer:     transport.config.Ciphers,
+		CiphersServerClient:     transport.config.Ciphers,
+		MACsClientServer:        transport.config.MACs,
+		MACsServerClient:        transport.config.MACs,
+		CompressionClientServer: supportedCompressions,
+		CompressionServerClient: supportedCompressions,
+		ServerHostKeyAlgos:      []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA},
+	}
+	packet := Marshal(msg)
+	// writePacket destroys the contents, so save a copy.
+	packetCopy := make([]byte, len(packet))
+	copy(packetCopy, packet)
+	if err := transport.pushPacket(packetCopy); err != nil {
+		t.Fatalf("pushPacket: %s", err)
+	}
+	transport.sentInitMsg = msg
+	transport.sentInitPacket = packet
+
+	if err := transport.getWriteError(); err != nil {
+		t.Fatalf("getWriteError failed: %s", err)
+	}
+	var request *pendingKex
+	select {
+	case err = <-readOneFailure:
+		t.Fatalf("server readOnePacket failed: %s", err)
+	case request = <-transport.startKex:
+		break
+	}
+
+	// The following calls are expected to succeed: packets that would be
+	// unexpected under strict KEX but come from the side which does not support
+	// kex-strict must be tolerated, even if the other side does support kex-strict.
+
+	if err := transport.enterKeyExchange(request.otherInit); err != nil {
+		t.Fatalf("enterKeyExchange failed: %s", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+}
diff --git a/ssh/transport.go b/ssh/transport.go
index da01580..0424d2d 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -49,6 +49,9 @@
 	rand      io.Reader
 	isClient  bool
 	io.Closer
+
+	strictMode     bool
+	initialKEXDone bool
 }
 
 // packetCipher represents a combination of SSH encryption/MAC
@@ -74,6 +77,18 @@
 	pendingKeyChange chan packetCipher
 }
 
+func (t *transport) setStrictMode() error {
+	if t.reader.seqNum != 1 {
+		return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
+	}
+	t.strictMode = true
+	return nil
+}
+
+func (t *transport) setInitialKEXDone() {
+	t.initialKEXDone = true
+}
+
 // prepareKeyChange sets up key material for a keychange. The key changes in
 // both directions are triggered by reading and writing a msgNewKey packet
 // respectively.
@@ -112,11 +127,12 @@
 // Read and decrypt next packet.
 func (t *transport) readPacket() (p []byte, err error) {
 	for {
-		p, err = t.reader.readPacket(t.bufReader)
+		p, err = t.reader.readPacket(t.bufReader, t.strictMode)
 		if err != nil {
 			break
 		}
-		if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) {
+		// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
+		if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
 			break
 		}
 	}
@@ -127,7 +143,7 @@
 	return p, err
 }
 
-func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
+func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
 	packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
 	s.seqNum++
 	if err == nil && len(packet) == 0 {
@@ -140,6 +156,9 @@
 			select {
 			case cipher := <-s.pendingKeyChange:
 				s.packetCipher = cipher
+				if strictMode {
+					s.seqNum = 0
+				}
 			default:
 				return nil, errors.New("ssh: got bogus newkeys message")
 			}
@@ -170,10 +189,10 @@
 	if debugTransport {
 		t.printPacket(packet, true)
 	}
-	return t.writer.writePacket(t.bufWriter, t.rand, packet)
+	return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
 }
 
-func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
+func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
 	changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
 
 	err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
@@ -188,6 +207,9 @@
 		select {
 		case cipher := <-s.pendingKeyChange:
 			s.packetCipher = cipher
+			if strictMode {
+				s.seqNum = 0
+			}
 		default:
 			panic("ssh: no key material for msgNewKeys")
 		}