ssh: check the declared public key algo against decoded one

This check will ensure we don't accept e.g. ssh-rsa-cert-v01@openssh.com
algorithm with ssh-rsa public key type.
The algorithm and public key type must be consistent: both must be
certificate algorithms, or neither.

Change-Id: I1d75074fb4d6db3a8796408e98ddffe577a96ab1
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/506836
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index 70558a9..0d2d8d2 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -1076,3 +1076,94 @@
 		}
 	}
 }
+
+// configurablePublicKeyCallback is a public key callback that allows to
+// configure the signature algorithm and format. This way we can emulate the
+// behavior of buggy clients.
+type configurablePublicKeyCallback struct {
+	signer          AlgorithmSigner
+	signatureAlgo   string
+	signatureFormat string
+}
+
+func (cb configurablePublicKeyCallback) method() string {
+	return "publickey"
+}
+
+func (cb configurablePublicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
+	pub := cb.signer.PublicKey()
+
+	ok, err := validateKey(pub, cb.signatureAlgo, user, c)
+	if err != nil {
+		return authFailure, nil, err
+	}
+	if !ok {
+		return authFailure, nil, fmt.Errorf("invalid public key")
+	}
+
+	pubKey := pub.Marshal()
+	data := buildDataSignedForAuth(session, userAuthRequestMsg{
+		User:    user,
+		Service: serviceSSH,
+		Method:  cb.method(),
+	}, cb.signatureAlgo, pubKey)
+	sign, err := cb.signer.SignWithAlgorithm(rand, data, underlyingAlgo(cb.signatureFormat))
+	if err != nil {
+		return authFailure, nil, err
+	}
+
+	s := Marshal(sign)
+	sig := make([]byte, stringLength(len(s)))
+	marshalString(sig, s)
+	msg := publickeyAuthMsg{
+		User:     user,
+		Service:  serviceSSH,
+		Method:   cb.method(),
+		HasSig:   true,
+		Algoname: cb.signatureAlgo,
+		PubKey:   pubKey,
+		Sig:      sig,
+	}
+	p := Marshal(&msg)
+	if err := c.writePacket(p); err != nil {
+		return authFailure, nil, err
+	}
+	var success authResult
+	success, methods, err := handleAuthResponse(c)
+	if err != nil {
+		return authFailure, nil, err
+	}
+	if success == authSuccess || !containsMethod(methods, cb.method()) {
+		return success, methods, err
+	}
+
+	return authFailure, methods, nil
+}
+
+func TestPublicKeyAndAlgoCompatibility(t *testing.T) {
+	cert := &Certificate{
+		Key:         testPublicKeys["rsa"],
+		ValidBefore: CertTimeInfinity,
+		CertType:    UserCert,
+	}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	certSigner, err := NewCertSigner(cert, testSigners["rsa"])
+	if err != nil {
+		t.Fatalf("NewCertSigner: %v", err)
+	}
+
+	clientConfig := &ClientConfig{
+		User:            "user",
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		Auth: []AuthMethod{
+			configurablePublicKeyCallback{
+				signer:          certSigner.(AlgorithmSigner),
+				signatureAlgo:   KeyAlgoRSASHA256,
+				signatureFormat: KeyAlgoRSASHA256,
+			},
+		},
+	}
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Error("cert login passed with incompatible public key type and algorithm")
+	}
+}
diff --git a/ssh/server.go b/ssh/server.go
index b21322a..727c71b 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -576,7 +576,16 @@
 				if !ok || len(payload) > 0 {
 					return nil, parseError(msgUserAuthRequest)
 				}
-
+				// Ensure the declared public key algo is compatible with the
+				// decoded one. This check will ensure we don't accept e.g.
+				// ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public
+				// key type. The algorithm and public key type must be
+				// consistent: both must be certificate algorithms, or neither.
+				if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
+					authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q",
+						pubKey.Type(), algo)
+					break
+				}
 				// Ensure the public key algo and signature algo
 				// are supported.  Compare the private key
 				// algorithm name that corresponds to algo with