net: deflake TestDialTimeoutFDLeak

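This change replaces the Linux-only file descriptor counting in
TestDialTimeoutFDLeak with socktest-based socket tracking, so the test
now runs on every supported platform except Plan 9.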

Updates #4384.

Change-Id: I3608f438003003f9b7cfa17c9e5fe7077700fd60
Reviewed-on: https://go-review.googlesource.com/8392
Reviewed-by: Ian Lance Taylor <iant@golang.org>
diff --git a/src/net/dial_test.go b/src/net/dial_test.go
index 39e61d5..448faac 100644
--- a/src/net/dial_test.go
+++ b/src/net/dial_test.go
@@ -7,7 +7,7 @@
 import (
 	"bytes"
 	"fmt"
-	"io"
+	"net/internal/socktest"
 	"os"
 	"os/exec"
 	"reflect"
@@ -179,61 +179,51 @@
 }
 
 func TestDialTimeoutFDLeak(t *testing.T) {
-	if runtime.GOOS != "linux" {
-		// TODO(bradfitz): test on other platforms
-		t.Skipf("skipping test on %q", runtime.GOOS)
+	switch runtime.GOOS {
+	case "plan9":
+		t.Skipf("%s does not have full support of socktest", runtime.GOOS)
 	}
 
-	ln := newLocalListener(t)
-	defer ln.Close()
+	const T = 100 * time.Millisecond
 
-	type connErr struct {
-		conn Conn
-		err  error
+	switch runtime.GOOS {
+	case "plan9", "windows":
+		origTestHookDialChannel := testHookDialChannel
+		testHookDialChannel = func() { time.Sleep(2 * T) }
+		defer func() { testHookDialChannel = origTestHookDialChannel }()
+		if runtime.GOOS == "plan9" {
+			break
+		}
+		fallthrough
+	default:
+		sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
+			time.Sleep(2 * T)
+			return nil, errTimeout
+		})
+		defer sw.Set(socktest.FilterConnect, nil)
 	}
-	dials := listenerBacklog + 100
-	// used to be listenerBacklog + 5, but was found to be unreliable, issue 4384.
-	maxGoodConnect := listenerBacklog + runtime.NumCPU()*10
-	resc := make(chan connErr)
-	for i := 0; i < dials; i++ {
+
+	before := sw.Sockets()
+	const N = 100
+	var wg sync.WaitGroup
+	wg.Add(N)
+	for i := 0; i < N; i++ {
 		go func() {
-			conn, err := DialTimeout("tcp", ln.Addr().String(), 500*time.Millisecond)
-			resc <- connErr{conn, err}
+			defer wg.Done()
+			// This dial never starts to send any SYN
+			// segment because of above socket filter and
+			// test hook.
+			c, err := DialTimeout("tcp", "127.0.0.1:0", T)
+			if err == nil {
+				t.Errorf("unexpectedly established: tcp:%s->%s", c.LocalAddr(), c.RemoteAddr())
+				c.Close()
+			}
 		}()
 	}
-
-	var firstErr string
-	var ngood int
-	var toClose []io.Closer
-	for i := 0; i < dials; i++ {
-		ce := <-resc
-		if ce.err == nil {
-			ngood++
-			if ngood > maxGoodConnect {
-				t.Errorf("%d good connects; expected at most %d", ngood, maxGoodConnect)
-			}
-			toClose = append(toClose, ce.conn)
-			continue
-		}
-		err := ce.err
-		if firstErr == "" {
-			firstErr = err.Error()
-		} else if err.Error() != firstErr {
-			t.Fatalf("inconsistent error messages: first was %q, then later %q", firstErr, err)
-		}
-	}
-	for _, c := range toClose {
-		c.Close()
-	}
-	for i := 0; i < 100; i++ {
-		if got := numFD(); got < dials {
-			// Test passes.
-			return
-		}
-		time.Sleep(10 * time.Millisecond)
-	}
-	if got := numFD(); got >= dials {
-		t.Errorf("num fds after %d timeouts = %d; want <%d", dials, got, dials)
+	wg.Wait()
+	after := sw.Sockets()
+	if len(after) != len(before) {
+		t.Errorf("got %d; want %d", len(after), len(before))
 	}
 }
 
@@ -329,23 +319,6 @@
 	}
 }
 
-func numFD() int {
-	if runtime.GOOS == "linux" {
-		f, err := os.Open("/proc/self/fd")
-		if err != nil {
-			panic(err)
-		}
-		defer f.Close()
-		names, err := f.Readdirnames(0)
-		if err != nil {
-			panic(err)
-		}
-		return len(names)
-	}
-	// All tests using this should be skipped anyway, but:
-	panic("numFDs not implemented on " + runtime.GOOS)
-}
-
 func TestDialer(t *testing.T) {
 	ln, err := Listen("tcp4", "127.0.0.1:0")
 	if err != nil {