ssh: support rsa-sha2-256/512 for client authentication

CL 220037 had implemented support for host authentication using
rsa-sha2-256/512, but not client public key authentication. OpenSSH
disabled the SHA-1 based ssh-rsa by default in version 8.8 (after
pre-announcing it in versions 8.2, 8.3, 8.4, 8.5, 8.6, and 8.7) although
some distributions re-enable it. GitHub will start rejecting ssh-rsa for
keys uploaded before November 2, 2021 on March 15, 2022.

https://github.blog/2021-09-01-improving-git-protocol-security-github/

The server side already worked, as long as the client selected one of
the SHA-2 algorithms, because the signature flowed freely to Verify.
There was however nothing verifying that the signature algorithm matched
the advertised one. The comment suggested the check was being performed,
but it got lost back in CL 86190043. Not a security issue because the
signature had to pass the callback's Verify method regardless, and both
values were checked to be acceptable.

Tested with OpenSSH 8.8 configured with "PubkeyAcceptedKeyTypes -ssh-rsa"
and no application-side changes.

The Signers returned by ssh/agent (when backed by an agent client)
didn't actually implement AlgorithmSigner but ParameterizedSigner, an
interface defined in an earlier version of CL 123955.

Updates golang/go#49269
Fixes golang/go#39885
For golang/go#49952

Change-Id: I13b41db8041f1112a70f106c55f077b904b12cb8
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/392394
Trust: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
diff --git a/ssh/agent/client.go b/ssh/agent/client.go
index b909471..3cfe723 100644
--- a/ssh/agent/client.go
+++ b/ssh/agent/client.go
@@ -25,7 +25,6 @@
 	"math/big"
 	"sync"
 
-	"crypto"
 	"golang.org/x/crypto/ed25519"
 	"golang.org/x/crypto/ssh"
 )
@@ -771,19 +770,26 @@
 	return s.agent.Sign(s.pub, data)
 }
 
