ssh: add server side multi-step authentication

Add support for sending back partial success to the client while
handling authentication in the server. This is implemented by a special
error that can be returned by any of the authentication methods, which
contains the authentication methods to offer next.

This patch is based on CL 399075 with some minor changes and the
addition of test cases.

Fixes golang/go#17889
Fixes golang/go#61447
Fixes golang/go#64974

Co-authored-by: Peter Verraedt <peter.verraedt@kuleuven.be>
Change-Id: I05c8f913bb407d22c2e41c4cbe965e36ab4739b0
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/516355
Reviewed-by: Andrew Lytvynov <awly@tailscale.com>
Reviewed-by: Than McIntosh <thanm@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
diff --git a/ssh/server.go b/ssh/server.go
index c2dfe32..92d7323 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -426,6 +426,35 @@
 	return "[" + strings.Join(errs, ", ") + "]"
 }
 
+// ServerAuthCallbacks defines server-side authentication callbacks.
+type ServerAuthCallbacks struct {
+	// PasswordCallback behaves like [ServerConfig.PasswordCallback].
+	PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
+
+	// PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback].
+	PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
+
+	// KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback].
+	KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
+
+	// GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig].
+	GSSAPIWithMICConfig *GSSAPIWithMICConfig
+}
+
+// PartialSuccessError can be returned by any of the [ServerConfig]
+// authentication callbacks to indicate to the client that authentication has
+// partially succeeded, but further steps are required.
+type PartialSuccessError struct {
+	// Next defines the authentication callbacks to apply to further steps. The
+	// available methods communicated to the client are based on the non-nil
+	// ServerAuthCallbacks fields.
+	Next ServerAuthCallbacks
+}
+
+func (p *PartialSuccessError) Error() string {
+	return "ssh: authenticated with partial success"
+}
+
 // ErrNoAuth is the error value returned if no
 // authentication method has been passed yet. This happens as a normal
 // part of the authentication loop, since the client first tries
@@ -441,6 +470,15 @@
 	authFailures := 0
 	var authErrs []error
 	var displayedBanner bool
+	partialSuccessReturned := false
+	// Set the initial authentication callbacks from the config. They can be
+	// changed if a PartialSuccessError is returned.
+	authConfig := ServerAuthCallbacks{
+		PasswordCallback:            config.PasswordCallback,
+		PublicKeyCallback:           config.PublicKeyCallback,
+		KeyboardInteractiveCallback: config.KeyboardInteractiveCallback,
+		GSSAPIWithMICConfig:         config.GSSAPIWithMICConfig,
+	}
 
 userAuthLoop:
 	for {
@@ -471,6 +509,11 @@
 			return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
 		}
 
