blob: 3b3980243763ffccf86eb3bc09fa3212b0b2e739 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"errors"
"fmt"
"strings"
"testing"
)
// doClientServerAuth runs one client/server SSH handshake over a fresh
// connection pair obtained from netPipe.
//
// It mutates serverConfig: the "rsa" test signer is added as a host key and
// AuthLogCallback is replaced with one that records every error the server
// reports. It returns those recorded errors, in the order the server reported
// them, together with the error returned by NewClientConn.
//
// The server goroutine is always joined before returning. The returned slice
// is therefore complete, and reading it does not race with the server.
func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) {
	t.Helper()

	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}

	var serverAuthErrors []error
	serverConfig.AddHostKey(testSigners["rsa"])
	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
		// Invoked on the server goroutine. The read in the return statement
		// below is ordered after these writes by the <-serverDone receive.
		serverAuthErrors = append(serverAuthErrors, err)
	}

	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		newServer(c1, serverConfig)
	}()

	c, _, _, err := NewClientConn(c2, "", clientConfig)
	if err == nil {
		c.Close()
	}

	// After a failed login the server may still be waiting for another
	// authentication request. Close both ends so it is unblocked, then wait
	// for the goroutine to exit so it cannot outlive this helper.
	c1.Close()
	c2.Close()
	<-serverDone

	return serverAuthErrors, err
}
// TestMultiStepAuth checks multi-factor authentication driven by a
// *PartialSuccessError returned from PublicKeyCallback.
//
// usernameSecondFactor must authenticate with the "rsa" public key first. The
// server then answers with a *PartialSuccessError whose Next callbacks accept
// only a password. username may use the public key, the password, or both.
//
// Each scenario inspects the ordered errors the server passed to
// AuthLogCallback. The first entry is the initial "no auth passed yet"
// rejection, and a successful login ends with a nil entry.
func TestMultiStepAuth(t *testing.T) {
	// This user can login with password, public key or public key + password.
	username := "testuser"
	// This user can login with public key + password only.
	usernameSecondFactor := "testuser_second_factor"
	errPwdAuthFailed := errors.New("password auth failed")
	errWrongSequence := errors.New("wrong sequence")

	serverConfig := &ServerConfig{
		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
			// A password alone is never accepted for the second-factor user.
			if conn.User() == usernameSecondFactor {
				return nil, errWrongSequence
			}
			if conn.User() == username && string(password) == clientPassword {
				return nil, nil
			}
			return nil, errPwdAuthFailed
		},
		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
			if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
				if conn.User() == usernameSecondFactor {
					// The key is only the first factor: require a password next.
					return nil, &PartialSuccessError{
						Next: ServerAuthCallbacks{
							PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
								if string(password) == clientPassword {
									return nil, nil
								}
								return nil, errPwdAuthFailed
							},
						},
					}
				}
				return nil, nil
			}
			return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
		},
	}

	clientConfig := &ClientConfig{
		User: usernameSecondFactor,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	// - nil
	if len(serverAuthErrors) != 3 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
	}
	if serverAuthErrors[2] != nil {
		t.Fatalf("expected successful auth, got: %v", serverAuthErrors[2])
	}

	// Now test a wrong sequence.
	clientConfig.Auth = []AuthMethod{
		Password(clientPassword),
		PublicKeys(testSigners["rsa"]),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("client login with wrong sequence must fail")
	}
	// The error sequence is:
	// - no auth passed yet
	// - wrong sequence
	// - partial success
	if len(serverAuthErrors) != 3 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if !errors.Is(serverAuthErrors[1], errWrongSequence) {
		t.Fatalf("expected wrong sequence error, got: %v", serverAuthErrors[1])
	}
	if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2])
	}

	// Now test using a correct sequence but a wrong password before the right
	// one.
	n := 0
	passwords := []string{"WRONG", "WRONG", clientPassword}
	clientConfig.Auth = []AuthMethod{
		PublicKeys(testSigners["rsa"]),
		RetryableAuthMethod(PasswordCallback(func() (string, error) {
			p := passwords[n]
			n++
			return p, nil
		}), 3),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	// - wrong password
	// - wrong password
	// - nil
	if len(serverAuthErrors) != 5 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
	}
	if !errors.Is(serverAuthErrors[2], errPwdAuthFailed) {
		t.Fatalf("expected password auth failed error, got: %v", serverAuthErrors[2])
	}
	if !errors.Is(serverAuthErrors[3], errPwdAuthFailed) {
		t.Fatalf("expected password auth failed error, got: %v", serverAuthErrors[3])
	}
	if serverAuthErrors[4] != nil {
		t.Fatalf("expected successful auth, got: %v", serverAuthErrors[4])
	}

	// Only password authentication should fail.
	clientConfig.Auth = []AuthMethod{
		Password(clientPassword),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("client login with password only must fail")
	}
	// The error sequence is:
	// - no auth passed yet
	// - wrong sequence
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if !errors.Is(serverAuthErrors[1], errWrongSequence) {
		t.Fatalf("expected wrong sequence error, got: %v", serverAuthErrors[1])
	}

	// Only public key authentication should fail.
	clientConfig.Auth = []AuthMethod{
		PublicKeys(testSigners["rsa"]),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("client login with public key only must fail")
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
	}

	// Public key and wrong password.
	clientConfig.Auth = []AuthMethod{
		PublicKeys(testSigners["rsa"]),
		Password("WRONG"),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("client login with wrong password after public key must fail")
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	// - password auth failed
	if len(serverAuthErrors) != 3 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
	}
	if !errors.Is(serverAuthErrors[2], errPwdAuthFailed) {
		t.Fatalf("expected password auth failed error, got: %v", serverAuthErrors[2])
	}

	// Public key, public key again and then correct password. Public key
	// authentication is attempted only once because the partial success error
	// returns only "password" as the allowed authentication method.
	clientConfig.Auth = []AuthMethod{
		PublicKeys(testSigners["rsa"]),
		PublicKeys(testSigners["rsa"]),
		Password(clientPassword),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - no auth passed yet
	// - partial success
	// - nil
	if len(serverAuthErrors) != 3 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
	}
	if serverAuthErrors[2] != nil {
		t.Fatalf("expected successful auth, got: %v", serverAuthErrors[2])
	}

	// The unrestricted username can do anything.
	clientConfig = &ClientConfig{
		User: username,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	_, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("unrestricted client login error: %s", err)
	}

	clientConfig = &ClientConfig{
		User: username,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	_, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("unrestricted client login error: %s", err)
	}

	clientConfig = &ClientConfig{
		User: username,
		Auth: []AuthMethod{
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	_, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("unrestricted client login error: %s", err)
	}
}
// TestDynamicAuthCallbacks checks that NoClientAuthCallback can choose the
// follow-up authentication method per user by returning a
// *PartialSuccessError.
//
// user1 must then supply a password, and user2 must supply the "rsa" public
// key. Offering only the other method must fail with "no supported methods
// remain". A successful login ends the server's auth log with a nil entry.
func TestDynamicAuthCallbacks(t *testing.T) {
	user1 := "user1"
	user2 := "user2"
	errInvalidCredentials := errors.New("invalid credentials")

	serverConfig := &ServerConfig{
		NoClientAuth: true,
		NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
			switch conn.User() {
			case user1:
				// user1 may continue only with a password.
				return nil, &PartialSuccessError{
					Next: ServerAuthCallbacks{
						PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
							if conn.User() == user1 && string(password) == clientPassword {
								return nil, nil
							}
							return nil, errInvalidCredentials
						},
					},
				}
			case user2:
				// user2 may continue only with a public key.
				return nil, &PartialSuccessError{
					Next: ServerAuthCallbacks{
						PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
							if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
								if conn.User() == user2 {
									return nil, nil
								}
							}
							return nil, errInvalidCredentials
						},
					},
				}
			default:
				return nil, errInvalidCredentials
			}
		},
	}

	clientConfig := &ClientConfig{
		User: user1,
		Auth: []AuthMethod{
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - partial success
	// - nil
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[0])
	}
	if serverAuthErrors[1] != nil {
		t.Fatalf("expected successful auth, got: %v", serverAuthErrors[1])
	}

	clientConfig = &ClientConfig{
		User: user2,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %s", err)
	}
	// The error sequence is:
	// - partial success
	// - nil
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[0])
	}
	if serverAuthErrors[1] != nil {
		t.Fatalf("expected successful auth, got: %v", serverAuthErrors[1])
	}

	// user1 cannot login with public key
	clientConfig = &ClientConfig{
		User: user1,
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("user1 login with public key must fail")
	}
	if !strings.Contains(err.Error(), "no supported methods remain") {
		t.Errorf("got %v, expected 'no supported methods remain'", err)
	}
	if len(serverAuthErrors) != 1 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[0])
	}

	// user2 cannot login with password
	clientConfig = &ClientConfig{
		User: user2,
		Auth: []AuthMethod{
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
	if err == nil {
		t.Fatal("user2 login with password must fail")
	}
	if !strings.Contains(err.Error(), "no supported methods remain") {
		t.Errorf("got %v, expected 'no supported methods remain'", err)
	}
	if len(serverAuthErrors) != 1 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[0])
	}
}