| // Copyright 2011 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package ssh |
| |
| // Session tests. |
| |
| import ( |
| "bytes" |
| crypto_rand "crypto/rand" |
| "errors" |
| "io" |
| "math/rand" |
| "net" |
| "sync" |
| "testing" |
| |
| "golang.org/x/crypto/ssh/terminal" |
| ) |
| |
| type serverType func(Channel, <-chan *Request, *testing.T) |
| |
| // dial constructs a new test server and returns a *ClientConn. |
| func dial(handler serverType, t *testing.T) *Client { |
| c1, c2, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| |
| var wg sync.WaitGroup |
| t.Cleanup(wg.Wait) |
| wg.Add(1) |
| go func() { |
| defer func() { |
| c1.Close() |
| wg.Done() |
| }() |
| conf := ServerConfig{ |
| NoClientAuth: true, |
| } |
| conf.AddHostKey(testSigners["rsa"]) |
| |
| conn, chans, reqs, err := NewServerConn(c1, &conf) |
| if err != nil { |
| t.Errorf("Unable to handshake: %v", err) |
| return |
| } |
| wg.Add(1) |
| go func() { |
| DiscardRequests(reqs) |
| wg.Done() |
| }() |
| |
| for newCh := range chans { |
| if newCh.ChannelType() != "session" { |
| newCh.Reject(UnknownChannelType, "unknown channel type") |
| continue |
| } |
| |
| ch, inReqs, err := newCh.Accept() |
| if err != nil { |
| t.Errorf("Accept: %v", err) |
| continue |
| } |
| wg.Add(1) |
| go func() { |
| handler(ch, inReqs, t) |
| wg.Done() |
| }() |
| } |
| if err := conn.Wait(); err != io.EOF { |
| t.Logf("server exit reason: %v", err) |
| } |
| }() |
| |
| config := &ClientConfig{ |
| User: "testuser", |
| HostKeyCallback: InsecureIgnoreHostKey(), |
| } |
| |
| conn, chans, reqs, err := NewClientConn(c2, "", config) |
| if err != nil { |
| t.Fatalf("unable to dial remote side: %v", err) |
| } |
| |
| return NewClient(conn, chans, reqs) |
| } |
| |
| // Test a simple string is returned to session.Stdout. |
| func TestSessionShell(t *testing.T) { |
| conn := dial(shellHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| stdout := new(bytes.Buffer) |
| session.Stdout = stdout |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %s", err) |
| } |
| if err := session.Wait(); err != nil { |
| t.Fatalf("Remote command did not exit cleanly: %v", err) |
| } |
| actual := stdout.String() |
| if actual != "golang" { |
| t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) |
| } |
| } |
| |
| // TODO(dfc) add support for Std{in,err}Pipe when the Server supports it. |
| |
| // Test a simple string is returned via StdoutPipe. |
| func TestSessionStdoutPipe(t *testing.T) { |
| conn := dial(shellHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| stdout, err := session.StdoutPipe() |
| if err != nil { |
| t.Fatalf("Unable to request StdoutPipe(): %v", err) |
| } |
| var buf bytes.Buffer |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %v", err) |
| } |
| done := make(chan bool, 1) |
| go func() { |
| if _, err := io.Copy(&buf, stdout); err != nil { |
| t.Errorf("Copy of stdout failed: %v", err) |
| } |
| done <- true |
| }() |
| if err := session.Wait(); err != nil { |
| t.Fatalf("Remote command did not exit cleanly: %v", err) |
| } |
| <-done |
| actual := buf.String() |
| if actual != "golang" { |
| t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) |
| } |
| } |
| |
| // Test that a simple string is returned via the Output helper, |
| // and that stderr is discarded. |
| func TestSessionOutput(t *testing.T) { |
| conn := dial(fixedOutputHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| |
| buf, err := session.Output("") // cmd is ignored by fixedOutputHandler |
| if err != nil { |
| t.Error("Remote command did not exit cleanly:", err) |
| } |
| w := "this-is-stdout." |
| g := string(buf) |
| if g != w { |
| t.Error("Remote command did not return expected string:") |
| t.Logf("want %q", w) |
| t.Logf("got %q", g) |
| } |
| } |
| |
| // Test that both stdout and stderr are returned |
| // via the CombinedOutput helper. |
| func TestSessionCombinedOutput(t *testing.T) { |
| conn := dial(fixedOutputHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| |
| buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler |
| if err != nil { |
| t.Error("Remote command did not exit cleanly:", err) |
| } |
| const stdout = "this-is-stdout." |
| const stderr = "this-is-stderr." |
| g := string(buf) |
| if g != stdout+stderr && g != stderr+stdout { |
| t.Error("Remote command did not return expected string:") |
| t.Logf("want %q, or %q", stdout+stderr, stderr+stdout) |
| t.Logf("got %q", g) |
| } |
| } |
| |
| // Test non-0 exit status is returned correctly. |
| func TestExitStatusNonZero(t *testing.T) { |
| conn := dial(exitStatusNonZeroHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %v", err) |
| } |
| err = session.Wait() |
| if err == nil { |
| t.Fatalf("expected command to fail but it didn't") |
| } |
| e, ok := err.(*ExitError) |
| if !ok { |
| t.Fatalf("expected *ExitError but got %T", err) |
| } |
| if e.ExitStatus() != 15 { |
| t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus()) |
| } |
| } |
| |
| // Test 0 exit status is returned correctly. |
| func TestExitStatusZero(t *testing.T) { |
| conn := dial(exitStatusZeroHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %v", err) |
| } |
| err = session.Wait() |
| if err != nil { |
| t.Fatalf("expected nil but got %v", err) |
| } |
| } |
| |
| // Test exit signal and status are both returned correctly. |
| func TestExitSignalAndStatus(t *testing.T) { |
| conn := dial(exitSignalAndStatusHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %v", err) |
| } |
| err = session.Wait() |
| if err == nil { |
| t.Fatalf("expected command to fail but it didn't") |
| } |
| e, ok := err.(*ExitError) |
| if !ok { |
| t.Fatalf("expected *ExitError but got %T", err) |
| } |
| if e.Signal() != "TERM" || e.ExitStatus() != 15 { |
| t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus()) |
| } |
| } |
| |
| // Test exit signal and status are both returned correctly. |
| func TestKnownExitSignalOnly(t *testing.T) { |
| conn := dial(exitSignalHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %v", err) |
| } |
| err = session.Wait() |
| if err == nil { |
| t.Fatalf("expected command to fail but it didn't") |
| } |
| e, ok := err.(*ExitError) |
| if !ok { |
| t.Fatalf("expected *ExitError but got %T", err) |
| } |
| if e.Signal() != "TERM" || e.ExitStatus() != 143 { |
| t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus()) |
| } |
| } |
| |
| // Test exit signal and status are both returned correctly. |
| func TestUnknownExitSignal(t *testing.T) { |
| conn := dial(exitSignalUnknownHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %v", err) |
| } |
| err = session.Wait() |
| if err == nil { |
| t.Fatalf("expected command to fail but it didn't") |
| } |
| e, ok := err.(*ExitError) |
| if !ok { |
| t.Fatalf("expected *ExitError but got %T", err) |
| } |
| if e.Signal() != "SYS" || e.ExitStatus() != 128 { |
| t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus()) |
| } |
| } |
| |
| func TestExitWithoutStatusOrSignal(t *testing.T) { |
| conn := dial(exitWithoutSignalOrStatus, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatalf("Unable to request new session: %v", err) |
| } |
| defer session.Close() |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %v", err) |
| } |
| err = session.Wait() |
| if err == nil { |
| t.Fatalf("expected command to fail but it didn't") |
| } |
| if _, ok := err.(*ExitMissingError); !ok { |
| t.Fatalf("got %T want *ExitMissingError", err) |
| } |
| } |
| |
// windowTestBytes is how many bytes TestServerWindow pushes through the SSH
// server and expects to be echoed back.
const windowTestBytes = 200 * 16000
| |
| // TestServerWindow writes random data to the server. The server is expected to echo |
| // the same data back, which is compared against the original. |
| func TestServerWindow(t *testing.T) { |
| origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) |
| io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) |
| origBytes := origBuf.Bytes() |
| |
| conn := dial(echoHandler, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer session.Close() |
| |
| serverStdin, err := session.StdinPipe() |
| if err != nil { |
| t.Fatalf("StdinPipe failed: %v", err) |
| } |
| |
| result := make(chan []byte) |
| go func() { |
| defer close(result) |
| echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) |
| serverStdout, err := session.StdoutPipe() |
| if err != nil { |
| t.Errorf("StdoutPipe failed: %v", err) |
| return |
| } |
| n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes) |
| if err != nil && err != io.EOF { |
| t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err) |
| } |
| result <- echoedBuf.Bytes() |
| }() |
| |
| written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) |
| if err != nil { |
| t.Errorf("failed to copy origBuf to serverStdin: %v", err) |
| } else if written != windowTestBytes { |
| t.Errorf("Wrote only %d of %d bytes to server", written, windowTestBytes) |
| } |
| |
| echoedBytes := <-result |
| |
| if !bytes.Equal(origBytes, echoedBytes) { |
| t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes)) |
| } |
| } |
| |
| // Verify the client can handle a keepalive packet from the server. |
| func TestClientHandlesKeepalives(t *testing.T) { |
| conn := dial(channelKeepaliveSender, t) |
| defer conn.Close() |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer session.Close() |
| if err := session.Shell(); err != nil { |
| t.Fatalf("Unable to execute command: %v", err) |
| } |
| err = session.Wait() |
| if err != nil { |
| t.Fatalf("expected nil but got: %v", err) |
| } |
| } |
| |
// exitStatusMsg is the payload that sendStatus marshals into an "exit-status"
// channel request.
type exitStatusMsg struct {
	Status uint32 // exit status reported to the client
}
| |
// exitSignalMsg is the payload that sendSignal marshals into an "exit-signal"
// channel request. The fields are kept in their original order.
type exitSignalMsg struct {
	Signal     string // signal name, e.g. "TERM"
	CoreDumped bool
	Errmsg     string
	Lang       string
}
| |
| func handleTerminalRequests(in <-chan *Request) { |
| for req := range in { |
| ok := false |
| switch req.Type { |
| case "shell": |
| ok = true |
| if len(req.Payload) > 0 { |
| // We don't accept any commands, only the default shell. |
| ok = false |
| } |
| case "env": |
| ok = true |
| } |
| req.Reply(ok, nil) |
| } |
| } |
| |
| func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { |
| term := terminal.NewTerminal(ch, prompt) |
| go handleTerminalRequests(in) |
| return term |
| } |
| |
| func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| // this string is returned to stdout |
| shell := newServerShell(ch, in, "> ") |
| readLine(shell, t) |
| sendStatus(0, ch, t) |
| } |
| |
| func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| shell := newServerShell(ch, in, "> ") |
| readLine(shell, t) |
| sendStatus(15, ch, t) |
| } |
| |
| func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| shell := newServerShell(ch, in, "> ") |
| readLine(shell, t) |
| sendStatus(15, ch, t) |
| sendSignal("TERM", ch, t) |
| } |
| |
| func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| shell := newServerShell(ch, in, "> ") |
| readLine(shell, t) |
| sendSignal("TERM", ch, t) |
| } |
| |
| func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| shell := newServerShell(ch, in, "> ") |
| readLine(shell, t) |
| sendSignal("SYS", ch, t) |
| } |
| |
| func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| shell := newServerShell(ch, in, "> ") |
| readLine(shell, t) |
| } |
| |
| func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| // this string is returned to stdout |
| shell := newServerShell(ch, in, "golang") |
| readLine(shell, t) |
| sendStatus(0, ch, t) |
| } |
| |
| // Ignores the command, writes fixed strings to stderr and stdout. |
| // Strings are "this-is-stdout." and "this-is-stderr.". |
| func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| _, err := ch.Read(nil) |
| |
| req, ok := <-in |
| if !ok { |
| t.Fatalf("error: expected channel request, got: %#v", err) |
| return |
| } |
| |
| // ignore request, always send some text |
| req.Reply(true, nil) |
| |
| _, err = io.WriteString(ch, "this-is-stdout.") |
| if err != nil { |
| t.Fatalf("error writing on server: %v", err) |
| } |
| _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") |
| if err != nil { |
| t.Fatalf("error writing on server: %v", err) |
| } |
| sendStatus(0, ch, t) |
| } |
| |
| func readLine(shell *terminal.Terminal, t *testing.T) { |
| if _, err := shell.ReadLine(); err != nil && err != io.EOF { |
| t.Errorf("unable to read line: %v", err) |
| } |
| } |
| |
| func sendStatus(status uint32, ch Channel, t *testing.T) { |
| msg := exitStatusMsg{ |
| Status: status, |
| } |
| if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { |
| t.Errorf("unable to send status: %v", err) |
| } |
| } |
| |
| func sendSignal(signal string, ch Channel, t *testing.T) { |
| sig := exitSignalMsg{ |
| Signal: signal, |
| CoreDumped: false, |
| Errmsg: "Process terminated", |
| Lang: "en-GB-oed", |
| } |
| if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { |
| t.Errorf("unable to send signal: %v", err) |
| } |
| } |
| |
| func discardHandler(ch Channel, t *testing.T) { |
| defer ch.Close() |
| io.Copy(io.Discard, ch) |
| } |
| |
| func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { |
| t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) |
| } |
| } |
| |
// copyNRandomly copies up to n bytes from src to dst, asking src for a random
// number of bytes (below 32 KiB, and never more than what is still
// outstanding) on every iteration in order to exercise more code paths.
//
// It returns the number of bytes written to dst and the first error
// encountered:
//   - a write error from dst, or io.ErrShortWrite when dst accepts fewer bytes
//     than it was given;
//   - io.EOF when src is exhausted before n bytes have been copied (without
//     this check the loop would spin forever on a source that ends early);
//   - any other read error from src.
//
// A nil error means exactly n bytes were copied. title is not used by the
// body; callers pass a label that identifies the stream being copied.
func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
	var (
		buf       = make([]byte, 32*1024)
		written   int
		remaining = n
	)
	for remaining > 0 {
		l := rand.Intn(1 << 15)
		if remaining < l {
			l = remaining
		}
		nr, er := src.Read(buf[:l])
		// Bytes returned alongside an error are still forwarded before the
		// error is acted upon.
		nw, ew := dst.Write(buf[:nr])
		remaining -= nw
		written += nw
		if ew != nil {
			return written, ew
		}
		if nr != nw {
			return written, io.ErrShortWrite
		}
		if er == io.EOF {
			if remaining > 0 {
				// src ended early; nothing more will ever arrive.
				return written, io.EOF
			}
			break
		}
		if er != nil {
			return written, er
		}
	}
	return written, nil
}
| |
| func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| shell := newServerShell(ch, in, "> ") |
| readLine(shell, t) |
| if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { |
| t.Errorf("unable to send channel keepalive request: %v", err) |
| } |
| sendStatus(0, ch, t) |
| } |
| |
| func TestClientWriteEOF(t *testing.T) { |
| conn := dial(simpleEchoHandler, t) |
| defer conn.Close() |
| |
| session, err := conn.NewSession() |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer session.Close() |
| stdin, err := session.StdinPipe() |
| if err != nil { |
| t.Fatalf("StdinPipe failed: %v", err) |
| } |
| stdout, err := session.StdoutPipe() |
| if err != nil { |
| t.Fatalf("StdoutPipe failed: %v", err) |
| } |
| |
| data := []byte(`0000`) |
| _, err = stdin.Write(data) |
| if err != nil { |
| t.Fatalf("Write failed: %v", err) |
| } |
| stdin.Close() |
| |
| res, err := io.ReadAll(stdout) |
| if err != nil { |
| t.Fatalf("Read failed: %v", err) |
| } |
| |
| if !bytes.Equal(data, res) { |
| t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) |
| } |
| } |
| |
| func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { |
| defer ch.Close() |
| data, err := io.ReadAll(ch) |
| if err != nil { |
| t.Errorf("handler read error: %v", err) |
| } |
| _, err = ch.Write(data) |
| if err != nil { |
| t.Errorf("handler write error: %v", err) |
| } |
| } |
| |
| func TestSessionID(t *testing.T) { |
| c1, c2, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| defer c1.Close() |
| defer c2.Close() |
| |
| serverID := make(chan []byte, 1) |
| clientID := make(chan []byte, 1) |
| |
| serverConf := &ServerConfig{ |
| NoClientAuth: true, |
| } |
| serverConf.AddHostKey(testSigners["ecdsa"]) |
| clientConf := &ClientConfig{ |
| HostKeyCallback: InsecureIgnoreHostKey(), |
| User: "user", |
| } |
| |
| var wg sync.WaitGroup |
| t.Cleanup(wg.Wait) |
| |
| srvErrCh := make(chan error, 1) |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| conn, chans, reqs, err := NewServerConn(c1, serverConf) |
| srvErrCh <- err |
| if err != nil { |
| return |
| } |
| serverID <- conn.SessionID() |
| wg.Add(1) |
| go func() { |
| DiscardRequests(reqs) |
| wg.Done() |
| }() |
| for ch := range chans { |
| ch.Reject(Prohibited, "") |
| } |
| }() |
| |
| cliErrCh := make(chan error, 1) |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| conn, chans, reqs, err := NewClientConn(c2, "", clientConf) |
| cliErrCh <- err |
| if err != nil { |
| return |
| } |
| clientID <- conn.SessionID() |
| wg.Add(1) |
| go func() { |
| DiscardRequests(reqs) |
| wg.Done() |
| }() |
| for ch := range chans { |
| ch.Reject(Prohibited, "") |
| } |
| }() |
| |
| if err := <-srvErrCh; err != nil { |
| t.Fatalf("server handshake: %v", err) |
| } |
| |
| if err := <-cliErrCh; err != nil { |
| t.Fatalf("client handshake: %v", err) |
| } |
| |
| s := <-serverID |
| c := <-clientID |
| if bytes.Compare(s, c) != 0 { |
| t.Errorf("server session ID (%x) != client session ID (%x)", s, c) |
| } else if len(s) == 0 { |
| t.Errorf("client and server SessionID were empty.") |
| } |
| } |
| |
// noReadConn wraps a net.Conn, fails every Read, and remembers whether Read
// was ever called. Its Close does nothing and leaves the wrapped Conn alone.
type noReadConn struct {
	readSeen bool // set by the first call to Read
	net.Conn
}

