go.crypto/ssh: improve test reliability
Fixes golang/go#3989.
Tested for several hours on an 8 core ec2 instance with
random GOMAXPROC values.
Also, rolls server_test.go into session_test using the
existing dial() framework.
R=fullung, agl, kardianos
CC=golang-dev
https://golang.org/cl/6475063
diff --git a/ssh/server_test.go b/ssh/server_test.go
deleted file mode 100644
index 18dbff1..0000000
--- a/ssh/server_test.go
+++ /dev/null
@@ -1,179 +0,0 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package ssh
-
-import (
- "bytes"
- crypto_rand "crypto/rand"
- "io"
- "math/rand"
- "testing"
-)
-
-// windowTestBytes is the number of bytes that we'll send to the SSH server.
-const windowTestBytes = 16000 * 200
-
-// CopyNRandomly copies n bytes from src to dst. It uses a variable, and random,
-// buffer size to exercise more code paths.
-func CopyNRandomly(dst io.Writer, src io.Reader, n int) (written int, err error) {
- buf := make([]byte, 32*1024)
- for written < n {
- l := (rand.Intn(30) + 1) * 1024
- if d := n - written; d < l {
- l = d
- }
- nr, er := src.Read(buf[0:l])
- if nr > 0 {
- nw, ew := dst.Write(buf[0:nr])
- if nw > 0 {
- written += nw
- }
- if ew != nil {
- err = ew
- break
- }
- if nr != nw {
- err = io.ErrShortWrite
- break
- }
- }
- if er != nil {
- err = er
- break
- }
- }
- return written, err
-}
-
-func TestServerWindow(t *testing.T) {
- addr := startSSHServer(t)
- runSSHClient(t, addr)
-}
-
-// runSSHClient writes random data to the server. The server is expected to echo
-// the same data back, which is compared against the original.
-func runSSHClient(t *testing.T, addr string) {
- conn, err := Dial("tcp", addr, &ClientConfig{})
- if err != nil {
- t.Fatal(err)
- }
-
- session, err := conn.NewSession()
- if err != nil {
- t.Fatal(err)
- }
-
- origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
- echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
- io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
- origBytes := origBuf.Bytes()
-
- wait := make(chan bool)
-
- // Read back the data from the server.
- go func() {
- defer session.Close()
- defer close(wait)
- serverStdout, err := session.StdoutPipe()
- if err != nil {
- t.Fatal(err)
- }
-
- n, err := CopyNRandomly(echoedBuf, serverStdout, windowTestBytes)
- if err != nil && err != io.EOF {
- t.Fatal(err)
- }
- if n != windowTestBytes {
- t.Fatalf("Read only %d bytes from server, expected %d", n, windowTestBytes)
- }
- }()
-
- serverStdin, err := session.StdinPipe()
- if err != nil {
- t.Fatal(err)
- }
-
- written, err := CopyNRandomly(serverStdin, origBuf, windowTestBytes)
- if err != nil {
- t.Fatal(err)
- }
- if written != windowTestBytes {
- t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
- }
-
- <-wait
-
- if !bytes.Equal(origBytes, echoedBuf.Bytes()) {
- t.Error("Echoed buffer differed from original")
- }
-}
-
-func startSSHServer(t *testing.T) (addr string) {
- config := &ServerConfig{
- NoClientAuth: true,
- }
-
- err := config.SetRSAPrivateKey([]byte(testServerPrivateKey))
- if err != nil {
- t.Fatalf("Failed to parse private key: %s", err.Error())
- }
-
- listener, err := Listen("tcp", "127.0.0.1:0", config)
- if err != nil {
- t.Fatalf("Bind error: %s", err)
- }
-
- addr = listener.Addr().String()
- go func() {
- defer listener.Close()
- for {
- sConn, err := listener.Accept()
- err = sConn.Handshake()
- if err != nil {
- if err != io.EOF {
- t.Fatalf("failed to handshake: %s", err)
- }
- return
- }
-
- go connRun(t, sConn)
- }
- }()
-
- return
-}
-
-func connRun(t *testing.T, sConn *ServerConn) {
- defer sConn.Close()
- for {
- channel, err := sConn.Accept()
- if err != nil {
- if err == io.EOF {
- break
- }
- t.Fatalf("ServerConn.Accept failed: %s", err)
- }
-
- if channel.ChannelType() != "session" {
- channel.Reject(UnknownChannelType, "unknown channel type")
- continue
- }
- err = channel.Accept()
- if err != nil {
- t.Fatalf("Channel.Accept failed: %s", err)
- }
-
- go func() {
- defer channel.Close()
- n, err := CopyNRandomly(channel, channel, windowTestBytes)
- if err != nil && err != io.EOF {
- if err == io.ErrShortWrite {
- t.Fatalf("short write, wrote %d, expected %d", n, windowTestBytes)
- }
- t.Fatal(err)
- }
- }()
- }
-}
diff --git a/ssh/session_test.go b/ssh/session_test.go
index bc21c50..7fc52ed 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -8,8 +8,10 @@
import (
"bytes"
+ crypto_rand "crypto/rand"
"io"
"io/ioutil"
+ "math/rand"
"net"
"testing"
@@ -42,6 +44,7 @@
t.Errorf("Unable to handshake: %v", err)
return
}
+ done := make(chan struct{})
for {
ch, err := conn.Accept()
if err == io.EOF {
@@ -60,9 +63,12 @@
continue
}
ch.Accept()
- go handler(ch.(*serverChan), t)
+ go func() {
+ defer close(done)
+ handler(ch.(*serverChan), t)
+ }()
}
- t.Log("done")
+ <-done
}()
config := &ClientConfig{
@@ -345,17 +351,19 @@
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
- t.Fatalf("Unable to request new session: %v", err)
+ t.Fatalf("failed to request new session: %v", err)
}
defer session.Close()
- if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %v", err)
+ stdin, err := session.StdinPipe()
+ if err != nil {
+ t.Fatalf("failed to obtain stdinpipe: %v", err)
}
- // try to stuff 128k of data into a 32k hole.
- const size = 128 * 1024
- n, err := session.clientChan.stdin.Write(make([]byte, size))
- if n != size || err != nil {
- t.Fatalf("failed to write: %d, %v", n, err)
+ const size = 100 * 1000
+ for i := 0; i < 10; i++ {
+ n, err := stdin.Write(make([]byte, size))
+ if n != size || err != nil {
+ t.Fatalf("failed to write: %d, %v", n, err)
+ }
}
}
@@ -385,7 +393,7 @@
t.Logf("test skipped")
return
- conn := dial(shellHandler, t)
+ conn := dial(exitWithoutSignalOrStatus, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
@@ -431,6 +439,59 @@
}
}
+// windowTestBytes is the number of bytes that we'll send to the SSH server.
+const windowTestBytes = 16000 * 200
+
+// TestServerWindow writes random data to the server. The server is expected to echo
+// the same data back, which is compared against the original.
+func TestServerWindow(t *testing.T) {
+ origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
+ io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
+ origBytes := origBuf.Bytes()
+
+ conn := dial(echoHandler, t)
+ defer conn.Close()
+ session, err := conn.NewSession()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer session.Close()
+ result := make(chan []byte)
+
+ go func() {
+ defer close(result)
+ echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
+ serverStdout, err := session.StdoutPipe()
+ if err != nil {
+ t.Errorf("StdoutPipe failed: %v", err)
+ return
+ }
+ n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
+ if err != nil && err != io.EOF {
+ t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
+ }
+ result <- echoedBuf.Bytes()
+ }()
+
+ serverStdin, err := session.StdinPipe()
+ if err != nil {
+ t.Fatalf("StdinPipe failed: %v", err)
+ }
+ written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
+ if err != nil {
+ t.Fatalf("falied to copy origBuf to serverStdin: %v", err)
+ }
+ if written != windowTestBytes {
+ t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
+ }
+
+ echoedBytes := <-result
+
+ if !bytes.Equal(origBytes, echoedBytes) {
+ t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
+ }
+}
+
type exitStatusMsg struct {
PeersId uint32
Request string
@@ -509,7 +570,7 @@
func readLine(shell *ServerTerminal, t *testing.T) {
if _, err := shell.ReadLine(); err != nil && err != io.EOF {
- t.Fatalf("unable to read line: %v", err)
+ t.Errorf("unable to read line: %v", err)
}
}
@@ -567,9 +628,11 @@
// grow the window to avoid being fooled by
// the initial 1 << 14 window.
ch.sendWindowAdj(1024 * 1024)
- shell := newServerShell(ch, "> ")
- readLine(shell, t)
- io.Copy(ioutil.Discard, ch.serverConn)
+ // TODO(dfc) io.Copy can return a non EOF error here
+ // because closed channel errors can leak here if the
+ // read from ch causes a window adjustment after the
+ // remote has signaled close.
+ io.Copy(ioutil.Discard, ch)
}
func largeSendHandler(ch *serverChan, t *testing.T) {
@@ -585,3 +648,40 @@
t.Errorf("wrote packet larger than 32k")
}
}
+
+func echoHandler(ch *serverChan, t *testing.T) {
+ defer ch.Close()
+ if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
+ t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
+ }
+}
+
+// copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
+// buffer size to exercise more code paths.
+func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
+ var (
+ buf = make([]byte, 32*1024)
+ written int
+ remaining = n
+ )
+ for remaining > 0 {
+ l := rand.Intn(1 << 15)
+ if remaining < l {
+ l = remaining
+ }
+ nr, er := src.Read(buf[:l])
+ nw, ew := dst.Write(buf[:nr])
+ remaining -= nw
+ written += nw
+ if ew != nil {
+ return written, ew
+ }
+ if nr != nw {
+ return written, io.ErrShortWrite
+ }
+ if er != nil && er != io.EOF {
+ return written, er
+ }
+ }
+ return written, nil
+}