http2: perform connection health check
After the connection has been idle for a while, periodic pings are sent
over the connection to check its health. An unhealthy connection is closed
and removed from the connection pool.
Fixes golang/go#31643
Change-Id: Idbbc9cb2d3e26c39f84a33e945e482d41cd8583c
GitHub-Last-Rev: 36607fe185ce6684e9900403f82a298ad5567650
GitHub-Pull-Request: golang/net#55
Reviewed-on: https://go-review.googlesource.com/c/net/+/198040
Run-TryBot: Andrew Bonventre <andybons@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 54acc1e..76a92e0 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -108,6 +108,19 @@
// waiting for their turn.
StrictMaxConcurrentStreams bool
+ // ReadIdleTimeout is the timeout after which a health check using a ping
+ // frame will be carried out if no frame is received on the connection.
+ // Note that a ping response is considered a received frame, so if
+ // there is no other traffic on the connection, the health check will
+ // be performed every ReadIdleTimeout interval.
+ // If zero, no health check is performed.
+ ReadIdleTimeout time.Duration
+
+ // PingTimeout is the timeout after which the connection will be closed
+ // if a response to Ping is not received.
+ // Defaults to 15s.
+ PingTimeout time.Duration
+
// t1, if non-nil, is the standard library Transport using
// this transport. Its settings are used (but not its
// RoundTrip method, etc).
@@ -131,6 +144,14 @@
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}
+// pingTimeout returns t.PingTimeout, substituting the 15-second default
+// when the field is zero.
+func (t *Transport) pingTimeout() time.Duration {
+ if t.PingTimeout == 0 {
+ return 15 * time.Second
+ }
+ return t.PingTimeout
+
+}
+
// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
// It returns an error if t1 has already been HTTP/2-enabled.
func ConfigureTransport(t1 *http.Transport) error {
@@ -675,6 +696,20 @@
return cc, nil
}
+// healthCheck sends a single ping on cc, bounded by the transport's ping
+// timeout. If the ping returns an error, the connection is closed via
+// closeForLostPing and reported to the connection pool with MarkDead.
+func (cc *ClientConn) healthCheck() {
+ pingTimeout := cc.t.pingTimeout()
+ // We don't need to periodically ping in the health check, because the readLoop of ClientConn will
+ // trigger the healthCheck again if there is no frame received.
+ ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+ defer cancel()
+ err := cc.Ping(ctx)
+ if err != nil {
+ cc.closeForLostPing()
+ cc.t.connPool().MarkDead(cc)
+ return
+ }
+}
+
func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
@@ -846,14 +881,12 @@
return nil
}
-// Close closes the client connection immediately.
-//
-// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
-func (cc *ClientConn) Close() error {
+// closeForError closes the client connection immediately. In-flight requests are interrupted.
+// err is sent to streams.
+func (cc *ClientConn) closeForError(err error) error {
cc.mu.Lock()
defer cc.cond.Broadcast()
defer cc.mu.Unlock()
- err := errors.New("http2: client connection force closed via ClientConn.Close")
for id, cs := range cc.streams {
select {
case cs.resc <- resAndError{err: err}:
@@ -866,6 +899,20 @@
return cc.tconn.Close()
}
+// Close closes the client connection immediately.
+//
+// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
+func (cc *ClientConn) Close() error {
+ // This error is handed to closeForError, which sends it to the streams.
+ err := errors.New("http2: client connection force closed via ClientConn.Close")
+ return cc.closeForError(err)
+}
+
+// closeForLostPing closes the client connection immediately, passing the
+// "http2: client connection lost" error to closeForError. In-flight requests
+// are interrupted.
+func (cc *ClientConn) closeForLostPing() error {
+ err := errors.New("http2: client connection lost")
+ return cc.closeForError(err)
+}
+
const maxAllocFrameSize = 512 << 10
// frameBuffer returns a scratch buffer suitable for writing DATA frames.
@@ -1737,8 +1784,17 @@
rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
gotReply := false // ever saw a HEADERS reply
gotSettings := false
+ readIdleTimeout := cc.t.ReadIdleTimeout
+ var t *time.Timer
+ if readIdleTimeout != 0 {
+ t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
+ defer t.Stop()
+ }
for {
f, err := cc.fr.ReadFrame()
+ if t != nil {
+ t.Reset(readIdleTimeout)
+ }
if err != nil {
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index a17aa4a..23b4989 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3309,6 +3309,166 @@
req.Header = http.Header{}
}
+// TestTransportCloseAfterLostPing checks that RoundTrip fails with a
+// "client connection lost" error when the server greets the client and
+// then writes nothing further, so the health-check ping goes unanswered.
+func TestTransportCloseAfterLostPing(t *testing.T) {
+ clientDone := make(chan struct{})
+ ct := newClientTester(t)
+ ct.tr.PingTimeout = 1 * time.Second
+ ct.tr.ReadIdleTimeout = 1 * time.Second
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ _, err := ct.tr.RoundTrip(req)
+ if err == nil || !strings.Contains(err.Error(), "client connection lost") {
+ return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ // Stay silent until the client side has finished; no frame is ever
+ // written in reply to the request or to the ping.
+ <-clientDone
+ return nil
+ }
+ ct.run()
+}
+
+// TestTransportPingWhenReading is a table-driven test that runs
+// testTransportPingWhenReading as parallel subtests, one per combination of
+// readIdleTimeout and serverResponseInterval, each with the number of pings
+// the server is expected to receive.
+func TestTransportPingWhenReading(t *testing.T) {
+ // The helper's server sleeps for serverResponseInterval twice, so
+ // expectedPingCount is the per-interval ping count times two.
+ testCases := []struct {
+ name string
+ readIdleTimeout time.Duration
+ serverResponseInterval time.Duration
+ expectedPingCount int
+ }{
+ {
+ name: "two pings in each serverResponseInterval",
+ readIdleTimeout: 400 * time.Millisecond,
+ serverResponseInterval: 1000 * time.Millisecond,
+ expectedPingCount: 4,
+ },
+ {
+ name: "one ping in each serverResponseInterval",
+ readIdleTimeout: 700 * time.Millisecond,
+ serverResponseInterval: 1000 * time.Millisecond,
+ expectedPingCount: 2,
+ },
+ {
+ name: "zero ping in each serverResponseInterval",
+ readIdleTimeout: 1000 * time.Millisecond,
+ serverResponseInterval: 500 * time.Millisecond,
+ expectedPingCount: 0,
+ },
+ {
+ name: "0 readIdleTimeout means no ping",
+ readIdleTimeout: 0 * time.Millisecond,
+ serverResponseInterval: 500 * time.Millisecond,
+ expectedPingCount: 0,
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount)
+ })
+ }
+}
+
+// testTransportPingWhenReading runs one client request against a fake server
+// and checks how many PING frames the server reads.
+//
+// The server answers the HEADERS frame with a 200 response, then writes two
+// DATA frames, each followed by a sleep of serverResponseInterval, and a
+// final DATA frame that ends the stream. Every PING frame it reads is
+// counted and acknowledged. After ct.run returns, the count is compared
+// with expectedPingCount.
+func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
+ var pingCount int
+ clientDone := make(chan struct{})
+ ct := newClientTester(t)
+ ct.tr.PingTimeout = 10 * time.Millisecond
+ ct.tr.ReadIdleTimeout = readIdleTimeout
+ // guards the ct.fr.Write
+ var wmu sync.Mutex
+
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
+ }
+ // Read the whole body so the client keeps reading until the server
+ // ends the stream.
+ _, err = ioutil.ReadAll(res.Body)
+ return err
+ }
+
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it
+ // will have reported any
+ // errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+
+ // Write the body from a separate goroutine so this loop keeps
+ // reading (and acking) PING frames while the writer sleeps.
+ go func() {
+ for i := 0; i < 2; i++ {
+ wmu.Lock()
+ if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
+ wmu.Unlock()
+ t.Error(err)
+ return
+ }
+ wmu.Unlock()
+ time.Sleep(serverResponseInterval)
+ }
+ wmu.Lock()
+ if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil {
+ wmu.Unlock()
+ t.Error(err)
+ return
+ }
+ wmu.Unlock()
+ }()
+ case *PingFrame:
+ // Count the ping and echo its payload back as an ack.
+ pingCount++
+ wmu.Lock()
+ if err := ct.fr.WritePing(true, f.Data); err != nil {
+ wmu.Unlock()
+ return err
+ }
+ wmu.Unlock()
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+ ct.run()
+ if e, a := expectedPingCount, pingCount; e != a {
+ t.Errorf("expected receiving %d pings, got %d pings", e, a)
+
+ }
+}
+
func TestTransportRetryAfterGOAWAY(t *testing.T) {
var dialer struct {
sync.Mutex