-func (s *agentKeyringSigner) SignWithOpts(rand io.Reader, data []byte, opts crypto.SignerOpts) (*ssh.Signature, error) {
-	var flags SignatureFlags
-	if opts != nil {
-		switch opts.HashFunc() {
-		case crypto.SHA256:
-			flags = SignatureFlagRsaSha256
-		case crypto.SHA512:
-			flags = SignatureFlagRsaSha512
-		}
+func (s *agentKeyringSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
+	if algorithm == "" || algorithm == s.pub.Type() {
+		return s.Sign(rand, data)
 	}
+
+	var flags SignatureFlags
+	switch algorithm {
+	case ssh.KeyAlgoRSASHA256:
+		flags = SignatureFlagRsaSha256
+	case ssh.KeyAlgoRSASHA512:
+		flags = SignatureFlagRsaSha512
+	default:
+		return nil, fmt.Errorf("agent: unsupported algorithm %q", algorithm)
+	}
+
 	return s.agent.SignWithFlags(s.pub, data, flags)
 }
 
+var _ ssh.AlgorithmSigner = &agentKeyringSigner{}
+
 // Calls an extension method. It is up to the agent implementation as to whether or not
 // any particular extension is supported and may always return an error. Because the
 // type of the response is up to the implementation, this returns the bytes of the
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index affc01d..a962a67 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -9,6 +9,7 @@
 	"errors"
 	"fmt"
 	"io"
+	"strings"
 )
 
 type authResult int
@@ -29,6 +30,33 @@
 	if err != nil {
 		return err
 	}
+	// The server may choose to send a SSH_MSG_EXT_INFO at this point (if we
+	// advertised willingness to receive one, which we always do) or not. See
+	// RFC 8308, Section 2.4.
+	extensions := make(map[string][]byte)
+	if len(packet) > 0 && packet[0] == msgExtInfo {
+		var extInfo extInfoMsg
+		if err := Unmarshal(packet, &extInfo); err != nil {
+			return err
+		}
+		payload := extInfo.Payload
+		for i := uint32(0); i < extInfo.NumExtensions; i++ {
+			name, rest, ok := parseString(payload)
+			if !ok {
+				return parseError(msgExtInfo)
+			}
+			value, rest, ok := parseString(rest)
+			if !ok {
+				return parseError(msgExtInfo)
+			}
+			extensions[string(name)] = value
+			payload = rest
+		}
+		packet, err = c.transport.readPacket()
+		if err != nil {
+			return err
+		}
+	}
 	var serviceAccept serviceAcceptMsg
 	if err := Unmarshal(packet, &serviceAccept); err != nil {
 		return err
@@ -41,7 +69,7 @@
 
 	sessionID := c.transport.getSessionID()
 	for auth := AuthMethod(new(noneAuth)); auth != nil; {
-		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand)
+		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
 		if err != nil {
 			return err
 		}
@@ -93,7 +121,7 @@
 	// If authentication is not successful, a []string of alternative
 	// method names is returned. If the slice is nil, it will be ignored
 	// and the previous set of possible methods will be reused.
-	auth(session []byte, user string, p packetConn, rand io.Reader) (authResult, []string, error)
+	auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error)
 
 	// method returns the RFC 4252 method name.
 	method() string
@@ -102,7 +130,7 @@
 // "none" authentication, RFC 4252 section 5.2.
 type noneAuth int
 
-func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
 	if err := c.writePacket(Marshal(&userAuthRequestMsg{
 		User:    user,
 		Service: serviceSSH,
@@ -122,7 +150,7 @@
 // a function call, e.g. by prompting the user.
 type passwordCallback func() (password string, err error)
 
-func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
 	type passwordAuthMsg struct {
 		User     string `sshtype:"50"`
 		Service  string
@@ -189,7 +217,36 @@
 	return "publickey"
 }
 
-func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (as AlgorithmSigner, algo string) {
+	keyFormat := signer.PublicKey().Type()
+
+	// Like in sendKexInit, if the public key implements AlgorithmSigner we
+	// assume it supports all algorithms, otherwise only the key format one.
+	as, ok := signer.(AlgorithmSigner)
+	if !ok {
+		return algorithmSignerWrapper{signer}, keyFormat
+	}
+
+	extPayload, ok := extensions["server-sig-algs"]
+	if !ok {
+		// If there is no "server-sig-algs" extension, fall back to the key
+		// format algorithm.
+		return as, keyFormat
+	}
+
+	serverAlgos := strings.Split(string(extPayload), ",")
+	keyAlgos := algorithmsForKeyFormat(keyFormat)
+	algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
+	if err != nil {
+		// If there is no overlap, try the key anyway with the key format
+		// algorithm, to support servers that fail to list all supported
+		// algorithms.
+		return as, keyFormat
+	}
+	return as, algo
+}
+
+func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
 	// Authentication is performed by sending an enquiry to test if a key is
 	// acceptable to the remote. If the key is acceptable, the client will
 	// attempt to authenticate with the valid key.  If not the client will repeat
@@ -201,7 +258,10 @@
 	}
 	var methods []string
 	for _, signer := range signers {
-		ok, err := validateKey(signer.PublicKey(), user, c)
+		pub := signer.PublicKey()
+		as, algo := pickSignatureAlgorithm(signer, extensions)
+
+		ok, err := validateKey(pub, algo, user, c)
 		if err != nil {
 			return authFailure, nil, err
 		}
@@ -209,13 +269,13 @@
 			continue
 		}
 
-		pub := signer.PublicKey()
 		pubKey := pub.Marshal()
-		sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
+		data := buildDataSignedForAuth(session, userAuthRequestMsg{
 			User:    user,
 			Service: serviceSSH,
 			Method:  cb.method(),
-		}, []byte(pub.Type()), pubKey))
+		}, algo, pubKey)
+		sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
 		if err != nil {
 			return authFailure, nil, err
 		}
@@ -229,7 +289,7 @@
 			Service:  serviceSSH,
 			Method:   cb.method(),
 			HasSig:   true,
-			Algoname: pub.Type(),
+			Algoname: algo,
 			PubKey:   pubKey,
 			Sig:      sig,
 		}
@@ -266,26 +326,25 @@
 }
 
 // validateKey validates the key provided is acceptable to the server.
