http2: don't reuse connections that are experiencing errors

When a request on a connection fails to complete successfully,
mark the conn as doNotReuse. It's possible for requests to
fail for reasons unrelated to connection health,
but opening a new connection unnecessarily is less of an
impact than reusing a dead connection.

Fixes golang/go#59690

Change-Id: I40bf6cefae602ead70c3bcf2fe573cc13f34a385
Reviewed-on: https://go-review.googlesource.com/c/net/+/486156
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
diff --git a/http2/transport.go b/http2/transport.go
index f965579..ac90a26 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -1266,6 +1266,27 @@
 		return res, nil
 	}
 
+	cancelRequest := func(cs *clientStream, err error) error {
+		cs.cc.mu.Lock() // held for both the abort and the doNotReuse write below
+		defer cs.cc.mu.Unlock()
+		cs.abortStreamLocked(err)
+		if cs.ID != 0 {
+			// This request may have failed because of a problem with the connection,
+			// or for some unrelated reason. (For example, the user might have canceled
+			// the request without waiting for a response.) Mark the connection as
+			// not reusable, since trying to reuse a dead connection is worse than
+			// unnecessarily creating a new one.
+			//
+			// If cs.ID is 0, then the request was never allocated a stream ID and
+			// whatever went wrong was unrelated to the connection. We might have
+			// timed out waiting for a stream slot when StrictMaxConcurrentStreams
+			// is set, for example, in which case retrying on a different connection
+			// will not help.
+			cs.cc.doNotReuse = true
+		}
+		return err // handed back unchanged so call sites can return it directly
+	}
+
 	for {
 		select {
 		case <-cs.respHeaderRecv:
@@ -1280,15 +1301,12 @@
 				return handleResponseHeaders()
 			default:
 				waitDone()
-				return nil, cs.abortErr
+				return nil, cancelRequest(cs, cs.abortErr)
 			}
 		case <-ctx.Done():
-			err := ctx.Err()
-			cs.abortStream(err)
-			return nil, err
+			return nil, cancelRequest(cs, ctx.Err())
 		case <-cs.reqCancel:
-			cs.abortStream(errRequestCanceled)
-			return nil, errRequestCanceled
+			return nil, cancelRequest(cs, errRequestCanceled)
 		}
 	}
 }
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 5adef42..54d4551 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -775,7 +775,6 @@
 	cc, err := net.Dial("tcp", ln.Addr().String())
 	if err != nil {
 		t.Fatal(err)
-
 	}
 	sc, err := ln.Accept()
 	if err != nil {
@@ -1765,6 +1764,18 @@
 	defer tr.CloseIdleConnections()
 
 	checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
+		// Make an arbitrary request to ensure we get the server's
+		// settings frame and initialize peerMaxHeaderListSize.
+		req0, err := http.NewRequest("GET", st.ts.URL, nil)
+		if err != nil {
+			t.Fatalf("newRequest: NewRequest: %v", err)
+		}
+		res0, err := tr.RoundTrip(req0)
+		if err != nil {
+			t.Fatalf("%v: Initial RoundTrip err = %v", desc, err) // stop here: res0 must not be dereferenced below when err != nil
+		}
+		res0.Body.Close()
+
 		res, err := tr.RoundTrip(req)
 		if err != wantErr {
 			if res != nil {
@@ -1825,13 +1836,9 @@
 		return req
 	}
 
-	// Make an arbitrary request to ensure we get the server's
-	// settings frame and initialize peerMaxHeaderListSize.
+	// Validate peerMaxHeaderListSize.
 	req := newRequest()
 	checkRoundTrip(req, nil, "Initial request")
-
-	// Get the ClientConn associated with the request and validate
-	// peerMaxHeaderListSize.
 	addr := authorityAddr(req.URL.Scheme, req.URL.Host)
 	cc, err := tr.connPool().GetClientConn(req, addr)
 	if err != nil {
@@ -3738,35 +3745,33 @@
 	ct.run()
 }
 
-func TestTransportRetryAfterGOAWAY(t *testing.T) {
-	var dialer struct {
-		sync.Mutex
-		count int
-	}
-	ct1 := make(chan *clientTester)
-	ct2 := make(chan *clientTester)
-
+func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
 	ln := newLocalListener(t)
 	defer ln.Close()
 
+	var (
+		mu    sync.Mutex
+		count int
+		conns []net.Conn
+	)
+	var wg sync.WaitGroup
 	tr := &Transport{
 		TLSClientConfig: tlsConfigInsecure,
 	}
 	tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
-		dialer.Lock()
-		defer dialer.Unlock()
-		dialer.count++
-		if dialer.count == 3 {
-			return nil, errors.New("unexpected number of dials")
-		}
+		mu.Lock()
+		defer mu.Unlock()
+		count++
 		cc, err := net.Dial("tcp", ln.Addr().String())
 		if err != nil {
 			return nil, fmt.Errorf("dial error: %v", err)
 		}
+		conns = append(conns, cc)
 		sc, err := ln.Accept()
 		if err != nil {
 			return nil, fmt.Errorf("accept error: %v", err)
 		}
+		conns = append(conns, sc)
 		ct := &clientTester{
 			t:  t,
 			tr: tr,
@@ -3774,19 +3779,26 @@
 			sc: sc,
 			fr: NewFramer(sc, sc),
 		}
-		switch dialer.count {
-		case 1:
-			ct1 <- ct
-		case 2:
-			ct2 <- ct
-		}
+		wg.Add(1)
+		go func(count int) {
+			defer wg.Done()
+			server(count, ct)
+			sc.Close()
+		}(count)
 		return cc, nil
 	}
 
-	errs := make(chan error, 3)
+	client(tr)
+	tr.CloseIdleConnections()
+	ln.Close()
+	for _, c := range conns {
+		c.Close()
+	}
+	wg.Wait()
+}
 
-	// Client.
-	go func() {
+func TestTransportRetryAfterGOAWAY(t *testing.T) {
+	client := func(tr *Transport) {
 		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
 		res, err := tr.RoundTrip(req)
 		if res != nil {
@@ -3796,102 +3808,76 @@
 			}
 		}
 		if err != nil {
-			err = fmt.Errorf("RoundTrip: %v", err)
-		}
-		errs <- err
-	}()
-
-	connToClose := make(chan io.Closer, 2)
-
-	// Server for the first request.
-	go func() {
-		ct := <-ct1
-
-		connToClose <- ct.cc
-		ct.greet()
-		hf, err := ct.firstHeaders()
-		if err != nil {
-			errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
-			return
-		}
-		t.Logf("server1 got %v", hf)
-		if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
-			errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
-			return
-		}
-		errs <- nil
-	}()
-
-	// Server for the second request.
-	go func() {
-		ct := <-ct2
-
-		connToClose <- ct.cc
-		ct.greet()
-		hf, err := ct.firstHeaders()
-		if err != nil {
-			errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
-			return
-		}
-		t.Logf("server2 got %v", hf)
-
-		var buf bytes.Buffer
-		enc := hpack.NewEncoder(&buf)
-		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
-		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
-		err = ct.fr.WriteHeaders(HeadersFrameParam{
-			StreamID:      hf.StreamID,
-			EndHeaders:    true,
-			EndStream:     false,
-			BlockFragment: buf.Bytes(),
-		})
-		if err != nil {
-			errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
-		} else {
-			errs <- nil
-		}
-	}()
-
-	for k := 0; k < 3; k++ {
-		err := <-errs
-		if err != nil {
-			t.Error(err)
+			t.Errorf("RoundTrip: %v", err)
 		}
 	}
 
-	close(connToClose)
-	for c := range connToClose {
-		c.Close()
+	server := func(count int, ct *clientTester) {
+		switch count {
+		case 1:
+			ct.greet()
+			hf, err := ct.firstHeaders()
+			if err != nil {
+				t.Errorf("server1 failed reading HEADERS: %v", err)
+				return
+			}
+			t.Logf("server1 got %v", hf)
+			if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
+				t.Errorf("server1 failed writing GOAWAY: %v", err)
+				return
+			}
+		case 2:
+			ct.greet()
+			hf, err := ct.firstHeaders()
+			if err != nil {
+				t.Errorf("server2 failed reading HEADERS: %v", err)
+				return
+			}
+			t.Logf("server2 got %v", hf)
+
+			var buf bytes.Buffer
+			enc := hpack.NewEncoder(&buf)
+			enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+			enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
+			err = ct.fr.WriteHeaders(HeadersFrameParam{
+				StreamID:      hf.StreamID,
+				EndHeaders:    true,
+				EndStream:     false,
+				BlockFragment: buf.Bytes(),
+			})
+			if err != nil {
+				t.Errorf("server2 failed writing response HEADERS: %v", err)
+			}
+		default:
+			t.Errorf("unexpected number of dials")
+			return
+		}
 	}
+
+	testClientMultipleDials(t, client, server)
 }
 
 func TestTransportRetryAfterRefusedStream(t *testing.T) {
 	clientDone := make(chan struct{})
-	ct := newClientTester(t)
-	ct.client = func() error {
-		defer ct.cc.(*net.TCPConn).CloseWrite()
-		if runtime.GOOS == "plan9" {
-			// CloseWrite not supported on Plan 9; Issue 17906
-			defer ct.cc.(*net.TCPConn).Close()
-		}
+	client := func(tr *Transport) {
 		defer close(clientDone)
 		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
-		resp, err := ct.tr.RoundTrip(req)
+		resp, err := tr.RoundTrip(req)
 		if err != nil {
-			return fmt.Errorf("RoundTrip: %v", err)
+			t.Errorf("RoundTrip: %v", err)
+			return
 		}
 		resp.Body.Close()
 		if resp.StatusCode != 204 {
-			return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+			t.Errorf("Status = %v; want 204", resp.StatusCode)
+			return
 		}
-		return nil
 	}
-	ct.server = func() error {
+
+	server := func(count int, ct *clientTester) {
 		ct.greet()
 		var buf bytes.Buffer
 		enc := hpack.NewEncoder(&buf)
-		nreq := 0
-
 		for {
 			f, err := ct.fr.ReadFrame()
 			if err != nil {
@@ -3900,19 +3886,19 @@
 					// If the client's done, it
 					// will have reported any
 					// errors on its side.
-					return nil
 				default:
-					return err
+					t.Error(err)
 				}
+				return
 			}
 			switch f := f.(type) {
 			case *WindowUpdateFrame, *SettingsFrame:
 			case *HeadersFrame:
 				if !f.HeadersEnded() {
-					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+					t.Errorf("headers should have END_HEADERS be ended: %v", f)
+					return
 				}
-				nreq++
-				if nreq == 1 {
+				if count == 1 {
 					ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
 				} else {
 					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
@@ -3924,11 +3910,13 @@
 					})
 				}
 			default:
-				return fmt.Errorf("Unexpected client frame %v", f)
+				t.Errorf("Unexpected client frame %v", f)
+				return
 			}
 		}
 	}
-	ct.run()
+
+	testClientMultipleDials(t, client, server)
 }
 
 func TestTransportRetryHasLimit(t *testing.T) {
@@ -4143,6 +4131,7 @@
 	greet := make(chan struct{})      // server sends initial SETTINGS frame
 	gotRequest := make(chan struct{}) // server received a request
 	clientDone := make(chan struct{})
+	cancelClientRequest := make(chan struct{})
 
 	// Collect errors from goroutines.
 	var wg sync.WaitGroup
@@ -4221,9 +4210,8 @@
 				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
 				if k == maxConcurrent {
 					// This request will be canceled.
-					cancel := make(chan struct{})
-					req.Cancel = cancel
-					close(cancel)
+					req.Cancel = cancelClientRequest
+					close(cancelClientRequest)
 					_, err := ct.tr.RoundTrip(req)
 					close(clientRequestCancelled)
 					if err == nil {
@@ -5986,14 +5974,21 @@
 }
 
 func TestClientConnReservations(t *testing.T) {
-	cc := &ClientConn{
-		reqHeaderMu:          make(chan struct{}, 1),
-		streams:              make(map[uint32]*clientStream),
-		maxConcurrentStreams: initialMaxConcurrentStreams,
-		nextStreamID:         1,
-		t:                    &Transport{},
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+	}, func(s *Server) {
+		s.MaxConcurrentStreams = initialMaxConcurrentStreams
+	})
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+
+	cc, err := tr.newClientConn(st.cc, false)
+	if err != nil {
+		t.Fatal(err)
 	}
-	cc.cond = sync.NewCond(&cc.mu)
+
+	req, _ := http.NewRequest("GET", st.ts.URL, nil)
 	n := 0
 	for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
 		n++
@@ -6001,8 +5996,8 @@
 	if n != initialMaxConcurrentStreams {
 		t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
 	}
-	if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) {
-		t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err)
+	if _, err := cc.RoundTrip(req); err != nil {
+		t.Fatalf("RoundTrip error = %v", err)
 	}
 	n2 := 0
 	for n2 <= 5 && cc.ReserveNewRequest() {
@@ -6014,7 +6009,7 @@
 
 	// Use up all the reservations
 	for i := 0; i < n; i++ {
-		cc.RoundTrip(new(http.Request))
+		cc.RoundTrip(req)
 	}
 
 	n2 = 0
@@ -6370,3 +6365,95 @@
 	}
 	res.Body.Close()
 }
