ssh: add MultiAlgorithmSigner

MultiAlgorithmSigner allows to restrict client-side, server-side and
certificate signing algorithms.

Fixes golang/go#52132
Fixes golang/go#36261

Change-Id: I295092f1bba647327aaaf294f110e9157d294159
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/508398
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
diff --git a/ssh/certs.go b/ssh/certs.go
index fc04d03..27d0e14 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -16,8 +16,9 @@
 
 // Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear
 // in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms.
-// Unlike key algorithm names, these are not passed to AlgorithmSigner and don't
-// appear in the Signature.Format field.
+// Unlike key algorithm names, these are not passed to AlgorithmSigner nor
+// returned by MultiAlgorithmSigner and don't appear in the Signature.Format
+// field.
 const (
 	CertAlgoRSAv01        = "ssh-rsa-cert-v01@openssh.com"
 	CertAlgoDSAv01        = "ssh-dss-cert-v01@openssh.com"
@@ -255,10 +256,17 @@
 		return nil, errors.New("ssh: signer and cert have different public key")
 	}
 
-	if algorithmSigner, ok := signer.(AlgorithmSigner); ok {
+	switch s := signer.(type) {
+	case MultiAlgorithmSigner:
+		return &multiAlgorithmSigner{
+			AlgorithmSigner: &algorithmOpenSSHCertSigner{
+				&openSSHCertSigner{cert, signer}, s},
+			supportedAlgorithms: s.Algorithms(),
+		}, nil
+	case AlgorithmSigner:
 		return &algorithmOpenSSHCertSigner{
-			&openSSHCertSigner{cert, signer}, algorithmSigner}, nil
-	} else {
+			&openSSHCertSigner{cert, signer}, s}, nil
+	default:
 		return &openSSHCertSigner{cert, signer}, nil
 	}
 }
@@ -432,7 +440,9 @@
 }
 
 // SignCert signs the certificate with an authority, setting the Nonce,
-// SignatureKey, and Signature fields.
+// SignatureKey, and Signature fields. If the authority implements the
+// MultiAlgorithmSigner interface the first algorithm in the list is used. This
+// is useful if you want to sign with a specific algorithm.
 func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
 	c.Nonce = make([]byte, 32)
 	if _, err := io.ReadFull(rand, c.Nonce); err != nil {
@@ -440,8 +450,20 @@
 	}
 	c.SignatureKey = authority.PublicKey()
 
