ssh: support SSH agent signature flags and custom extensions

This commit implements two new features. To preserve backwards
compatibility the new methods are added to an `ExtendedAgent` interface
which extends `Agent`. The client code implements `ExtendedAgent`
(which extends Agent) so you can call these additional methods against
SSH agents such as the OpenSSH agent. The ServeAgent method still
accepts Agent but will attempt to upcast the agent to `ExtendedAgent`
as needed, so if you supply an ExtendedAgent implementation you can
implement these additional methods (which keyring does).

The first feature is supporting the standard flags that can be passed to
SSH Sign requests, requesting that RSA signatures use SHA-256 or
SHA-512. See section 4.5.1 of the SSH agent protocol draft:
https://tools.ietf.org/html/draft-miller-ssh-agent-02

The second feature is supporting calling custom extensions from clients
and implementing custom extensions from servers. See section 4.7 of the
SSH agent protocol draft:
https://tools.ietf.org/html/draft-miller-ssh-agent-02

Change-Id: I0f74feb893762c27e921ec37604d3a46434ee6ef
GitHub-Last-Rev: 2e23fd01c0e95b664e8507682f0bd5bd61d4c146
GitHub-Pull-Request: golang/crypto#53
Reviewed-on: https://go-review.googlesource.com/c/123955
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/agent/client.go b/ssh/agent/client.go
index b1808dd..1934950 100644
--- a/ssh/agent/client.go
+++ b/ssh/agent/client.go
@@ -25,10 +25,22 @@
 	"math/big"
 	"sync"
 
+	"crypto"
 	"golang.org/x/crypto/ed25519"
 	"golang.org/x/crypto/ssh"
 )
 
+// SignatureFlags represent additional flags that can be passed to the signature
+// requests an defined in [PROTOCOL.agent] section 4.5.1.
+type SignatureFlags uint32
+
+// SignatureFlag values as defined in [PROTOCOL.agent] section 5.3.
+const (
+	SignatureFlagReserved SignatureFlags = 1 << iota
+	SignatureFlagRsaSha256
+	SignatureFlagRsaSha512
+)
+
 // Agent represents the capabilities of an ssh-agent.
 type Agent interface {
 	// List returns the identities known to the agent.
@@ -57,6 +69,26 @@
 	Signers() ([]ssh.Signer, error)
 }
 
+type ExtendedAgent interface {
+	Agent
+
+	// SignWithFlags signs like Sign, but allows for additional flags to be sent/received
+	SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error)
+
+	// Extension processes a custom extension request. Standard-compliant agents are not
+	// required to support any extensions, but this method allows agents to implement
+	// vendor-specific methods or add experimental features. See [PROTOCOL.agent] section 4.7.
+	// If agent extensions are unsupported entirely this method MUST return an
+	// ErrExtensionUnsupported error. Similarly, if just the specific extensionType in
+	// the request is unsupported by the agent then ErrExtensionUnsupported MUST be
+	// returned.
+	//
+	// In the case of success, since [PROTOCOL.agent] section 4.7 specifies that the contents
+	// of the response are unspecified (including the type of the message), the complete
+	// response will be returned as a []byte slice, including the "type" byte of the message.
+	Extension(extensionType string, contents []byte) ([]byte, error)
+}
+
 // ConstraintExtension describes an optional constraint defined by users.
 type ConstraintExtension struct {
 	// ExtensionName consist of a UTF-8 string suffixed by the
@@ -179,6 +211,23 @@
 	Rest []byte `ssh:"rest"`
 }
 
