ssh: add decode support for banners These banners can be printed when enabling debugHandshake, add decode support so that they're not printed as unknown messages. Change-Id: Ic8d56079d8225c35aac843accdbc80a642dd6249 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/650635 Reviewed-by: Junyang Shao <shaojunyang@google.com> Reviewed-by: Nicola Murino <nicola.murino@gmail.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Auto-Submit: Nicola Murino <nicola.murino@gmail.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
diff --git a/ssh/messages.go b/ssh/messages.go index b55f860..118427b 100644 --- a/ssh/messages.go +++ b/ssh/messages.go
@@ -818,6 +818,8 @@ return new(userAuthSuccessMsg), nil case msgUserAuthFailure: msg = new(userAuthFailureMsg) + case msgUserAuthBanner: + msg = new(userAuthBannerMsg) case msgUserAuthPubKeyOk: msg = new(userAuthPubKeyOkMsg) case msgGlobalRequest:
diff --git a/ssh/messages_test.go b/ssh/messages_test.go index e790764..d8691bd 100644 --- a/ssh/messages_test.go +++ b/ssh/messages_test.go
@@ -206,6 +206,62 @@ } } +func TestDecode(t *testing.T) { + rnd := rand.New(rand.NewSource(0)) + kexInit := new(kexInitMsg).Generate(rnd, 10).Interface() + kexDHInit := new(kexDHInitMsg).Generate(rnd, 10).Interface() + kexDHReply := new(kexDHReplyMsg) + kexDHReply.Y = randomInt(rnd) + // Note: userAuthSuccessMsg can't be tested directly since it + // doesn't have a field for sshtype. So it's tested separately + // at the end. + decodeMessageTypes := []interface{}{ + new(disconnectMsg), + new(serviceRequestMsg), + new(serviceAcceptMsg), + new(extInfoMsg), + kexInit, + kexDHInit, + kexDHReply, + new(userAuthRequestMsg), + new(userAuthFailureMsg), + new(userAuthBannerMsg), + new(userAuthPubKeyOkMsg), + new(globalRequestMsg), + new(globalRequestSuccessMsg), + new(globalRequestFailureMsg), + new(channelOpenMsg), + new(channelDataMsg), + new(channelOpenConfirmMsg), + new(channelOpenFailureMsg), + new(windowAdjustMsg), + new(channelEOFMsg), + new(channelCloseMsg), + new(channelRequestMsg), + new(channelRequestSuccessMsg), + new(channelRequestFailureMsg), + new(userAuthGSSAPIToken), + new(userAuthGSSAPIMIC), + new(userAuthGSSAPIErrTok), + new(userAuthGSSAPIError), + } + for _, msg := range decodeMessageTypes { + decoded, err := decode(Marshal(msg)) + if err != nil { + t.Errorf("error decoding %T", msg) + } else if reflect.TypeOf(msg) != reflect.TypeOf(decoded) { + t.Errorf("error decoding %T, unexpected %T", msg, decoded) + } + } + + userAuthSuccess, err := decode([]byte{msgUserAuthSuccess}) + if err != nil { + t.Errorf("error decoding userAuthSuccessMsg") + } else if reflect.TypeOf(userAuthSuccess) != reflect.TypeOf((*userAuthSuccessMsg)(nil)) { + t.Errorf("error decoding userAuthSuccessMsg, unexpected %T", userAuthSuccess) + } +} + func randomBytes(out []byte, rand *rand.Rand) { for i := 0; i < len(out); i++ { out[i] = byte(rand.Int31())