// blob: 3e79e48060ebb3519b881721028ebb8a1bb0de17 [file] [log] [blame]
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
crypto_rand "crypto/rand"
"io"
"math/rand"
"testing"
)
// windowTestBytes is the number of bytes that we'll send to the SSH server
// (16000 * 200 = 3,200,000 bytes). The same constant bounds how much the
// server side echoes back and how much the client expects to read in return.
const windowTestBytes = 16000 * 200
// CopyNRandomly transfers up to n bytes from src to dst. Every read asks for
// a randomly chosen size between 1 KiB and 30 KiB (capped at the number of
// bytes still outstanding) so that callers exercise varied chunk sizes.
//
// It returns the number of bytes written to dst and the first error
// encountered. A src that runs dry before n bytes surfaces its own error
// (typically io.EOF); a dst that accepts fewer bytes than it was handed
// without reporting an error yields io.ErrShortWrite.
func CopyNRandomly(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
	scratch := make([]byte, 32*1024)
	for written < n {
		// Pick this round's read size, never exceeding what is left to copy.
		chunk := int64(rand.Intn(30)+1) * 1024
		if remaining := n - written; remaining < chunk {
			chunk = remaining
		}

		got, readErr := src.Read(scratch[:chunk])
		if got > 0 {
			put, writeErr := dst.Write(scratch[:got])
			if put > 0 {
				written += int64(put)
			}
			if writeErr != nil {
				return written, writeErr
			}
			if put != got {
				return written, io.ErrShortWrite
			}
		}
		// A read error is reported only after any data returned alongside it
		// has been flushed to dst.
		if readErr != nil {
			return written, readErr
		}
	}
	return written, nil
}
func TestServerWindow(t *testing.T) {
addr := startSSHServer(t)
runSSHClient(t, addr)
}
// runSSHClient writes random data to the server. The server is expected to echo
// the same data back, which is compared against the original.
//
// It dials addr, opens a session, streams windowTestBytes of random data into
// the session's stdin with CopyNRandomly, and concurrently reads the echo from
// the session's stdout in a helper goroutine. Failures on the test goroutine
// abort the test with t.Fatal; failures in the helper goroutine are reported
// with t.Error, because FailNow-style calls are only valid on the goroutine
// running the test.
func runSSHClient(t *testing.T, addr string) {
	conn, err := Dial("tcp", addr, &ClientConfig{})
	if err != nil {
		t.Fatal(err)
	}
	session, err := conn.NewSession()
	if err != nil {
		t.Fatal(err)
	}

	origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
	echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
	if _, err := io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes); err != nil {
		t.Fatalf("generating random test data: %s", err)
	}
	// Snapshot the payload before origBuf is drained by the writer below.
	origBytes := origBuf.Bytes()

	// done is closed when the reader goroutine exits, on success or failure,
	// so the test goroutine can never block forever waiting for it.
	done := make(chan struct{})

	// Read back the data from the server.
	go func() {
		defer close(done)
		defer session.Close()
		serverStdout, err := session.StdoutPipe()
		if err != nil {
			t.Error(err)
			return
		}
		n, err := CopyNRandomly(echoedBuf, serverStdout, windowTestBytes)
		if err != nil && err != io.EOF {
			t.Error(err)
			return
		}
		if n != windowTestBytes {
			t.Errorf("Read only %d bytes from server, expected %d", n, windowTestBytes)
		}
	}()

	serverStdin, err := session.StdinPipe()
	if err != nil {
		t.Fatal(err)
	}
	written, err := CopyNRandomly(serverStdin, origBuf, windowTestBytes)
	if err != nil {
		t.Fatal(err)
	}
	if written != windowTestBytes {
		t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
	}

	// Receiving from the closed channel also orders the goroutine's writes to
	// echoedBuf before the comparison below.
	<-done
	if !bytes.Equal(origBytes, echoedBuf.Bytes()) {
		t.Error("Echoed buffer differed from original")
	}
}
// startSSHServer starts an SSH server listening on an ephemeral TCP port
// (":0") with client authentication disabled, and returns the address the
// listener is bound to.
//
// Setup failures abort the test with t.Fatalf. Connections are accepted in a
// background goroutine; each one that completes its handshake is served by
// connRun in its own goroutine. The accept goroutine stops after the first
// Accept or Handshake error; anything other than an io.EOF handshake error is
// reported with t.Errorf, since FailNow-style calls are only valid on the
// goroutine running the test.
func startSSHServer(t *testing.T) (addr string) {
	config := &ServerConfig{
		NoClientAuth: true,
	}
	err := config.SetRSAPrivateKey([]byte(testServerPrivateKey))
	if err != nil {
		t.Fatalf("Failed to parse private key: %s", err.Error())
	}
	listener, err := Listen("tcp", ":0", config)
	if err != nil {
		t.Fatalf("Bind error: %s", err)
	}
	addr = listener.Addr().String()

	go func() {
		for {
			sConn, err := listener.Accept()
			if err != nil {
				// Check before touching sConn: previously this error was
				// overwritten by the Handshake result and sConn was used
				// regardless.
				t.Errorf("failed to accept connection: %s", err)
				return
			}
			if err := sConn.Handshake(); err != nil {
				if err != io.EOF {
					t.Errorf("failed to handshake: %s", err)
				}
				return
			}
			go connRun(t, sConn)
		}
	}()
	return addr
}
// connRun serves a single server connection: it accepts channels from sConn
// until Accept returns an error, rejects any channel whose type is not
// "session", and for each accepted session channel starts a goroutine that
// echoes up to windowTestBytes read from the channel back onto the same
// channel using CopyNRandomly.
//
// connRun and its echo goroutines run outside the test goroutine, so failures
// are reported with t.Error/t.Errorf followed by a return rather than
// t.Fatal, which is only valid on the goroutine running the test. An io.EOF
// from sConn.Accept ends the loop silently.
func connRun(t *testing.T, sConn *ServerConn) {
	for {
		channel, err := sConn.Accept()
		if err != nil {
			if err != io.EOF {
				t.Errorf("ServerConn.Accept failed: %s", err)
			}
			return
		}
		if channel.ChannelType() != "session" {
			channel.Reject(UnknownChannelType, "unknown channel type")
			continue
		}
		if err := channel.Accept(); err != nil {
			t.Errorf("Channel.Accept failed: %s", err)
			return
		}
		go func() {
			defer channel.Close()
			n, err := CopyNRandomly(channel, channel, windowTestBytes)
			switch {
			case err == nil, err == io.EOF:
				// Echo finished, or the peer closed its side early.
			case err == io.ErrShortWrite:
				t.Errorf("short write, wrote %d, expected %d", n, windowTestBytes)
			default:
				t.Error(err)
			}
		}()
	}
}