http2: perform connection health check

After the connection has been idle for a while, periodic pings are sent
over the connection to check its health. Unhealthy connection is closed
and removed from the connection pool.

Fixes golang/go#31643

Change-Id: Idbbc9cb2d3e26c39f84a33e945e482d41cd8583c
GitHub-Last-Rev: 36607fe185ce6684e9900403f82a298ad5567650
GitHub-Pull-Request: golang/net#55
Reviewed-on: https://go-review.googlesource.com/c/net/+/198040
Run-TryBot: Andrew Bonventre <andybons@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 54acc1e..76a92e0 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -108,6 +108,19 @@
 	// waiting for their turn.
 	StrictMaxConcurrentStreams bool
 
+	// ReadIdleTimeout is the timeout after which a health check using ping
+	// frame will be carried out if no frame is received on the connection.
+	// Note that a ping response will is considered a received frame, so if
+	// there is no other traffic on the connection, the health check will
+	// be performed every ReadIdleTimeout interval.
+	// If zero, no health check is performed.
+	ReadIdleTimeout time.Duration
+
+	// PingTimeout is the timeout after which the connection will be closed
+	// if a response to Ping is not received.
+	// Defaults to 15s.
+	PingTimeout time.Duration
+
 	// t1, if non-nil, is the standard library Transport using
 	// this transport. Its settings are used (but not its
 	// RoundTrip method, etc).
@@ -131,6 +144,14 @@
 	return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
 }
 
+func (t *Transport) pingTimeout() time.Duration {
+	if t.PingTimeout == 0 {
+		return 15 * time.Second
+	}
+	return t.PingTimeout
+
+}
+
 // ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
 // It returns an error if t1 has already been HTTP/2-enabled.
 func ConfigureTransport(t1 *http.Transport) error {
@@ -675,6 +696,20 @@
 	return cc, nil
 }
 
+func (cc *ClientConn) healthCheck() {
+	pingTimeout := cc.t.pingTimeout()
+	// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
+	// trigger the healthCheck again if there is no frame received.
+	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+	defer cancel()
+	err := cc.Ping(ctx)
+	if err != nil {
+		cc.closeForLostPing()
+		cc.t.connPool().MarkDead(cc)
+		return
+	}
+}
+
 func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
@@ -846,14 +881,12 @@
 	return nil
 }
 
