ssh: handle bad servers better.
This change prevents bad servers from crashing a client by sending an
invalid channel ID. It also makes the client disconnect in more cases
of invalid messages from a server and cleans up the client channels
in the event of a disconnect.
R=dave
CC=golang-dev
https://golang.org/cl/6099050
diff --git a/ssh/client.go b/ssh/client.go
index 493d8ec..3b29923 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -184,8 +184,16 @@
// mainLoop reads incoming messages and routes channel messages
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
- // TODO(dfc) signal the underlying close to all channels
- defer c.Close()
+ defer func() {
+ // We don't check, for example, that the channel IDs from the
+ // server are valid before using them. Thus a bad server can
+ // cause us to panic, but we don't want to crash the program.
+ recover()
+
+ c.Close()
+ c.closeAll()
+ }()
+
for {
packet, err := c.readPacket()
if err != nil {
@@ -199,28 +207,34 @@
case msgChannelData:
if len(packet) < 9 {
// malformed data packet
- break
+ return
}
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
- if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
- packet = packet[9:]
- c.getChan(peersId).stdout.handleData(packet[:length])
+ length := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
+ packet = packet[9:]
+
+ if length != uint32(len(packet)) {
+ return
}
+ c.getChan(peersId).stdout.handleData(packet)
case msgChannelExtendedData:
if len(packet) < 13 {
// malformed data packet
- break
+ return
}
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
- if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
- packet = packet[13:]
- // RFC 4254 5.2 defines data_type_code 1 to be data destined
- // for stderr on interactive sessions. Other data types are
- // silently discarded.
- if datatype == 1 {
- c.getChan(peersId).stderr.handleData(packet[:length])
- }
+ length := uint32(packet[9])<<24 | uint32(packet[10])<<16 | uint32(packet[11])<<8 | uint32(packet[12])
+ packet = packet[13:]
+
+ if length != uint32(len(packet)) {
+ return
+ }
+ // RFC 4254 5.2 defines data_type_code 1 to be data destined
+ // for stderr on interactive sessions. Other data types are
+ // silently discarded.
+ if datatype == 1 {
+ c.getChan(peersId).stderr.handleData(packet)
}
default:
switch msg := decode(packet).(type) {
@@ -256,10 +270,10 @@
case *windowAdjustMsg:
if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
// invalid window update
- break
+ return
}
case *disconnectMsg:
- break
+ return
default:
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
}
@@ -408,6 +422,9 @@
func (c *chanlist) getChan(id uint32) *clientChan {
c.Lock()
defer c.Unlock()
+ if id >= uint32(len(c.chans)) {
+ return nil
+ }
return c.chans[int(id)]
}
@@ -417,6 +434,22 @@
c.chans[int(id)] = nil
}
+func (c *chanlist) closeAll() {
+ c.Lock()
+ defer c.Unlock()
+
+ for _, ch := range c.chans {
+ if ch == nil {
+ continue
+ }
+
+ ch.theyClosed = true
+ ch.stdout.eof()
+ ch.stderr.eof()
+ close(ch.msg)
+ }
+}
+
// A chanWriter represents the stdin of a remote process.
type chanWriter struct {
win *window
diff --git a/ssh/session_test.go b/ssh/session_test.go
index df66e1d..df97fcf 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -275,6 +275,20 @@
}
}
+func TestInvalidServerMessage(t *testing.T) {
+ conn := dial(sendInvalidRecord, t)
+ defer conn.Close()
+ session, err := conn.NewSession()
+ if err != nil {
+ t.Fatalf("Unable to request new session: %s", err)
+ }
+ // Make sure that we closed all the clientChans when the connection
+ // failed.
+ session.wait()
+
+ defer session.Close()
+}
+
type exitStatusMsg struct {
PeersId uint32
Request string
@@ -373,3 +387,14 @@
}
ch.serverConn.writePacket(marshal(msgChannelRequest, sig))
}
+
+func sendInvalidRecord(ch *channel) {
+ defer ch.Close()
+ packet := make([]byte, 1+4+4+1)
+ packet[0] = msgChannelData
+ marshalUint32(packet[1:], 29348723 /* invalid channel id */)
+ marshalUint32(packet[5:], 1)
+ packet[9] = 42
+
+ ch.serverConn.writePacket(packet)
+}