x/crypto/ssh: interpret disconnect message as error in the transport layer.

This ensures that higher level parts (e.g. the client authentication
loop) never have to deal with disconnect messages.

Fixes https://github.com/coreos/fleet/issues/565.

Change-Id: Ie164b6c4b0982c7ed9af6d3bf91697a78a911a20
Reviewed-on: https://go-review.googlesource.com/20801
Reviewed-by: Anton Khramov <anton@endocode.com>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index e15be3e..6956ce4 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -321,8 +321,6 @@
 			return false, msg.Methods, nil
 		case msgUserAuthSuccess:
 			return true, nil, nil
-		case msgDisconnect:
-			return false, nil, io.EOF
 		default:
 			return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
 		}
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index b86d369..bd7fe77 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -10,6 +10,7 @@
 	"errors"
 	"fmt"
 	"net"
+	"reflect"
 	"runtime"
 	"strings"
 	"sync"
@@ -413,3 +414,45 @@
 
 	wg.Wait()
 }
+
+func TestDisconnect(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+
+	defer trC.Close()
+	defer trS.Close()
+
+	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+	errMsg := &disconnectMsg{
+		Reason: 42,
+		Message: "such is life",
+	}
+	trC.writePacket(Marshal(errMsg))
+	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+
+	packet, err := trS.readPacket()
+	if err != nil {
+		t.Fatalf("readPacket 1: %v", err)
+	}
+	if packet[0] != msgRequestSuccess {
+		t.Errorf("got packet %v, want packet type %d", packet,  msgRequestSuccess)
+	}
+
+	_, err = trS.readPacket()
+	if err == nil {
+		t.Errorf("readPacket 2 succeeded")
+	} else if !reflect.DeepEqual(err, errMsg) {
+		t.Errorf("got error %#v, want %#v", err, errMsg)
+	}
+
+	_, err = trS.readPacket()
+	if err == nil {
+		t.Errorf("readPacket 3 succeeded")
+	}
+}
diff --git a/ssh/messages.go b/ssh/messages.go
index eaf6106..247694b 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -47,7 +47,7 @@
 }
 
 func (d *disconnectMsg) Error() string {
-	return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
+	return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
 }
 
 // See RFC 4253, section 7.1.
diff --git a/ssh/mux.go b/ssh/mux.go
index 321880a..a2af7f4 100644
--- a/ssh/mux.go
+++ b/ssh/mux.go
@@ -175,18 +175,6 @@
 	return m.sendMessage(globalRequestFailureMsg{Data: data})
 }
 
-// TODO(hanwen): Disconnect is a transport layer message. We should
-// probably send and receive Disconnect somewhere in the transport
-// code.
-
-// Disconnect sends a disconnect message.
-func (m *mux) Disconnect(reason uint32, message string) error {
-	return m.sendMessage(disconnectMsg{
-		Reason:  reason,
-		Message: message,
-	})
-}
-
 func (m *mux) Close() error {
 	return m.conn.Close()
 }
@@ -239,8 +227,6 @@
 	case msgNewKeys:
 		// Ignore notification of key change.
 		return nil
-	case msgDisconnect:
-		return m.handleDisconnect(packet)
 	case msgChannelOpen:
 		return m.handleChannelOpen(packet)
 	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
@@ -260,18 +246,6 @@
 	return ch.handlePacket(packet)
 }
 
-func (m *mux) handleDisconnect(packet []byte) error {
-	var d disconnectMsg
-	if err := Unmarshal(packet, &d); err != nil {
-		return err
-	}
-
-	if debugMux {
-		log.Printf("caught disconnect: %v", d)
-	}
-	return &d
-}
-
 func (m *mux) handleGlobalPacket(packet []byte) error {
 	msg, err := decode(packet)
 	if err != nil {
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
index 5230389..591aae8 100644
--- a/ssh/mux_test.go
+++ b/ssh/mux_test.go
@@ -331,7 +331,6 @@
 			ok, data, err)
 	}
 
-	clientMux.Disconnect(0, "")
 	if !seen {
 		t.Errorf("never saw 'peek' request")
 	}
@@ -378,28 +377,6 @@
 	}
 }
 
-func TestMuxDisconnect(t *testing.T) {
-	a, b := muxPair()
-	defer a.Close()
-	defer b.Close()
-
-	go func() {
-		for r := range b.incomingRequests {
-			r.Reply(true, nil)
-		}
-	}()
-
-	a.Disconnect(42, "whatever")
-	ok, _, err := a.SendRequest("hello", true, nil)
-	if ok || err == nil {
-		t.Errorf("got reply after disconnecting")
-	}
-	err = b.Wait()
-	if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
-		t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
-	}
-}
-
 func TestMuxCloseChannel(t *testing.T) {
 	r, w, mux := channelPair(t)
 	defer mux.Close()
diff --git a/ssh/transport.go b/ssh/transport.go
index 8351d37..4de98a6 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -114,12 +114,27 @@
 		err = errors.New("ssh: zero length packet")
 	}
 
-	if len(packet) > 0 && packet[0] == msgNewKeys {
-		select {
-		case cipher := <-s.pendingKeyChange:
+	if len(packet) > 0 {
+		switch packet[0] {
+		case msgNewKeys:
+			select {
+			case cipher := <-s.pendingKeyChange:
 			s.packetCipher = cipher
-		default:
-			return nil, errors.New("ssh: got bogus newkeys message.")
+			default:
+				return nil, errors.New("ssh: got bogus newkeys message.")
+			}
+
+		case msgDisconnect:
+			// Transform a disconnect message into an
+			// error. Since this is lowest level at which
+			// we interpret message types, doing it here
+			// ensures that we don't have to handle it
+			// elsewhere.
+			var msg disconnectMsg
+			if err := Unmarshal(packet, &msg); err != nil {
+				return nil, err
+			}
+			return nil, &msg
 		}
 	}