-	// Default to KeyAlgoRSASHA512 for ssh-rsa signers.
-	if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA {
+	if v, ok := authority.(MultiAlgorithmSigner); ok {
+		if len(v.Algorithms()) == 0 {
+			return errors.New("the provided authority has no signature algorithm")
+		}
+		// Use the first algorithm in the list.
+		sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), v.Algorithms()[0])
+		if err != nil {
+			return err
+		}
+		c.Signature = sig
+		return nil
+	} else if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA {
+		// Default to KeyAlgoRSASHA512 for ssh-rsa signers.
+		// TODO: consider using KeyAlgoRSASHA256 as default.
 		sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512)
 		if err != nil {
 			return err
diff --git a/ssh/certs_test.go b/ssh/certs_test.go
index e600483..97b7486 100644
--- a/ssh/certs_test.go
+++ b/ssh/certs_test.go
@@ -187,10 +187,30 @@
 	}
 
 	for _, test := range []struct {
-		addr    string
-		succeed bool
+		addr                    string
+		succeed                 bool
+		certSignerAlgorithms    []string // Empty means no algorithm restrictions.
+		clientHostKeyAlgorithms []string
 	}{
 		{addr: "hostname:22", succeed: true},
+		{
+			addr:                    "hostname:22",
+			succeed:                 true,
+			certSignerAlgorithms:    []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
+			clientHostKeyAlgorithms: []string{CertAlgoRSASHA512v01},
+		},
+		{
+			addr:                    "hostname:22",
+			succeed:                 false,
+			certSignerAlgorithms:    []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
+			clientHostKeyAlgorithms: []string{CertAlgoRSAv01},
+		},
+		{
+			addr:                    "hostname:22",
+			succeed:                 false,
+			certSignerAlgorithms:    []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
+			clientHostKeyAlgorithms: []string{KeyAlgoRSASHA512}, // Not a certificate algorithm.
+		},
 		{addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22'
 		{addr: "lasthost:22", succeed: false},
 	} {
@@ -207,14 +227,24 @@
 			conf := ServerConfig{
 				NoClientAuth: true,
 			}
-			conf.AddHostKey(certSigner)
+			if len(test.certSignerAlgorithms) > 0 {
+				mas, err := NewSignerWithAlgorithms(certSigner.(AlgorithmSigner), test.certSignerAlgorithms)
+				if err != nil {
+					errc <- err
+					return
+				}
+				conf.AddHostKey(mas)
+			} else {
+				conf.AddHostKey(certSigner)
+			}
 			_, _, _, err := NewServerConn(c1, &conf)
 			errc <- err
 		}()
 
 		config := &ClientConfig{
-			User:            "user",
-			HostKeyCallback: checker.CheckHostKey,
+			User:              "user",
+			HostKeyCallback:   checker.CheckHostKey,
+			HostKeyAlgorithms: test.clientHostKeyAlgorithms,
 		}
 		_, _, _, err = NewClientConn(c2, test.addr, config)
 
@@ -242,6 +272,20 @@
 }
 
 func TestCertTypes(t *testing.T) {
+	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	multiAlgoSignerSHA256, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer SHA256: %v", err)
+	}
+	// Algorithms are in order of preference, we expect rsa-sha2-512 to be used.
+	multiAlgoSignerSHA512, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer SHA512: %v", err)
+	}
+
 	var testVars = []struct {
 		name   string
 		signer Signer
@@ -251,8 +295,10 @@
 		{CertAlgoECDSA384v01, testSigners["ecdsap384"], ""},
 		{CertAlgoECDSA521v01, testSigners["ecdsap521"], ""},
 		{CertAlgoED25519v01, testSigners["ed25519"], ""},
-		{CertAlgoRSAv01, testSigners["rsa"], KeyAlgoRSASHA512},
+		{CertAlgoRSAv01, testSigners["rsa"], KeyAlgoRSASHA256},
 		{"legacyRSASigner", &legacyRSASigner{testSigners["rsa"]}, KeyAlgoRSA},
+		{"multiAlgoRSASignerSHA256", multiAlgoSignerSHA256, KeyAlgoRSASHA256},
+		{"multiAlgoRSASignerSHA512", multiAlgoSignerSHA512, KeyAlgoRSASHA512},
 		{CertAlgoDSAv01, testSigners["dsa"], ""},
 	}
 
@@ -318,3 +364,45 @@
 		})
 	}
 }
+
+func TestCertSignWithMultiAlgorithmSigner(t *testing.T) {
+	type testcase struct {
+		sigAlgo   string
+		algoritms []string
+	}
+	cases := []testcase{
+		{
+			sigAlgo:   KeyAlgoRSA,
+			algoritms: []string{KeyAlgoRSA, KeyAlgoRSASHA512},
+		},
+		{
+			sigAlgo:   KeyAlgoRSASHA256,
+			algoritms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512},
+		},
+		{
+			sigAlgo:   KeyAlgoRSASHA512,
+			algoritms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256},
+		},
+	}
+
+	cert := &Certificate{
+		Key:         testPublicKeys["rsa"],
+		ValidBefore: CertTimeInfinity,
+		CertType:    UserCert,
+	}
+
+	for _, c := range cases {
+		t.Run(c.sigAlgo, func(t *testing.T) {
+			signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algoritms)
+			if err != nil {
+				t.Fatalf("NewSignerWithAlgorithms error: %v", err)
+			}
+			if err := cert.SignCert(rand.Reader, signer); err != nil {
+				t.Fatalf("SignCert error: %v", err)
+			}
+			if cert.Signature.Format != c.sigAlgo {
+				t.Fatalf("got signature format %q, want %q", cert.Signature.Format, c.sigAlgo)
+			}
+		})
+	}
+}
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 409b5ea..5c3bc25 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -71,7 +71,9 @@
 	for auth := AuthMethod(new(noneAuth)); auth != nil; {
 		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
 		if err != nil {
-			return err
+			// We return the error later if there is no other method left to
+			// try.
+			ok = authFailure
 		}
 		if ok == authSuccess {
 			// success
@@ -101,6 +103,12 @@
 				}
 			}
 		}