+		if s.user != userAuthReq.User && partialSuccessReturned {
+			return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q",
+				s.user, userAuthReq.User)
+		}
+
 		s.user = userAuthReq.User
 
 		if !displayedBanner && config.BannerCallback != nil {
@@ -491,20 +534,17 @@
 
 		switch userAuthReq.Method {
 		case "none":
-			if config.NoClientAuth {
+			// We don't allow none authentication after a partial success
+			// response.
+			if config.NoClientAuth && !partialSuccessReturned {
 				if config.NoClientAuthCallback != nil {
 					perms, authErr = config.NoClientAuthCallback(s)
 				} else {
 					authErr = nil
 				}
 			}
-
-			// allow initial attempt of 'none' without penalty
-			if authFailures == 0 {
-				authFailures--
-			}
 		case "password":
-			if config.PasswordCallback == nil {
+			if authConfig.PasswordCallback == nil {
 				authErr = errors.New("ssh: password auth not configured")
 				break
 			}
@@ -518,17 +558,17 @@
 				return nil, parseError(msgUserAuthRequest)
 			}
 
-			perms, authErr = config.PasswordCallback(s, password)
+			perms, authErr = authConfig.PasswordCallback(s, password)
 		case "keyboard-interactive":
-			if config.KeyboardInteractiveCallback == nil {
+			if authConfig.KeyboardInteractiveCallback == nil {
 				authErr = errors.New("ssh: keyboard-interactive auth not configured")
 				break
 			}
 
 			prompter := &sshClientKeyboardInteractive{s}
-			perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
+			perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge)
 		case "publickey":
-			if config.PublicKeyCallback == nil {
+			if authConfig.PublicKeyCallback == nil {
 				authErr = errors.New("ssh: publickey auth not configured")
 				break
 			}
@@ -562,11 +602,18 @@
 			if !ok {
 				candidate.user = s.user
 				candidate.pubKeyData = pubKeyData
-				candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
-				if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
-					candidate.result = checkSourceAddress(
+				candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
+				_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+
+				if (candidate.result == nil || isPartialSuccessError) &&
+					candidate.perms != nil &&
+					candidate.perms.CriticalOptions != nil &&
+					candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
+					if err := checkSourceAddress(
 						s.RemoteAddr(),
-						candidate.perms.CriticalOptions[sourceAddressCriticalOption])
+						candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil {
+						candidate.result = err
+					}
 				}
 				cache.add(candidate)
 			}
@@ -578,8 +625,8 @@
 				if len(payload) > 0 {
 					return nil, parseError(msgUserAuthRequest)
 				}
-
-				if candidate.result == nil {
+				_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+				if candidate.result == nil || isPartialSuccessError {
 					okMsg := userAuthPubKeyOkMsg{
 						Algo:   algo,
 						PubKey: pubKeyData,
@@ -629,11 +676,11 @@
 				perms = candidate.perms
 			}
 		case "gssapi-with-mic":
-			if config.GSSAPIWithMICConfig == nil {
+			if authConfig.GSSAPIWithMICConfig == nil {
 				authErr = errors.New("ssh: gssapi-with-mic auth not configured")
 				break
 			}
-			gssapiConfig := config.GSSAPIWithMICConfig
+			gssapiConfig := authConfig.GSSAPIWithMICConfig
 			userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
 			if err != nil {
 				return nil, parseError(msgUserAuthRequest)
@@ -689,49 +736,70 @@
 			break userAuthLoop
 		}
 
-		authFailures++
-		if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
-			// If we have hit the max attempts, don't bother sending the
-			// final SSH_MSG_USERAUTH_FAILURE message, since there are
-			// no more authentication methods which can be attempted,
-			// and this message may cause the client to re-attempt
-			// authentication while we send the disconnect message.
-			// Continue, and trigger the disconnect at the start of
-			// the loop.
-			//
-			// The SSH specification is somewhat confusing about this,
-			// RFC 4252 Section 5.1 requires each authentication failure
-			// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
-			// message, but Section 4 says the server should disconnect
-			// after some number of attempts, but it isn't explicit which
-			// message should take precedence (i.e. should there be a failure
-			// message than a disconnect message, or if we are going to
-			// disconnect, should we only send that message.)
-			//
-			// Either way, OpenSSH disconnects immediately after the last
-			// failed authnetication attempt, and given they are typically
-			// considered the golden implementation it seems reasonable
-			// to match that behavior.
-			continue
+		var failureMsg userAuthFailureMsg
+
+		if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
+			// After a partial success error we don't allow changing the user
+			// name and execute the NoClientAuthCallback.
+			partialSuccessReturned = true
+
+			// In case a partial success is returned, the server may send
+			// a new set of authentication methods.
+			authConfig = partialSuccess.Next
+
+			// Reset pubkey cache, as the new PublicKeyCallback might
+			// accept a different set of public keys.
+			cache = pubKeyCache{}
+
+			// Send back a partial success message to the user.
+			failureMsg.PartialSuccess = true
+		} else {
+			// Allow initial attempt of 'none' without penalty.
+			if authFailures > 0 || userAuthReq.Method != "none" {
+				authFailures++
+			}
+			if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
+				// If we have hit the max attempts, don't bother sending the
+				// final SSH_MSG_USERAUTH_FAILURE message, since there are
+				// no more authentication methods which can be attempted,
+				// and this message may cause the client to re-attempt
+				// authentication while we send the disconnect message.
+				// Continue, and trigger the disconnect at the start of
+				// the loop.
+				//
+				// The SSH specification is somewhat confusing about this,
+				// RFC 4252 Section 5.1 requires each authentication failure
+				// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
+				// message, but Section 4 says the server should disconnect
+				// after some number of attempts, but it isn't explicit which
+				// message should take precedence (i.e. should there be a failure
+				// message than a disconnect message, or if we are going to
+				// disconnect, should we only send that message.)
+				//
+				// Either way, OpenSSH disconnects immediately after the last
+				// failed authnetication attempt, and given they are typically
+				// considered the golden implementation it seems reasonable
+				// to match that behavior.
+				continue
+			}
 		}
 
-		var failureMsg userAuthFailureMsg
-		if config.PasswordCallback != nil {
+		if authConfig.PasswordCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "password")
 		}
-		if config.PublicKeyCallback != nil {
+		if authConfig.PublicKeyCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "publickey")
 		}
-		if config.KeyboardInteractiveCallback != nil {
+		if authConfig.KeyboardInteractiveCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
 		}
-		if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil &&
-			config.GSSAPIWithMICConfig.AllowLogin != nil {
+		if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil &&
+			authConfig.GSSAPIWithMICConfig.AllowLogin != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
 		}
 
 		if len(failureMsg.Methods) == 0 {
-			return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
+			return nil, errors.New("ssh: no authentication methods available")
 		}
 
 		if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
diff --git a/ssh/server_multi_auth_test.go b/ssh/server_multi_auth_test.go
new file mode 100644
index 0000000..3b39802
--- /dev/null
+++ b/ssh/server_multi_auth_test.go
@@ -0,0 +1,412 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"strings"
+	"testing"
+)
+
+func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	var serverAuthErrors []error
+
+	serverConfig.AddHostKey(testSigners["rsa"])
+	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
+		serverAuthErrors = append(serverAuthErrors, err)
+	}
+	go newServer(c1, serverConfig)
+	c, _, _, err := NewClientConn(c2, "", clientConfig)
+	if err == nil {
+		c.Close()
+	}
+	return serverAuthErrors, err
+}
+
+func TestMultiStepAuth(t *testing.T) {
+	// This user can login with password, public key or public key + password.
+	username := "testuser"
+	// This user can login with public key + password only.
+	usernameSecondFactor := "testuser_second_factor"
+	errPwdAuthFailed := errors.New("password auth failed")
+	errWrongSequence := errors.New("wrong sequence")
+
+	serverConfig := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			if conn.User() == usernameSecondFactor {
+				return nil, errWrongSequence
+			}
+			if conn.User() == username && string(password) == clientPassword {
+				return nil, nil
+			}
+			return nil, errPwdAuthFailed
+		},
+		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+				if conn.User() == usernameSecondFactor {
+					return nil, &PartialSuccessError{
+						Next: ServerAuthCallbacks{
+							PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+								if string(password) == clientPassword {
+									return nil, nil
+								}
+								return nil, errPwdAuthFailed
+							},
+						},
+					}
+				}
+				return nil, nil
+			}
+			return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
+		},
+	}
+
+	clientConfig := &ClientConfig{
+		User: usernameSecondFactor,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
+	}
+	// Now test a wrong sequence.
+	clientConfig.Auth = []AuthMethod{
+		Password(clientPassword),
+		PublicKeys(testSigners["rsa"]),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with wrong sequence must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - wrong sequence
+	// - partial success
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if serverAuthErrors[1] != errWrongSequence {
+		t.Fatal("server not returned wrong sequence")
+	}
+	if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok {
+		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2])
+	}
+	// Now test using a correct sequence but a wrong password before the right
+	// one.
+	n := 0
+	passwords := []string{"WRONG", "WRONG", clientPassword}
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		RetryableAuthMethod(PasswordCallback(func() (string, error) {
+			p := passwords[n]
+			n++
+			return p, nil
+		}), 3),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - wrong password
+	// - wrong password
+	// - nil
+	if len(serverAuthErrors) != 5 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	if serverAuthErrors[2] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+	if serverAuthErrors[3] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+	// Only password authentication should fail.
+	clientConfig.Auth = []AuthMethod{
+		Password(clientPassword),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with password only must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - wrong sequence
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if serverAuthErrors[1] != errWrongSequence {
+		t.Fatal("server not returned wrong sequence")
+	}
+
+	// Only public key authentication should fail.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with public key only must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// Public key and wrong password.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		Password("WRONG"),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with wrong password after public key must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - password auth failed
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	if serverAuthErrors[2] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+
+	// Public key, public key again and then correct password. Public key
+	// authentication is attempted only once because the partial success error
+	// returns only "password" as the allowed authentication method.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		PublicKeys(testSigners["rsa"]),
+		Password(clientPassword),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// The unrestricted username can do anything
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+}
+
+func TestDynamicAuthCallbacks(t *testing.T) {
+	user1 := "user1"
+	user2 := "user2"
+	errInvalidCredentials := errors.New("invalid credentials")
+
+	serverConfig := &ServerConfig{
+		NoClientAuth: true,
+		NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
+			switch conn.User() {
+			case user1:
+				return nil, &PartialSuccessError{
+					Next: ServerAuthCallbacks{
+						PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+							if conn.User() == user1 && string(password) == clientPassword {
+								return nil, nil
+							}
+							return nil, errInvalidCredentials
+						},
+					},
+				}
+			case user2:
+				return nil, &PartialSuccessError{
+					Next: ServerAuthCallbacks{
+						PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+							if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+								if conn.User() == user2 {
+									return nil, nil
+								}
+							}
+							return nil, errInvalidCredentials
+						},
+					},
+				}
+			default:
+				return nil, errInvalidCredentials
+			}
+		},
+	}
+
+	clientConfig := &ClientConfig{
+		User: user1,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	clientConfig = &ClientConfig{
+		User: user2,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// user1 cannot login with public key
+	clientConfig = &ClientConfig{
+		User: user1,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("user1 login with public key must fail")
+	}
+	if !strings.Contains(err.Error(), "no supported methods remain") {
+		t.Errorf("got %v, expected 'no supported methods remain'", err)
+	}
+	if len(serverAuthErrors) != 1 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	// user2 cannot login with password
+	clientConfig = &ClientConfig{
+		User: user2,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("user2 login with password must fail")
+	}
+	if !strings.Contains(err.Error(), "no supported methods remain") {
+		t.Errorf("got %v, expected 'no supported methods remain'", err)
+	}
+	if len(serverAuthErrors) != 1 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+}