ssh: add Output and CombinedOutput helpers
R=golang-dev, dave
CC=golang-dev
https://golang.org/cl/9711043
diff --git a/ssh/session.go b/ssh/session.go
index 640d322..2576024 100644
--- a/ssh/session.go
+++ b/ssh/session.go
@@ -13,6 +13,7 @@
"fmt"
"io"
"io/ioutil"
+ "sync"
)
type Signal string
@@ -276,7 +277,8 @@
// Run runs cmd on the remote host. Typically, the remote
// server passes cmd to the shell for interpretation.
-// A Session only accepts one call to Run, Start or Shell.
+// A Session only accepts one call to Run, Start, Shell, Output,
+// or CombinedOutput.
//
// The returned error is nil if the command runs, has no problems
// copying stdin, stdout, and stderr, and exits with a zero exit
@@ -293,8 +295,46 @@
return s.Wait()
}
+// Output runs cmd on the remote host and returns its standard output.
+func (s *Session) Output(cmd string) ([]byte, error) {
+ if s.Stdout != nil {
+ return nil, errors.New("ssh: Stdout already set")
+ }
+ var b bytes.Buffer
+ s.Stdout = &b
+ err := s.Run(cmd)
+ return b.Bytes(), err
+}
+
+type singleWriter struct {
+ b bytes.Buffer
+ mu sync.Mutex
+}
+
+func (w *singleWriter) Write(p []byte) (int, error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.b.Write(p)
+}
+
+// CombinedOutput runs cmd on the remote host and returns its combined
+// standard output and standard error.
+func (s *Session) CombinedOutput(cmd string) ([]byte, error) {
+ if s.Stdout != nil {
+ return nil, errors.New("ssh: Stdout already set")
+ }
+ if s.Stderr != nil {
+ return nil, errors.New("ssh: Stderr already set")
+ }
+ var b singleWriter
+ s.Stdout = &b
+ s.Stderr = &b
+ err := s.Run(cmd)
+ return b.b.Bytes(), err
+}
+
// Shell starts a login shell on the remote host. A Session only
-// accepts one call to Run, Start or Shell.
+// accepts one call to Run, Start, Shell, Output, or CombinedOutput.
func (s *Session) Shell() error {
if s.started {
return errors.New("ssh: session already started")
@@ -521,8 +561,6 @@
return s.clientChan.stderr, nil
}
-// TODO(dfc) add Output and CombinedOutput helpers
-
// NewSession returns a new interactive session on the remote host.
func (c *ClientConn) NewSession() (*Session, error) {
ch := c.newChan(c.transport)
diff --git a/ssh/session_test.go b/ssh/session_test.go
index 7453c8a..9588265 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -138,6 +138,54 @@
}
}
+// Test that a simple string is returned via the Output helper,
+// and that stderr is discarded.
+func TestSessionOutput(t *testing.T) {
+ conn := dial(fixedOutputHandler, t)
+ defer conn.Close()
+ session, err := conn.NewSession()
+ if err != nil {
+ t.Fatalf("Unable to request new session: %v", err)
+ }
+ defer session.Close()
+
+ buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
+ if err != nil {
+ t.Error("Remote command did not exit cleanly:", err)
+ }
+ w := "this-is-stdout."
+ g := string(buf)
+ if g != w {
+ t.Error("Remote command did not return expected string:")
+ t.Logf("want %q", w)
+ t.Logf("got %q", g)
+ }
+}
+
+// Test that both stdout and stderr are returned
+// via the CombinedOutput helper.
+func TestSessionCombinedOutput(t *testing.T) {
+ conn := dial(fixedOutputHandler, t)
+ defer conn.Close()
+ session, err := conn.NewSession()
+ if err != nil {
+ t.Fatalf("Unable to request new session: %v", err)
+ }
+ defer session.Close()
+
+ buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
+ if err != nil {
+ t.Error("Remote command did not exit cleanly:", err)
+ }
+ w := "this-is-stdout.this-is-stderr."
+ g := string(buf)
+ if g != w {
+ t.Error("Remote command did not return expected string:")
+ t.Logf("want %q", w)
+ t.Logf("got %q", g)
+ }
+}
+
// Test non-0 exit status is returned correctly.
func TestExitStatusNonZero(t *testing.T) {
conn := dial(exitStatusNonZeroHandler, t)
@@ -586,6 +634,30 @@
sendStatus(0, ch, t)
}
+// Ignores the command, writes fixed strings to stderr and stdout.
+// Strings are "this-is-stdout." and "this-is-stderr.".
+func fixedOutputHandler(ch *serverChan, t *testing.T) {
+ defer ch.Close()
+
+ _, err := ch.Read(make([]byte, 0))
+ if _, ok := err.(ChannelRequest); !ok {
+ t.Fatalf("error: expected channel request, got: %#v", err)
+ return
+ }
+ // ignore request, always send some text
+ ch.AckRequest(true)
+
+ _, err = io.WriteString(ch, "this-is-stdout.")
+ if err != nil {
+ t.Fatalf("error writing on server: %v", err)
+ }
+ _, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
+ if err != nil {
+ t.Fatalf("error writing on server: %v", err)
+ }
+ sendStatus(0, ch, t)
+}
+
func readLine(shell *ServerTerminal, t *testing.T) {
if _, err := shell.ReadLine(); err != nil && err != io.EOF {
t.Errorf("unable to read line: %v", err)