Merge pull request #529 from bradfitz/leaks

Fix test-only goroutine leaks; add leak checker to end2end tests.
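
Every test now installs the checker with

	defer leakCheck(t)()

which snapshots the goroutines running when the test starts and,
when the deferred function runs at test exit, waits up to 5 seconds
for any newly created goroutines to finish before reporting the
stragglers. Leaked goroutines are logged by default; they fail the
test only when the new -fail_on_leaks flag is set. A new -only_env
flag restricts the end2end tests to a single environment, e.g.
(assuming the tests are run from the repository root):

	go test ./test -only_env=tcp-clear -fail_on_leaks
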
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 4abe165..f1c7b42 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -34,12 +34,15 @@
 package grpc_test
 
 import (
+	"flag"
 	"fmt"
 	"io"
 	"math"
 	"net"
 	"reflect"
 	"runtime"
+	"sort"
+	"strings"
 	"sync"
 	"syscall"
 	"testing"
@@ -267,6 +270,7 @@
 const tlsDir = "testdata/"
 
 func TestReconnectTimeout(t *testing.T) {
+	defer leakCheck(t)()
 	lis, err := net.Listen("tcp", ":0")
 	if err != nil {
 		t.Fatalf("Failed to listen: %v", err)
@@ -317,19 +321,51 @@
 }
 
 type env struct {
+	name     string // The name of this environment, e.g. "tcp-clear".
 	network  string // The type of network such as tcp, unix, etc.
 	dialer   func(addr string, timeout time.Duration) (net.Conn, error)
 	security string // The security protocol such as TLS, SSH, etc.
 }
 
-func listTestEnv() []env {
-	if runtime.GOOS == "windows" {
-		return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}}
+func (e env) runnable() bool {
+	if runtime.GOOS == "windows" && strings.HasPrefix(e.name, "unix-") {
+		return false
 	}
-	return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}}
+	return true
+}
+
+var (
+	tcpClearEnv  = env{name: "tcp-clear", network: "tcp"}
+	tcpTLSEnv    = env{name: "tcp-tls", network: "tcp", security: "tls"}
+	unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer}
+	unixTLSEnv   = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"}
+	allEnv       = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv}
+)
+
+var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', or 'unix-tls' to only run the tests for that environment. Empty means all.")
+
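+// listTestEnv returns the environments to run the tests in, honoring
+// the -only_env flag and skipping environments that can't run on the
+// current operating system.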
+func listTestEnv() (envs []env) {
+	if *onlyEnv != "" {
+		for _, e := range allEnv {
+			if e.name == *onlyEnv {
+				if !e.runnable() {
+					panic(fmt.Sprintf("--only_env environment %q does not run on %s", *onlyEnv, runtime.GOOS))
+				}
+				return []env{e}
+			}
+		}
+		panic(fmt.Sprintf("invalid --only_env value %q", *onlyEnv))
+	}
+	for _, e := range allEnv {
+		if e.runnable() {
+			envs = append(envs, e)
+		}
+	}
+	return envs
 }
 
 func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) {
+	t.Logf("Running test in %s environment...", e.name)
 	sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)}
 	la := ":0"
 	switch e.network {
@@ -392,6 +428,7 @@
 }
 
 func TestTimeoutOnDeadServer(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testTimeoutOnDeadServer(t, e)
 	}
@@ -434,8 +471,8 @@
 	cc.Close()
 }
 
-func healthCheck(t time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) {
-	ctx, _ := context.WithTimeout(context.Background(), t)
+func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) {
+	ctx, _ := context.WithTimeout(context.Background(), d)
 	hc := healthpb.NewHealthClient(cc)
 	req := &healthpb.HealthCheckRequest{
 		Service: serviceName,
@@ -444,6 +481,7 @@
 }
 
 func TestHealthCheckOnSuccess(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testHealthCheckOnSuccess(t, e)
 	}
@@ -461,6 +499,7 @@
 }
 
 func TestHealthCheckOnFailure(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testHealthCheckOnFailure(t, e)
 	}
@@ -478,6 +517,7 @@
 }
 
 func TestHealthCheckOff(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testHealthCheckOff(t, e)
 	}
@@ -493,6 +533,7 @@
 }
 
 func TestHealthCheckServingStatus(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testHealthCheckServingStatus(t, e)
 	}
@@ -533,6 +574,7 @@
 }
 
 func TestEmptyUnaryWithUserAgent(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testEmptyUnaryWithUserAgent(t, e)
 	}
@@ -577,6 +619,7 @@
 }
 
 func TestFailedEmptyUnary(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testFailedEmptyUnary(t, e)
 	}
@@ -594,6 +637,7 @@
 }
 
 func TestLargeUnary(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testLargeUnary(t, e)
 	}
@@ -629,6 +673,7 @@
 }
 
 func TestMetadataUnaryRPC(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testMetadataUnaryRPC(t, e)
 	}
@@ -657,10 +702,10 @@
 	if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.Trailer(&trailer)); err != nil {
 		t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
 	}
-	if !reflect.DeepEqual(testMetadata, header) {
+	if !reflect.DeepEqual(header, testMetadata) {
 		t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
 	}
-	if !reflect.DeepEqual(testMetadata, trailer) {
+	if !reflect.DeepEqual(trailer, testMetadata) {
 		t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata)
 	}
 }
@@ -695,6 +740,7 @@
 }
 
 func TestRetry(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testRetry(t, e)
 	}
@@ -709,9 +755,15 @@
 	tc := testpb.NewTestServiceClient(cc)
 	defer tearDown(s, cc)
 	var wg sync.WaitGroup
+
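+	// Pace the RPCs a fixed interval apart so that the connection
+	// teardown below lands while RPCs are still being started.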
+	numRPC := 1000
+	rpcSpacing := 2 * time.Millisecond
+
 	wg.Add(1)
 	go func() {
-		time.Sleep(1 * time.Second)
+		// Halfway through starting RPCs, kill all connections:
+		time.Sleep(time.Duration(numRPC/2) * rpcSpacing)
+
 		// The server shuts down the network connection to trigger
 		// a transport error, which the client-side code will detect.
@@ -719,8 +771,8 @@
 		wg.Done()
 	}()
 	// All these RPCs should succeed eventually.
-	for i := 0; i < 1000; i++ {
-		time.Sleep(2 * time.Millisecond)
+	for i := 0; i < numRPC; i++ {
+		time.Sleep(rpcSpacing)
 		wg.Add(1)
 		go performOneRPC(t, tc, &wg)
 	}
@@ -728,6 +780,7 @@
 }
 
 func TestRPCTimeout(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testRPCTimeout(t, e)
 	}
@@ -762,6 +815,7 @@
 }
 
 func TestCancel(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testCancel(t, e)
 	}
@@ -794,6 +848,7 @@
 }
 
 func TestCancelNoIO(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testCancelNoIO(t, e)
 	}
@@ -847,6 +902,7 @@
 )
 
 func TestNoService(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testNoService(t, e)
 	}
@@ -858,8 +914,10 @@
 	tc := testpb.NewTestServiceClient(cc)
 	defer tearDown(s, cc)
 	// Make sure the settings ack has been sent.
-	time.Sleep(2 * time.Second)
-	stream, err := tc.FullDuplexCall(context.Background())
+	time.Sleep(20 * time.Millisecond)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	stream, err := tc.FullDuplexCall(ctx)
 	if err != nil {
 		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
 	}
@@ -869,6 +927,7 @@
 }
 
 func TestPingPong(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testPingPong(t, e)
 	}
@@ -927,6 +986,7 @@
 }
 
 func TestMetadataStreamingRPC(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testMetadataStreamingRPC(t, e)
 	}
@@ -994,6 +1054,7 @@
 }
 
 func TestServerStreaming(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testServerStreaming(t, e)
 	}
@@ -1047,6 +1108,7 @@
 }
 
 func TestFailedServerStreaming(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testFailedServerStreaming(t, e)
 	}
@@ -1078,6 +1140,7 @@
 }
 
 func TestClientStreaming(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testClientStreaming(t, e)
 	}
@@ -1118,6 +1181,7 @@
 }
 
 func TestExceedMaxStreamsLimit(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testExceedMaxStreamsLimit(t, e)
 	}
@@ -1129,13 +1193,16 @@
 	cc := clientSetUp(t, addr, nil, nil, "", e)
 	tc := testpb.NewTestServiceClient(cc)
 	defer tearDown(s, cc)
-	_, err := tc.StreamingInputCall(context.Background())
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	_, err := tc.StreamingInputCall(ctx)
 	if err != nil {
 		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
 	}
 	// Loop until receiving the new max stream setting from the server.
 	for {
-		ctx, _ := context.WithTimeout(context.Background(), time.Second)
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
 		_, err := tc.StreamingInputCall(ctx)
 		if err == nil {
 			time.Sleep(time.Second)
@@ -1149,6 +1216,7 @@
 }
 
 func TestCompressServerHasNoSupport(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testCompressServerHasNoSupport(t, e)
 	}
@@ -1202,6 +1270,7 @@
 }
 
 func TestCompressOK(t *testing.T) {
+	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testCompressOK(t, e)
 	}
@@ -1228,7 +1297,9 @@
 		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
 	}
 	// Streaming RPC
-	stream, err := tc.FullDuplexCall(context.Background())
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	stream, err := tc.FullDuplexCall(ctx)
 	if err != nil {
 		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
 	}
@@ -1253,3 +1324,72 @@
 		t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
 	}
 }
+
+// interestingGoroutines returns all goroutines we care about for the purpose
+// of leak checking. It excludes testing or runtime ones.
+func interestingGoroutines() (gs []string) {
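+	// Capture the stack traces of all current goroutines (2 MB buffer).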
+	buf := make([]byte, 2<<20)
+	buf = buf[:runtime.Stack(buf, true)]
+	for _, g := range strings.Split(string(buf), "\n\n") {
+		sl := strings.SplitN(g, "\n", 2)
+		if len(sl) != 2 {
+			continue
+		}
+		stack := strings.TrimSpace(sl[1])
+		if strings.HasPrefix(stack, "testing.RunTests") {
+			continue
+		}
+
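+		// Ignore goroutines from the testing and runtime machinery,
+		// goroutines that are already exiting, and this function's own.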
+		if stack == "" ||
+			strings.Contains(stack, "testing.Main(") ||
+			strings.Contains(stack, "runtime.goexit") ||
+			strings.Contains(stack, "created by runtime.gc") ||
+			strings.Contains(stack, "interestingGoroutines") ||
+			strings.Contains(stack, "runtime.MHeap_Scavenger") {
+			continue
+		}
+		gs = append(gs, g)
+	}
+	sort.Strings(gs)
+	return
+}
+
+var failOnLeaks = flag.Bool("fail_on_leaks", false, "Fail tests if goroutines leak.")
+
+// leakCheck snapshots the currently-running goroutines and returns a
+// function to be run at the end of tests to see whether any
+// goroutines leaked.
+func leakCheck(t testing.TB) func() {
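+	// Snapshot the interesting goroutines running as the test starts.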
+	orig := map[string]bool{}
+	for _, g := range interestingGoroutines() {
+		orig[g] = true
+	}
+	leakf := t.Logf
+	if *failOnLeaks {
+		leakf = t.Errorf
+	}
+	return func() {
+		// Loop, waiting for goroutines to shut down.
+		// Wait up to 5 seconds, but finish as quickly as possible.
+		deadline := time.Now().Add(5 * time.Second)
+		for {
+			var leaked []string
+			for _, g := range interestingGoroutines() {
+				if !orig[g] {
+					leaked = append(leaked, g)
+				}
+			}
+			if len(leaked) == 0 {
+				return
+			}
+			if time.Now().Before(deadline) {
+				time.Sleep(50 * time.Millisecond)
+				continue
+			}
+			for _, g := range leaked {
+				leakf("Leaked goroutine: %v", g)
+			}
+			return
+		}
+	}
+}