+
+		if auth == nil && err != nil {
+			// We have an error and there are no other authentication methods to
+			// try, so we return it.
+			return err
+		}
 	}
 	return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
 }
@@ -217,21 +225,45 @@
 	return "publickey"
 }
 
-func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (as AlgorithmSigner, algo string) {
+func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) {
+	var as MultiAlgorithmSigner
 	keyFormat := signer.PublicKey().Type()
 
-	// Like in sendKexInit, if the public key implements AlgorithmSigner we
-	// assume it supports all algorithms, otherwise only the key format one.
-	as, ok := signer.(AlgorithmSigner)
-	if !ok {
-		return algorithmSignerWrapper{signer}, keyFormat
+	// If the signer implements MultiAlgorithmSigner we use the algorithms it
+	// support, if it implements AlgorithmSigner we assume it supports all
+	// algorithms, otherwise only the key format one.
+	switch s := signer.(type) {
+	case MultiAlgorithmSigner:
+		as = s
+	case AlgorithmSigner:
+		as = &multiAlgorithmSigner{
+			AlgorithmSigner:     s,
+			supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)),
+		}
+	default:
+		as = &multiAlgorithmSigner{
+			AlgorithmSigner:     algorithmSignerWrapper{signer},
+			supportedAlgorithms: []string{underlyingAlgo(keyFormat)},
+		}
+	}
+
+	getFallbackAlgo := func() (string, error) {
+		// Fallback to use if there is no "server-sig-algs" extension or a
+		// common algorithm cannot be found. We use the public key format if the
+		// MultiAlgorithmSigner supports it, otherwise we return an error.
+		if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
+			return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
+				underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
+		}
+		return keyFormat, nil
 	}
 
 	extPayload, ok := extensions["server-sig-algs"]
 	if !ok {
-		// If there is no "server-sig-algs" extension, fall back to the key
-		// format algorithm.
-		return as, keyFormat
+		// If there is no "server-sig-algs" extension use the fallback
+		// algorithm.
+		algo, err := getFallbackAlgo()
+		return as, algo, err
 	}
 
 	// The server-sig-algs extension only carries underlying signature
@@ -245,15 +277,22 @@
 		}
 	}
 
-	keyAlgos := algorithmsForKeyFormat(keyFormat)
+	// Filter algorithms based on those supported by MultiAlgorithmSigner.
+	var keyAlgos []string
+	for _, algo := range algorithmsForKeyFormat(keyFormat) {
+		if contains(as.Algorithms(), underlyingAlgo(algo)) {
+			keyAlgos = append(keyAlgos, algo)
+		}
+	}
+
 	algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
 	if err != nil {
-		// If there is no overlap, try the key anyway with the key format
-		// algorithm, to support servers that fail to list all supported
-		// algorithms.
-		return as, keyFormat
+		// If there is no overlap, return the fallback algorithm to support
+		// servers that fail to list all supported algorithms.
+		algo, err := getFallbackAlgo()
+		return as, algo, err
 	}
-	return as, algo
+	return as, algo, nil
 }
 
 func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
@@ -267,10 +306,17 @@
 		return authFailure, nil, err
 	}
 	var methods []string
+	var errSigAlgo error
 	for _, signer := range signers {
 		pub := signer.PublicKey()
-		as, algo := pickSignatureAlgorithm(signer, extensions)
-
+		as, algo, err := pickSignatureAlgorithm(signer, extensions)
+		if err != nil && errSigAlgo == nil {
+			// If we cannot negotiate a signature algorithm store the first
+			// error so we can return it to provide a more meaningful message if
+			// no other signers work.
+			errSigAlgo = err
+			continue
+		}
 		ok, err := validateKey(pub, algo, user, c)
 		if err != nil {
 			return authFailure, nil, err
@@ -317,22 +363,12 @@
 		// contain the "publickey" method, do not attempt to authenticate with any
 		// other keys.  According to RFC 4252 Section 7, the latter can occur when
 		// additional authentication methods are required.
-		if success == authSuccess || !containsMethod(methods, cb.method()) {
+		if success == authSuccess || !contains(methods, cb.method()) {
 			return success, methods, err
 		}
 	}
 
-	return authFailure, methods, nil
-}
-
-func containsMethod(methods []string, method string) bool {
-	for _, m := range methods {
-		if m == method {
-			return true
-		}
-	}
-
-	return false
+	return authFailure, methods, errSigAlgo
 }
 
 // validateKey validates the key provided is acceptable to the server.
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index 0d2d8d2..16d4113 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -1077,6 +1077,73 @@
 	}
 }
 
