go.crypto/ssh: improve channel max packet handling

This proposal moves the max packet check into
channel.writePacket. Callers must not pass a buffer larger
than the channel's max packet size. In practice this only
affects chanWriter.Write, which already has the appropriate
guards in place.

There was some max packet handling in transport.go but it was
incorrect. This has been removed.

This proposal also cleans up session_test.go.

R=gustav.paul, agl, fullung, huin
CC=golang-dev
https://golang.org/cl/6460075
diff --git a/ssh/channel.go b/ssh/channel.go
index 9ca6789..7304140 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -6,6 +6,7 @@
 
 import (
 	"errors"
+	"fmt"
 	"io"
 	"sync"
 )
@@ -14,8 +15,13 @@
 // section 5.2.
 type extendedDataTypeCode uint32
 
-// extendedDataStderr is the extended data type that is used for stderr.
-const extendedDataStderr extendedDataTypeCode = 1
+const (
+	// extendedDataStderr is the extended data type that is used for stderr.
+	extendedDataStderr extendedDataTypeCode = 1
+
+	// minPacketLength defines the smallest valid packet
+	minPacketLength = 9
+)
 
 // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
 // SSH connection. Channel.Read can return a ChannelRequest as an error.
@@ -74,7 +80,7 @@
 	conn              // the underlying transport
 	localId, remoteId uint32
 	remoteWin         window
-	maxPacketSize     uint32
+	maxPacket         uint32
 
 	theyClosed  bool // indicates the close msg has been received from the remote side
 	weClosed    bool // incidates the close msg has been sent from our side
@@ -114,6 +120,13 @@
 	return c.writePacket(marshal(msgChannelOpenFailure, reject))
 }
 
+func (c *channel) writePacket(b []byte) error {
+	if uint32(len(b)) > c.maxPacket {
+		return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
+	}
+	return c.conn.writePacket(b)
+}
+
 type serverChan struct {
 	channel
 	// immutable once created
@@ -144,7 +157,7 @@
 		PeersId:       c.remoteId,
 		MyId:          c.localId,
 		MyWindow:      c.myWindow,
-		MaxPacketSize: c.maxPacketSize,
+		MaxPacketSize: c.maxPacket,
 	}
 	return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
 }
@@ -450,10 +463,12 @@
 func (c *clientChan) waitForChannelOpenResponse() error {
 	switch msg := (<-c.msg).(type) {
 	case *channelOpenConfirmMsg:
+		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+			return errors.New("ssh: invalid MaxPacketSize from peer")
+		}
 		// fixup remoteId field
 		c.remoteId = msg.MyId
-		// TODO(dfc) asset this is < 2^31.
-		c.maxPacketSize = msg.MaxPacketSize
+		c.maxPacket = msg.MaxPacketSize
 		c.remoteWin.add(msg.MyWindow)
 		return nil
 	case *channelOpenFailureMsg:
@@ -478,10 +493,11 @@
 
 // Write writes data to the remote process's standard input.
 func (w *chanWriter) Write(data []byte) (written int, err error) {
+	const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
 	for len(data) > 0 {
-		// never send more data than maxPacketSize even if
+		// never send more data than maxPacket even if
 		// there is sufficent window.
-		n := min(int(w.maxPacketSize), len(data))
+		n := min(int(w.maxPacket-headerLength), len(data))
 		n = int(w.remoteWin.reserve(uint32(n)))
 		remoteId := w.remoteId
 		packet := []byte{
diff --git a/ssh/server.go b/ssh/server.go
index 2825a3d..fc985a1 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -564,13 +564,15 @@
 		default:
 			switch msg := decode(packet).(type) {
 			case *channelOpenMsg:
+				if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+					return nil, errors.New("ssh: invalid MaxPacketSize from peer")
+				}
 				c := &serverChan{
 					channel: channel{
 						conn:      s,
 						remoteId:  msg.PeersId,
 						remoteWin: window{Cond: newCond()},
-						// TODO(dfc) assert this param is < 2^31.
-						maxPacketSize: msg.MaxPacketSize,
+						maxPacket: msg.MaxPacketSize,
 					},
 					chanType:    msg.ChanType,
 					extraData:   msg.TypeSpecificData,
diff --git a/ssh/session_test.go b/ssh/session_test.go
index a5aafb1..9f1c3b6 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -16,7 +16,7 @@
 	"code.google.com/p/go.crypto/ssh/terminal"
 )
 
-type serverType func(*serverChan)
+type serverType func(*serverChan, *testing.T)
 
 // dial constructs a new test server and returns a *ClientConn.
 func dial(handler serverType, t *testing.T) *ClientConn {
@@ -28,7 +28,7 @@
 
 	l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
 	if err != nil {
-		t.Fatalf("unable to listen: %s", err)
+		t.Fatalf("unable to listen: %v", err)
 	}
 	go func() {
 		defer l.Close()
@@ -60,7 +60,7 @@
 				continue
 			}
 			ch.Accept()
-			go handler(ch.(*serverChan))
+			go handler(ch.(*serverChan), t)
 		}
 		t.Log("done")
 	}()
@@ -74,7 +74,7 @@
 
 	c, err := Dial("tcp", l.Addr().String(), config)
 	if err != nil {
-		t.Fatalf("unable to dial remote side: %s", err)
+		t.Fatalf("unable to dial remote side: %v", err)
 	}
 	return c
 }
@@ -85,7 +85,7 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	stdout := new(bytes.Buffer)
@@ -94,7 +94,7 @@
 		t.Fatalf("Unable to execute command: %s", err)
 	}
 	if err := session.Wait(); err != nil {
-		t.Fatalf("Remote command did not exit cleanly: %s", err)
+		t.Fatalf("Remote command did not exit cleanly: %v", err)
 	}
 	actual := stdout.String()
 	if actual != "golang" {
@@ -110,7 +110,7 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	stdout, err := session.StdoutPipe()
@@ -119,7 +119,7 @@
 	}
 	var buf bytes.Buffer
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	done := make(chan bool, 1)
 	go func() {
@@ -129,7 +129,7 @@
 		done <- true
 	}()
 	if err := session.Wait(); err != nil {
-		t.Fatalf("Remote command did not exit cleanly: %s", err)
+		t.Fatalf("Remote command did not exit cleanly: %v", err)
 	}
 	<-done
 	actual := buf.String()
@@ -144,11 +144,11 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -159,7 +159,7 @@
 		t.Fatalf("expected *ExitError but got %T", err)
 	}
 	if e.ExitStatus() != 15 {
-		t.Fatalf("expected command to exit with 15 but got %s", e.ExitStatus())
+		t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
 	}
 }
 