// Close is a no-op that always reports success.
func (nc *noReadConn) Close() error { return nil }

// Read records that a read was attempted and returns a fixed error without
// touching the wrapped Conn or the supplied buffer.
func (nc *noReadConn) Read([]byte) (int, error) {
	nc.readSeen = true
	return 0, errors.New("noReadConn error")
}
| |
| func TestInvalidServerConfiguration(t *testing.T) { |
| c1, c2, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| defer c1.Close() |
| defer c2.Close() |
| |
| serveConn := noReadConn{Conn: c1} |
| serverConf := &ServerConfig{} |
| |
| NewServerConn(&serveConn, serverConf) |
| if serveConn.readSeen { |
| t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key") |
| } |
| |
| serverConf.AddHostKey(testSigners["ecdsa"]) |
| |
| NewServerConn(&serveConn, serverConf) |
| if serveConn.readSeen { |
| t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method") |
| } |
| } |
| |
| func TestHostKeyAlgorithms(t *testing.T) { |
| serverConf := &ServerConfig{ |
| NoClientAuth: true, |
| } |
| serverConf.AddHostKey(testSigners["rsa"]) |
| serverConf.AddHostKey(testSigners["ecdsa"]) |
| |
| var wg sync.WaitGroup |
| t.Cleanup(wg.Wait) |
| connect := func(clientConf *ClientConfig, want string) { |
| var alg string |
| clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { |
| alg = key.Type() |
| return nil |
| } |
| c1, c2, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| defer c1.Close() |
| defer c2.Close() |
| |
| wg.Add(1) |
| go func() { |
| NewServerConn(c1, serverConf) |
| wg.Done() |
| }() |
| _, _, _, err = NewClientConn(c2, "", clientConf) |
| if err != nil { |
| t.Fatalf("NewClientConn: %v", err) |
| } |
| if alg != want { |
| t.Errorf("selected key algorithm %s, want %s", alg, want) |
| } |
| } |
| |
| // By default, we get the preferred algorithm, which is ECDSA 256. |
| |
| clientConf := &ClientConfig{ |
| HostKeyCallback: InsecureIgnoreHostKey(), |
| } |
| connect(clientConf, KeyAlgoECDSA256) |
| |
| // Client asks for RSA explicitly. |
| clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA} |
| connect(clientConf, KeyAlgoRSA) |
| |
| // Client asks for RSA-SHA2-512 explicitly. |
| clientConf.HostKeyAlgorithms = []string{KeyAlgoRSASHA512} |
| // We get back an "ssh-rsa" key but the verification happened |
| // with an RSA-SHA2-512 signature. |
| connect(clientConf, KeyAlgoRSA) |
| |
| c1, c2, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| defer c1.Close() |
| defer c2.Close() |
| |
| wg.Add(1) |
| go func() { |
| NewServerConn(c1, serverConf) |
| wg.Done() |
| }() |
| clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} |
| _, _, _, err = NewClientConn(c2, "", clientConf) |
| if err == nil { |
| t.Fatal("succeeded connecting with unknown hostkey algorithm") |
| } |
| } |
| |
| func TestServerClientAuthCallback(t *testing.T) { |
| c1, c2, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| defer c1.Close() |
| defer c2.Close() |
| |
| userCh := make(chan string, 1) |
| |
| serverConf := &ServerConfig{ |
| NoClientAuth: true, |
| NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { |
| userCh <- conn.User() |
| return nil, nil |
| }, |
| } |
| const someUsername = "some-username" |
| |
| serverConf.AddHostKey(testSigners["ecdsa"]) |
| clientConf := &ClientConfig{ |
| HostKeyCallback: InsecureIgnoreHostKey(), |
| User: someUsername, |
| } |
| |
| var wg sync.WaitGroup |
| t.Cleanup(wg.Wait) |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| _, chans, reqs, err := NewServerConn(c1, serverConf) |
| if err != nil { |
| t.Errorf("server handshake: %v", err) |
| userCh <- "error" |
| return |
| } |
| wg.Add(1) |
| go func() { |
| DiscardRequests(reqs) |
| wg.Done() |
| }() |
| for ch := range chans { |
| ch.Reject(Prohibited, "") |
| } |
| }() |
| |
| conn, _, _, err := NewClientConn(c2, "", clientConf) |
| if err != nil { |
| t.Fatalf("client handshake: %v", err) |
| return |
| } |
| conn.Close() |
| |
| got := <-userCh |
| if got != someUsername { |
| t.Errorf("username = %q; want %q", got, someUsername) |
| } |
| } |