http2: don't reuse connections that are experiencing errors
When a request on a connection fails to complete successfully,
mark the conn as doNotReuse. It's possible for requests to
fail for reasons unrelated to connection health,
but unnecessarily opening a new connection has less
impact than reusing a dead connection.
Fixes golang/go#59690
Change-Id: I40bf6cefae602ead70c3bcf2fe573cc13f34a385
Reviewed-on: https://go-review.googlesource.com/c/net/+/486156
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
diff --git a/http2/transport.go b/http2/transport.go
index f965579..ac90a26 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -1266,6 +1266,27 @@
return res, nil
}
+ cancelRequest := func(cs *clientStream, err error) error {
+ cs.cc.mu.Lock()
+ defer cs.cc.mu.Unlock()
+ cs.abortStreamLocked(err)
+ if cs.ID != 0 {
+ // This request may have failed because of a problem with the connection,
+ // or for some unrelated reason. (For example, the user might have canceled
+ // the request without waiting for a response.) Mark the connection as
+ // not reusable, since trying to reuse a dead connection is worse than
+ // unnecessarily creating a new one.
+ //
+ // If cs.ID is 0, then the request was never allocated a stream ID and
+ // whatever went wrong was unrelated to the connection. We might have
+ // timed out waiting for a stream slot when StrictMaxConcurrentStreams
+ // is set, for example, in which case retrying on a different connection
+ // will not help.
+ cs.cc.doNotReuse = true
+ }
+ return err
+ }
+
for {
select {
case <-cs.respHeaderRecv:
@@ -1280,15 +1301,12 @@
return handleResponseHeaders()
default:
waitDone()
- return nil, cs.abortErr
+ return nil, cancelRequest(cs, cs.abortErr)
}
case <-ctx.Done():
- err := ctx.Err()
- cs.abortStream(err)
- return nil, err
+ return nil, cancelRequest(cs, ctx.Err())
case <-cs.reqCancel:
- cs.abortStream(errRequestCanceled)
- return nil, errRequestCanceled
+ return nil, cancelRequest(cs, errRequestCanceled)
}
}
}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 5adef42..54d4551 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -775,7 +775,6 @@
cc, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
-
}
sc, err := ln.Accept()
if err != nil {
@@ -1765,6 +1764,18 @@
defer tr.CloseIdleConnections()
checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
+ // Make an arbitrary request to ensure we get the server's
+ // settings frame and initialize peerMaxHeaderListSize.
+ req0, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatalf("newRequest: NewRequest: %v", err)
+ }
+ res0, err := tr.RoundTrip(req0)
+ if err != nil {
+ t.Errorf("%v: Initial RoundTrip err = %v", desc, err)
+ }
+ res0.Body.Close()
+
res, err := tr.RoundTrip(req)
if err != wantErr {
if res != nil {
@@ -1825,13 +1836,9 @@
return req
}
- // Make an arbitrary request to ensure we get the server's
- // settings frame and initialize peerMaxHeaderListSize.
+ // Validate peerMaxHeaderListSize.
req := newRequest()
checkRoundTrip(req, nil, "Initial request")
-
- // Get the ClientConn associated with the request and validate
- // peerMaxHeaderListSize.
addr := authorityAddr(req.URL.Scheme, req.URL.Host)
cc, err := tr.connPool().GetClientConn(req, addr)
if err != nil {
@@ -3738,35 +3745,33 @@
ct.run()
}
-func TestTransportRetryAfterGOAWAY(t *testing.T) {
- var dialer struct {
- sync.Mutex
- count int
- }
- ct1 := make(chan *clientTester)
- ct2 := make(chan *clientTester)
-
+func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
ln := newLocalListener(t)
defer ln.Close()
+ var (
+ mu sync.Mutex
+ count int
+ conns []net.Conn
+ )
+ var wg sync.WaitGroup
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
}
tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
- dialer.Lock()
- defer dialer.Unlock()
- dialer.count++
- if dialer.count == 3 {
- return nil, errors.New("unexpected number of dials")
- }
+ mu.Lock()
+ defer mu.Unlock()
+ count++
cc, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
return nil, fmt.Errorf("dial error: %v", err)
}
+ conns = append(conns, cc)
sc, err := ln.Accept()
if err != nil {
return nil, fmt.Errorf("accept error: %v", err)
}
+ conns = append(conns, sc)
ct := &clientTester{
t: t,
tr: tr,
@@ -3774,19 +3779,26 @@
sc: sc,
fr: NewFramer(sc, sc),
}
- switch dialer.count {
- case 1:
- ct1 <- ct
- case 2:
- ct2 <- ct
- }
+ wg.Add(1)
+ go func(count int) {
+ defer wg.Done()
+ server(count, ct)
+ sc.Close()
+ }(count)
return cc, nil
}
- errs := make(chan error, 3)
+ client(tr)
+ tr.CloseIdleConnections()
+ ln.Close()
+ for _, c := range conns {
+ c.Close()
+ }
+ wg.Wait()
+}
- // Client.
- go func() {
+func TestTransportRetryAfterGOAWAY(t *testing.T) {
+ client := func(tr *Transport) {
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
res, err := tr.RoundTrip(req)
if res != nil {
@@ -3796,102 +3808,76 @@
}
}
if err != nil {
- err = fmt.Errorf("RoundTrip: %v", err)
- }
- errs <- err
- }()
-
- connToClose := make(chan io.Closer, 2)
-
- // Server for the first request.
- go func() {
- ct := <-ct1
-
- connToClose <- ct.cc
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
- return
- }
- t.Logf("server1 got %v", hf)
- if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
- errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
- return
- }
- errs <- nil
- }()
-
- // Server for the second request.
- go func() {
- ct := <-ct2
-
- connToClose <- ct.cc
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
- return
- }
- t.Logf("server2 got %v", hf)
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
- err = ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- if err != nil {
- errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
- } else {
- errs <- nil
- }
- }()
-
- for k := 0; k < 3; k++ {
- err := <-errs
- if err != nil {
- t.Error(err)
+ t.Errorf("RoundTrip: %v", err)
}
}
- close(connToClose)
- for c := range connToClose {
- c.Close()
+ server := func(count int, ct *clientTester) {
+ switch count {
+ case 1:
+ ct.greet()
+ hf, err := ct.firstHeaders()
+ if err != nil {
+ t.Errorf("server1 failed reading HEADERS: %v", err)
+ return
+ }
+ t.Logf("server1 got %v", hf)
+ if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
+ t.Errorf("server1 failed writing GOAWAY: %v", err)
+ return
+ }
+ case 2:
+ ct.greet()
+ hf, err := ct.firstHeaders()
+ if err != nil {
+ t.Errorf("server2 failed reading HEADERS: %v", err)
+ return
+ }
+ t.Logf("server2 got %v", hf)
+
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
+ err = ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ if err != nil {
+ t.Errorf("server2 failed writing response HEADERS: %v", err)
+ }
+ default:
+ t.Errorf("unexpected number of dials")
+ return
+ }
}
+
+ testClientMultipleDials(t, client, server)
}
func TestTransportRetryAfterRefusedStream(t *testing.T) {
clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
+ client := func(tr *Transport) {
defer close(clientDone)
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- resp, err := ct.tr.RoundTrip(req)
+ resp, err := tr.RoundTrip(req)
if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
+ t.Errorf("RoundTrip: %v", err)
+ return
}
resp.Body.Close()
if resp.StatusCode != 204 {
- return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+ t.Errorf("Status = %v; want 204", resp.StatusCode)
+ return
}
- return nil
}
- ct.server = func() error {
+
+ server := func(count int, ct *clientTester) {
ct.greet()
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
- nreq := 0
-
for {
f, err := ct.fr.ReadFrame()
if err != nil {
@@ -3900,19 +3886,19 @@
// If the client's done, it
// will have reported any
// errors on its side.
- return nil
default:
- return err
+ t.Error(err)
}
+ return
}
switch f := f.(type) {
case *WindowUpdateFrame, *SettingsFrame:
case *HeadersFrame:
if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ t.Errorf("headers should have END_HEADERS be ended: %v", f)
+ return
}
- nreq++
- if nreq == 1 {
+ if count == 1 {
ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
} else {
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
@@ -3924,11 +3910,13 @@
})
}
default:
- return fmt.Errorf("Unexpected client frame %v", f)
+ t.Errorf("Unexpected client frame %v", f)
+ return
}
}
}
- ct.run()
+
+ testClientMultipleDials(t, client, server)
}
func TestTransportRetryHasLimit(t *testing.T) {
@@ -4143,6 +4131,7 @@
greet := make(chan struct{}) // server sends initial SETTINGS frame
gotRequest := make(chan struct{}) // server received a request
clientDone := make(chan struct{})
+ cancelClientRequest := make(chan struct{})
// Collect errors from goroutines.
var wg sync.WaitGroup
@@ -4221,9 +4210,8 @@
req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
if k == maxConcurrent {
// This request will be canceled.
- cancel := make(chan struct{})
- req.Cancel = cancel
- close(cancel)
+ req.Cancel = cancelClientRequest
+ close(cancelClientRequest)
_, err := ct.tr.RoundTrip(req)
close(clientRequestCancelled)
if err == nil {
@@ -5986,14 +5974,21 @@
}
func TestClientConnReservations(t *testing.T) {
- cc := &ClientConn{
- reqHeaderMu: make(chan struct{}, 1),
- streams: make(map[uint32]*clientStream),
- maxConcurrentStreams: initialMaxConcurrentStreams,
- nextStreamID: 1,
- t: &Transport{},
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
+ s.MaxConcurrentStreams = initialMaxConcurrentStreams
+ })
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ cc, err := tr.newClientConn(st.cc, false)
+ if err != nil {
+ t.Fatal(err)
}
- cc.cond = sync.NewCond(&cc.mu)
+
+ req, _ := http.NewRequest("GET", st.ts.URL, nil)
n := 0
for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
n++
@@ -6001,8 +5996,8 @@
if n != initialMaxConcurrentStreams {
t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
}
- if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) {
- t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err)
+ if _, err := cc.RoundTrip(req); err != nil {
+ t.Fatalf("RoundTrip error = %v", err)
}
n2 := 0
for n2 <= 5 && cc.ReserveNewRequest() {
@@ -6014,7 +6009,7 @@
// Use up all the reservations
for i := 0; i < n; i++ {
- cc.RoundTrip(new(http.Request))
+ cc.RoundTrip(req)
}
n2 = 0
@@ -6370,3 +6365,95 @@
}
res.Body.Close()
}
+
+type blockReadConn struct {
+ net.Conn
+ blockc chan struct{}
+}
+
+func (c *blockReadConn) Read(b []byte) (n int, err error) {
+ <-c.blockc
+ return c.Conn.Read(b)
+}
+
+func TestTransportReuseAfterError(t *testing.T) {
+ serverReqc := make(chan struct{}, 3)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ serverReqc <- struct{}{}
+ }, optOnlyServer)
+ defer st.Close()
+
+ var (
+ unblockOnce sync.Once
+ blockc = make(chan struct{})
+ connCountMu sync.Mutex
+ connCount int
+ )
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ // The first connection dialed will block on reads until blockc is closed.
+ connCountMu.Lock()
+ defer connCountMu.Unlock()
+ connCount++
+ conn, err := tls.Dial(network, addr, cfg)
+ if err != nil {
+ return nil, err
+ }
+ if connCount == 1 {
+ return &blockReadConn{
+ Conn: conn,
+ blockc: blockc,
+ }, nil
+ }
+ return conn, nil
+ },
+ }
+ defer tr.CloseIdleConnections()
+ defer unblockOnce.Do(func() {
+ // Ensure that reads on blockc are unblocked if we return early.
+ close(blockc)
+ })
+
+ req, _ := http.NewRequest("GET", st.ts.URL, nil)
+
+ // Request 1 is made on conn 1.
+ // Reading the response will block.
+ // Wait until the server receives the request, and continue.
+ req1c := make(chan struct{})
+ go func() {
+ defer close(req1c)
+ res1, err := tr.RoundTrip(req.Clone(context.Background()))
+ if err != nil {
+ t.Errorf("request 1: %v", err)
+ } else {
+ res1.Body.Close()
+ }
+ }()
+ <-serverReqc
+
+ // Request 2 is also made on conn 1.
+ // Reading the response will block.
+ // The request fails when the context deadline expires.
+ // Conn 1 should now be flagged as unfit for reuse.
+ timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+ _, err := tr.RoundTrip(req.Clone(timeoutCtx))
+ if err == nil {
+ t.Errorf("request 2 unexpectedly succeeded (want timeout)")
+ }
+ time.Sleep(1 * time.Millisecond)
+
+ // Request 3 is made on a new conn, and succeeds.
+ res3, err := tr.RoundTrip(req.Clone(context.Background()))
+ if err != nil {
+ t.Fatalf("request 3: %v", err)
+ }
+ res3.Body.Close()
+
+ // Unblock conn 1, and verify that request 1 completes.
+ unblockOnce.Do(func() {
+ close(blockc)
+ })
+ <-req1c
+}