| // Copyright 2012 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package ssh |
| |
| import ( |
| "bytes" |
| crypto_rand "crypto/rand" |
| "io" |
| "math/rand" |
| "testing" |
| ) |
| |
// windowTestBytes is the size, in bytes, of the payload pushed through the
// SSH server during the test (3,200,000 bytes).
const windowTestBytes = 200 * 16000
| |
// CopyNRandomly moves up to n bytes from src into dst and reports how many
// bytes were written. Every read uses a randomly chosen chunk size between
// 1 KiB and 30 KiB (capped at the number of bytes still outstanding) so that
// callers exercise many different buffer sizes.
//
// Copying stops early on the first write error, on a short write
// (io.ErrShortWrite), or on a read error such as io.EOF; that error is
// returned together with the count written so far.
func CopyNRandomly(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
	scratch := make([]byte, 32*1024)
	for written < n {
		// Random chunk size in [1 KiB, 30 KiB], never more than what is left.
		chunk := 1024 * (1 + rand.Intn(30))
		if remaining := n - written; remaining < int64(chunk) {
			chunk = int(remaining)
		}

		got, readErr := src.Read(scratch[:chunk])
		if got > 0 {
			put, writeErr := dst.Write(scratch[:got])
			if put > 0 {
				written += int64(put)
			}
			// A write failure takes precedence over any pending read error.
			switch {
			case writeErr != nil:
				return written, writeErr
			case put != got:
				return written, io.ErrShortWrite
			}
		}
		if readErr != nil {
			return written, readErr
		}
	}
	return written, nil
}
| |
| func TestServerWindow(t *testing.T) { |
| addr := startSSHServer(t) |
| runSSHClient(t, addr) |
| } |
| |
| // runSSHClient writes random data to the server. The server is expected to echo |
| // the same data back, which is compared against the original. |
| func runSSHClient(t *testing.T, addr string) { |
| conn, err := Dial("tcp", addr, &ClientConfig{}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) |
| echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) |
| io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) |
| origBytes := origBuf.Bytes() |
| |
| wait := make(chan bool) |
| |
| // Read back the data from the server. |
| go func() { |
| defer session.Close() |
| serverStdout, err := session.StdoutPipe() |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| n, err := CopyNRandomly(echoedBuf, serverStdout, windowTestBytes) |
| if err != nil && err != io.EOF { |
| t.Fatal(err) |
| } |
| if n != windowTestBytes { |
| t.Fatalf("Read only %d bytes from server, expected %d", n, windowTestBytes) |
| } |
| wait <- true |
| }() |
| |
| serverStdin, err := session.StdinPipe() |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| written, err := CopyNRandomly(serverStdin, origBuf, windowTestBytes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if written != windowTestBytes { |
| t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes) |
| } |
| |
| <-wait |
| |
| if !bytes.Equal(origBytes, echoedBuf.Bytes()) { |
| t.Error("Echoed buffer differed from original") |
| } |
| } |
| |
| func startSSHServer(t *testing.T) (addr string) { |
| config := &ServerConfig{ |
| NoClientAuth: true, |
| } |
| |
| err := config.SetRSAPrivateKey([]byte(testServerPrivateKey)) |
| if err != nil { |
| t.Fatalf("Failed to parse private key: %s", err.Error()) |
| } |
| |
| listener, err := Listen("tcp", ":0", config) |
| if err != nil { |
| t.Fatalf("Bind error: %s", err) |
| } |
| |
| addr = listener.Addr().String() |
| |
| go func() { |
| for { |
| sConn, err := listener.Accept() |
| |
| err = sConn.Handshake() |
| if err != nil { |
| if err != io.EOF { |
| t.Fatalf("failed to handshake: %s", err) |
| } |
| return |
| } |
| |
| go connRun(t, sConn) |
| } |
| }() |
| |
| return |
| } |
| |
| func connRun(t *testing.T, sConn *ServerConn) { |
| for { |
| channel, err := sConn.Accept() |
| if err != nil { |
| if err == io.EOF { |
| break |
| } |
| t.Fatalf("ServerConn.Accept failed: %s", err) |
| } |
| |
| if channel.ChannelType() != "session" { |
| channel.Reject(UnknownChannelType, "unknown channel type") |
| continue |
| } |
| err = channel.Accept() |
| if err != nil { |
| t.Fatalf("Channel.Accept failed: %s", err) |
| } |
| |
| go func() { |
| defer channel.Close() |
| |
| n, err := CopyNRandomly(channel, channel, windowTestBytes) |
| if err != nil && err != io.EOF { |
| if err == io.ErrShortWrite { |
| t.Fatalf("short write, wrote %d, expected %d", n, windowTestBytes) |
| } |
| t.Fatal(err) |
| } |
| }() |
| } |
| } |