ssh/test: set a timeout and WaitDelay on sshd subcommands

This uses a copy of testenv.Command copied from the main repo, with
light edits to allow the testenv helpers to build with Go 1.19.

The testenv helper revealed an exec.Command leak in TestCertLogin, so
we also fix that leak and simplify server cleanup using
testing.T.Cleanup.

For golang/go#60099.
Fixes golang/go#60343.

Change-Id: I7f79fcdb559498b987ee7689972ac53b83870aaf
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/496935
Auto-Submit: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Run-TryBot: Bryan Mills <bcmills@google.com>
diff --git a/internal/testenv/exec.go b/internal/testenv/exec.go
new file mode 100644
index 0000000..4bacdc3
--- /dev/null
+++ b/internal/testenv/exec.go
@@ -0,0 +1,120 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testenv
+
+import (
+	"context"
+	"os"
+	"os/exec"
+	"reflect"
+	"strconv"
+	"testing"
+	"time"
+)
+
+// CommandContext is like exec.CommandContext, but:
+//   - skips t if the platform does not support os/exec,
+//   - sends SIGQUIT (if supported by the platform) instead of SIGKILL
+//     in its Cancel function
+//   - if the test has a deadline, adds a Context timeout and WaitDelay
+//     for an arbitrary grace period before the test's deadline expires,
+//   - fails the test if the command does not complete before the test's deadline, and
+//   - sets a Cleanup function that verifies that the test did not leak a subprocess.
+func CommandContext(t testing.TB, ctx context.Context, name string, args ...string) *exec.Cmd {
+	t.Helper()
+
+	var (
+		cancelCtx   context.CancelFunc
+		gracePeriod time.Duration // unlimited unless the test has a deadline (to allow for interactive debugging)
+	)
+
+	if t, ok := t.(interface {
+		testing.TB
+		Deadline() (time.Time, bool)
+	}); ok {
+		if td, ok := t.Deadline(); ok {
+			// Start with a minimum grace period, just long enough to consume the
+			// output of a reasonable program after it terminates.
+			gracePeriod = 100 * time.Millisecond
+			if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" {
+				scale, err := strconv.Atoi(s)
+				if err != nil {
+					t.Fatalf("invalid GO_TEST_TIMEOUT_SCALE: %v", err)
+				}
+				gracePeriod *= time.Duration(scale)
+			}
+
+			// If time allows, increase the termination grace period to 5% of the
+			// test's remaining time.
+			testTimeout := time.Until(td)
+			if gp := testTimeout / 20; gp > gracePeriod {
+				gracePeriod = gp
+			}
+
+			// When we run commands that execute subprocesses, we want to reserve two
+			// grace periods to clean up: one for the delay between the first
+			// termination signal being sent (via the Cancel callback when the Context
+			// expires) and the process being forcibly terminated (via the WaitDelay
+			// field), and a second one for the delay becween the process being
+			// terminated and and the test logging its output for debugging.
+			//
+			// (We want to ensure that the test process itself has enough time to
+			// log the output before it is also terminated.)
+			cmdTimeout := testTimeout - 2*gracePeriod
+
+			if cd, ok := ctx.Deadline(); !ok || time.Until(cd) > cmdTimeout {
+				// Either ctx doesn't have a deadline, or its deadline would expire
+				// after (or too close before) the test has already timed out.
+				// Add a shorter timeout so that the test will produce useful output.
+				ctx, cancelCtx = context.WithTimeout(ctx, cmdTimeout)
+			}
+		}
+	}
+
+	cmd := exec.CommandContext(ctx, name, args...)
+	// Set the Cancel and WaitDelay fields only if present (go 1.20 and later).
+	// TODO: When Go 1.19 is no longer supported, remove this use of reflection
+	// and instead set the fields directly.
+	if cmdCancel := reflect.ValueOf(cmd).Elem().FieldByName("Cancel"); cmdCancel.IsValid() {
+		cmdCancel.Set(reflect.ValueOf(func() error {
+			if cancelCtx != nil && ctx.Err() == context.DeadlineExceeded {
+				// The command timed out due to running too close to the test's deadline.
+				// There is no way the test did that intentionally — it's too close to the
+				// wire! — so mark it as a test failure. That way, if the test expects the
+				// command to fail for some other reason, it doesn't have to distinguish
+				// between that reason and a timeout.
+				t.Errorf("test timed out while running command: %v", cmd)
+			} else {
+				// The command is being terminated due to ctx being canceled, but
+				// apparently not due to an explicit test deadline that we added.
+				// Log that information in case it is useful for diagnosing a failure,
+				// but don't actually fail the test because of it.
+				t.Logf("%v: terminating command: %v", ctx.Err(), cmd)
+			}
+			return cmd.Process.Signal(Sigquit)
+		}))
+	}
+	if cmdWaitDelay := reflect.ValueOf(cmd).Elem().FieldByName("WaitDelay"); cmdWaitDelay.IsValid() {
+		cmdWaitDelay.Set(reflect.ValueOf(gracePeriod))
+	}
+
+	t.Cleanup(func() {
+		if cancelCtx != nil {
+			cancelCtx()
+		}
+		if cmd.Process != nil && cmd.ProcessState == nil {
+			t.Errorf("command was started, but test did not wait for it to complete: %v", cmd)
+		}
+	})
+
+	return cmd
+}
+
+// Command is like exec.Command, but applies the same changes as
+// testenv.CommandContext (with a default Context).
+func Command(t testing.TB, name string, args ...string) *exec.Cmd {
+	t.Helper()
+	return CommandContext(t, context.Background(), name, args...)
+}
diff --git a/internal/testenv/testenv_notunix.go b/internal/testenv/testenv_notunix.go
new file mode 100644
index 0000000..c8918ce
--- /dev/null
+++ b/internal/testenv/testenv_notunix.go
@@ -0,0 +1,15 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows || plan9 || (js && wasm) || wasip1
+
+package testenv
+
+import (
+	"os"
+)
+
+// Sigquit is the signal to send to kill a hanging subprocess.
+// On Unix we send SIGQUIT, but on non-Unix we only have os.Kill.
+var Sigquit = os.Kill
diff --git a/internal/testenv/testenv_unix.go b/internal/testenv/testenv_unix.go
new file mode 100644
index 0000000..4f51823
--- /dev/null
+++ b/internal/testenv/testenv_unix.go
@@ -0,0 +1,15 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package testenv
+
+import (
+	"syscall"
+)
+
+// Sigquit is the signal to send to kill a hanging subprocess.
+// Send SIGQUIT to get a stack trace.
+var Sigquit = syscall.SIGQUIT
diff --git a/ssh/test/agent_unix_test.go b/ssh/test/agent_unix_test.go
index d90526c..43fbdb2 100644
--- a/ssh/test/agent_unix_test.go
+++ b/ssh/test/agent_unix_test.go
@@ -17,7 +17,6 @@
 
 func TestAgentForward(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
diff --git a/ssh/test/banner_test.go b/ssh/test/banner_test.go
index 22bdd67..3bfdd4b 100644
--- a/ssh/test/banner_test.go
+++ b/ssh/test/banner_test.go
@@ -13,7 +13,6 @@
 
 func TestBannerCallbackAgainstOpenSSH(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 
 	clientConf := clientConfig()
 
diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go
index 77891e3..83dd534 100644
--- a/ssh/test/cert_test.go
+++ b/ssh/test/cert_test.go
@@ -18,7 +18,6 @@
 // Test both logging in with a cert, and also that the certificate presented by an OpenSSH host can be validated correctly
 func TestCertLogin(t *testing.T) {
 	s := newServer(t)
-	defer s.Shutdown()
 
 	// Use a key different from the default.
 	clientKey := testSigners["dsa"]
diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go
index d3e3d54..4a7ec31 100644
--- a/ssh/test/dial_unix_test.go
+++ b/ssh/test/dial_unix_test.go
@@ -24,7 +24,6 @@
 
 func testDial(t *testing.T, n, listenAddr string, x dialTester) {
 	server := newServer(t)
-	defer server.Shutdown()
 	sshConn := server.Dial(clientConfig())
 	defer sshConn.Close()
 
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go
index f0595af..1171bc3 100644
--- a/ssh/test/forward_unix_test.go
+++ b/ssh/test/forward_unix_test.go
@@ -23,7 +23,6 @@
 
 func testPortForward(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -120,7 +119,6 @@
 
 func testAcceptClose(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 
 	sshListener, err := conn.Listen(n, listenAddr)
@@ -162,10 +160,9 @@
 // Check that listeners exit if the underlying client transport dies.
 func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
-	defer server.Shutdown()
-	conn := server.Dial(clientConfig())
+	client := server.Dial(clientConfig())
 
-	sshListener, err := conn.Listen(n, listenAddr)
+	sshListener, err := client.Listen(n, listenAddr)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -184,14 +181,10 @@
 
 	// It would be even nicer if we closed the server side, but it
 	// is more involved as the fd for that side is dup()ed.
-	server.clientConn.Close()
+	server.lastDialConn.Close()
 
-	select {
-	case <-time.After(1 * time.Second):
-		t.Errorf("timeout: listener did not close.")
-	case err := <-quit:
-		t.Logf("quit as expected (error %v)", err)
-	}
+	err = <-quit
+	t.Logf("quit as expected (error %v)", err)
 }
 
 func TestPortForwardConnectionCloseTCP(t *testing.T) {
diff --git a/ssh/test/multi_auth_test.go b/ssh/test/multi_auth_test.go
index da8f674..6c253a7 100644
--- a/ssh/test/multi_auth_test.go
+++ b/ssh/test/multi_auth_test.go
@@ -108,7 +108,6 @@
 			ctx := newMultiAuthTestCtx(t)
 
 			server := newServerForConfig(t, "MultiAuth", map[string]string{"AuthMethods": strings.Join(testCase.authMethods, ",")})
-			defer server.Shutdown()
 
 			clientConfig := clientConfig()
 			server.setTestPassword(clientConfig.User, ctx.password)
diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go
index 7d96ced..e98b786 100644
--- a/ssh/test/session_test.go
+++ b/ssh/test/session_test.go
@@ -25,7 +25,6 @@
 
 func TestRunCommandSuccess(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -42,7 +41,6 @@
 
 func TestHostKeyCheck(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 
 	conf := clientConfig()
 	hostDB := hostKeyDB()
@@ -64,7 +62,6 @@
 
 func TestRunCommandStdin(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -87,7 +84,6 @@
 
 func TestRunCommandStdinError(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -111,7 +107,6 @@
 
 func TestRunCommandFailed(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -128,7 +123,6 @@
 
 func TestRunCommandWeClosed(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -148,7 +142,6 @@
 
 func TestFuncLargeRead(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -180,7 +173,6 @@
 
 func TestKeyChange(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conf := clientConfig()
 	hostDB := hostKeyDB()
 	conf.HostKeyCallback = hostDB.Check
@@ -227,7 +219,6 @@
 		t.Skipf("skipping on %s", runtime.GOOS)
 	}
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -292,7 +283,6 @@
 		t.Skipf("skipping on %s", runtime.GOOS)
 	}
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -340,7 +330,6 @@
 
 func testOneCipher(t *testing.T, cipher string, cipherOrder []string) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conf := clientConfig()
 	conf.Ciphers = []string{cipher}
 	// Don't fail if sshd doesn't have the cipher.
@@ -399,7 +388,6 @@
 	for _, mac := range macOrder {
 		t.Run(mac, func(t *testing.T) {
 			server := newServer(t)
-			defer server.Shutdown()
 			conf := clientConfig()
 			conf.MACs = []string{mac}
 			// Don't fail if sshd doesn't have the MAC.
@@ -425,7 +413,6 @@
 	for _, kex := range kexOrder {
 		t.Run(kex, func(t *testing.T) {
 			server := newServer(t)
-			defer server.Shutdown()
 			conf := clientConfig()
 			// Don't fail if sshd doesn't have the kex.
 			conf.KeyExchanges = append([]string{kex}, kexOrder...)
@@ -460,8 +447,6 @@
 			} else {
 				t.Errorf("failed for key %q", key)
 			}
-
-			server.Shutdown()
 		})
 	}
 }
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go
index 3012a97..f3f55db 100644
--- a/ssh/test/test_unix_test.go
+++ b/ssh/test/test_unix_test.go
@@ -23,6 +23,7 @@
 	"testing"
 	"text/template"
 
+	"golang.org/x/crypto/internal/testenv"
 	"golang.org/x/crypto/ssh"
 	"golang.org/x/crypto/ssh/testdata"
 )
@@ -67,17 +68,13 @@
 
 type server struct {
 	t          *testing.T
-	cleanup    func() // executed during Shutdown
 	configfile string
-	cmd        *exec.Cmd
-	output     bytes.Buffer // holds stderr from sshd process
 
 	testUser     string // test username for sshd
 	testPasswd   string // test password for sshd
 	sshdTestPwSo string // dynamic library to inject a custom password into sshd
 
-	// Client half of the network connection.
-	clientConn net.Conn
+	lastDialConn net.Conn
 }
 
 func username() string {
@@ -193,15 +190,15 @@
 		s.t.Fatalf("unixConnection: %v", err)
 	}
 
-	s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
+	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
 	f, err := c2.File()
 	if err != nil {
 		s.t.Fatalf("UnixConn.File: %v", err)
 	}
 	defer f.Close()
-	s.cmd.Stdin = f
-	s.cmd.Stdout = f
-	s.cmd.Stderr = &s.output
+	cmd.Stdin = f
+	cmd.Stdout = f
+	cmd.Stderr = new(bytes.Buffer)
 
 	if s.sshdTestPwSo != "" {
 		if s.testUser == "" {
@@ -210,18 +207,29 @@
 		if s.testPasswd == "" {
 			s.t.Fatal("password missing from sshd_test_pw.so config")
 		}
-		s.cmd.Env = append(os.Environ(),
+		cmd.Env = append(os.Environ(),
 			fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
 			fmt.Sprintf("TEST_USER=%s", s.testUser),
 			fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
 	}
 
-	if err := s.cmd.Start(); err != nil {
-		s.t.Fail()
-		s.Shutdown()
+	if err := cmd.Start(); err != nil {
 		s.t.Fatalf("s.cmd.Start: %v", err)
 	}
-	s.clientConn = c1
+	s.lastDialConn = c1
+	s.t.Cleanup(func() {
+		// Don't check for errors; if it fails it's most
+		// likely "os: process already finished", and we don't
+		// care about that. Use os.Interrupt, so child
+		// processes are killed too.
+		cmd.Process.Signal(os.Interrupt)
+		cmd.Wait()
+		if s.t.Failed() {
+			// log any output from sshd process
+			s.t.Logf("sshd:\n%s", cmd.Stderr)
+		}
+	})
+
 	conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
 	if err != nil {
 		return nil, err
@@ -232,29 +240,11 @@
 func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
 	conn, err := s.TryDial(config)
 	if err != nil {
-		s.t.Fail()
-		s.Shutdown()
 		s.t.Fatalf("ssh.Client: %v", err)
 	}
 	return conn
 }
 
-func (s *server) Shutdown() {
-	if s.cmd != nil && s.cmd.Process != nil {
-		// Don't check for errors; if it fails it's most
-		// likely "os: process already finished", and we don't
-		// care about that. Use os.Interrupt, so child
-		// processes are killed too.
-		s.cmd.Process.Signal(os.Interrupt)
-		s.cmd.Wait()
-	}
-	if s.t.Failed() {
-		// log any output from sshd process
-		s.t.Logf("sshd: %s", s.output.String())
-	}
-	s.cleanup()
-}
-
 func writeFile(path string, contents []byte) {
 	f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
 	if err != nil {
@@ -351,15 +341,15 @@
 		authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
 	}
 	writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
+	t.Cleanup(func() {
+		if err := os.RemoveAll(dir); err != nil {
+			t.Error(err)
+		}
+	})
 
 	return &server{
 		t:          t,
 		configfile: f.Name(),
-		cleanup: func() {
-			if err := os.RemoveAll(dir); err != nil {
-				t.Error(err)
-			}
-		},
 	}
 }