ssh: add decode support for banners

These banners can be printed when enabling debugHandshake, add decode
support so that they're not printed as unknown messages.

Change-Id: Ic8d56079d8225c35aac843accdbc80a642dd6249
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/650635
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
diff --git a/ssh/messages.go b/ssh/messages.go
index b55f860..118427b 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -818,6 +818,8 @@
 		return new(userAuthSuccessMsg), nil
 	case msgUserAuthFailure:
 		msg = new(userAuthFailureMsg)
+	case msgUserAuthBanner:
+		msg = new(userAuthBannerMsg)
 	case msgUserAuthPubKeyOk:
 		msg = new(userAuthPubKeyOkMsg)
 	case msgGlobalRequest:
diff --git a/ssh/messages_test.go b/ssh/messages_test.go
index e790764..d8691bd 100644
--- a/ssh/messages_test.go
+++ b/ssh/messages_test.go
@@ -206,6 +206,62 @@
 	}
 }
 
+func TestDecode(t *testing.T) {
+	rnd := rand.New(rand.NewSource(0))
+	kexInit := new(kexInitMsg).Generate(rnd, 10).Interface()
+	kexDHInit := new(kexDHInitMsg).Generate(rnd, 10).Interface()
+	kexDHReply := new(kexDHReplyMsg)
+	kexDHReply.Y = randomInt(rnd)
+	// Note: userAuthSuccessMsg can't be tested directly since it
+	// doesn't have a field for sshtype. So it's tested separately
+	// at the end.
+	decodeMessageTypes := []interface{}{
+		new(disconnectMsg),
+		new(serviceRequestMsg),
+		new(serviceAcceptMsg),
+		new(extInfoMsg),
+		kexInit,
+		kexDHInit,
+		kexDHReply,
+		new(userAuthRequestMsg),
+		new(userAuthFailureMsg),
+		new(userAuthBannerMsg),
+		new(userAuthPubKeyOkMsg),
+		new(globalRequestMsg),
+		new(globalRequestSuccessMsg),
+		new(globalRequestFailureMsg),
+		new(channelOpenMsg),
+		new(channelDataMsg),
+		new(channelOpenConfirmMsg),
+		new(channelOpenFailureMsg),
+		new(windowAdjustMsg),
+		new(channelEOFMsg),
+		new(channelCloseMsg),
+		new(channelRequestMsg),
+		new(channelRequestSuccessMsg),
+		new(channelRequestFailureMsg),
+		new(userAuthGSSAPIToken),
+		new(userAuthGSSAPIMIC),
+		new(userAuthGSSAPIErrTok),
+		new(userAuthGSSAPIError),
+	}
+	for _, msg := range decodeMessageTypes {
+		decoded, err := decode(Marshal(msg))
+		if err != nil {
+			t.Errorf("error decoding %T", msg)
+		} else if reflect.TypeOf(msg) != reflect.TypeOf(decoded) {
+			t.Errorf("error decoding %T, unexpected %T", msg, decoded)
+		}
+	}
+
+	userAuthSuccess, err := decode([]byte{msgUserAuthSuccess})
+	if err != nil {
+		t.Errorf("error decoding userAuthSuccessMsg")
+	} else if reflect.TypeOf(userAuthSuccess) != reflect.TypeOf((*userAuthSuccessMsg)(nil)) {
+		t.Errorf("error decoding userAuthSuccessMsg, unexpected %T", userAuthSuccess)
+	}
+}
+
 func randomBytes(out []byte, rand *rand.Rand) {
 	for i := 0; i < len(out); i++ {
 		out[i] = byte(rand.Int31())