+// See [PROTOCOL.agent], section 4.7
+const agentExtension = 27
+const agentExtensionFailure = 28
+
+// ErrExtensionUnsupported indicates that an extension defined in
+// [PROTOCOL.agent] section 4.7 is unsupported by the agent. Specifically this
+// error indicates that the agent returned a standard SSH_AGENT_FAILURE message
+// as the result of a SSH_AGENTC_EXTENSION request. Note that the protocol
+// specification (and therefore this error) does not distinguish between a
+// specific extension being unsupported and extensions being unsupported entirely.
+var ErrExtensionUnsupported = errors.New("agent: extension unsupported")
+
+type extensionAgentMsg struct {
+	ExtensionType string `sshtype:"27"`
+	Contents      []byte
+}
+
 // Key represents a protocol 2 public key as defined in
 // [PROTOCOL.agent], section 2.5.2.
 type Key struct {
@@ -260,7 +309,7 @@
 
 // NewClient returns an Agent that talks to an ssh-agent process over
 // the given connection.
-func NewClient(rw io.ReadWriter) Agent {
+func NewClient(rw io.ReadWriter) ExtendedAgent {
 	return &client{conn: rw}
 }
 
@@ -268,6 +317,21 @@
 // unmarshaled into reply and replyType is set to the first byte of
 // the reply, which contains the type of the message.
 func (c *client) call(req []byte) (reply interface{}, err error) {
+	buf, err := c.callRaw(req)
+	if err != nil {
+		return nil, err
+	}
+	reply, err = unmarshal(buf)
+	if err != nil {
+		return nil, clientErr(err)
+	}
+	return reply, nil
+}
+
+// callRaw sends an RPC to the agent. On success, the raw
+// bytes of the response are returned; no unmarshalling is
+// performed on the response.
+func (c *client) callRaw(req []byte) (reply []byte, err error) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
@@ -291,11 +355,7 @@
 	if _, err = io.ReadFull(c.conn, buf); err != nil {
 		return nil, clientErr(err)
 	}
-	reply, err = unmarshal(buf)
-	if err != nil {
-		return nil, clientErr(err)
-	}
-	return reply, err
+	return buf, nil
 }
 
 func (c *client) simpleCall(req []byte) error {
@@ -369,9 +429,14 @@
 // Sign has the agent sign the data using a protocol 2 key as defined
 // in [PROTOCOL.agent] section 2.6.2.
 func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
+	return c.SignWithFlags(key, data, 0)
+}
+
+func (c *client) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) {
 	req := ssh.Marshal(signRequestAgentMsg{
 		KeyBlob: key.Marshal(),
 		Data:    data,
+		Flags:   uint32(flags),
 	})
 
 	msg, err := c.call(req)
@@ -681,3 +746,44 @@
 	// The agent has its own entropy source, so the rand argument is ignored.
 	return s.agent.Sign(s.pub, data)
 }
+
+func (s *agentKeyringSigner) SignWithOpts(rand io.Reader, data []byte, opts crypto.SignerOpts) (*ssh.Signature, error) {
+	var flags SignatureFlags
+	if opts != nil {
+		switch opts.HashFunc() {
+		case crypto.SHA256:
+			flags = SignatureFlagRsaSha256
+		case crypto.SHA512:
+			flags = SignatureFlagRsaSha512
+		}
+	}
+	return s.agent.SignWithFlags(s.pub, data, flags)
+}
+
+// Calls an extension method. It is up to the agent implementation as to whether or not
+// any particular extension is supported and may always return an error. Because the
+// type of the response is up to the implementation, this returns the bytes of the
+// response and does not attempt any type of unmarshalling.
+func (c *client) Extension(extensionType string, contents []byte) ([]byte, error) {
+	req := ssh.Marshal(extensionAgentMsg{
+		ExtensionType: extensionType,
+		Contents:      contents,
+	})
+	buf, err := c.callRaw(req)
+	if err != nil {
+		return nil, err
+	}
+	if len(buf) == 0 {
+		return nil, errors.New("agent: failure; empty response")
+	}
+	// [PROTOCOL.agent] section 4.7 indicates that an SSH_AGENT_FAILURE message
+	// represents an agent that does not support the extension
+	if buf[0] == agentFailure {
+		return nil, ErrExtensionUnsupported
+	}
+	if buf[0] == agentExtensionFailure {
+		return nil, errors.New("agent: generic extension failure")
+	}
+
+	return buf, nil
+}
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go
index 266fd6d..e8cd978 100644
--- a/ssh/agent/client_test.go
+++ b/ssh/agent/client_test.go
@@ -20,7 +20,7 @@
 )
 
 // startOpenSSHAgent executes ssh-agent, and returns an Agent interface to it.
