ssh: return authErr array if all authentication attempts fail

Change-Id: I4d6cab266410a8c7960073665eddf8935693087f
Reviewed-on: https://go-review.googlesource.com/44332
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index bd9f8a1..2e9672a 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -577,3 +577,54 @@
 		t.Fatalf("client: got %s, want %s", err, expectedErr)
 	}
 }
+
+// Test whether authentication errors are being properly logged if all
+// authentication methods have been exhausted
+func TestClientAuthErrorList(t *testing.T) {
+	publicKeyErr := errors.New("This is an error from PublicKeyCallback")
+
+	clientConfig := &ClientConfig{
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	serverConfig := &ServerConfig{
+		PublicKeyCallback: func(_ ConnMetadata, _ PublicKey) (*Permissions, error) {
+			return nil, publicKeyErr
+		},
+	}
+	serverConfig.AddHostKey(testSigners["rsa"])
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	go NewClientConn(c2, "", clientConfig)
+	_, err = newServer(c1, serverConfig)
+	if err == nil {
+		t.Fatal("newServer: got nil, expected errors")
+	}
+
+	authErrs, ok := err.(*ServerAuthError)
+	if !ok {
+		t.Fatalf("errors: got %T, want *ssh.ServerAuthError", err)
+	}
+	for i, e := range authErrs.Errors {
+		switch i {
+		case 0:
+			if e.Error() != "no auth passed yet" {
+				t.Fatalf("errors: got %v, want no auth passed yet", e.Error())
+			}
+		case 1:
+			if e != publicKeyErr {
+				t.Fatalf("errors: got %v, want %v", e, publicKeyErr)
+			}
+		default:
+			t.Fatal("errors: got %v, expected 2 errors", authErrs.Errors)
+		}
+	}
+}
diff --git a/ssh/server.go b/ssh/server.go
index 70d6077..b6f4cc8 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -288,12 +288,30 @@
 	return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
 }
 
+// ServerAuthError implements the error interface. It appends any authentication
+// errors that may occur, and is returned if all of the authentication methods
+// provided by the user failed to authenticate.
+type ServerAuthError struct {
+	// Errors contains authentication errors returned by the authentication
+	// callback methods.
+	Errors []error
+}
+
+func (l ServerAuthError) Error() string {
+	var errs []string
+	for _, err := range l.Errors {
+		errs = append(errs, err.Error())
+	}
+	return "[" + strings.Join(errs, ", ") + "]"
+}
+
 func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
 	sessionID := s.transport.getSessionID()
 	var cache pubKeyCache
 	var perms *Permissions
 
 	authFailures := 0
+	var authErrs []error
 
 userAuthLoop:
 	for {
@@ -312,6 +330,9 @@
 
 		var userAuthReq userAuthRequestMsg
 		if packet, err := s.transport.readPacket(); err != nil {
+			if err == io.EOF {
+				return nil, &ServerAuthError{Errors: authErrs}
+			}
 			return nil, err
 		} else if err = Unmarshal(packet, &userAuthReq); err != nil {
 			return nil, err
@@ -448,6 +469,8 @@
 			authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
 		}
 
+		authErrs = append(authErrs, authErr)
+
 		if config.AuthLogCallback != nil {
 			config.AuthLogCallback(s, userAuthReq.Method, authErr)
 		}