http2: rework Ping test to rely less on timing

Pings are either expected to occur, so count until you reach your goal
before a deadline, or they do not occur, and the deadline is exceeded.

Fixes golang/go#42514

Change-Id: If9ff19ed4954bee83ddeba83a4ac9c2d43f6e1c1
Reviewed-on: https://go-review.googlesource.com/c/net/+/269797
Trust: Bryan C. Mills <bcmills@google.com>
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 900fe85..c9c948c 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3393,63 +3393,54 @@
 
 func TestTransportPingWhenReading(t *testing.T) {
 	testCases := []struct {
-		name                   string
-		readIdleTimeout        time.Duration
-		serverResponseInterval time.Duration
-		expectedPingCount      int
+		name              string
+		readIdleTimeout   time.Duration
+		deadline          time.Duration
+		expectedPingCount int
 	}{
 		{
-			name:                   "two pings in each serverResponseInterval",
-			readIdleTimeout:        400 * time.Millisecond,
-			serverResponseInterval: 1000 * time.Millisecond,
-			expectedPingCount:      4,
+			name:              "two pings",
+			readIdleTimeout:   100 * time.Millisecond,
+			deadline:          time.Second,
+			expectedPingCount: 2,
 		},
 		{
-			name:                   "one ping in each serverResponseInterval",
-			readIdleTimeout:        700 * time.Millisecond,
-			serverResponseInterval: 1000 * time.Millisecond,
-			expectedPingCount:      2,
+			name:              "zero ping",
+			readIdleTimeout:   time.Second,
+			deadline:          200 * time.Millisecond,
+			expectedPingCount: 0,
 		},
 		{
-			name:                   "zero ping in each serverResponseInterval",
-			readIdleTimeout:        1000 * time.Millisecond,
-			serverResponseInterval: 500 * time.Millisecond,
-			expectedPingCount:      0,
-		},
-		{
-			name:                   "0 readIdleTimeout means no ping",
-			readIdleTimeout:        0 * time.Millisecond,
-			serverResponseInterval: 500 * time.Millisecond,
-			expectedPingCount:      0,
+			name:              "0 readIdleTimeout means no ping",
+			readIdleTimeout:   0 * time.Millisecond,
+			deadline:          500 * time.Millisecond,
+			expectedPingCount: 0,
 		},
 	}
 
 	for _, tc := range testCases {
 		tc := tc // capture range variable
 		t.Run(tc.name, func(t *testing.T) {
-			t.Parallel()
-			testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount)
+			testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
 		})
 	}
 }
 
-func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
+func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) {
 	var pingCount int
-	clientDone := make(chan struct{})
 	ct := newClientTester(t)
 	ct.tr.PingTimeout = 10 * time.Millisecond
 	ct.tr.ReadIdleTimeout = readIdleTimeout
-	// guards the ct.fr.Write
-	var wmu sync.Mutex
 
+	ctx, cancel := context.WithTimeout(context.Background(), deadline)
+	defer cancel()
 	ct.client = func() error {
 		defer ct.cc.(*net.TCPConn).CloseWrite()
 		if runtime.GOOS == "plan9" {
 			// CloseWrite not supported on Plan 9; Issue 17906
 			defer ct.cc.(*net.TCPConn).Close()
 		}
-		defer close(clientDone)
-		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
 		res, err := ct.tr.RoundTrip(req)
 		if err != nil {
 			return fmt.Errorf("RoundTrip: %v", err)
@@ -3459,6 +3450,11 @@
 			return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
 		}
 		_, err = ioutil.ReadAll(res.Body)
+		if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) {
+			return nil
+		}
+
+		cancel()
 		return err
 	}
 
@@ -3466,13 +3462,12 @@
 		ct.greet()
 		var buf bytes.Buffer
 		enc := hpack.NewEncoder(&buf)
-		var wg sync.WaitGroup
-		defer wg.Wait()
+		var streamID uint32
 		for {
 			f, err := ct.fr.ReadFrame()
 			if err != nil {
 				select {
-				case <-clientDone:
+				case <-ctx.Done():
 					// If the client's done, it
 					// will have reported any
 					// errors on its side.
@@ -3494,46 +3489,24 @@
 					EndStream:     false,
 					BlockFragment: buf.Bytes(),
 				})
-
-				wg.Add(1)
-				go func() {
-					defer wg.Done()
-					for i := 0; i < 2; i++ {
-						wmu.Lock()
-						if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
-							wmu.Unlock()
-							t.Error(err)
-							return
-						}
-						wmu.Unlock()
-						time.Sleep(serverResponseInterval)
-					}
-					wmu.Lock()
-					if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil {
-						wmu.Unlock()
-						t.Error(err)
-						return
-					}
-					wmu.Unlock()
-				}()
+				streamID = f.StreamID
 			case *PingFrame:
 				pingCount++
-				wmu.Lock()
+				if pingCount == expectedPingCount {
+					if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil {
+						return err
+					}
+				}
 				if err := ct.fr.WritePing(true, f.Data); err != nil {
-					wmu.Unlock()
 					return err
 				}
-				wmu.Unlock()
+			case *RSTStreamFrame:
 			default:
 				return fmt.Errorf("Unexpected client frame %v", f)
 			}
 		}
 	}
 	ct.run()
-	if e, a := expectedPingCount, pingCount; e != a {
-		t.Errorf("expected receiving %d pings, got %d pings", e, a)
-
-	}
 }
 
 func TestTransportRetryAfterGOAWAY(t *testing.T) {