-func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
+func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) {
 	pubKey := key.Marshal()
 	msg := publickeyAuthMsg{
 		User:     user,
 		Service:  serviceSSH,
 		Method:   "publickey",
 		HasSig:   false,
-		Algoname: key.Type(),
+		Algoname: algo,
 		PubKey:   pubKey,
 	}
 	if err := c.writePacket(Marshal(&msg)); err != nil {
 		return false, err
 	}
 
-	return confirmKeyAck(key, c)
+	return confirmKeyAck(key, algo, c)
 }
 
-func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
+func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
 	pubKey := key.Marshal()
-	algoname := key.Type()
 
 	for {
 		packet, err := c.readPacket()
@@ -302,14 +361,14 @@
 			if err := Unmarshal(packet, &msg); err != nil {
 				return false, err
 			}
-			if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) {
+			if msg.Algo != algo || !bytes.Equal(msg.PubKey, pubKey) {
 				return false, nil
 			}
 			return true, nil
 		case msgUserAuthFailure:
 			return false, nil
 		default:
-			return false, unexpectedMessageError(msgUserAuthSuccess, packet[0])
+			return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0])
 		}
 	}
 }
@@ -330,6 +389,7 @@
 // along with a list of remaining authentication methods to try next and
 // an error if an unexpected response was received.
 func handleAuthResponse(c packetConn) (authResult, []string, error) {
+	gotMsgExtInfo := false
 	for {
 		packet, err := c.readPacket()
 		if err != nil {
@@ -341,6 +401,12 @@
 			if err := handleBannerResponse(c, packet); err != nil {
 				return authFailure, nil, err
 			}
+		case msgExtInfo:
+			// Ignore post-authentication RFC 8308 extensions, once.
+			if gotMsgExtInfo {
+				return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
+			}
+			gotMsgExtInfo = true
 		case msgUserAuthFailure:
 			var msg userAuthFailureMsg
 			if err := Unmarshal(packet, &msg); err != nil {
@@ -395,7 +461,7 @@
 	return "keyboard-interactive"
 }
 
-func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
 	type initiateMsg struct {
 		User       string `sshtype:"50"`
 		Service    string
@@ -412,6 +478,7 @@
 		return authFailure, nil, err
 	}
 
+	gotMsgExtInfo := false
 	for {
 		packet, err := c.readPacket()
 		if err != nil {
@@ -425,6 +492,13 @@
 				return authFailure, nil, err
 			}
 			continue
+		case msgExtInfo:
+			// Ignore post-authentication RFC 8308 extensions, once.
+			if gotMsgExtInfo {
+				return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
+			}
+			gotMsgExtInfo = true
+			continue
 		case msgUserAuthInfoRequest:
 			// OK
 		case msgUserAuthFailure:
@@ -497,9 +571,9 @@
 	maxTries   int
 }
 
-func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok authResult, methods []string, err error) {
+func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) {
 	for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
-		ok, methods, err = r.authMethod.auth(session, user, c, rand)
+		ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions)
 		if ok != authFailure || err != nil { // either success, partial success or error terminate
 			return ok, methods, err
 		}
@@ -542,7 +616,7 @@
 	target       string
 }
 
-func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
 	m := &userAuthRequestMsg{
 		User:    user,
 		Service: serviceSSH,
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index f73079b..a6bedbf 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -105,11 +105,63 @@
 	return err, serverAuthErrors
 }
 
+type loggingAlgorithmSigner struct {
+	used []string
+	AlgorithmSigner
+}
+
+func (l *loggingAlgorithmSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
+	l.used = append(l.used, "[Sign]")
+	return l.AlgorithmSigner.Sign(rand, data)
+}
+
+func (l *loggingAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
+	l.used = append(l.used, algorithm)
+	return l.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
+}
+
 func TestClientAuthPublicKey(t *testing.T) {
+	signer := &loggingAlgorithmSigner{AlgorithmSigner: testSigners["rsa"].(AlgorithmSigner)}
 	config := &ClientConfig{
 		User: "testuser",
 		Auth: []AuthMethod{
-			PublicKeys(testSigners["rsa"]),
+			PublicKeys(signer),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	if err := tryAuth(t, config); err != nil {
+		t.Fatalf("unable to dial remote side: %s", err)
+	}
+	// Once the server implements the server-sig-algs extension, this will turn
+	// into KeyAlgoRSASHA256.
+	if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSA {
+		t.Errorf("unexpected Sign/SignWithAlgorithm calls: %q", signer.used)
+	}
+}
+
+// TestClientAuthNoSHA2 tests a ssh-rsa Signer that doesn't implement AlgorithmSigner.
+func TestClientAuthNoSHA2(t *testing.T) {
+	config := &ClientConfig{
+		User: "testuser",
+		Auth: []AuthMethod{
+			PublicKeys(&legacyRSASigner{testSigners["rsa"]}),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	if err := tryAuth(t, config); err != nil {
+		t.Fatalf("unable to dial remote side: %s", err)
+	}
+}
+
+// TestClientAuthThirdKey checks that the third configured can succeed. If we
+// were to do three attempts for each key (rsa-sha2-256, rsa-sha2-512, ssh-rsa),
+// we'd hit the six maximum attempts before reaching it.
+func TestClientAuthThirdKey(t *testing.T) {
+	config := &ClientConfig{
+		User: "testuser",
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa-openssh-format"],
+				testSigners["rsa-openssh-format"], testSigners["rsa"]),
 		},
 		HostKeyCallback: InsecureIgnoreHostKey(),
 	}
diff --git a/ssh/common.go b/ssh/common.go
index d6d9bf9..2a47a61 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -297,8 +297,9 @@
 }
 
 // buildDataSignedForAuth returns the data that is signed in order to prove
-// possession of a private key. See RFC 4252, section 7.
-func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
+// possession of a private key. See RFC 4252, section 7. algo is the advertised
+// algorithm, and may be a certificate type.
+func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo string, pubKey []byte) []byte {
 	data := struct {
 		Session []byte
 		Type    byte
@@ -306,7 +307,7 @@
 		Service string
 		Method  string
 		Sign    bool
-		Algo    []byte
+		Algo    string
 		PubKey  []byte
 	}{
 		sessionID,
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 4bceb33..f815cdb 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -476,6 +476,13 @@
 		}
 	} else {
 		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
+
+		// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
+		// algorithms the server supports for public key authentication. See RFC
+		// 8303, Section 2.1.
+		msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1)
+		msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
+		msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
 	}
 
 	packet := Marshal(msg)
diff --git a/ssh/messages.go b/ssh/messages.go
index 62f9330..19bc67c 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -141,6 +141,14 @@
 	Service string `sshtype:"6"`
 }
 