-func startOpenSSHAgent(t *testing.T) (client Agent, socket string, cleanup func()) {
+func startOpenSSHAgent(t *testing.T) (client ExtendedAgent, socket string, cleanup func()) {
 	if testing.Short() {
 		// ssh-agent is not always available, and the key
 		// types supported vary by platform.
@@ -79,13 +79,12 @@
 	}
 }
 
-// startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client.
-func startKeyringAgent(t *testing.T) (client Agent, cleanup func()) {
+func startAgent(t *testing.T, agent Agent) (client ExtendedAgent, cleanup func()) {
 	c1, c2, err := netPipe()
 	if err != nil {
 		t.Fatalf("netPipe: %v", err)
 	}
-	go ServeAgent(NewKeyring(), c2)
+	go ServeAgent(agent, c2)
 
 	return NewClient(c1), func() {
 		c1.Close()
@@ -93,6 +92,11 @@
 	}
 }
 
+// startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client.
+func startKeyringAgent(t *testing.T) (client ExtendedAgent, cleanup func()) {
+	return startAgent(t, NewKeyring())
+}
+
 func testOpenSSHAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
 	agent, _, cleanup := startOpenSSHAgent(t)
 	defer cleanup()
@@ -107,7 +111,7 @@
 	testAgentInterface(t, agent, key, cert, lifetimeSecs)
 }
 
-func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
+func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
 	signer, err := ssh.NewSignerFromKey(key)
 	if err != nil {
 		t.Fatalf("NewSignerFromKey(%T): %v", key, err)
@@ -159,6 +163,25 @@
 		t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
 	}
 
+	// For tests on RSA keys, try signing with SHA-256 and SHA-512 flags
+	if pubKey.Type() == "ssh-rsa" {
+		sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
+			sig, err = agent.SignWithFlags(pubKey, data, flag)
+			if err != nil {
+				t.Fatalf("SignWithFlags(%s): %v", pubKey.Type(), err)
+			}
+			if sig.Format != expectedSigFormat {
+				t.Fatalf("Signature format didn't match expected value: %s != %s", sig.Format, expectedSigFormat)
+			}
+			if err := pubKey.Verify(data, sig); err != nil {
+				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
+			}
+		}
+		sshFlagTest(0, ssh.SigAlgoRSA)
+		sshFlagTest(SignatureFlagRsaSha256, ssh.SigAlgoRSASHA2256)
+		sshFlagTest(SignatureFlagRsaSha512, ssh.SigAlgoRSASHA2512)
+	}
+
 	// If the key has a lifetime, is it removed when it should be?
 	if lifetimeSecs > 0 {
 		time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond)
@@ -377,3 +400,38 @@
 		t.Errorf("Want 0 keys, got %v", len(keys))
 	}
 }
+
+type keyringExtended struct {
+	*keyring
+}
+
+func (r *keyringExtended) Extension(extensionType string, contents []byte) ([]byte, error) {
+	if extensionType != "my-extension@example.com" {
+		return []byte{agentExtensionFailure}, nil
+	}
+	return append([]byte{agentSuccess}, contents...), nil
+}
+
+func TestAgentExtensions(t *testing.T) {
+	agent, _, cleanup := startOpenSSHAgent(t)
+	defer cleanup()
+	result, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
+	if err == nil {
+		t.Fatal("should have gotten agent extension failure")
+	}
+
+	agent, cleanup = startAgent(t, &keyringExtended{})
+	defer cleanup()
+	result, err = agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
+	if err != nil {
+		t.Fatalf("agent extension failure: %v", err)
+	}
+	if len(result) != 4 || !bytes.Equal(result, []byte{agentSuccess, 0x00, 0x01, 0x02}) {
+		t.Fatalf("agent extension result invalid: %v", result)
+	}
+
+	result, err = agent.Extension("bad-extension@example.com", []byte{0x00, 0x01, 0x02})
+	if err == nil {
+		t.Fatal("should have gotten agent extension failure")
+	}
+}
diff --git a/ssh/agent/keyring.go b/ssh/agent/keyring.go
index 1a51632..c9d9794 100644
--- a/ssh/agent/keyring.go
+++ b/ssh/agent/keyring.go
@@ -182,6 +182,10 @@
 
 // Sign returns a signature for the data.
 func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
