http2: make Transport not reuse conns after a stream protocol error

If a server sends a stream error of type "protocol error" to a client,
that's the server saying "you're speaking http2 wrong". At that point,
regardless of whether we're in the right or not (that is, regardless of
whether the Transport is bug-free), clearly there's some confusion and
one of the two parties is either wrong or confused. There's no point
pushing on and trying to use the connection and potentially exacerbating
the confusion (as we saw in golang/go#47635).

Instead, make the client "poison" the connection by setting a new "do
not reuse" bit on it. Existing streams can finish up but new requests
won't pick that connection.

Also, make those requests as retryable without the caller getting an
error.

Given that golang/go#42777 existed, there are HTTP/2 servers in the
wild that incorrectly set RST_STREAM PROTOCOL_ERROR codes. But even
once those go away, this is still a reasonable fix for preventing
a broken connection from being stuck in the connection pool that fails
all future requests if a similar bug happens in another HTTP/2 server.

Updates golang/go#47635

Change-Id: I3f89ecd1d3710e49f7219ccb846e016eb269515b
Reviewed-on: https://go-review.googlesource.com/c/net/+/347033
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/http2/errors.go b/http2/errors.go
index 71f2c46..c789fa3 100644
--- a/http2/errors.go
+++ b/http2/errors.go
@@ -67,6 +67,11 @@
 	Cause    error // optional additional detail
 }
 
+// errFromPeer is a sentinel error value for StreamError.Cause to
+// indicate that the StreamError was sent from the peer over the wire
+// and wasn't locally generated in the Transport.
+var errFromPeer = errors.New("received from peer")
+
 func streamError(id uint32, code ErrCode) StreamError {
 	return StreamError{StreamID: id, Code: code}
 }
diff --git a/http2/transport.go b/http2/transport.go
index b261beb..dc31cfd 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -244,6 +244,7 @@
 	cond            *sync.Cond // hold mu; broadcast on flow/closed changes
 	flow            flow       // our conn-level flow control quota (cs.flow is per stream)
 	inflow          flow       // peer's conn-level flow control
+	doNotReuse      bool       // whether conn is marked to not be reused for any future requests
 	closing         bool
 	closed          bool
 	wantSettingsAck bool                     // we sent a SETTINGS frame and haven't heard back
@@ -563,6 +564,10 @@
 		return true
 	}
 	if se, ok := err.(StreamError); ok {
+		if se.Code == ErrCodeProtocol && se.Cause == errFromPeer {
+			// See golang/go#47635, golang/go#42777
+			return true
+		}
 		return se.Code == ErrCodeRefusedStream
 	}
 	return false
@@ -714,6 +719,13 @@
 	}
 }
 
+// SetDoNotReuse marks cc as not reusable for future HTTP requests.
+func (cc *ClientConn) SetDoNotReuse() {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	cc.doNotReuse = true
+}
+
 func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
@@ -776,6 +788,7 @@
 	}
 
 	st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
+		!cc.doNotReuse &&
 		int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
 		!cc.tooIdleLocked()
 	st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
@@ -2419,10 +2432,17 @@
 		// which closes this, so there
 		// isn't a race.
 	default:
-		err := streamError(cs.ID, f.ErrCode)
-		cs.resetErr = err
+		serr := streamError(cs.ID, f.ErrCode)
+		if f.ErrCode == ErrCodeProtocol {
+			rl.cc.SetDoNotReuse()
+			serr.Cause = errFromPeer
+			// TODO(bradfitz): increment a varz here, once Transport
+			// takes an optional interface-typed field that expvar.Map.Add
+			// implements.
+		}
+		cs.resetErr = serr
 		close(cs.peerReset)
-		cs.bufPipe.CloseWithError(err)
+		cs.bufPipe.CloseWithError(serr)
 		cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl
 	}
 	return nil
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 2da7d9d..4412a89 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -4944,3 +4944,104 @@
 		})
 	}
 }
+
+// collectClientsConnPool is a ClientConnPool that wraps lower and
+// collects what calls were made on it.
+type collectClientsConnPool struct {
+	lower ClientConnPool
+
+	mu      sync.Mutex
+	getErrs int
+	got     []*ClientConn
+}
+
+func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
+	cc, err := p.lower.GetClientConn(req, addr)
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if err != nil {
+		p.getErrs++
+		return nil, err
+	}
+	p.got = append(p.got, cc)
+	return cc, nil
+}
+
+func (p *collectClientsConnPool) MarkDead(cc *ClientConn) {
+	p.lower.MarkDead(cc)
+}
+
+func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
+	ct := newClientTester(t)
+	pool := &collectClientsConnPool{
+		lower: &clientConnPool{t: ct.tr},
+	}
+	ct.tr.ConnPool = pool
+	done := make(chan struct{})
+	ct.client = func() error {
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := ct.tr.RoundTrip(req)
+		const want = "only one dial allowed in test mode"
+		if got := fmt.Sprint(err); got != want {
+			t.Errorf("didn't dial again: got %#q; want %#q", got, want)
+		}
+		close(done)
+		ct.sc.Close()
+		if res != nil {
+			res.Body.Close()
+		}
+
+		pool.mu.Lock()
+		defer pool.mu.Unlock()
+		if pool.getErrs != 1 {
+			t.Errorf("pool get errors = %v; want 1", pool.getErrs)
+		}
+		if len(pool.got) == 1 {
+			cc := pool.got[0]
+			cc.mu.Lock()
+			if !cc.doNotReuse {
+				t.Error("ClientConn not marked doNotReuse")
+			}
+			cc.mu.Unlock()
+		} else {
+			t.Errorf("pool get success = %v; want 1", len(pool.got))
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		var sentErr bool
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-done:
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame:
+			case *HeadersFrame:
+				if !sentErr {
+					sentErr = true
+					ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol)
+					continue
+				}
+				var buf bytes.Buffer
+				enc := hpack.NewEncoder(&buf)
+				// send headers without Trailer header
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      f.StreamID,
+					EndHeaders:    true,
+					EndStream:     true,
+					BlockFragment: buf.Bytes(),
+				})
+			}
+		}
+		return nil
+	}
+	ct.run()
+}