http2: avoid blocking while holding ClientConn.mu

Operations that examine the state of a ClientConn (notably, the
connection pool's check to see if a conn is available to take a
new request) need to acquire mu. Blocking while holding mu, such
as when writing to the network, blocks these operations.

Remove blocking operations from sections guarded by mu.
Perform network writes with only ClientConn.wmu held.
Clarify that wmu guards the per-conn HPACK encoder and buffer.
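
For illustration, the new write discipline looks roughly like this
(a simplified sketch of sendGoAway from the diff below, not the
literal code):

	// Snapshot mu-guarded state, then release mu before writing.
	cc.mu.Lock()
	closing := cc.closing
	cc.closing = true
	maxStreamID := cc.nextStreamID
	cc.mu.Unlock()
	if closing {
		return nil // GOAWAY already sent
	}

	// Perform the potentially blocking network write with only wmu held.
	cc.wmu.Lock()
	defer cc.wmu.Unlock()
	if err := cc.fr.WriteGoAway(maxStreamID, ErrCodeNo, nil); err != nil {
		return err
	}
	return cc.bw.Flush()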

Add a new mutex guarding request creation, covering the critical
section starting with allocating a new stream ID and continuing
until the stream is created.
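
The new lock is a one-element channel used as a mutex, so acquiring
it can be selected against request cancellation (a sketch of the
pattern; the field and error names match the diff below):

	// Lock: the send blocks until the semaphore slot is free.
	select {
	case cc.reqHeaderMu <- struct{}{}:
	case <-req.Cancel:
		return nil, false, errRequestCanceled
	case <-ctx.Done():
		return nil, false, ctx.Err()
	}
	// ... allocate a stream ID and create the stream ...
	// Unlock: the receive drains the slot.
	<-cc.reqHeaderMu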

Fix a locking issue where trailers were written from the HPACK
buffer with only wmu held, but headers were encoded into the buffer
with only mu held. (Now both encoding and writes occur with wmu
held.)
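
After the fix, trailers are encoded into cc.hbuf and written out
within a single wmu critical section (a sketch; compare the
encodeTrailers call sites in the diff below):

	cc.wmu.Lock()
	trls, err := cc.encodeTrailers(req) // fills cc.hbuf, guarded by wmu
	if err != nil {
		cc.wmu.Unlock()
		return err
	}
	defer cc.wmu.Unlock()
	// ... write trls via cc.fr while still holding wmu ...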

Fixes golang/go#32388.
Fixes golang/go#48340.

Change-Id: Ibb313424ed2f32c1aeac4645b76aedf227b597a3
Reviewed-on: https://go-review.googlesource.com/c/net/+/349594
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index aaaffa8..dc4bdde 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -270,22 +270,29 @@
 	nextStreamID    uint32
 	pendingRequests int                       // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
 	pings           map[[8]byte]chan struct{} // in flight ping data to notification channel
-	bw              *bufio.Writer
 	br              *bufio.Reader
-	fr              *Framer
 	lastActive      time.Time
 	lastIdle        time.Time // time last idle
-	// Settings from peer: (also guarded by mu)
+	// Settings from peer: (also guarded by wmu)
 	maxFrameSize          uint32
 	maxConcurrentStreams  uint32
 	peerMaxHeaderListSize uint64
 	initialWindowSize     uint32
 
+	// reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
+	// Write to reqHeaderMu to lock it, read from it to unlock.
+	// Lock reqHeaderMu BEFORE mu or wmu.
+	reqHeaderMu chan struct{}
+
+	// wmu is held while writing.
+	// Acquire BEFORE mu when holding both, to avoid blocking mu on network writes.
+	// Only acquire both at the same time when changing peer settings.
+	wmu  sync.Mutex
+	bw   *bufio.Writer
+	fr   *Framer
+	werr error        // first write error that has occurred
 	hbuf bytes.Buffer // HPACK encoder writes into this
 	henc *hpack.Encoder
-
-	wmu  sync.Mutex // held while writing; acquire AFTER mu if holding both
-	werr error      // first write error that has occurred
 }
 
 // clientStream is the state for a single HTTP/2 stream. One of these
@@ -404,10 +411,11 @@
 	cc.mu.Lock()
 	if cs.stopReqBody == nil {
 		cs.stopReqBody = err
-		if cs.req.Body != nil {
-			cs.req.Body.Close()
-		}
 		cc.cond.Broadcast()
+		// Close the body after releasing the mutex, in case it blocks.
+		if body := cs.req.Body; body != nil {
+			defer body.Close()
+		}
 	}
 	cc.mu.Unlock()
 }
@@ -672,6 +680,7 @@
 		singleUse:             singleUse,
 		wantSettingsAck:       true,
 		pings:                 make(map[[8]byte]chan struct{}),
+		reqHeaderMu:           make(chan struct{}, 1),
 	}
 	if d := t.idleConnTimeout(); d != 0 {
 		cc.idleTimeout = d
@@ -900,15 +909,18 @@
 
 func (cc *ClientConn) sendGoAway() error {
 	cc.mu.Lock()
-	defer cc.mu.Unlock()
-	cc.wmu.Lock()
-	defer cc.wmu.Unlock()
-	if cc.closing {
+	closing := cc.closing
+	cc.closing = true
+	maxStreamID := cc.nextStreamID
+	cc.mu.Unlock()
+	if closing {
 		// GOAWAY sent already
 		return nil
 	}
+
+	cc.wmu.Lock()
+	defer cc.wmu.Unlock()
 	// Send a graceful shutdown frame to server
-	maxStreamID := cc.nextStreamID
 	if err := cc.fr.WriteGoAway(maxStreamID, ErrCodeNo, nil); err != nil {
 		return err
 	}
@@ -916,7 +928,6 @@
 		return err
 	}
-	// Prevent new requests
-	cc.closing = true
 	return nil
 }
 
@@ -924,17 +935,22 @@
 // err is sent to streams.
 func (cc *ClientConn) closeForError(err error) error {
 	cc.mu.Lock()
-	defer cc.cond.Broadcast()
-	defer cc.mu.Unlock()
-	for id, cs := range cc.streams {
+	streams := cc.streams
+	cc.streams = nil
+	cc.closed = true
+	cc.mu.Unlock()
+
+	for _, cs := range streams {
 		select {
 		case cs.resc <- resAndError{err: err}:
 		default:
 		}
 		cs.bufPipe.CloseWithError(err)
-		delete(cc.streams, id)
 	}
-	cc.closed = true
+
+	cc.mu.Lock()
+	defer cc.cond.Broadcast()
+	defer cc.mu.Unlock()
 	return cc.tconn.Close()
 }
 
@@ -1022,6 +1038,7 @@
 }
 
 func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
+	ctx := req.Context()
 	if err := checkConnHeaders(req); err != nil {
 		return nil, false, err
 	}
@@ -1035,6 +1052,26 @@
 	}
 	hasTrailers := trailers != ""
 
+	// Acquire the new-request lock by writing to reqHeaderMu.
+	// This lock guards the critical section covering allocating a new stream ID
+	// (requires mu) and creating the stream (requires wmu).
+	if cc.reqHeaderMu == nil {
+		panic("RoundTrip on initialized ClientConn") // for tests
+	}
+	select {
+	case cc.reqHeaderMu <- struct{}{}:
+	case <-req.Cancel:
+		return nil, false, errRequestCanceled
+	case <-ctx.Done():
+		return nil, false, ctx.Err()
+	}
+	reqHeaderMuNeedsUnlock := true
+	defer func() {
+		if reqHeaderMuNeedsUnlock {
+			<-cc.reqHeaderMu
+		}
+	}()
+
 	cc.mu.Lock()
 	if err := cc.awaitOpenSlotForRequest(req); err != nil {
 		cc.mu.Unlock()
@@ -1066,21 +1103,23 @@
 		requestedGzip = true
 	}
 
-	// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
-	// sent by writeRequestBody below, along with any Trailers,
-	// again in form HEADERS{1}, CONTINUATION{0,})
-	hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen)
-	if err != nil {
-		cc.mu.Unlock()
-		return nil, false, err
-	}
-
 	cs := cc.newStream()
 	cs.req = req
 	cs.trace = httptrace.ContextClientTrace(req.Context())
 	cs.requestedGzip = requestedGzip
 	bodyWriter := cc.t.getBodyWriterState(cs, body)
 	cs.on100 = bodyWriter.on100
+	cc.mu.Unlock()
+
+	// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
+	// sent by writeRequestBody below, along with any Trailers,
+	// again in form HEADERS{1}, CONTINUATION{0,})
+	cc.wmu.Lock()
+	hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen)
+	if err != nil {
+		cc.wmu.Unlock()
+		return nil, false, err
+	}
 
 	defer func() {
 		cc.wmu.Lock()
@@ -1091,24 +1130,24 @@
 		}
 	}()
 
-	cc.wmu.Lock()
 	endStream := !hasBody && !hasTrailers
-	werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
+	err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
 	cc.wmu.Unlock()
+	<-cc.reqHeaderMu // release the new-request lock
+	reqHeaderMuNeedsUnlock = false
 	traceWroteHeaders(cs.trace)
-	cc.mu.Unlock()
 
-	if werr != nil {
+	if err != nil {
 		if hasBody {
 			bodyWriter.cancel()
 		}
 		cc.forgetStreamID(cs.ID)
 		// Don't bother sending a RST_STREAM (our write already failed;
 		// no need to keep writing)
-		traceWroteRequest(cs.trace, werr)
+		traceWroteRequest(cs.trace, err)
 		// TODO(dneil): An error occurred while writing the headers.
 		// Should we return an error indicating that this request can be retried?
-		return nil, false, werr
+		return nil, false, err
 	}
 
 	var respHeaderTimer <-chan time.Time
@@ -1125,7 +1164,6 @@
 
 	readLoopResCh := cs.resc
 	bodyWritten := false
-	ctx := req.Context()
 
 	handleReadLoopResponse := func(re resAndError) (*http.Response, bool, error) {
 		res := re.res
@@ -1427,19 +1465,17 @@
 		return nil
 	}
 
+	cc.wmu.Lock()
 	var trls []byte
 	if hasTrailers {
-		cc.mu.Lock()
 		trls, err = cc.encodeTrailers(req)
-		cc.mu.Unlock()
 		if err != nil {
+			cc.wmu.Unlock()
 			cc.writeStreamReset(cs.ID, ErrCodeInternal, err)
 			cc.forgetStreamID(cs.ID)
 			return err
 		}
 	}
-
-	cc.wmu.Lock()
 	defer cc.wmu.Unlock()
 
 	// Two ways to send END_STREAM: either with trailers, or
@@ -1489,7 +1525,7 @@
 	}
 }
 
-// requires cc.mu be held.
+// requires cc.wmu be held.
 func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
 	cc.hbuf.Reset()
 
@@ -1677,7 +1713,7 @@
 	}
 }
 
-// requires cc.mu be held.
+// requires cc.wmu be held.
 func (cc *ClientConn) encodeTrailers(req *http.Request) ([]byte, error) {
 	cc.hbuf.Reset()
 
@@ -1826,7 +1862,11 @@
 	} else if err == io.EOF {
 		err = io.ErrUnexpectedEOF
 	}
-	for _, cs := range cc.streams {
+	cc.closed = true
+	streams := cc.streams
+	cc.streams = nil
+	cc.mu.Unlock()
+	for _, cs := range streams {
 		cs.bufPipe.CloseWithError(err) // no-op if already closed
 		select {
 		case cs.resc <- resAndError{err: err}:
@@ -1834,7 +1874,7 @@
 		}
 		close(cs.done)
 	}
-	cc.closed = true
+	cc.mu.Lock()
 	cc.cond.Broadcast()
 	cc.mu.Unlock()
 }
@@ -2192,8 +2232,6 @@
 	}
 
 	cc.mu.Lock()
-	defer cc.mu.Unlock()
-
 	var connAdd, streamAdd int32
 	// Check the conn-level first, before the stream-level.
 	if v := cc.inflow.available(); v < transportDefaultConnFlow/2 {
@@ -2210,6 +2248,8 @@
 			cs.inflow.add(streamAdd)
 		}
 	}
+	cc.mu.Unlock()
+
 	if connAdd != 0 || streamAdd != 0 {
 		cc.wmu.Lock()
 		defer cc.wmu.Unlock()
@@ -2235,19 +2275,25 @@
 
 	if unread > 0 || !serverSentStreamEnd {
 		cc.mu.Lock()
-		cc.wmu.Lock()
 		if !serverSentStreamEnd {
-			cc.fr.WriteRSTStream(cs.ID, ErrCodeCancel)
 			cs.didReset = true
 		}
 		// Return connection-level flow control.
 		if unread > 0 {
 			cc.inflow.add(int32(unread))
+		}
+		cc.mu.Unlock()
+
+		cc.wmu.Lock()
+		if !serverSentStreamEnd {
+			cc.fr.WriteRSTStream(cs.ID, ErrCodeCancel)
+		}
+		// Return connection-level flow control.
+		if unread > 0 {
 			cc.fr.WriteWindowUpdate(0, uint32(unread))
 		}
 		cc.bw.Flush()
 		cc.wmu.Unlock()
-		cc.mu.Unlock()
 	}
 
 	cs.bufPipe.BreakWithError(errClosedResponseBody)
@@ -2325,6 +2371,10 @@
 		}
 		if refund > 0 {
 			cc.inflow.add(int32(refund))
+		}
+		cc.mu.Unlock()
+
+		if refund > 0 {
 			cc.wmu.Lock()
 			cc.fr.WriteWindowUpdate(0, uint32(refund))
 			if !didReset {
@@ -2334,7 +2384,6 @@
 			cc.bw.Flush()
 			cc.wmu.Unlock()
 		}
-		cc.mu.Unlock()
 
 		if len(data) > 0 && !didReset {
 			if _, err := cs.bufPipe.Write(data); err != nil {
@@ -2400,6 +2449,23 @@
 
 func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error {
 	cc := rl.cc
+	// Locking both mu and wmu here allows frame encoding to read settings with only wmu held.
+	// Acquiring wmu when f.IsAck() is true is unnecessary, but convenient and mostly harmless.
+	cc.wmu.Lock()
+	defer cc.wmu.Unlock()
+
+	if err := rl.processSettingsNoWrite(f); err != nil {
+		return err
+	}
+	if !f.IsAck() {
+		cc.fr.WriteSettingsAck()
+		cc.bw.Flush()
+	}
+	return nil
+}
+
+func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
+	cc := rl.cc
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 
@@ -2461,12 +2527,7 @@
 		cc.seenSettings = true
 	}
 
-	cc.wmu.Lock()
-	defer cc.wmu.Unlock()
-
-	cc.fr.WriteSettingsAck()
-	cc.bw.Flush()
-	return cc.werr
+	return nil
 }
 
 func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 97735fe..ab31640 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -9,6 +9,7 @@
 	"bytes"
 	"context"
 	"crypto/tls"
+	"encoding/hex"
 	"errors"
 	"flag"
 	"fmt"
@@ -3261,7 +3262,8 @@
 	const body = "foo"
 	req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body)))
 	cc := &ClientConn{
-		closed: true,
+		closed:      true,
+		reqHeaderMu: make(chan struct{}, 1),
 	}
 	_, err := cc.RoundTrip(req)
 	if err != errClientConnUnusable {
@@ -4990,6 +4992,245 @@
 	return nil
 }
 
+// A blockingWriteConn is a net.Conn that blocks in Write after some number of bytes are written.
+type blockingWriteConn struct {
+	net.Conn
+	writeOnce    sync.Once
+	writec       chan struct{} // closed after the write limit is reached
+	unblockc     chan struct{} // closed to unblock writes
+	count, limit int
+}
+
+func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn {
+	return &blockingWriteConn{
+		Conn:     conn,
+		limit:    limit,
+		writec:   make(chan struct{}),
+		unblockc: make(chan struct{}),
+	}
+}
+
+// wait waits until the conn blocks writing the limit+1st byte.
+func (c *blockingWriteConn) wait() {
+	<-c.writec
+}
+
+// unblock unblocks writes to the conn.
+func (c *blockingWriteConn) unblock() {
+	close(c.unblockc)
+}
+
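+// Write writes to the underlying conn, blocking before any write
+// that would exceed the limit until unblock is called.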
+func (c *blockingWriteConn) Write(b []byte) (n int, err error) {
+	if c.count+len(b) > c.limit {
+		c.writeOnce.Do(func() {
+			close(c.writec)
+		})
+		<-c.unblockc
+	}
+	n, err = c.Conn.Write(b)
+	c.count += n
+	return n, err
+}
+
+// Write several requests to a ClientConn at the same time, looking for race conditions.
+// See golang.org/issue/48340
+func TestTransportFrameBufferReuse(t *testing.T) {
+	filler := hex.EncodeToString([]byte(randString(2048)))
+
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		if got, want := r.Header.Get("Big"), filler; got != want {
+			t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want)
+		}
+		b, err := ioutil.ReadAll(r.Body)
+		if err != nil {
+			t.Errorf("error reading request body: %v", err)
+		}
+		if got, want := string(b), filler; got != want {
+			t.Errorf("request body = %q, want %q", got, want)
+		}
+		if got, want := r.Trailer.Get("Big"), filler; got != want {
+			t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want)
+		}
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+
+	var wg sync.WaitGroup
+	defer wg.Wait()
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(filler))
+			if err != nil {
+				t.Error(err)
+				return
+			}
+			req.Header.Set("Big", filler)
+			req.Trailer = make(http.Header)
+			req.Trailer.Set("Big", filler)
+			res, err := tr.RoundTrip(req)
+			if err != nil {
+				t.Error(err)
+				return
+			}
+			if got, want := res.StatusCode, 200; got != want {
+				t.Errorf("StatusCode = %v; want %v", got, want)
+			}
+			if res != nil && res.Body != nil {
+				res.Body.Close()
+			}
+		}()
+	}
+}
+
+// Ensure that a request blocking while being written to the underlying net.Conn doesn't
+// block access to the ClientConn pool. Test requests blocking while writing headers, the body,
+// and trailers.
+// See golang.org/issue/32388
+func TestTransportBlockingRequestWrite(t *testing.T) {
+	filler := hex.EncodeToString([]byte(randString(2048)))
+	for _, test := range []struct {
+		name string
+		req  func(url string) (*http.Request, error)
+	}{{
+		name: "headers",
+		req: func(url string) (*http.Request, error) {
+			req, err := http.NewRequest("POST", url, nil)
+			if err != nil {
+				return nil, err
+			}
+			req.Header.Set("Big", filler)
+			return req, err
+		},
+	}, {
+		name: "body",
+		req: func(url string) (*http.Request, error) {
+			req, err := http.NewRequest("POST", url, strings.NewReader(filler))
+			if err != nil {
+				return nil, err
+			}
+			return req, err
+		},
+	}, {
+		name: "trailer",
+		req: func(url string) (*http.Request, error) {
+			req, err := http.NewRequest("POST", url, strings.NewReader("body"))
+			if err != nil {
+				return nil, err
+			}
+			req.Trailer = make(http.Header)
+			req.Trailer.Set("Big", filler)
+			return req, err
+		},
+	}} {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+				if v := r.Header.Get("Big"); v != "" && v != filler {
+					t.Errorf("request header mismatch")
+				}
+				if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler {
+					t.Errorf("request body mismatch\ngot:  %q\nwant: %q", string(v), filler)
+				}
+				if v := r.Trailer.Get("Big"); v != "" && v != filler {
+					t.Errorf("request trailer mismatch\ngot:  %q\nwant: %q", string(v), filler)
+				}
+			}, optOnlyServer, func(s *Server) {
+				s.MaxConcurrentStreams = 1
+			})
+			defer st.Close()
+
+			// This Transport creates connections that block on writes after 1024 bytes.
+			connc := make(chan *blockingWriteConn, 1)
+			connCount := 0
+			tr := &Transport{
+				TLSClientConfig: tlsConfigInsecure,
+				DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+					connCount++
+					c, err := tls.Dial(network, addr, cfg)
+					wc := newBlockingWriteConn(c, 1024)
+					select {
+					case connc <- wc:
+					default:
+					}
+					return wc, err
+				},
+			}
+			defer tr.CloseIdleConnections()
+
+			// Request 1: A small request to ensure we read the server MaxConcurrentStreams.
+			{
+				req, err := http.NewRequest("POST", st.ts.URL, nil)
+				if err != nil {
+					t.Fatal(err)
+				}
+				res, err := tr.RoundTrip(req)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if got, want := res.StatusCode, 200; got != want {
+					t.Errorf("StatusCode = %v; want %v", got, want)
+				}
+				if res != nil && res.Body != nil {
+					res.Body.Close()
+				}
+			}
+
+			// Request 2: A large request that blocks while being written.
+			reqc := make(chan struct{})
+			go func() {
+				defer close(reqc)
+				req, err := test.req(st.ts.URL)
+				if err != nil {
+					t.Error(err)
+					return
+				}
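+				// The RoundTrip result is ignored: the test only checks
+				// that the request stays blocked until conn.unblock.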
+				res, _ := tr.RoundTrip(req)
+				if res != nil && res.Body != nil {
+					res.Body.Close()
+				}
+			}()
+			conn := <-connc
+			conn.wait() // wait for the request to block
+
+			// Request 3: A small request that is sent on a new connection, since request 2
+			// is hogging the only available stream on the previous connection.
+			{
+				req, err := http.NewRequest("POST", st.ts.URL, nil)
+				if err != nil {
+					t.Fatal(err)
+				}
+				res, err := tr.RoundTrip(req)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if got, want := res.StatusCode, 200; got != want {
+					t.Errorf("StatusCode = %v; want %v", got, want)
+				}
+				if res != nil && res.Body != nil {
+					res.Body.Close()
+				}
+			}
+
+			// Request 2 should still be blocking at this point.
+			select {
+			case <-reqc:
+				t.Errorf("request 2 unexpectedly completed")
+			default:
+			}
+
+			conn.unblock()
+			<-reqc
+
+			if connCount != 2 {
+				t.Errorf("created %v connections, want 1", connCount)
+			}
+		})
+	}
+}
+
 func TestTransportCloseRequestBody(t *testing.T) {
 	var statusCode int
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {