http2: call httptrace.ClientTrace.GetConn in Transport when needed

Tests in CL 122591 in the standard library, to be submitted after this
CL.

Updates golang/go#23041

Change-Id: I3538cc7d2a71e3a26ab4c2f47bb220a25404cddb
Reviewed-on: https://go-review.googlesource.com/122590
Reviewed-by: Ian Lance Taylor <iant@golang.org>
diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go
index bdf5652..f4d9b5e 100644
--- a/http2/client_conn_pool.go
+++ b/http2/client_conn_pool.go
@@ -52,9 +52,31 @@
 	noDialOnMiss = false
 )
 
+// shouldTraceGetConn reports whether getClientConn should call any
+// ClientTrace.GetConn hook associated with the http.Request.
+//
+// This complexity is needed to avoid double calls of the GetConn hook
+// during the back-and-forth between net/http and x/net/http2 (when the
+// net/http.Transport is upgraded to also speak http2), as well as support
+// the case where x/net/http2 is being used directly.
+func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
+	// If our Transport wasn't made via ConfigureTransport, always
+	// trace the GetConn hook if provided, because that means the
+	// http2 package is being used directly and it's the one
+	// dialing, as opposed to net/http.
+	if _, ok := p.t.ConnPool.(noDialClientConnPool); !ok {
+		return true
+	}
+	// Otherwise, only use the GetConn hook if this connection has
+	// been used previously for other requests. For fresh
+	// connections, the net/http package does the dialing.
+	return !st.freshConn
+}
+
 func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
 	if isConnectionCloseRequest(req) && dialOnMiss {
 		// It gets its own connection.
+		traceGetConn(req, addr)
 		const singleUse = true
 		cc, err := p.t.dialClientConn(addr, singleUse)
 		if err != nil {
@@ -64,7 +86,10 @@
 	}
 	p.mu.Lock()
 	for _, cc := range p.conns[addr] {
-		if cc.CanTakeNewRequest() {
+		if st := cc.idleState(); st.canTakeNewRequest {
+			if p.shouldTraceGetConn(st) {
+				traceGetConn(req, addr)
+			}
 			p.mu.Unlock()
 			return cc, nil
 		}
@@ -73,6 +98,7 @@
 		p.mu.Unlock()
 		return nil, ErrNoCachedConn
 	}
+	traceGetConn(req, addr)
 	call := p.getStartDialLocked(addr)
 	p.mu.Unlock()
 	<-call.done
diff --git a/http2/go17.go b/http2/go17.go
index bf3a7c1..d957b7b 100644
--- a/http2/go17.go
+++ b/http2/go17.go
@@ -50,6 +50,14 @@
 
 func setResponseUncompressed(res *http.Response) { res.Uncompressed = true }
 
+func traceGetConn(req *http.Request, hostPort string) {
+	trace := httptrace.ContextClientTrace(req.Context())
+	if trace == nil || trace.GetConn == nil {
+		return
+	}
+	trace.GetConn(hostPort)
+}
+
 func traceGotConn(req *http.Request, cc *ClientConn) {
 	trace := httptrace.ContextClientTrace(req.Context())
 	if trace == nil || trace.GotConn == nil {
diff --git a/http2/not_go17.go b/http2/not_go17.go
index 976163f..7ffb250 100644
--- a/http2/not_go17.go
+++ b/http2/not_go17.go
@@ -37,6 +37,7 @@
 type clientTrace struct{}
 
 func requestTrace(*http.Request) *clientTrace { return nil }
+func traceGetConn(*http.Request, string)      {}
 func traceGotConn(*http.Request, *ClientConn) {}
 func traceFirstResponseByte(*clientTrace)     {}
 func traceWroteHeaders(*clientTrace)          {}
diff --git a/http2/transport.go b/http2/transport.go
index d474377..300b02f 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -631,12 +631,32 @@
 	return cc.canTakeNewRequestLocked()
 }
 
-func (cc *ClientConn) canTakeNewRequestLocked() bool {
+// clientConnIdleState describes the suitability of a client
+// connection to initiate a new RoundTrip request.
+type clientConnIdleState struct {
+	canTakeNewRequest bool
+	freshConn         bool // whether it's unused by any previous request
+}
+
+func (cc *ClientConn) idleState() clientConnIdleState {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	return cc.idleStateLocked()
+}
+
+func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
 	if cc.singleUse && cc.nextStreamID > 1 {
-		return false
+		return
 	}
-	return cc.goAway == nil && !cc.closed && !cc.closing &&
+	st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing &&
 		int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
+	st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
+	return
+}
+
+func (cc *ClientConn) canTakeNewRequestLocked() bool {
+	st := cc.idleStateLocked()
+	return st.canTakeNewRequest
 }
 
 // onIdleTimeout is called from a time.AfterFunc goroutine. It will