x/crypto/ssh: interpret disconnect message as error in the transport layer.
This ensures that higher level parts (e.g. the client authentication
loop) never have to deal with disconnect messages.
Fixes https://github.com/coreos/fleet/issues/565.
Change-Id: Ie164b6c4b0982c7ed9af6d3bf91697a78a911a20
Reviewed-on: https://go-review.googlesource.com/20801
Reviewed-by: Anton Khramov <anton@endocode.com>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index e15be3e..6956ce4 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -321,8 +321,6 @@
return false, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
- case msgDisconnect:
- return false, nil, io.EOF
default:
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index b86d369..bd7fe77 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -10,6 +10,7 @@
"errors"
"fmt"
"net"
+ "reflect"
"runtime"
"strings"
"sync"
@@ -413,3 +414,45 @@
wg.Wait()
}
+
+func TestDisconnect(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("see golang.org/issue/7237")
+ }
+ checker := &testChecker{}
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+
+ defer trC.Close()
+ defer trS.Close()
+
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+ errMsg := &disconnectMsg{
+ Reason: 42,
+ Message: "such is life",
+ }
+ trC.writePacket(Marshal(errMsg))
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+
+ packet, err := trS.readPacket()
+ if err != nil {
+ t.Fatalf("readPacket 1: %v", err)
+ }
+ if packet[0] != msgRequestSuccess {
+ t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
+ }
+
+ _, err = trS.readPacket()
+ if err == nil {
+ t.Errorf("readPacket 2 succeeded")
+ } else if !reflect.DeepEqual(err, errMsg) {
+ t.Errorf("got error %#v, want %#v", err, errMsg)
+ }
+
+ _, err = trS.readPacket()
+ if err == nil {
+ t.Errorf("readPacket 3 succeeded")
+ }
+}
diff --git a/ssh/messages.go b/ssh/messages.go
index eaf6106..247694b 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -47,7 +47,7 @@
}
func (d *disconnectMsg) Error() string {
- return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
+ return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
}
// See RFC 4253, section 7.1.
diff --git a/ssh/mux.go b/ssh/mux.go
index 321880a..a2af7f4 100644
--- a/ssh/mux.go
+++ b/ssh/mux.go
@@ -175,18 +175,6 @@
return m.sendMessage(globalRequestFailureMsg{Data: data})
}
-// TODO(hanwen): Disconnect is a transport layer message. We should
-// probably send and receive Disconnect somewhere in the transport
-// code.
-
-// Disconnect sends a disconnect message.
-func (m *mux) Disconnect(reason uint32, message string) error {
- return m.sendMessage(disconnectMsg{
- Reason: reason,
- Message: message,
- })
-}
-
func (m *mux) Close() error {
return m.conn.Close()
}
@@ -239,8 +227,6 @@
case msgNewKeys:
// Ignore notification of key change.
return nil
- case msgDisconnect:
- return m.handleDisconnect(packet)
case msgChannelOpen:
return m.handleChannelOpen(packet)
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
@@ -260,18 +246,6 @@
return ch.handlePacket(packet)
}
-func (m *mux) handleDisconnect(packet []byte) error {
- var d disconnectMsg
- if err := Unmarshal(packet, &d); err != nil {
- return err
- }
-
- if debugMux {
- log.Printf("caught disconnect: %v", d)
- }
- return &d
-}
-
func (m *mux) handleGlobalPacket(packet []byte) error {
msg, err := decode(packet)
if err != nil {
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
index 5230389..591aae8 100644
--- a/ssh/mux_test.go
+++ b/ssh/mux_test.go
@@ -331,7 +331,6 @@
ok, data, err)
}
- clientMux.Disconnect(0, "")
if !seen {
t.Errorf("never saw 'peek' request")
}
@@ -378,28 +377,6 @@
}
}
-func TestMuxDisconnect(t *testing.T) {
- a, b := muxPair()
- defer a.Close()
- defer b.Close()
-
- go func() {
- for r := range b.incomingRequests {
- r.Reply(true, nil)
- }
- }()
-
- a.Disconnect(42, "whatever")
- ok, _, err := a.SendRequest("hello", true, nil)
- if ok || err == nil {
- t.Errorf("got reply after disconnecting")
- }
- err = b.Wait()
- if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
- t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
- }
-}
-
func TestMuxCloseChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
diff --git a/ssh/transport.go b/ssh/transport.go
index 8351d37..4de98a6 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -114,12 +114,27 @@
err = errors.New("ssh: zero length packet")
}
- if len(packet) > 0 && packet[0] == msgNewKeys {
- select {
- case cipher := <-s.pendingKeyChange:
+ if len(packet) > 0 {
+ switch packet[0] {
+ case msgNewKeys:
+ select {
+ case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
- default:
- return nil, errors.New("ssh: got bogus newkeys message.")
+ default:
+ return nil, errors.New("ssh: got bogus newkeys message.")
+ }
+
+ case msgDisconnect:
+ // Transform a disconnect message into an
+ // error. Since this is lowest level at which
+ // we interpret message types, doing it here
+ // ensures that we don't have to handle it
+ // elsewhere.
+ var msg disconnectMsg
+ if err := Unmarshal(packet, &msg); err != nil {
+ return nil, err
+ }
+ return nil, &msg
}
}