-// Close closes the client connection immediately.
-//
-// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
-func (cc *ClientConn) Close() error {
+// closes the client connection immediately. In-flight requests are interrupted.
+// err is sent to streams.
+func (cc *ClientConn) closeForError(err error) error {
 	cc.mu.Lock()
 	defer cc.cond.Broadcast()
 	defer cc.mu.Unlock()
-	err := errors.New("http2: client connection force closed via ClientConn.Close")
 	for id, cs := range cc.streams {
 		select {
 		case cs.resc <- resAndError{err: err}:
@@ -866,6 +899,20 @@
 	return cc.tconn.Close()
 }
 
+// Close closes the client connection immediately.
+//
+// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
+func (cc *ClientConn) Close() error {
+	err := errors.New("http2: client connection force closed via ClientConn.Close")
+	return cc.closeForError(err)
+}
+
+// closes the client connection immediately. In-flight requests are interrupted.
+func (cc *ClientConn) closeForLostPing() error {
+	err := errors.New("http2: client connection lost")
+	return cc.closeForError(err)
+}
+
 const maxAllocFrameSize = 512 << 10
 
 // frameBuffer returns a scratch buffer suitable for writing DATA frames.
@@ -1737,8 +1784,17 @@
 	rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
 	gotReply := false // ever saw a HEADERS reply
 	gotSettings := false
+	readIdleTimeout := cc.t.ReadIdleTimeout
+	var t *time.Timer
+	if readIdleTimeout != 0 {
+		t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
+		defer t.Stop()
+	}
 	for {
 		f, err := cc.fr.ReadFrame()
+		if t != nil {
+			t.Reset(readIdleTimeout)
+		}
 		if err != nil {
 			cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
 		}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index a17aa4a..23b4989 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3309,6 +3309,166 @@
 	req.Header = http.Header{}
 }
 
+func TestTransportCloseAfterLostPing(t *testing.T) {
+	clientDone := make(chan struct{})
+	ct := newClientTester(t)
+	ct.tr.PingTimeout = 1 * time.Second
+	ct.tr.ReadIdleTimeout = 1 * time.Second
+	ct.client = func() error {
+		defer ct.cc.(*net.TCPConn).CloseWrite()
+		defer close(clientDone)
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		_, err := ct.tr.RoundTrip(req)
+		if err == nil || !strings.Contains(err.Error(), "client connection lost") {
+			return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		<-clientDone
+		return nil
+	}
+	ct.run()
+}
+
+func TestTransportPingWhenReading(t *testing.T) {
+	testCases := []struct {
+		name                   string
+		readIdleTimeout        time.Duration
+		serverResponseInterval time.Duration
+		expectedPingCount      int
+	}{
+		{
+			name:                   "two pings in each serverResponseInterval",
+			readIdleTimeout:        400 * time.Millisecond,
+			serverResponseInterval: 1000 * time.Millisecond,
+			expectedPingCount:      4,
+		},
+		{
+			name:                   "one ping in each serverResponseInterval",
+			readIdleTimeout:        700 * time.Millisecond,
+			serverResponseInterval: 1000 * time.Millisecond,
+			expectedPingCount:      2,
+		},
+		{
+			name:                   "zero ping in each serverResponseInterval",
+			readIdleTimeout:        1000 * time.Millisecond,
+			serverResponseInterval: 500 * time.Millisecond,
+			expectedPingCount:      0,
+		},
+		{
+			name:                   "0 readIdleTimeout means no ping",
+			readIdleTimeout:        0 * time.Millisecond,
+			serverResponseInterval: 500 * time.Millisecond,
+			expectedPingCount:      0,
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc // capture range variable
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount)
+		})
+	}
+}
+
+func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
+	var pingCount int
+	clientDone := make(chan struct{})
+	ct := newClientTester(t)
+	ct.tr.PingTimeout = 10 * time.Millisecond
+	ct.tr.ReadIdleTimeout = readIdleTimeout
+	// guards the ct.fr.Write
+	var wmu sync.Mutex
+
+	ct.client = func() error {
+		defer ct.cc.(*net.TCPConn).CloseWrite()
+		defer close(clientDone)
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := ct.tr.RoundTrip(req)
+		if err != nil {
+			return fmt.Errorf("RoundTrip: %v", err)
+		}
+		defer res.Body.Close()
+		if res.StatusCode != 200 {
+			return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
+		}
+		_, err = ioutil.ReadAll(res.Body)
+		return err
+	}
+
+	ct.server = func() error {
+		ct.greet()
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-clientDone:
+					// If the client's done, it
+					// will have reported any
+					// errors on its side.
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame:
+			case *HeadersFrame:
+				if !f.HeadersEnded() {
+					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+				}
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      f.StreamID,
+					EndHeaders:    true,
+					EndStream:     false,
+					BlockFragment: buf.Bytes(),
+				})
+
+				go func() {
+					for i := 0; i < 2; i++ {
+						wmu.Lock()
+						if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
+							wmu.Unlock()
+							t.Error(err)
+							return
+						}
+						wmu.Unlock()
+						time.Sleep(serverResponseInterval)
+					}
+					wmu.Lock()
+					if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil {
+						wmu.Unlock()
+						t.Error(err)
+						return
+					}
+					wmu.Unlock()
+				}()
+			case *PingFrame:
+				pingCount++
+				wmu.Lock()
+				if err := ct.fr.WritePing(true, f.Data); err != nil {
+					wmu.Unlock()
+					return err
+				}
+				wmu.Unlock()
+			default:
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
+		}
+	}
+	ct.run()
+	if e, a := expectedPingCount, pingCount; e != a {
+		t.Errorf("expected receiving %d pings, got %d pings", e, a)
+
+	}
+}
+
 func TestTransportRetryAfterGOAWAY(t *testing.T) {
 	var dialer struct {
 		sync.Mutex