+func TestPickSignatureAlgorithm(t *testing.T) {
+	type testcase struct {
+		name       string
+		extensions map[string][]byte
+	}
+	cases := []testcase{
+		{
+			name: "server with empty server-sig-algs",
+			extensions: map[string][]byte{
+				"server-sig-algs": []byte(``),
+			},
+		},
+		{
+			name:       "server with no server-sig-algs",
+			extensions: nil,
+		},
+	}
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			signer, ok := testSigners["rsa"].(MultiAlgorithmSigner)
+			if !ok {
+				t.Fatalf("rsa test signer does not implement the MultiAlgorithmSigner interface")
+			}
+			// The signer supports the public key algorithm which is then returned.
+			_, algo, err := pickSignatureAlgorithm(signer, c.extensions)
+			if err != nil {
+				t.Fatalf("got %v, want no error", err)
+			}
+			if algo != signer.PublicKey().Type() {
+				t.Fatalf("got algo %q, want %q", algo, signer.PublicKey().Type())
+			}
+			// Test a signer that uses a certificate algorithm as the public key
+			// type.
+			cert := &Certificate{
+				CertType: UserCert,
+				Key:      signer.PublicKey(),
+			}
+			cert.SignCert(rand.Reader, signer)
+
+			certSigner, err := NewCertSigner(cert, signer)
+			if err != nil {
+				t.Fatalf("error generating cert signer: %v", err)
+			}
+			// The signer supports the public key algorithm and the
+			// public key format is a certificate type so the cerificate
+			// algorithm matching the key format must be returned
+			_, algo, err = pickSignatureAlgorithm(certSigner, c.extensions)
+			if err != nil {
+				t.Fatalf("got %v, want no error", err)
+			}
+			if algo != certSigner.PublicKey().Type() {
+				t.Fatalf("got algo %q, want %q", algo, certSigner.PublicKey().Type())
+			}
+			signer, err = NewSignerWithAlgorithms(signer.(AlgorithmSigner), []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256})
+			if err != nil {
+				t.Fatalf("unable to create signer with algorithms: %v", err)
+			}
+			// The signer does not support the public key algorithm so an error
+			// is returned.
+			_, _, err = pickSignatureAlgorithm(signer, c.extensions)
+			if err == nil {
+				t.Fatal("got no error, no common public key signature algorithm error expected")
+			}
+		})
+	}
+}
+
 // configurablePublicKeyCallback is a public key callback that allows to
 // configure the signature algorithm and format. This way we can emulate the
 // behavior of buggy clients.
@@ -1133,7 +1200,7 @@
 	if err != nil {
 		return authFailure, nil, err
 	}
-	if success == authSuccess || !containsMethod(methods, cb.method()) {
+	if success == authSuccess || !contains(methods, cb.method()) {
 		return success, methods, err
 	}
 
diff --git a/ssh/example_test.go b/ssh/example_test.go
index bee6796..0a6b076 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -7,6 +7,8 @@
 import (
 	"bufio"
 	"bytes"
+	"crypto/rand"
+	"crypto/rsa"
 	"fmt"
 	"log"
 	"net"
@@ -75,7 +77,6 @@
 	if err != nil {
 		log.Fatal("Failed to parse private key: ", err)
 	}
-
 	config.AddHostKey(private)
 
 	// Once a ServerConfig has been configured, connections can be
@@ -139,6 +140,36 @@
 	}
 }
 