+	return r.SignWithFlags(key, data, 0)
+}
+
+func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) {
 	r.mu.Lock()
 	defer r.mu.Unlock()
 	if r.locked {
@@ -192,7 +196,24 @@
 	wanted := key.Marshal()
 	for _, k := range r.keys {
 		if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
-			return k.signer.Sign(rand.Reader, data)
+			if flags == 0 {
+				return k.signer.Sign(rand.Reader, data)
+			} else {
+				if algorithmSigner, ok := k.signer.(ssh.AlgorithmSigner); !ok {
+					return nil, fmt.Errorf("agent: signature does not support non-default signature algorithm: %T", k.signer)
+				} else {
+					var algorithm string
+					switch flags {
+					case SignatureFlagRsaSha256:
+						algorithm = ssh.SigAlgoRSASHA2256
+					case SignatureFlagRsaSha512:
+						algorithm = ssh.SigAlgoRSASHA2512
+					default:
+						return nil, fmt.Errorf("agent: unsupported signature flags: %d", flags)
+					}
+					return algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm)
+				}
+			}
 		}
 	}
 	return nil, errors.New("not found")
@@ -213,3 +234,8 @@
 	}
 	return s, nil
 }
+
+// The keyring does not support any extensions
+func (r *keyring) Extension(extensionType string, contents []byte) ([]byte, error) {
+	return nil, ErrExtensionUnsupported
+}
diff --git a/ssh/agent/server.go b/ssh/agent/server.go
index 2e4692c..a194976 100644
--- a/ssh/agent/server.go
+++ b/ssh/agent/server.go
@@ -128,7 +128,14 @@
 			Blob:   req.KeyBlob,
 		}
 
-		sig, err := s.agent.Sign(k, req.Data) //  TODO(hanwen): flags.
+		var sig *ssh.Signature
+		var err error
+		if extendedAgent, ok := s.agent.(ExtendedAgent); ok {
+			sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags))
+		} else {
+			sig, err = s.agent.Sign(k, req.Data)
+		}
+
 		if err != nil {
 			return nil, err
 		}
@@ -150,6 +157,43 @@
 
 	case agentAddIDConstrained, agentAddIdentity:
 		return nil, s.insertIdentity(data)
+
+	case agentExtension:
+		// Return a stub object where the whole contents of the response gets marshaled.
+		var responseStub struct {
+			Rest []byte `ssh:"rest"`
+		}
+
+		if extendedAgent, ok := s.agent.(ExtendedAgent); !ok {
+			// If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7
+			// requires that we return a standard SSH_AGENT_FAILURE message.
+			responseStub.Rest = []byte{agentFailure}
+		} else {
+			var req extensionAgentMsg
+			if err := ssh.Unmarshal(data, &req); err != nil {
+				return nil, err
+			}
+			res, err := extendedAgent.Extension(req.ExtensionType, req.Contents)
+			if err != nil {
+				// If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE
+				// message as required by [PROTOCOL.agent] section 4.7.
+				if err == ErrExtensionUnsupported {
+					responseStub.Rest = []byte{agentFailure}
+				} else {
+					// As the result of any other error processing an extension request,
+					// [PROTOCOL.agent] section 4.7 requires that we return a
+					// SSH_AGENT_EXTENSION_FAILURE code.
+					responseStub.Rest = []byte{agentExtensionFailure}
+				}
+			} else {
+				if len(res) == 0 {
+					return nil, nil
+				}
+				responseStub.Rest = res
+			}
+		}
+
+		return responseStub, nil
 	}
 
 	return nil, fmt.Errorf("unknown opcode %d", data[0])
