ssh: support rsa-sha2-256/512 on the server side

This lets clients know we support rsa-sha2-256/512 signatures from
ssh-rsa public keys. OpenSSH prefers to break the connection rather than
attempting trial and error, apparently.

We don't enable support for the "ext-info-s" because we're not
interested in any client->server extensions.

This also replaces isAcceptableAlgo which was rejecting the
rsa-sha2-256/512-cert-v01@openssh.com public key algorithms.

Tested with OpenSSH 9.1 on macOS Ventura.

Fixes golang/go#49269
Updates golang/go#49952

Co-authored-by: Nicola Murino <nicola.murino@gmail.com>
Co-authored-by: Kristin Davidson <kdavidson@atlassian.com>
Change-Id: I4955c3b12bb45575e9977ac657bb5805b49d00c3
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/447757
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index a6bedbf..35b62e3 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -132,9 +132,7 @@
 	if err := tryAuth(t, config); err != nil {
 		t.Fatalf("unable to dial remote side: %s", err)
 	}
-	// Once the server implements the server-sig-algs extension, this will turn
-	// into KeyAlgoRSASHA256.
-	if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSA {
+	if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSASHA256 {
 		t.Errorf("unexpected Sign/SignWithAlgorithm calls: %q", signer.used)
 	}
 }
diff --git a/ssh/common.go b/ssh/common.go
index 7a5ff2d..c796427 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -10,6 +10,7 @@
 	"fmt"
 	"io"
 	"math"
+	"strings"
 	"sync"
 
 	_ "crypto/sha1"
@@ -118,6 +119,20 @@
 	}
 }
 
+// supportedPubKeyAuthAlgos specifies the supported client public key
+// authentication algorithms. Note that this doesn't include certificate types
+// since those use the underlying algorithm. This list is sent to the client if
+// it supports the server-sig-algs extension. Order is irrelevant.
+var supportedPubKeyAuthAlgos = []string{
+	KeyAlgoED25519,
+	KeyAlgoSKED25519, KeyAlgoSKECDSA256,
+	KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
+	KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
+	KeyAlgoDSA,
+}
+
+var supportedPubKeyAuthAlgosList = strings.Join(supportedPubKeyAuthAlgos, ",")
+
 // unexpectedMessageError results when the SSH message that we received didn't
 // match what we wanted.
 func unexpectedMessageError(expected, got uint8) error {
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 653dc4d..2b84c35 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -615,7 +615,8 @@
 		return err
 	}
 
-	if t.sessionID == nil {
+	firstKeyExchange := t.sessionID == nil
+	if firstKeyExchange {
 		t.sessionID = result.H
 	}
 	result.SessionID = t.sessionID
@@ -626,6 +627,24 @@
 	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 	}
+
+	// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
+	// message with the server-sig-algs extension if the client supports it. See
+	// RFC 8308, Sections 2.4 and 3.1.
+	if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
+		extInfo := &extInfoMsg{
+			NumExtensions: 1,
+			Payload:       make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)),
+		}
+		extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
+		extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
+		extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
+		extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
+		if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
+			return err
+		}
+	}
+
 	if packet, err := t.conn.readPacket(); err != nil {
 		return err
 	} else if packet[0] != msgNewKeys {
diff --git a/ssh/server.go b/ssh/server.go
index 2260b20..9e38702 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -291,15 +291,6 @@
 	return perms, err
 }
 
-func isAcceptableAlgo(algo string) bool {
-	switch algo {
-	case KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoSKECDSA256, KeyAlgoED25519, KeyAlgoSKED25519,
-		CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
-		return true
-	}
-	return false
-}
-
 func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
 	if addr == nil {
 		return errors.New("ssh: no address known for client, but source-address match required")
@@ -514,7 +505,7 @@
 				return nil, parseError(msgUserAuthRequest)
 			}
 			algo := string(algoBytes)
-			if !isAcceptableAlgo(algo) {
+			if !contains(supportedPubKeyAuthAlgos, underlyingAlgo(algo)) {
 				authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
 				break
 			}
@@ -572,7 +563,7 @@
 				// algorithm name that corresponds to algo with
 				// sig.Format.  This is usually the same, but
 				// for certs, the names differ.
-				if !isAcceptableAlgo(sig.Format) {
+				if !contains(supportedPubKeyAuthAlgos, sig.Format) {
 					authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
 					break
 				}