net/http: don't cache http2.erringRoundTripper connections

Fixes #34978

Change-Id: I3baf1392ba7366ae6628889c47c343ef702ec438
Reviewed-on: https://go-review.googlesource.com/c/go/+/202078
Reviewed-by: Bryan C. Mills <bcmills@google.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index ceda34c..c2880a0 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -540,10 +540,15 @@
 		if err == nil {
 			return resp, nil
 		}
-		if http2isNoCachedConnError(err) {
+
+		// Failed. Clean up and determine whether to retry.
+
+		_, isH2DialError := pconn.alt.(http2erringRoundTripper)
+		if http2isNoCachedConnError(err) || isH2DialError {
 			t.removeIdleConn(pconn)
 			t.decConnsPerHost(pconn.cacheKey)
-		} else if !pconn.shouldRetryRequest(req, err) {
+		}
+		if !pconn.shouldRetryRequest(req, err) {
 			// Issue 16465: return underlying net.Conn.Read error from peek,
 			// as we've historically done.
 			if e, ok := err.(transportReadFromServerError); ok {
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index c84d3ea..f76530b 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -5814,3 +5814,78 @@
 		})
 	}
 }
+
+// breakableConn is a net.Conn wrapper with a Write method
+// that will fail when its brokenState is true.
+type breakableConn struct {
+	net.Conn
+	*brokenState
+}
+
+type brokenState struct {
+	sync.Mutex
+	broken bool
+}
+
+func (w *breakableConn) Write(b []byte) (n int, err error) {
+	w.Lock()
+	defer w.Unlock()
+	if w.broken {
+		return 0, errors.New("some write error")
+	}
+	return w.Conn.Write(b)
+}
+
+// Issue 34978: don't cache a broken HTTP/2 connection
+func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
+	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
+	defer cst.close()
+
+	var brokenState brokenState
+
+	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
+		c, err := net.Dial(netw, addr)
+		if err != nil {
+			t.Errorf("unexpected Dial error: %v", err)
+			return nil, err
+		}
+		return &breakableConn{c, &brokenState}, err
+	}
+
+	const numReqs = 5
+	var gotConns uint32 // atomic
+	for i := 1; i <= numReqs; i++ {
+		brokenState.Lock()
+		brokenState.broken = false
+		brokenState.Unlock()
+
+		// doBreak controls whether we break the TCP connection after the TLS
+		// handshake (before the HTTP/2 handshake). We test a few failures
+		// in a row followed by a final success.
+		doBreak := i != numReqs
+
+		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+			GotConn: func(info httptrace.GotConnInfo) {
+				atomic.AddUint32(&gotConns, 1)
+			},
+			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
+				brokenState.Lock()
+				defer brokenState.Unlock()
+				if doBreak {
+					brokenState.broken = true
+				}
+			},
+		})
+		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		_, err = cst.c.Do(req)
+		if doBreak != (err != nil) {
+			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
+		}
+	}
+	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
+		t.Errorf("GotConn calls = %v; want %v", got, want)
+	}
+}