go.crypt/ssh: Add additional test for server.

R=golang-dev, agl
CC=golang-dev
https://golang.org/cl/6075046
diff --git a/ssh/server_test.go b/ssh/server_test.go
new file mode 100644
index 0000000..3e79e48
--- /dev/null
+++ b/ssh/server_test.go
@@ -0,0 +1,180 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"bytes"
+	crypto_rand "crypto/rand"
+	"io"
+	"math/rand"
+	"testing"
+)
+
+// windowTestBytes is the number of bytes that we'll send to the SSH server.
+const windowTestBytes = 16000 * 200
+
+// CopyNRandomly copies n bytes from src to dst. It uses a variable, and random,
+// buffer size to exercise more code paths.
+func CopyNRandomly(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
+	buf := make([]byte, 32*1024)
+	for written < n {
+		l := (rand.Intn(30) + 1) * 1024
+		if d := n - written; d < int64(l) {
+			l = int(d)
+		}
+		nr, er := src.Read(buf[0:l])
+		if nr > 0 {
+			nw, ew := dst.Write(buf[0:nr])
+			if nw > 0 {
+				written += int64(nw)
+			}
+			if ew != nil {
+				err = ew
+				break
+			}
+			if nr != nw {
+				err = io.ErrShortWrite
+				break
+			}
+		}
+		if er != nil {
+			err = er
+			break
+		}
+	}
+	return written, err
+}
+
+func TestServerWindow(t *testing.T) {
+	addr := startSSHServer(t)
+	runSSHClient(t, addr)
+}
+
+// runSSHClient writes random data to the server. The server is expected to echo
+// the same data back, which is compared against the original.
+func runSSHClient(t *testing.T, addr string) {
+	conn, err := Dial("tcp", addr, &ClientConfig{})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
+	echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
+	io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
+	origBytes := origBuf.Bytes()
+
+	wait := make(chan bool)
+
+	// Read back the data from the server.
+	go func() {
+		defer session.Close()
+		serverStdout, err := session.StdoutPipe()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		n, err := CopyNRandomly(echoedBuf, serverStdout, windowTestBytes)
+		if err != nil && err != io.EOF {
+			t.Fatal(err)
+		}
+		if n != windowTestBytes {
+			t.Fatalf("Read only %d bytes from server, expected %d", n, windowTestBytes)
+		}
+		wait <- true
+	}()
+
+	serverStdin, err := session.StdinPipe()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	written, err := CopyNRandomly(serverStdin, origBuf, windowTestBytes)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if written != windowTestBytes {
+		t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
+	}
+
+	<-wait
+
+	if !bytes.Equal(origBytes, echoedBuf.Bytes()) {
+		t.Error("Echoed buffer differed from original")
+	}
+}
+
+func startSSHServer(t *testing.T) (addr string) {
+	config := &ServerConfig{
+		NoClientAuth: true,
+	}
+
+	err := config.SetRSAPrivateKey([]byte(testServerPrivateKey))
+	if err != nil {
+		t.Fatalf("Failed to parse private key: %s", err.Error())
+	}
+
+	listener, err := Listen("tcp", ":0", config)
+	if err != nil {
+		t.Fatalf("Bind error: %s", err)
+	}
+
+	addr = listener.Addr().String()
+
+	go func() {
+		for {
+			sConn, err := listener.Accept()
+
+			err = sConn.Handshake()
+			if err != nil {
+				if err != io.EOF {
+					t.Fatalf("failed to handshake: %s", err)
+				}
+				return
+			}
+
+			go connRun(t, sConn)
+		}
+	}()
+
+	return
+}
+
+func connRun(t *testing.T, sConn *ServerConn) {
+	for {
+		channel, err := sConn.Accept()
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+			t.Fatalf("ServerConn.Accept failed: %s", err)
+		}
+
+		if channel.ChannelType() != "session" {
+			channel.Reject(UnknownChannelType, "unknown channel type")
+			continue
+		}
+		err = channel.Accept()
+		if err != nil {
+			t.Fatalf("Channel.Accept failed: %s", err)
+		}
+
+		go func() {
+			defer channel.Close()
+
+			n, err := CopyNRandomly(channel, channel, windowTestBytes)
+			if err != nil && err != io.EOF {
+				if err == io.ErrShortWrite {
+					t.Fatalf("short write, wrote %d, expected %d", n, windowTestBytes)
+				}
+				t.Fatal(err)
+			}
+		}()
+	}
+}