x/crypto/ssh: interpret disconnect message as error in the transport layer. This ensures that higher-level parts (e.g. the client authentication loop) never have to deal with disconnect messages. Fixes https://github.com/coreos/fleet/issues/565. Change-Id: Ie164b6c4b0982c7ed9af6d3bf91697a78a911a20 Reviewed-on: https://go-review.googlesource.com/20801 Reviewed-by: Anton Khramov <anton@endocode.com> Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/client_auth.go b/ssh/client_auth.go index e15be3e..6956ce4 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go
@@ -321,8 +321,6 @@ return false, msg.Methods, nil case msgUserAuthSuccess: return true, nil, nil - case msgDisconnect: - return false, nil, io.EOF default: return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) }
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go index b86d369..bd7fe77 100644 --- a/ssh/handshake_test.go +++ b/ssh/handshake_test.go
@@ -10,6 +10,7 @@ "errors" "fmt" "net" + "reflect" "runtime" "strings" "sync" @@ -413,3 +414,45 @@ wg.Wait() } + +func TestDisconnect(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + errMsg := &disconnectMsg{ + Reason: 42, + Message: "such is life", + } + trC.writePacket(Marshal(errMsg)) + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + + packet, err := trS.readPacket() + if err != nil { + t.Fatalf("readPacket 1: %v", err) + } + if packet[0] != msgRequestSuccess { + t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 2 succeeded") + } else if !reflect.DeepEqual(err, errMsg) { + t.Errorf("got error %#v, want %#v", err, errMsg) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 3 succeeded") + } +}
diff --git a/ssh/messages.go b/ssh/messages.go index eaf6106..247694b 100644 --- a/ssh/messages.go +++ b/ssh/messages.go
@@ -47,7 +47,7 @@ } func (d *disconnectMsg) Error() string { - return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message) + return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) } // See RFC 4253, section 7.1.
diff --git a/ssh/mux.go b/ssh/mux.go index 321880a..a2af7f4 100644 --- a/ssh/mux.go +++ b/ssh/mux.go
@@ -175,18 +175,6 @@ return m.sendMessage(globalRequestFailureMsg{Data: data}) } -// TODO(hanwen): Disconnect is a transport layer message. We should -// probably send and receive Disconnect somewhere in the transport -// code. - -// Disconnect sends a disconnect message. -func (m *mux) Disconnect(reason uint32, message string) error { - return m.sendMessage(disconnectMsg{ - Reason: reason, - Message: message, - }) -} - func (m *mux) Close() error { return m.conn.Close() } @@ -239,8 +227,6 @@ case msgNewKeys: // Ignore notification of key change. return nil - case msgDisconnect: - return m.handleDisconnect(packet) case msgChannelOpen: return m.handleChannelOpen(packet) case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: @@ -260,18 +246,6 @@ return ch.handlePacket(packet) } -func (m *mux) handleDisconnect(packet []byte) error { - var d disconnectMsg - if err := Unmarshal(packet, &d); err != nil { - return err - } - - if debugMux { - log.Printf("caught disconnect: %v", d) - } - return &d -} - func (m *mux) handleGlobalPacket(packet []byte) error { msg, err := decode(packet) if err != nil {
diff --git a/ssh/mux_test.go b/ssh/mux_test.go index 5230389..591aae8 100644 --- a/ssh/mux_test.go +++ b/ssh/mux_test.go
@@ -331,7 +331,6 @@ ok, data, err) } - clientMux.Disconnect(0, "") if !seen { t.Errorf("never saw 'peek' request") } @@ -378,28 +377,6 @@ } } -func TestMuxDisconnect(t *testing.T) { - a, b := muxPair() - defer a.Close() - defer b.Close() - - go func() { - for r := range b.incomingRequests { - r.Reply(true, nil) - } - }() - - a.Disconnect(42, "whatever") - ok, _, err := a.SendRequest("hello", true, nil) - if ok || err == nil { - t.Errorf("got reply after disconnecting") - } - err = b.Wait() - if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 { - t.Errorf("got %#v, want disconnectMsg{Reason:42}", err) - } -} - func TestMuxCloseChannel(t *testing.T) { r, w, mux := channelPair(t) defer mux.Close()
diff --git a/ssh/transport.go b/ssh/transport.go index 8351d37..4de98a6 100644 --- a/ssh/transport.go +++ b/ssh/transport.go
@@ -114,12 +114,27 @@ err = errors.New("ssh: zero length packet") } - if len(packet) > 0 && packet[0] == msgNewKeys { - select { - case cipher := <-s.pendingKeyChange: + if len(packet) > 0 { + switch packet[0] { + case msgNewKeys: + select { + case cipher := <-s.pendingKeyChange: s.packetCipher = cipher - default: - return nil, errors.New("ssh: got bogus newkeys message.") + default: + return nil, errors.New("ssh: got bogus newkeys message.") + } + + case msgDisconnect: + // Transform a disconnect message into an + // error. Since this is the lowest level at which + // we interpret message types, doing it here + // ensures that we don't have to handle it + // elsewhere. + var msg disconnectMsg + if err := Unmarshal(packet, &msg); err != nil { + return nil, err + } + return nil, &msg } }