x/crypto/ssh: public key authentication example

Fixes golang/go#13902.

Adds public key authentication to the
password authentication example.

Change-Id: I4af0ca627fb15b617cc1ba1c6e0954b013f4d94f
Reviewed-on: https://go-review.googlesource.com/29374
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/example_test.go b/ssh/example_test.go
index 8b8a2fe..5315e45 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -17,9 +17,29 @@
 )
 
 func ExampleNewServerConn() {
+	// Public key authentication is done by comparing
+	// the public key of a received connection
+	// with the entries in the authorized_keys file.
+	authorizedKeysBytes, err := ioutil.ReadFile("authorized_keys")
+	if err != nil {
+		log.Fatalf("Failed to load authorized_keys, err: %v", err)
+	}
+
+	authorizedKeysMap := map[string]bool{}
+	for len(authorizedKeysBytes) > 0 {
+		pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		authorizedKeysMap[string(pubKey.Marshal())] = true
+		authorizedKeysBytes = rest
+	}
+
 	// An SSH server is represented by a ServerConfig, which holds
 	// certificate details and handles authentication of ServerConns.
 	config := &ssh.ServerConfig{
+		// Remove to disable password auth.
 		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
 			// Should use constant-time compare (or better, salt+hash) in
 			// a production setting.
@@ -28,6 +48,14 @@
 			}
 			return nil, fmt.Errorf("password rejected for %q", c.User())
 		},
+
+		// Remove to disable public key auth.
+		PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
+			if authorizedKeysMap[string(pubKey.Marshal())] {
+				return nil, nil
+			}
+			return nil, fmt.Errorf("unknown public key for %q", c.User())
+		},
 	}
 
 	privateBytes, err := ioutil.ReadFile("id_rsa")
@@ -63,6 +91,8 @@
 	go ssh.DiscardRequests(reqs)
 
 	// Service the incoming Channel channel.
+
+	// Service the incoming Channel channel.
 	for newChannel := range chans {
 		// Channels have a type, depending on the application level
 		// protocol intended. In the case of a shell, the type is
@@ -74,7 +104,7 @@
 		}
 		channel, requests, err := newChannel.Accept()
 		if err != nil {
-			log.Fatal("could not accept channel: ", err)
+			log.Fatalf("Could not accept channel: %v", err)
 		}
 
 		// Sessions have out-of-band requests such as "shell",
@@ -82,18 +112,7 @@
 		// "shell" request.
 		go func(in <-chan *ssh.Request) {
 			for req := range in {
-				ok := false
-				switch req.Type {
-				case "shell":
-					ok = true
-					if len(req.Payload) > 0 {
-						// We don't accept any
-						// commands, only the
-						// default shell.
-						ok = false
-					}
-				}
-				req.Reply(ok, nil)
+				req.Reply(req.Type == "shell", nil)
 			}
 		}(requests)