ssh/test: avoid leaking a net.UnixConn in server.TryDialWithAddr
For golang/go#64959.
Change-Id: I2153166f4960058cdc2b82ae34ca250dcc6ba1c6
Cq-Include-Trybots: luci.golang.try:x_crypto-gotip-linux-amd64-longtest,x_crypto-gotip-windows-amd64-longtest
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/554062
Run-TryBot: Bryan Mills <bcmills@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go
index 8dbedb0..9695de7 100644
--- a/ssh/test/test_unix_test.go
+++ b/ssh/test/test_unix_test.go
@@ -178,7 +178,7 @@
// addr is the user specified host:port. While we don't actually dial it,
// we need to know this for host key matching
-func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
+func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (client *ssh.Client, err error) {
sshd, err := exec.LookPath("sshd")
if err != nil {
s.t.Skipf("skipping test: %v", err)
@@ -188,13 +188,26 @@
if err != nil {
s.t.Fatalf("unixConnection: %v", err)
}
+ defer func() {
+ // Close c2 after we've started the sshd command so that it won't prevent c1
+ // from returning EOF when the sshd command exits.
+ c2.Close()
- cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
+ // Leave c1 open if we're returning a client that wraps it.
+ // (The client is responsible for closing it.)
+ // Otherwise, close it to free up the socket.
+ if client == nil {
+ c1.Close()
+ }
+ }()
+
f, err := c2.File()
if err != nil {
s.t.Fatalf("UnixConn.File: %v", err)
}
defer f.Close()
+
+ cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
cmd.Stdin = f
cmd.Stdout = f
cmd.Stderr = new(bytes.Buffer)
@@ -223,7 +236,7 @@
// processes are killed too.
cmd.Process.Signal(os.Interrupt)
cmd.Wait()
- if s.t.Failed() {
+ if s.t.Failed() || testing.Verbose() {
// log any output from sshd process
s.t.Logf("sshd:\n%s", cmd.Stderr)
}