diff --git a/ssh/certs.go b/ssh/certs.go
index 42106f3..00ed992 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -222,6 +222,11 @@
 	signer Signer
 }
 
+type algorithmOpenSSHCertSigner struct {
+	*openSSHCertSigner
+	algorithmSigner AlgorithmSigner
+}
+
 // NewCertSigner returns a Signer that signs with the given Certificate, whose
 // private key is held by signer. It returns an error if the public key in cert
 // doesn't match the key used by signer.
@@ -230,7 +235,12 @@
 		return nil, errors.New("ssh: signer and cert have different public key")
 	}
 
-	return &openSSHCertSigner{cert, signer}, nil
+	if algorithmSigner, ok := signer.(AlgorithmSigner); ok {
+		return &algorithmOpenSSHCertSigner{
+			&openSSHCertSigner{cert, signer}, algorithmSigner}, nil
+	} else {
+		return &openSSHCertSigner{cert, signer}, nil
+	}
 }
 
 func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
@@ -241,6 +251,10 @@
 	return s.pub
 }
 
+func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
+	return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
+}
+
 const sourceAddressCriticalOption = "source-address"
 
 // CertChecker does the work of verifying a certificate. Its methods
diff --git a/ssh/keys.go b/ssh/keys.go
index 2261dc3..9698047 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -38,6 +38,16 @@
 	KeyAlgoED25519  = "ssh-ed25519"
 )
 
+// These constants represent non-default signature algorithms that are supported
+// as algorithm parameters to AlgorithmSigner.SignWithAlgorithm methods. See
+// [PROTOCOL.agent] section 4.5.1 and
+// https://tools.ietf.org/html/draft-ietf-curdle-rsa-sha2-10
+const (
+	SigAlgoRSA        = "ssh-rsa"
+	SigAlgoRSASHA2256 = "rsa-sha2-256"
+	SigAlgoRSASHA2512 = "rsa-sha2-512"
+)
+
 // parsePubKey parses a public key of the given algorithm.
 // Use ParsePublicKey for keys with prepended algorithm.
 func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) {
@@ -301,6 +311,19 @@
 	Sign(rand io.Reader, data []byte) (*Signature, error)
 }
 
+// A AlgorithmSigner is a Signer that also supports specifying a specific
+// algorithm to use for signing.
+type AlgorithmSigner interface {
+	Signer
+
+	// SignWithAlgorithm is like Signer.Sign, but allows specification of a
+	// non-default signing algorithm. See the SigAlgo* constants in this
+	// package for signature algorithms supported by this package. Callers may
+	// pass an empty string for the algorithm in which case the AlgorithmSigner
+	// will use its default algorithm.
+	SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error)
+}
+
 type rsaPublicKey rsa.PublicKey
 
 func (r *rsaPublicKey) Type() string {
@@ -349,13 +372,21 @@
 }
 
 func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
-	if sig.Format != r.Type() {
+	var hash crypto.Hash
+	switch sig.Format {
+	case SigAlgoRSA:
+		hash = crypto.SHA1
+	case SigAlgoRSASHA2256:
+		hash = crypto.SHA256
+	case SigAlgoRSASHA2512:
+		hash = crypto.SHA512
+	default:
 		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
 	}
-	h := crypto.SHA1.New()
+	h := hash.New()
 	h.Write(data)
 	digest := h.Sum(nil)
-	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob)
+	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, sig.Blob)
 }
 
 func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
@@ -459,6 +490,14 @@
 }
 
 func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
+	return k.SignWithAlgorithm(rand, data, "")
+}
+
+func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
+	if algorithm != "" && algorithm != k.PublicKey().Type() {
+		return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
+	}
+
 	h := crypto.SHA1.New()
 	h.Write(data)
 	digest := h.Sum(nil)
