http2: rework Ping test to rely less on timing
Pings are either expected to occur, so count until you reach your goal
before a deadline, or they do not occur, and the deadline is exceeded.
Fixes golang/go#42514
Change-Id: If9ff19ed4954bee83ddeba83a4ac9c2d43f6e1c1
Reviewed-on: https://go-review.googlesource.com/c/net/+/269797
Trust: Bryan C. Mills <bcmills@google.com>
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 900fe85..c9c948c 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3393,63 +3393,54 @@
func TestTransportPingWhenReading(t *testing.T) {
testCases := []struct {
- name string
- readIdleTimeout time.Duration
- serverResponseInterval time.Duration
- expectedPingCount int
+ name string
+ readIdleTimeout time.Duration
+ deadline time.Duration
+ expectedPingCount int
}{
{
- name: "two pings in each serverResponseInterval",
- readIdleTimeout: 400 * time.Millisecond,
- serverResponseInterval: 1000 * time.Millisecond,
- expectedPingCount: 4,
+ name: "two pings",
+ readIdleTimeout: 100 * time.Millisecond,
+ deadline: time.Second,
+ expectedPingCount: 2,
},
{
- name: "one ping in each serverResponseInterval",
- readIdleTimeout: 700 * time.Millisecond,
- serverResponseInterval: 1000 * time.Millisecond,
- expectedPingCount: 2,
+ name: "zero ping",
+ readIdleTimeout: time.Second,
+ deadline: 200 * time.Millisecond,
+ expectedPingCount: 0,
},
{
- name: "zero ping in each serverResponseInterval",
- readIdleTimeout: 1000 * time.Millisecond,
- serverResponseInterval: 500 * time.Millisecond,
- expectedPingCount: 0,
- },
- {
- name: "0 readIdleTimeout means no ping",
- readIdleTimeout: 0 * time.Millisecond,
- serverResponseInterval: 500 * time.Millisecond,
- expectedPingCount: 0,
+ name: "0 readIdleTimeout means no ping",
+ readIdleTimeout: 0 * time.Millisecond,
+ deadline: 500 * time.Millisecond,
+ expectedPingCount: 0,
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
- testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount)
+ testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
})
}
}
-func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
+func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) {
var pingCount int
- clientDone := make(chan struct{})
ct := newClientTester(t)
ct.tr.PingTimeout = 10 * time.Millisecond
ct.tr.ReadIdleTimeout = readIdleTimeout
- // guards the ct.fr.Write
- var wmu sync.Mutex
+ ctx, cancel := context.WithTimeout(context.Background(), deadline)
+ defer cancel()
ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite()
if runtime.GOOS == "plan9" {
// CloseWrite not supported on Plan 9; Issue 17906
defer ct.cc.(*net.TCPConn).Close()
}
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
res, err := ct.tr.RoundTrip(req)
if err != nil {
return fmt.Errorf("RoundTrip: %v", err)
@@ -3459,6 +3450,11 @@
return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
}
_, err = ioutil.ReadAll(res.Body)
+ if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) {
+ return nil
+ }
+
+ cancel()
return err
}
@@ -3466,13 +3462,12 @@
ct.greet()
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
- var wg sync.WaitGroup
- defer wg.Wait()
+ var streamID uint32
for {
f, err := ct.fr.ReadFrame()
if err != nil {
select {
- case <-clientDone:
+ case <-ctx.Done():
// If the client's done, it
// will have reported any
// errors on its side.
@@ -3494,46 +3489,24 @@
EndStream: false,
BlockFragment: buf.Bytes(),
})
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- for i := 0; i < 2; i++ {
- wmu.Lock()
- if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
- wmu.Unlock()
- t.Error(err)
- return
- }
- wmu.Unlock()
- time.Sleep(serverResponseInterval)
- }
- wmu.Lock()
- if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil {
- wmu.Unlock()
- t.Error(err)
- return
- }
- wmu.Unlock()
- }()
+ streamID = f.StreamID
case *PingFrame:
pingCount++
- wmu.Lock()
+ if pingCount == expectedPingCount {
+ if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil {
+ return err
+ }
+ }
if err := ct.fr.WritePing(true, f.Data); err != nil {
- wmu.Unlock()
return err
}
- wmu.Unlock()
+ case *RSTStreamFrame:
default:
return fmt.Errorf("Unexpected client frame %v", f)
}
}
}
ct.run()
- if e, a := expectedPingCount, pingCount; e != a {
- t.Errorf("expected receiving %d pings, got %d pings", e, a)
-
- }
}
func TestTransportRetryAfterGOAWAY(t *testing.T) {