+
+type blockReadConn struct {
+	net.Conn
+	blockc chan struct{} // every Read waits on this channel before reading from Conn
+}
+
+func (c *blockReadConn) Read(b []byte) (n int, err error) {
+	<-c.blockc // blocks until blockc yields a value or is closed
+	return c.Conn.Read(b)
+}
+
+func TestTransportReuseAfterError(t *testing.T) {
+	serverReqc := make(chan struct{}, 3) // one slot per request issued below, so the handler's send cannot block
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		serverReqc <- struct{}{}
+	}, optOnlyServer)
+	defer st.Close()
+
+	var (
+		unblockOnce sync.Once
+		blockc      = make(chan struct{})
+		connCountMu sync.Mutex
+		connCount   int
+	)
+	tr := &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			// The first connection dialed will block on reads until blockc is closed.
+			connCountMu.Lock()
+			defer connCountMu.Unlock()
+			connCount++
+			conn, err := tls.Dial(network, addr, cfg)
+			if err != nil {
+				return nil, err
+			}
+			if connCount == 1 {
+				return &blockReadConn{
+					Conn:   conn,
+					blockc: blockc,
+				}, nil
+			}
+			return conn, nil
+		},
+	}
+	defer tr.CloseIdleConnections()
+	defer unblockOnce.Do(func() {
+		// Ensure that reads on blockc are unblocked if we return early.
+		close(blockc)
+	})
+
+	req, _ := http.NewRequest("GET", st.ts.URL, nil)
+
+	// Request 1 is made on conn 1.
+	// Reading the response will block.
+	// Wait until the server receives the request, and continue.
+	req1c := make(chan struct{})
+	go func() {
+		defer close(req1c)
+		res1, err := tr.RoundTrip(req.Clone(context.Background()))
+		if err != nil {
+			t.Errorf("request 1: %v", err)
+		} else {
+			res1.Body.Close()
+		}
+	}()
+	<-serverReqc
+
+	// Request 2 is also made on conn 1.
+	// Reading the response will block.
+	// The request fails when the context deadline expires.
+	// Conn 1 should now be flagged as unfit for reuse.
+	timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+	defer cancel()
+	_, err := tr.RoundTrip(req.Clone(timeoutCtx))
+	if err == nil {
+		t.Errorf("request 2 unexpectedly succeeded (want timeout)")
+	}
+	time.Sleep(1 * time.Millisecond) // NOTE(review): presumably lets the transport finish aborting request 2 — confirm
+
+	// Request 3 is made on a new conn, and succeeds.
+	res3, err := tr.RoundTrip(req.Clone(context.Background()))
+	if err != nil {
+		t.Fatalf("request 3: %v", err) // NOTE(review): exits without waiting on req1c, so the request 1 goroutine may still be running
+	}
+	res3.Body.Close()
+
+	// Unblock conn 1, and verify that request 1 completes.
+	unblockOnce.Do(func() {
+		close(blockc)
+	})
+	<-req1c // closed by the request 1 goroutine once its RoundTrip has returned
+}