+func ExampleServerConfig_AddHostKey() {
+	// Minimal ServerConfig supporting only password authentication.
+	config := &ssh.ServerConfig{
+		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
+			// Should use constant-time compare (or better, salt+hash) in
+			// a production setting.
+			if c.User() == "testuser" && string(pass) == "tiger" {
+				return nil, nil
+			}
+			return nil, fmt.Errorf("password rejected for %q", c.User())
+		},
+	}
+
+	privateBytes, err := os.ReadFile("id_rsa")
+	if err != nil {
+		log.Fatal("Failed to load private key: ", err)
+	}
+
+	private, err := ssh.ParsePrivateKey(privateBytes)
+	if err != nil {
+		log.Fatal("Failed to parse private key: ", err)
+	}
+	// Restrict host key algorithms to disable ssh-rsa.
+	signer, err := ssh.NewSignerWithAlgorithms(private.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512})
+	if err != nil {
+		log.Fatal("Failed to create private key with restricted algorithms: ", err)
+	}
+	config.AddHostKey(signer)
+}
+
 func ExampleClientConfig_HostKeyCallback() {
 	// Every client must provide a host key check.  Here is a
 	// simple-minded parse of OpenSSH's known_hosts file
@@ -318,3 +349,38 @@
 		log.Fatal("failed to start shell: ", err)
 	}
 }
+
+func ExampleCertificate_SignCert() {
+	// Sign a certificate with a specific algorithm.
+	privateKey, err := rsa.GenerateKey(rand.Reader, 3072)
+	if err != nil {
+		log.Fatal("unable to generate RSA key: ", err)
+	}
+	publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+	if err != nil {
+		log.Fatal("unable to get RSA public key: ", err)
+	}
+	caKey, err := rsa.GenerateKey(rand.Reader, 3072)
+	if err != nil {
+		log.Fatal("unable to generate CA key: ", err)
+	}
+	signer, err := ssh.NewSignerFromKey(caKey)
+	if err != nil {
+		log.Fatal("unable to generate signer from key: ", err)
+	}
+	mas, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256})
+	if err != nil {
+		log.Fatal("unable to create signer with algoritms: ", err)
+	}
+	certificate := ssh.Certificate{
+		Key:      publicKey,
+		CertType: ssh.UserCert,
+	}
+	if err := certificate.SignCert(rand.Reader, mas); err != nil {
+		log.Fatal("unable to sign certificate: ", err)
+	}
+	// Save the public key to a file and check that rsa-sha-256 is used for
+	// signing:
+	// ssh-keygen -L -f <path to the file>
+	fmt.Println(string(ssh.MarshalAuthorizedKey(&certificate)))
+}
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 07a1843..b95b112 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -461,19 +461,24 @@
 	isServer := len(t.hostKeys) > 0
 	if isServer {
 		for _, k := range t.hostKeys {
-			// If k is an AlgorithmSigner, presume it supports all signature algorithms
-			// associated with the key format. (Ideally AlgorithmSigner would have a
-			// method to advertise supported algorithms, but it doesn't. This means that
-			// adding support for a new algorithm is a breaking change, as we will
-			// immediately negotiate it even if existing implementations don't support
-			// it. If that ever happens, we'll have to figure something out.)
-			// If k is not an AlgorithmSigner, we can only assume it only supports the
-			// algorithms that matches the key format. (This means that Sign can't pick
-			// a different default.)
+			// If k is a MultiAlgorithmSigner, we restrict the signature
+			// algorithms. If k is a AlgorithmSigner, presume it supports all
+			// signature algorithms associated with the key format. If k is not
+			// an AlgorithmSigner, we can only assume it only supports the
+			// algorithms that matches the key format. (This means that Sign
+			// can't pick a different default).
 			keyFormat := k.PublicKey().Type()
-			if _, ok := k.(AlgorithmSigner); ok {
+
+			switch s := k.(type) {
+			case MultiAlgorithmSigner:
+				for _, algo := range algorithmsForKeyFormat(keyFormat) {
+					if contains(s.Algorithms(), underlyingAlgo(algo)) {
+						msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
+					}
+				}
+			case AlgorithmSigner:
 				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
-			} else {
+			default:
 				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
 			}
 		}
@@ -685,9 +690,16 @@
 
 func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
 	for _, k := range hostKeys {
+		if s, ok := k.(MultiAlgorithmSigner); ok {
+			if !contains(s.Algorithms(), underlyingAlgo(algo)) {
+				continue
+			}
+		}
+
 		if algo == k.PublicKey().Type() {
 			return algorithmSignerWrapper{k}
 		}
+
 		k, ok := k.(AlgorithmSigner)
 		if !ok {
 			continue
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 879143a..65afc20 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -620,3 +620,93 @@
 		t.Fatal(err)
 	}
 }
+
+func TestMultiAlgoSignerHandshake(t *testing.T) {
+	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer: %v", err)
+	}
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	serverConf := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			return &Permissions{}, nil
+		},
+	}
+	serverConf.AddHostKey(multiAlgoSigner)
+	go NewServerConn(c1, serverConf)
+
+	clientConf := &ClientConfig{
+		User:              "test",
+		Auth:              []AuthMethod{Password("testpw")},
+		HostKeyCallback:   FixedHostKey(testSigners["rsa"].PublicKey()),
+		HostKeyAlgorithms: []string{KeyAlgoRSASHA512},
+	}
+
+	if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestMultiAlgoSignerNoCommonHostKeyAlgo(t *testing.T) {
+	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer: %v", err)
+	}
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	// ssh-rsa is disabled server side
+	serverConf := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			return &Permissions{}, nil
+		},
+	}
+	serverConf.AddHostKey(multiAlgoSigner)
+	go NewServerConn(c1, serverConf)
+
+	// the client only supports ssh-rsa
+	clientConf := &ClientConfig{
+		User:              "test",
+		Auth:              []AuthMethod{Password("testpw")},
+		HostKeyCallback:   FixedHostKey(testSigners["rsa"].PublicKey()),
+		HostKeyAlgorithms: []string{KeyAlgoRSA},
+	}
+
+	_, _, _, err = NewClientConn(c2, "", clientConf)
+	if err == nil {
+		t.Fatal("succeeded connecting with no common hostkey algorithm")
+	}
+}
+
+func TestPickIncompatibleHostKeyAlgo(t *testing.T) {
+	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer: %v", err)
+	}
+	signer := pickHostKey([]Signer{multiAlgoSigner}, KeyAlgoRSA)
+	if signer != nil {
+		t.Fatal("incompatible signer returned")
+	}
+}
diff --git a/ssh/keys.go b/ssh/keys.go
index 1bf28d8..bcaae50 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -335,7 +335,7 @@
 
 // A Signer can create signatures that verify against a public key.
 //