+// See RFC 8308, section 2.3
+const msgExtInfo = 7
+
+type extInfoMsg struct {
+	NumExtensions uint32 `sshtype:"7"`
+	Payload       []byte `ssh:"rest"`
+}
+
 // See RFC 4252, section 5.
 const msgUserAuthRequest = 50
 
@@ -782,6 +790,8 @@
 		msg = new(serviceRequestMsg)
 	case msgServiceAccept:
 		msg = new(serviceAcceptMsg)
+	case msgExtInfo:
+		msg = new(extInfoMsg)
 	case msgKexInit:
 		msg = new(kexInitMsg)
 	case msgKexDHInit:
@@ -843,6 +853,7 @@
 	msgDisconnect:          "disconnectMsg",
 	msgServiceRequest:      "serviceRequestMsg",
 	msgServiceAccept:       "serviceAcceptMsg",
+	msgExtInfo:             "extInfoMsg",
 	msgKexInit:             "kexInitMsg",
 	msgKexDHInit:           "kexDHInitMsg",
 	msgKexDHReply:          "kexDHReplyMsg",
diff --git a/ssh/server.go b/ssh/server.go
index 9bb714c..70045bd 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -554,6 +554,7 @@
 				if !ok || len(payload) > 0 {
 					return nil, parseError(msgUserAuthRequest)
 				}
+
 				// Ensure the public key algo and signature algo
 				// are supported.  Compare the private key
 				// algorithm name that corresponds to algo with
@@ -563,7 +564,12 @@
 					authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
 					break
 				}
-				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
+				if underlyingAlgo(algo) != sig.Format {
+					authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo)
+					break
+				}
+
+				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData)
 
 				if err := pubKey.Verify(signedData, sig); err != nil {
 					return nil, err