@@ -169,16 +169,16 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err != nil {
-		t.Fatalf("expected nil but got %s", err)
+		t.Fatalf("expected nil but got %v", err)
 	}
 }
 
@@ -188,11 +188,11 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -213,11 +213,11 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -238,11 +238,11 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -263,11 +263,11 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -286,7 +286,7 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	// Make sure that we closed all the clientChans when the connection
 	// failed.
@@ -302,16 +302,16 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err != nil {
-		t.Fatalf("expected nil but got %s", err)
+		t.Fatalf("expected nil but got %v", err)
 	}
 }
 
@@ -322,12 +322,12 @@
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 
 	// send a bogus zero sized window update
@@ -335,21 +335,21 @@
 
 	err = session.Wait()
 	if err != nil {
-		t.Fatalf("expected nil but got %s", err)
+		t.Fatalf("expected nil but got %v", err)
 	}
 }
 
-// Verify that we never send a packet larger than maxpacket.
+// Verify that the client never sends a packet larger than maxpacket.
 func TestClientStdinRespectsMaxPacketSize(t *testing.T) {
 	conn := dial(discardHandler, t)
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	// try to stuff 128k of data into a 32k hole.
 	const size = 128 * 1024
@@ -359,6 +359,27 @@
 	}
 }
 
+// Verify that the server never sends a packet larger than maxpacket.
+func TestServerStdoutRespectsMaxPacketSize(t *testing.T) {
+	conn := dial(largeSendHandler, t)
+	defer conn.Close()
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("Unable to request new session: %v", err)
+	}
+	defer session.Close()
+	out, err := session.StdoutPipe()
+	if err != nil {
+		t.Fatalf("Unable to connect to Stdout: %v", err)
+	}
+	if err := session.Shell(); err != nil {
+		t.Fatalf("Unable to execute command: %v", err)
+	}
+	if _, err := ioutil.ReadAll(out); err != nil {
+		t.Fatalf("failed to read: %v", err)
+	}
+}
+
 type exitStatusMsg struct {
 	PeersId   uint32
 	Request   string
@@ -384,68 +405,70 @@
 	}
 }
 
-func exitStatusZeroHandler(ch *serverChan) {
+func exitStatusZeroHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	// this string is returned to stdout
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendStatus(0, ch)
+	sendStatus(0, ch, t)
 }
 
-func exitStatusNonZeroHandler(ch *serverChan) {
+func exitStatusNonZeroHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendStatus(15, ch)
+	sendStatus(15, ch, t)
 }
 