-// Some Signers provided by this package also implement AlgorithmSigner.
+// Some Signers provided by this package also implement MultiAlgorithmSigner.
 type Signer interface {
 	// PublicKey returns the associated PublicKey.
 	PublicKey() PublicKey
@@ -350,9 +350,9 @@
 // An AlgorithmSigner is a Signer that also supports specifying an algorithm to
 // use for signing.
 //
-// An AlgorithmSigner can't advertise the algorithms it supports, so it should
-// be prepared to be invoked with every algorithm supported by the public key
-// format.
+// An AlgorithmSigner can't advertise the algorithms it supports, unless it also
+// implements MultiAlgorithmSigner, so it should be prepared to be invoked with
+// every algorithm supported by the public key format.
 type AlgorithmSigner interface {
 	Signer
 
@@ -363,6 +363,75 @@
 	SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error)
 }
 
+// MultiAlgorithmSigner is an AlgorithmSigner that also reports the algorithms
+// supported by that signer.
+type MultiAlgorithmSigner interface {
+	AlgorithmSigner
+
+	// Algorithms returns the available algorithms in preference order. The list
+	// must not be empty, and it must not include certificate types.
+	Algorithms() []string
+}
+
+// NewSignerWithAlgorithms returns a signer restricted to the specified
+// algorithms. The algorithms must be set in preference order. The list must not
+// be empty, and it must not include certificate types. An error is returned if
+// the specified algorithms are incompatible with the public key type.
+func NewSignerWithAlgorithms(signer AlgorithmSigner, algorithms []string) (MultiAlgorithmSigner, error) {
+	if len(algorithms) == 0 {
+		return nil, errors.New("ssh: please specify at least one valid signing algorithm")
+	}
+	var signerAlgos []string
+	supportedAlgos := algorithmsForKeyFormat(underlyingAlgo(signer.PublicKey().Type()))
+	if s, ok := signer.(*multiAlgorithmSigner); ok {
+		signerAlgos = s.Algorithms()
+	} else {
+		signerAlgos = supportedAlgos
+	}
+
+	for _, algo := range algorithms {
+		if !contains(supportedAlgos, algo) {
+			return nil, fmt.Errorf("ssh: algorithm %q is not supported for key type %q",
+				algo, signer.PublicKey().Type())
+		}
+		if !contains(signerAlgos, algo) {
+			return nil, fmt.Errorf("ssh: algorithm %q is restricted for the provided signer", algo)
+		}
+	}
+	return &multiAlgorithmSigner{
+		AlgorithmSigner:     signer,
+		supportedAlgorithms: algorithms,
+	}, nil
+}
+
+type multiAlgorithmSigner struct {
+	AlgorithmSigner
+	supportedAlgorithms []string
+}
+
+func (s *multiAlgorithmSigner) Algorithms() []string {
+	return s.supportedAlgorithms
+}
+
+func (s *multiAlgorithmSigner) isAlgorithmSupported(algorithm string) bool {
+	if algorithm == "" {
+		algorithm = underlyingAlgo(s.PublicKey().Type())
+	}
+	for _, algo := range s.supportedAlgorithms {
+		if algorithm == algo {
+			return true
+		}
+	}
+	return false
+}
+
+func (s *multiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
+	if !s.isAlgorithmSupported(algorithm) {
+		return nil, fmt.Errorf("ssh: algorithm %q is not supported: %v", algorithm, s.supportedAlgorithms)
+	}
+	return s.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
+}
+
 type rsaPublicKey rsa.PublicKey
 
 func (r *rsaPublicKey) Type() string {
@@ -526,6 +595,10 @@
 	return k.SignWithAlgorithm(rand, data, k.PublicKey().Type())
 }
 
