go.crypto/ssh: improve channel max packet handling
This proposal moves the max packet size check into
channel.writePacket, which now returns an error when given a
buffer larger than the channel's max packet size. Callers must
therefore never pass such a buffer. Only chanWriter.Write is
affected, and it already has the appropriate guards in place.
transport.go previously contained some max packet handling, but
it was incorrect, so it has been removed.
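This proposal also cleans up session_test.go. Errors are now
formatted with %v, and test handlers receive the *testing.T so
they can report failed packet writes. A new test checks that the
server cannot send a packet larger than max packet.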
R=gustav.paul, agl, fullung, huin
CC=golang-dev
https://golang.org/cl/6460075
diff --git a/ssh/session_test.go b/ssh/session_test.go
index a5aafb1..9f1c3b6 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -16,7 +16,7 @@
"code.google.com/p/go.crypto/ssh/terminal"
)
-type serverType func(*serverChan)
+type serverType func(*serverChan, *testing.T)
// dial constructs a new test server and returns a *ClientConn.
func dial(handler serverType, t *testing.T) *ClientConn {
@@ -28,7 +28,7 @@
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
- t.Fatalf("unable to listen: %s", err)
+ t.Fatalf("unable to listen: %v", err)
}
go func() {
defer l.Close()
@@ -60,7 +60,7 @@
continue
}
ch.Accept()
- go handler(ch.(*serverChan))
+ go handler(ch.(*serverChan), t)
}
t.Log("done")
}()
@@ -74,7 +74,7 @@
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
- t.Fatalf("unable to dial remote side: %s", err)
+ t.Fatalf("unable to dial remote side: %v", err)
}
return c
}
@@ -85,7 +85,7 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
stdout := new(bytes.Buffer)
@@ -94,7 +94,7 @@
t.Fatalf("Unable to execute command: %s", err)
}
if err := session.Wait(); err != nil {
- t.Fatalf("Remote command did not exit cleanly: %s", err)
+ t.Fatalf("Remote command did not exit cleanly: %v", err)
}
actual := stdout.String()
if actual != "golang" {
@@ -110,7 +110,7 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
@@ -119,7 +119,7 @@
}
var buf bytes.Buffer
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
done := make(chan bool, 1)
go func() {
@@ -129,7 +129,7 @@
done <- true
}()
if err := session.Wait(); err != nil {
- t.Fatalf("Remote command did not exit cleanly: %s", err)
+ t.Fatalf("Remote command did not exit cleanly: %v", err)
}
<-done
actual := buf.String()
@@ -144,11 +144,11 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
@@ -159,7 +159,7 @@
t.Fatalf("expected *ExitError but got %T", err)
}
if e.ExitStatus() != 15 {
- t.Fatalf("expected command to exit with 15 but got %s", e.ExitStatus())
+ t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
}
}
@@ -169,16 +169,16 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err != nil {
- t.Fatalf("expected nil but got %s", err)
+ t.Fatalf("expected nil but got %v", err)
}
}
@@ -188,11 +188,11 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
@@ -213,11 +213,11 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
@@ -238,11 +238,11 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
@@ -263,11 +263,11 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
@@ -286,7 +286,7 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
// Make sure that we closed all the clientChans when the connection
// failed.
@@ -302,16 +302,16 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err != nil {
- t.Fatalf("expected nil but got %s", err)
+ t.Fatalf("expected nil but got %v", err)
}
}
@@ -322,12 +322,12 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
// send a bogus zero sized window update
@@ -335,21 +335,21 @@
err = session.Wait()
if err != nil {
- t.Fatalf("expected nil but got %s", err)
+ t.Fatalf("expected nil but got %v", err)
}
}
-// Verify that we never send a packet larger than maxpacket.
+// Verify that the client never sends a packet larger than maxpacket.
func TestClientStdinRespectsMaxPacketSize(t *testing.T) {
conn := dial(discardHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %s", err)
+ t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %s", err)
+ t.Fatalf("Unable to execute command: %v", err)
}
// try to stuff 128k of data into a 32k hole.
const size = 128 * 1024
@@ -359,6 +359,27 @@
}
}
+// Verify that the server never sends a packet larger than maxpacket (see largeSendHandler).
+func TestServerStdoutRespectsMaxPacketSize(t *testing.T) {
+ conn := dial(largeSendHandler, t)
+ defer conn.Close()
+ session, err := conn.NewSession()
+ if err != nil {
+ t.Fatalf("Unable to request new session: %v", err)
+ }
+ defer session.Close()
+ out, err := session.StdoutPipe()
+ if err != nil {
+ t.Fatalf("Unable to connect to Stdout: %v", err)
+ }
+ if err := session.Shell(); err != nil {
+ t.Fatalf("Unable to execute command: %v", err)
+ }
+ if _, err := ioutil.ReadAll(out); err != nil {
+ t.Fatalf("failed to read: %v", err)
+ }
+}
+
type exitStatusMsg struct {
PeersId uint32
Request string
@@ -384,68 +405,70 @@
}
}
-func exitStatusZeroHandler(ch *serverChan) {
+func exitStatusZeroHandler(ch *serverChan, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
shell := newServerShell(ch, "> ")
shell.ReadLine()
- sendStatus(0, ch)
+ sendStatus(0, ch, t)
}
-func exitStatusNonZeroHandler(ch *serverChan) {
+func exitStatusNonZeroHandler(ch *serverChan, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, "> ")
shell.ReadLine()
- sendStatus(15, ch)
+ sendStatus(15, ch, t)
}
-func exitSignalAndStatusHandler(ch *serverChan) {
+func exitSignalAndStatusHandler(ch *serverChan, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, "> ")
shell.ReadLine()
- sendStatus(15, ch)
- sendSignal("TERM", ch)
+ sendStatus(15, ch, t)
+ sendSignal("TERM", ch, t)
}
-func exitSignalHandler(ch *serverChan) {
+func exitSignalHandler(ch *serverChan, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, "> ")
shell.ReadLine()
- sendSignal("TERM", ch)
+ sendSignal("TERM", ch, t)
}
-func exitSignalUnknownHandler(ch *serverChan) {
+func exitSignalUnknownHandler(ch *serverChan, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, "> ")
shell.ReadLine()
- sendSignal("SYS", ch)
+ sendSignal("SYS", ch, t)
}
-func exitWithoutSignalOrStatus(ch *serverChan) {
+func exitWithoutSignalOrStatus(ch *serverChan, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, "> ")
shell.ReadLine()
}
-func shellHandler(ch *serverChan) {
+func shellHandler(ch *serverChan, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
shell := newServerShell(ch, "golang")
shell.ReadLine()
- sendStatus(0, ch)
+ sendStatus(0, ch, t)
}
-func sendStatus(status uint32, ch *serverChan) {
+func sendStatus(status uint32, ch *serverChan, t *testing.T) {
msg := exitStatusMsg{
PeersId: ch.remoteId,
Request: "exit-status",
WantReply: false,
Status: status,
}
- ch.serverConn.writePacket(marshal(msgChannelRequest, msg))
+ if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil {
+ t.Errorf("unable to send status: %v", err)
+ }
}
-func sendSignal(signal string, ch *serverChan) {
+func sendSignal(signal string, ch *serverChan, t *testing.T) {
sig := exitSignalMsg{
PeersId: ch.remoteId,
Request: "exit-signal",
@@ -455,10 +478,12 @@
Errmsg: "Process terminated",
Lang: "en-GB-oed",
}
- ch.serverConn.writePacket(marshal(msgChannelRequest, sig))
+ if err := ch.writePacket(marshal(msgChannelRequest, sig)); err != nil {
+ t.Errorf("unable to send signal: %v", err)
+ }
}
-func sendInvalidRecord(ch *serverChan) {
+func sendInvalidRecord(ch *serverChan, t *testing.T) {
defer ch.Close()
packet := make([]byte, 1+4+4+1)
packet[0] = msgChannelData
@@ -466,19 +491,21 @@
marshalUint32(packet[5:], 1)
packet[9] = 42
- ch.serverConn.writePacket(packet)
+ if err := ch.writePacket(packet); err != nil {
+ t.Errorf("unable to send invalid record: %v", err)
+ }
}
-func sendZeroWindowAdjust(ch *serverChan) {
+func sendZeroWindowAdjust(ch *serverChan, t *testing.T) {
defer ch.Close()
// send a bogus zero sized window update
ch.sendWindowAdj(0)
shell := newServerShell(ch, "> ")
shell.ReadLine()
- sendStatus(0, ch)
+ sendStatus(0, ch, t)
}
-func discardHandler(ch *serverChan) {
+func discardHandler(ch *serverChan, t *testing.T) {
defer ch.Close()
// grow the window to avoid being fooled by
// the initial 1 << 14 window.
@@ -487,3 +514,17 @@
shell.ReadLine()
io.Copy(ioutil.Discard, ch.serverConn)
}
+
+func largeSendHandler(ch *serverChan, t *testing.T) {
+ defer ch.Close()
+ // grow the window to avoid being fooled by
+ // the initial 1 << 14 window.
+ ch.sendWindowAdj(1024 * 1024)
+ shell := newServerShell(ch, "> ")
+ shell.ReadLine()
+ // writePacket must reject a single 128k packet: it is
+ // larger than the 32k max packet, whatever the window.
+ if err := ch.writePacket(make([]byte, 128*1024)); err == nil {
+ t.Errorf("wrote packet larger than 32k")
+ }
+}