-func exitSignalAndStatusHandler(ch *serverChan) {
+func exitSignalAndStatusHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendStatus(15, ch)
-	sendSignal("TERM", ch)
+	sendStatus(15, ch, t)
+	sendSignal("TERM", ch, t)
 }
 
-func exitSignalHandler(ch *serverChan) {
+func exitSignalHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendSignal("TERM", ch)
+	sendSignal("TERM", ch, t)
 }
 
-func exitSignalUnknownHandler(ch *serverChan) {
+func exitSignalUnknownHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendSignal("SYS", ch)
+	sendSignal("SYS", ch, t)
 }
 
-func exitWithoutSignalOrStatus(ch *serverChan) {
+func exitWithoutSignalOrStatus(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
 }
 
-func shellHandler(ch *serverChan) {
+func shellHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	// this string is returned to stdout
 	shell := newServerShell(ch, "golang")
 	shell.ReadLine()
-	sendStatus(0, ch)
+	sendStatus(0, ch, t)
 }
 
-func sendStatus(status uint32, ch *serverChan) {
+func sendStatus(status uint32, ch *serverChan, t *testing.T) {
 	msg := exitStatusMsg{
 		PeersId:   ch.remoteId,
 		Request:   "exit-status",
 		WantReply: false,
 		Status:    status,
 	}
-	ch.serverConn.writePacket(marshal(msgChannelRequest, msg))
+	if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil {
+		t.Errorf("unable to send status: %v", err)
+	}
 }
 
-func sendSignal(signal string, ch *serverChan) {
+func sendSignal(signal string, ch *serverChan, t *testing.T) {
 	sig := exitSignalMsg{
 		PeersId:    ch.remoteId,
 		Request:    "exit-signal",
@@ -455,10 +478,12 @@
 		Errmsg:     "Process terminated",
 		Lang:       "en-GB-oed",
 	}
-	ch.serverConn.writePacket(marshal(msgChannelRequest, sig))
+	if err := ch.writePacket(marshal(msgChannelRequest, sig)); err != nil {
+		t.Errorf("unable to send signal: %v", err)
+	}
 }
 
-func sendInvalidRecord(ch *serverChan) {
+func sendInvalidRecord(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	packet := make([]byte, 1+4+4+1)
 	packet[0] = msgChannelData
@@ -466,19 +491,21 @@
 	marshalUint32(packet[5:], 1)
 	packet[9] = 42
 
-	ch.serverConn.writePacket(packet)
+	if err := ch.writePacket(packet); err != nil {
+		t.Errorf("unable send invalid record: %v", err)
+	}
 }
 
-func sendZeroWindowAdjust(ch *serverChan) {
+func sendZeroWindowAdjust(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	// send a bogus zero sized window update
 	ch.sendWindowAdj(0)
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendStatus(0, ch)
+	sendStatus(0, ch, t)
 }
 
-func discardHandler(ch *serverChan) {
+func discardHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	// grow the window to avoid being fooled by
 	// the initial 1 << 14 window.
@@ -487,3 +514,17 @@
 	shell.ReadLine()
 	io.Copy(ioutil.Discard, ch.serverConn)
 }
+
+func largeSendHandler(ch *serverChan, t *testing.T) {
+	defer ch.Close()
+	// grow the window to avoid being fooled by
+	// the initial 1 << 14 window.
+	ch.sendWindowAdj(1024 * 1024)
+	shell := newServerShell(ch, "> ")
+	shell.ReadLine()
+	// try to send a single packet larger than the 32k
+	// max packet size; writePacket must refuse it
+	if err := ch.writePacket(make([]byte, 128*1024)); err == nil {
+		t.Errorf("wrote packet larger than 32k")
+	}
+}
diff --git a/ssh/transport.go b/ssh/transport.go
index 020d6d7..73b9783 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -19,9 +19,6 @@
 
 const (
 	packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
-	minPacketSize      = 16
-	maxPacketSize      = 36000
-	minPaddingSize     = 4 // TODO(huin) should this be configurable?
 )
 
 // conn represents an ssh transport that implements packet based
@@ -97,9 +94,6 @@
 	if length <= paddingLength+1 {
 		return nil, errors.New("ssh: invalid packet length")
 	}
-	if length > maxPacketSize {
-		return nil, errors.New("ssh: packet too large")
-	}
 
 	packet := make([]byte, length-1+macSize)
 	if _, err := io.ReadFull(r, packet); err != nil {
@@ -196,11 +190,8 @@
 		}
 	}
 
-	if err := w.Flush(); err != nil {
-		return err
-	}
 	w.seqNum++
-	return err
+	return w.Flush()
 }
 
 // Send a message to the remote peer