@@ -691,16 +730,42 @@
 }
 
 func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
+	return s.SignWithAlgorithm(rand, data, "")
+}
+
+func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
 	var hashFunc crypto.Hash
 
-	switch key := s.pubKey.(type) {
-	case *rsaPublicKey, *dsaPublicKey:
-		hashFunc = crypto.SHA1
-	case *ecdsaPublicKey:
-		hashFunc = ecHash(key.Curve)
-	case ed25519PublicKey:
-	default:
-		return nil, fmt.Errorf("ssh: unsupported key type %T", key)
+	if _, ok := s.pubKey.(*rsaPublicKey); ok {
+		// RSA keys support a few hash functions determined by the requested signature algorithm
+		switch algorithm {
+		case "", SigAlgoRSA:
+			algorithm = SigAlgoRSA
+			hashFunc = crypto.SHA1
+		case SigAlgoRSASHA2256:
+			hashFunc = crypto.SHA256
+		case SigAlgoRSASHA2512:
+			hashFunc = crypto.SHA512
+		default:
+			return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
+		}
+	} else {
+		// The only supported algorithm for all other key types is the same as the type of the key
+		if algorithm == "" {
+			algorithm = s.pubKey.Type()
+		} else if algorithm != s.pubKey.Type() {
+			return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
+		}
+
+		switch key := s.pubKey.(type) {
+		case *dsaPublicKey:
+			hashFunc = crypto.SHA1
+		case *ecdsaPublicKey:
+			hashFunc = ecHash(key.Curve)
+		case ed25519PublicKey:
+		default:
+			return nil, fmt.Errorf("ssh: unsupported key type %T", key)
+		}
 	}
 
 	var digest []byte
@@ -745,7 +810,7 @@
 	}
 
 	return &Signature{
-		Format: s.pubKey.Type(),
+		Format: algorithm,
 		Blob:   signature,
 	}, nil
 }
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index f28725f..3847b3b 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -109,6 +109,49 @@
 	}
 }
 
+func TestKeySignWithAlgorithmVerify(t *testing.T) {
+	for _, priv := range testSigners {
+		if algorithmSigner, ok := priv.(AlgorithmSigner); !ok {
+			t.Errorf("Signers constructed by ssh package should always implement the AlgorithmSigner interface: %T", priv)
+		} else {
+			pub := priv.PublicKey()
+			data := []byte("sign me")
+
+			signWithAlgTestCase := func(algorithm string, expectedAlg string) {
+				sig, err := algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm)
+				if err != nil {
+					t.Fatalf("Sign(%T): %v", priv, err)
+				}
+				if sig.Format != expectedAlg {
+					t.Errorf("signature format did not match requested signature algorithm: %s != %s", sig.Format, expectedAlg)
+				}
+
+				if err := pub.Verify(data, sig); err != nil {
+					t.Errorf("publicKey.Verify(%T): %v", priv, err)
+				}
+				sig.Blob[5]++
+				if err := pub.Verify(data, sig); err == nil {
+					t.Errorf("publicKey.Verify on broken sig did not fail")
+				}
+			}
+
+			// Using the empty string as the algorithm name should result in the same signature format as the algorithm-free Sign method.
+			defaultSig, err := priv.Sign(rand.Reader, data)
+			if err != nil {
+				t.Fatalf("Sign(%T): %v", priv, err)
+			}
+			signWithAlgTestCase("", defaultSig.Format)
+
+			// RSA keys are the only ones which currently support more than one signing algorithm
+			if pub.Type() == KeyAlgoRSA {
+				for _, algorithm := range []string{SigAlgoRSA, SigAlgoRSASHA2256, SigAlgoRSASHA2512} {
+					signWithAlgTestCase(algorithm, algorithm)
+				}
+			}
+		}
+	}
+}
+
 func TestParseRSAPrivateKey(t *testing.T) {
 	key := testPrivateKeys["rsa"]