http2: track reused connections

nextStreamID was used as a means to determine if the connection was
being reused. Multiple requests can see a new connection because the
nextStreamID is updated after a ClientTrace reports it is being reused.

Updates golang/go#31982

Change-Id: Iaa4b62b217f015423cddb99fd86de75a352f8320
Reviewed-on: https://go-review.googlesource.com/c/net/+/176720
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 4ec0792..c0c80d8 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -28,6 +28,7 @@
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"golang.org/x/net/http/httpguts"
@@ -199,6 +200,7 @@
 	t         *Transport
 	tconn     net.Conn             // usually *tls.Conn, except specialized impls
 	tlsState  *tls.ConnectionState // nil only for specialized impls
+	reused    uint32               // whether conn is being reused; atomic
 	singleUse bool                 // whether being used for a single http.Request
 
 	// readLoop goroutine fields:
@@ -440,7 +442,8 @@
 			t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
 			return nil, err
 		}
-		traceGotConn(req, cc)
+		reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
+		traceGotConn(req, cc, reused)
 		res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
 		if err != nil && retry <= 6 {
 			if req, err = shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
@@ -2559,15 +2562,15 @@
 	trace.GetConn(hostPort)
 }
 
-func traceGotConn(req *http.Request, cc *ClientConn) {
+func traceGotConn(req *http.Request, cc *ClientConn, reused bool) {
 	trace := httptrace.ContextClientTrace(req.Context())
 	if trace == nil || trace.GotConn == nil {
 		return
 	}
 	ci := httptrace.GotConnInfo{Conn: cc.tconn}
+	ci.Reused = reused
 	cc.mu.Lock()
-	ci.Reused = cc.nextStreamID > 1
-	ci.WasIdle = len(cc.streams) == 0 && ci.Reused
+	ci.WasIdle = len(cc.streams) == 0 && reused
 	if ci.WasIdle && !cc.lastActive.IsZero() {
 		ci.IdleTime = time.Now().Sub(cc.lastActive)
 	}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 6c4bed3..567eb74 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -19,6 +19,7 @@
 	"net"
 	"net/http"
 	"net/http/httptest"
+	"net/http/httptrace"
 	"net/textproto"
 	"net/url"
 	"os"
@@ -98,6 +99,15 @@
 	if err != nil {
 		t.Fatal(err)
 	}
+	var gotConnCnt int32
+	trace := &httptrace.ClientTrace{
+		GotConn: func(connInfo httptrace.GotConnInfo) {
+			if !connInfo.Reused {
+				atomic.AddInt32(&gotConnCnt, 1)
+			}
+		},
+	}
+	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
 	tr := &Transport{
 		AllowHTTP: true,
 		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -118,6 +128,9 @@
 	if got, want := string(body), "Hello, /foobar, http: true"; got != want {
 		t.Fatalf("response got %v, want %v", got, want)
 	}
+	if got, want := gotConnCnt, int32(1); got != want {
+		t.Errorf("Too many got connections: %d", gotConnCnt)
+	}
 }
 
 func TestTransport(t *testing.T) {
@@ -244,6 +257,14 @@
 		mu    sync.Mutex
 		dials = map[string]int{}
 	)
+	var gotConnCnt int32
+	trace := &httptrace.ClientTrace{
+		GotConn: func(connInfo httptrace.GotConnInfo) {
+			if !connInfo.Reused {
+				atomic.AddInt32(&gotConnCnt, 1)
+			}
+		},
+	}
 	var wg sync.WaitGroup
 	for i := 0; i < 10; i++ {
 		wg.Add(1)
@@ -254,6 +275,7 @@
 				t.Error(err)
 				return
 			}
+			req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
 			res, err := tr.RoundTrip(req)
 			if err != nil {
 				t.Error(err)
@@ -298,6 +320,9 @@
 	}); err != nil {
 		t.Errorf("State of pool after CloseIdleConnections: %v", err)
 	}
+	if got, want := gotConnCnt, int32(1); got != want {
+		t.Errorf("Too many got connections: %d", gotConnCnt)
+	}
 }
 
 func retry(tries int, delay time.Duration, fn func() error) error {