http2: avoid extra GetConn trace call

CL 352469 inverts the case in shouldTraceGetConn: We want to call GetConn
for connections that have been previously used, but it calls GetConn
only on approximately the first use. "Approximately", because it uses
cc.nextStreamID to determine if the connection has been used, which
is racy.

Restructure the decision to call GetConn to track a per-ClientConn bool
indicating whether GetConn has already been called for this connection.
Set this bool for connections received from net/http, clear it after the
first use of the connection.

Fixes net/http's TestTransportEventTrace_h2.

Change-Id: I8e3dbba7cfbce9acd3612e39b6b6ee558bbfc864
Reviewed-on: https://go-review.googlesource.com/c/net/+/353875
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go
index 7f817e2..6648ebb 100644
--- a/http2/client_conn_pool.go
+++ b/http2/client_conn_pool.go
@@ -48,7 +48,7 @@
 	conns        map[string][]*ClientConn // key is host:port
 	dialing      map[string]*dialCall     // currently in-flight dials
 	keys         map[*ClientConn][]string
-	addConnCalls map[string]*addConnCall // in-flight addConnIfNeede calls
+	addConnCalls map[string]*addConnCall // in-flight addConnIfNeeded calls
 }
 
 func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
@@ -60,29 +60,6 @@
 	noDialOnMiss = false
 )
 
-// shouldTraceGetConn reports whether getClientConn should call any
-// ClientTrace.GetConn hook associated with the http.Request.
-//
-// This complexity is needed to avoid double calls of the GetConn hook
-// during the back-and-forth between net/http and x/net/http2 (when the
-// net/http.Transport is upgraded to also speak http2), as well as support
-// the case where x/net/http2 is being used directly.
-func (p *clientConnPool) shouldTraceGetConn(cc *ClientConn) bool {
-	// If our Transport wasn't made via ConfigureTransport, always
-	// trace the GetConn hook if provided, because that means the
-	// http2 package is being used directly and it's the one
-	// dialing, as opposed to net/http.
-	if _, ok := p.t.ConnPool.(noDialClientConnPool); !ok {
-		return true
-	}
-	// Otherwise, only use the GetConn hook if this connection has
-	// been used previously for other requests. For fresh
-	// connections, the net/http package does the dialing.
-	cc.mu.Lock()
-	defer cc.mu.Unlock()
-	return cc.nextStreamID == 1
-}
-
 func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
 	// TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
 	if isConnectionCloseRequest(req) && dialOnMiss {
@@ -99,9 +76,13 @@
 		p.mu.Lock()
 		for _, cc := range p.conns[addr] {
 			if cc.ReserveNewRequest() {
-				if p.shouldTraceGetConn(cc) {
+				// When a connection is presented to us by the net/http package,
+				// the GetConn hook has already been called.
+				// Don't call it a second time here.
+				if !cc.getConnCalled {
 					traceGetConn(req, addr)
 				}
+				cc.getConnCalled = false
 				p.mu.Unlock()
 				return cc, nil
 			}
@@ -214,6 +195,7 @@
 
 func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
 	cc, err := t.NewClientConn(tc)
+	cc.getConnCalled = true // already called by the net/http package
 
 	p := c.p
 	p.mu.Lock()
diff --git a/http2/transport.go b/http2/transport.go
index ee4531f..e6aede6 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -242,11 +242,12 @@
 // ClientConn is the state of a single HTTP/2 client connection to an
 // HTTP/2 server.
 type ClientConn struct {
-	t         *Transport
-	tconn     net.Conn             // usually *tls.Conn, except specialized impls
-	tlsState  *tls.ConnectionState // nil only for specialized impls
-	reused    uint32               // whether conn is being reused; atomic
-	singleUse bool                 // whether being used for a single http.Request
+	t             *Transport
+	tconn         net.Conn             // usually *tls.Conn, except specialized impls
+	tlsState      *tls.ConnectionState // nil only for specialized impls
+	reused        uint32               // whether conn is being reused; atomic
+	singleUse     bool                 // whether being used for a single http.Request
+	getConnCalled bool                 // used by clientConnPool
 
 	// readLoop goroutine fields:
 	readerDone chan struct{} // closed on error
@@ -762,7 +763,6 @@
 // connection to initiate a new RoundTrip request.
 type clientConnIdleState struct {
 	canTakeNewRequest bool
-	freshConn         bool // whether it's unused by any previous request
 }
 
 func (cc *ClientConn) idleState() clientConnIdleState {
@@ -790,7 +790,6 @@
 		!cc.doNotReuse &&
 		int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
 		!cc.tooIdleLocked()
-	st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
 	return
 }
 
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 60b67ed..fe0ab88 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -243,6 +243,68 @@
 	}
 }
 
+func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) {
+	testTransportGetGotConnHooks(t, false)
+}
+func TestTransportGetGotConnHooks_Client(t *testing.T) { testTransportGetGotConnHooks(t, true) }
+
+func testTransportGetGotConnHooks(t *testing.T, useClient bool) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, r.RemoteAddr)
+	}, func(s *httptest.Server) {
+		s.EnableHTTP2 = true
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	client := st.ts.Client()
+	ConfigureTransports(client.Transport.(*http.Transport))
+
+	var (
+		getConns int32
+		gotConns int32
+	)
+	for i := 0; i < 2; i++ {
+		trace := &httptrace.ClientTrace{
+			GetConn: func(hostport string) {
+				atomic.AddInt32(&getConns, 1)
+			},
+			GotConn: func(connInfo httptrace.GotConnInfo) {
+				got := atomic.AddInt32(&gotConns, 1)
+				wantReused, wantWasIdle := false, false
+				if got > 1 {
+					wantReused, wantWasIdle = true, true
+				}
+				if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle {
+					t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle)
+				}
+			},
+		}
+		req, err := http.NewRequest("GET", st.ts.URL, nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+		var res *http.Response
+		if useClient {
+			res, err = client.Do(req)
+		} else {
+			res, err = tr.RoundTrip(req)
+		}
+		if err != nil {
+			t.Fatal(err)
+		}
+		res.Body.Close()
+		if get := atomic.LoadInt32(&getConns); get != int32(i+1) {
+			t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1)
+		}
+		if got := atomic.LoadInt32(&gotConns); got != int32(i+1) {
+			t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1)
+		}
+	}
+}
+
 // Tests that the Transport only keeps one pending dial open per destination address.
 // https://golang.org/issue/13397
 func TestTransportGroupsPendingDials(t *testing.T) {