ssh: add support for banners

According to RFC 4252 section 5.4, the banner is sent between the
ssh-connection request and responding to user authentication.

Original support for server sending banner by joshua stein <jcs@jcs.org>

Fixes golang/go#19567

Change-Id: I68944a7f4711c0623759f6a59023e8e45a8781aa
Reviewed-on: https://go-review.googlesource.com/65271
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/client.go b/ssh/client.go
index a7e3263..6fd1994 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -9,6 +9,7 @@
 	"errors"
 	"fmt"
 	"net"
+	"os"
 	"sync"
 	"time"
 )
@@ -187,6 +188,10 @@
 // net.Conn underlying the the SSH connection.
 type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
 
+// BannerCallback is the function type used for treat the banner sent by
+// the server. A BannerCallback receives the message sent by the remote server.
+type BannerCallback func(message string) error
+
 // A ClientConfig structure is used to configure a Client. It must not be
 // modified after having been passed to an SSH function.
 type ClientConfig struct {
@@ -209,6 +214,12 @@
 	// FixedHostKey can be used for simplistic host key checks.
 	HostKeyCallback HostKeyCallback
 
+	// BannerCallback is called during the SSH dance to display a custom
+	// server's message. The client configuration can supply this callback to
+	// handle it as wished. The function BannerDisplayStderr can be used for
+	// simplistic display on Stderr.
+	BannerCallback BannerCallback
+
 	// ClientVersion contains the version identification string that will
 	// be used for the connection. If empty, a reasonable default is used.
 	ClientVersion string
@@ -255,3 +266,13 @@
 	hk := &fixedHostKey{key}
 	return hk.check
 }
+
+// BannerDisplayStderr returns a function that can be used for
+// ClientConfig.BannerCallback to display banners on os.Stderr.
+func BannerDisplayStderr() BannerCallback {
+	return func(banner string) error {
+		_, err := os.Stderr.WriteString(banner)
+
+		return err
+	}
+}
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 3acd8d4..a1252cb 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -283,7 +283,9 @@
 		}
 		switch packet[0] {
 		case msgUserAuthBanner:
-			// TODO(gpaul): add callback to present the banner to the user
+			if err := handleBannerResponse(c, packet); err != nil {
+				return false, err
+			}
 		case msgUserAuthPubKeyOk:
 			var msg userAuthPubKeyOkMsg
 			if err := Unmarshal(packet, &msg); err != nil {
@@ -325,7 +327,9 @@
 
 		switch packet[0] {
 		case msgUserAuthBanner:
-			// TODO: add callback to present the banner to the user
+			if err := handleBannerResponse(c, packet); err != nil {
+				return false, nil, err
+			}
 		case msgUserAuthFailure:
 			var msg userAuthFailureMsg
 			if err := Unmarshal(packet, &msg); err != nil {
@@ -340,6 +344,24 @@
 	}
 }
 
+func handleBannerResponse(c packetConn, packet []byte) error {
+	var msg userAuthBannerMsg
+	if err := Unmarshal(packet, &msg); err != nil {
+		return err
+	}
+
+	transport, ok := c.(*handshakeTransport)
+	if !ok {
+		return nil
+	}
+
+	if transport.bannerCallback != nil {
+		return transport.bannerCallback(msg.Message)
+	}
+
+	return nil
+}
+
 // KeyboardInteractiveChallenge should print questions, optionally
 // disabling echoing (e.g. for passwords), and return all the answers.
 // Challenge may be called multiple times in a single session. After
@@ -385,7 +407,9 @@
 		// like handleAuthResponse, but with less options.
 		switch packet[0] {
 		case msgUserAuthBanner:
-			// TODO: Print banners during userauth.
+			if err := handleBannerResponse(c, packet); err != nil {
+				return false, nil, err
+			}
 			continue
 		case msgUserAuthInfoRequest:
 			// OK
diff --git a/ssh/client_test.go b/ssh/client_test.go
index ccf5607..f751eb6 100644
--- a/ssh/client_test.go
+++ b/ssh/client_test.go
@@ -79,3 +79,40 @@
 		}
 	}
 }
+func TestBannerCallback(t *testing.T) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	serverConf := &ServerConfig{
+		NoClientAuth: true,
+		BannerCallback: func(conn ConnMetadata) string {
+			return "Hello World"
+		},
+	}
+	serverConf.AddHostKey(testSigners["rsa"])
+	go NewServerConn(c1, serverConf)
+
+	var receivedBanner string
+	clientConf := ClientConfig{
+		User:            "user",
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		BannerCallback: func(message string) error {
+			receivedBanner = message
+			return nil
+		},
+	}
+
+	_, _, _, err = NewClientConn(c2, "", &clientConf)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	expected := "Hello World"
+	if receivedBanner != expected {
+		t.Fatalf("got %s; want %s", receivedBanner, expected)
+	}
+}
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 932ce83..4f7912e 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -78,6 +78,11 @@
 	dialAddress     string
 	remoteAddr      net.Addr
 
+	// bannerCallback is non-empty if we are the client and it has been set in
+	// ClientConfig. In that case it is called during the user authentication
+	// dance to handle a custom server's message.
+	bannerCallback BannerCallback
+
 	// Algorithms agreed in the last key exchange.
 	algorithms *algorithms
 
@@ -120,6 +125,7 @@
 	t.dialAddress = dialAddr
 	t.remoteAddr = addr
 	t.hostKeyCallback = config.HostKeyCallback
+	t.bannerCallback = config.BannerCallback
 	if config.HostKeyAlgorithms != nil {
 		t.hostKeyAlgorithms = config.HostKeyAlgorithms
 	} else {
diff --git a/ssh/messages.go b/ssh/messages.go
index e6ecd3a..92f3810 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -23,10 +23,6 @@
 	msgUnimplemented = 3
 	msgDebug         = 4
 	msgNewKeys       = 21
-
-	// Standard authentication messages
-	msgUserAuthSuccess = 52
-	msgUserAuthBanner  = 53
 )
 
 // SSH messages:
@@ -137,6 +133,16 @@
 	PartialSuccess bool
 }
 
+// See RFC 4252, section 5.1
+const msgUserAuthSuccess = 52
+
+// See RFC 4252, section 5.4
+const msgUserAuthBanner = 53
+
+type userAuthBannerMsg struct {
+	Message string `sshtype:"53"`
+}
+
 // See RFC 4256, section 3.2
 const msgUserAuthInfoRequest = 60
 const msgUserAuthInfoResponse = 61
diff --git a/ssh/server.go b/ssh/server.go
index 8a78b7c..148d2cb 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -95,6 +95,10 @@
 	// Note that RFC 4253 section 4.2 requires that this string start with
 	// "SSH-2.0-".
 	ServerVersion string
+
+	// BannerCallback, if present, is called and the return string is sent to
+	// the client after key exchange completed but before authentication.
+	BannerCallback func(conn ConnMetadata) string
 }
 
 // AddHostKey adds a private key as a host key. If an existing host
@@ -343,6 +347,19 @@
 		}
 
 		s.user = userAuthReq.User
+
+		if authFailures == 0 && config.BannerCallback != nil {
+			msg := config.BannerCallback(s)
+			if msg != "" {
+				bannerMsg := &userAuthBannerMsg{
+					Message: msg,
+				}
+				if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
+					return nil, err
+				}
+			}
+		}
+
 		perms = nil
 		authErr := errors.New("no auth passed yet")