+func (k *dsaPrivateKey) Algorithms() []string {
+	return []string{k.PublicKey().Type()}
+}
+
 func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
 	if algorithm != "" && algorithm != k.PublicKey().Type() {
 		return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
@@ -975,13 +1048,16 @@
 	return s.SignWithAlgorithm(rand, data, s.pubKey.Type())
 }
 
+func (s *wrappedSigner) Algorithms() []string {
+	return algorithmsForKeyFormat(s.pubKey.Type())
+}
+
 func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
 	if algorithm == "" {
 		algorithm = s.pubKey.Type()
 	}
 
-	supportedAlgos := algorithmsForKeyFormat(s.pubKey.Type())
-	if !contains(supportedAlgos, algorithm) {
+	if !contains(s.Algorithms(), algorithm) {
 		return nil, fmt.Errorf("ssh: unsupported signature algorithm %q for key format %q", algorithm, s.pubKey.Type())
 	}
 
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index a8e216e..365d93d 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -111,9 +111,9 @@
 }
 
 func TestKeySignWithAlgorithmVerify(t *testing.T) {
-	for _, priv := range testSigners {
-		if algorithmSigner, ok := priv.(AlgorithmSigner); !ok {
-			t.Errorf("Signers constructed by ssh package should always implement the AlgorithmSigner interface: %T", priv)
+	for k, priv := range testSigners {
+		if algorithmSigner, ok := priv.(MultiAlgorithmSigner); !ok {
+			t.Errorf("Signers %q constructed by ssh package should always implement the MultiAlgorithmSigner interface: %T", k, priv)
 		} else {
 			pub := priv.PublicKey()
 			data := []byte("sign me")
@@ -684,3 +684,34 @@
 		}
 	}
 }
+
+func TestNewSignerWithAlgos(t *testing.T) {
+	algorithSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	_, err := NewSignerWithAlgorithms(algorithSigner, nil)
+	if err == nil {
+		t.Error("signer with algos created with no algorithms")
+	}
+
+	_, err = NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoED25519})
+	if err == nil {
+		t.Error("signer with algos created with invalid algorithms")
+	}
+
+	_, err = NewSignerWithAlgorithms(algorithSigner, []string{CertAlgoRSASHA256v01})
+	if err == nil {
+		t.Error("signer with algos created with certificate algorithms")
+	}
+
+	mas, err := NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
+	if err != nil {
+		t.Errorf("unable to create signer with valid algorithms: %v", err)
+	}
+
+	_, err = NewSignerWithAlgorithms(mas, []string{KeyAlgoRSA})
+	if err == nil {
+		t.Error("signer with algos created with restricted algorithms")
+	}
+}