net/http: don't cancel Dials when requests are canceled

Currently, when a Transport creates a new connection for a request,
it uses the request's Context to make the Dial. If a request
times out or is canceled before a Dial completes, the Dial is
canceled.

Change this so that the lifetime of a Dial call is not bound
by the request that originated it.

This change avoids a scenario where a Transport can start and
then cancel many Dial calls in rapid succession:

  - Request starts a Dial.
  - A previous request completes, making its connection available.
  - The new request uses the now-idle connection, and completes.
  - The request Context is canceled, and the Dial is aborted.

Fixes #59017

Change-Id: I996ffabc56d3b1b43129cbfd9b3e9ea7d53d263c
Reviewed-on: https://go-review.googlesource.com/c/go/+/576555
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go
index 569b58c..33e6946 100644
--- a/src/net/http/client_test.go
+++ b/src/net/http/client_test.go
@@ -1938,21 +1938,25 @@
 	}
 }
 
+type testRoundTripper func(*Request) (*Response, error)
+
+func (t testRoundTripper) RoundTrip(req *Request) (*Response, error) {
+	return t(req)
+}
+
 func TestClientPropagatesTimeoutToContext(t *testing.T) {
-	errDial := errors.New("not actually dialing")
 	c := &Client{
 		Timeout: 5 * time.Second,
-		Transport: &Transport{
-			DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) {
-				deadline, ok := ctx.Deadline()
-				if !ok {
-					t.Error("no deadline")
-				} else {
-					t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
-				}
-				return nil, errDial
-			},
-		},
+		Transport: testRoundTripper(func(req *Request) (*Response, error) {
+			ctx := req.Context()
+			deadline, ok := ctx.Deadline()
+			if !ok {
+				t.Error("no deadline")
+			} else {
+				t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
+			}
+			return nil, errors.New("not actually making a request")
+		}),
 	}
 	c.Get("https://example.tld/")
 }
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index 8a6f4f1..56ebda1 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -86,6 +86,14 @@
 
 func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
 
+func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) {
+	orig := testHookProxyConnectTimeout
+	t.Cleanup(func() {
+		testHookProxyConnectTimeout = orig
+	})
+	testHookProxyConnectTimeout = f
+}
+
 func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
 	return &timeoutHandler{
 		handler:     handler,
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index d97298e..e6a97a0 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -108,6 +108,7 @@
 	connsPerHostMu   sync.Mutex
 	connsPerHost     map[connectMethodKey]int
 	connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns
+	dialsInProgress  wantConnQueue
 
 	// Proxy specifies a function to return a proxy for a given
 	// Request. If the function returns a non-nil error, the
@@ -807,6 +808,13 @@
 			pconn.close(errCloseIdleConns)
 		}
 	}
+	t.connsPerHostMu.Lock()
+	t.dialsInProgress.all(func(w *wantConn) {
+		if w.cancelCtx != nil && !w.waiting() {
+			w.cancelCtx()
+		}
+	})
+	t.connsPerHostMu.Unlock()
 	if t2 := t.h2transport; t2 != nil {
 		t2.CloseIdleConnections()
 	}
@@ -1116,7 +1124,7 @@
 		t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
 	}
 	q := t.idleConnWait[w.key]
-	q.cleanFront()
+	q.cleanFrontNotWaiting()
 	q.pushBack(w)
 	t.idleConnWait[w.key] = q
 	return false
@@ -1230,10 +1238,11 @@
 	beforeDial func()
 	afterDial  func()
 
-	mu     sync.Mutex       // protects ctx, done and sending of the result
-	ctx    context.Context  // context for dial, cleared after delivered or canceled
-	done   bool             // true after delivered or canceled
-	result chan connOrError // channel to deliver connection or error
+	mu        sync.Mutex      // protects ctx, done and sending of the result
+	ctx       context.Context // context for dial, cleared after delivered or canceled
+	cancelCtx context.CancelFunc
+	done      bool             // true after delivered or canceled
+	result    chan connOrError // channel to deliver connection or error
 }
 
 type connOrError struct {
@@ -1352,9 +1361,9 @@
 	return nil
 }
 
-// cleanFront pops any wantConns that are no longer waiting from the head of the
+// cleanFrontNotWaiting pops any wantConns that are no longer waiting from the head of the
 // queue, reporting whether any were popped.
-func (q *wantConnQueue) cleanFront() (cleaned bool) {
+func (q *wantConnQueue) cleanFrontNotWaiting() (cleaned bool) {
 	for {
 		w := q.peekFront()
 		if w == nil || w.waiting() {
@@ -1365,6 +1374,28 @@
 	}
 }
 
+// cleanFrontCanceled pops any wantConns with canceled dials from the head of the queue.
+func (q *wantConnQueue) cleanFrontCanceled() {
+	for {
+		w := q.peekFront()
+		if w == nil || w.cancelCtx != nil {
+			return
+		}
+		q.popFront()
+	}
+}
+
+// all iterates over all wantConns in the queue.
+// The caller must not modify the queue while iterating.
+func (q *wantConnQueue) all(f func(*wantConn)) {
+	for _, w := range q.head[q.headPos:] {
+		f(w)
+	}
+	for _, w := range q.tail {
+		f(w)
+	}
+}
+
 func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
 	if t.DialTLSContext != nil {
 		conn, err = t.DialTLSContext(ctx, network, addr)
@@ -1389,10 +1420,18 @@
 		trace.GetConn(cm.addr())
 	}
 
+	// Detach from the request context's cancellation signal.
+	// The dial should proceed even if the request is canceled,
+	// because a future request may be able to make use of the connection.
+	//
+	// We retain the request context's values.
+	dialCtx, dialCancel := context.WithCancel(context.WithoutCancel(ctx))
+
 	w := &wantConn{
 		cm:         cm,
 		key:        cm.key(),
-		ctx:        ctx,
+		ctx:        dialCtx,
+		cancelCtx:  dialCancel,
 		result:     make(chan connOrError, 1),
 		beforeDial: testHookPrePendingDial,
 		afterDial:  testHookPostPendingDial,
@@ -1470,20 +1509,21 @@
 // Once w receives permission to dial, it will do so in a separate goroutine.
 func (t *Transport) queueForDial(w *wantConn) {
 	w.beforeDial()
-	if t.MaxConnsPerHost <= 0 {
-		go t.dialConnFor(w)
-		return
-	}
 
 	t.connsPerHostMu.Lock()
 	defer t.connsPerHostMu.Unlock()
 
+	if t.MaxConnsPerHost <= 0 {
+		t.startDialConnForLocked(w)
+		return
+	}
+
 	if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost {
 		if t.connsPerHost == nil {
 			t.connsPerHost = make(map[connectMethodKey]int)
 		}
 		t.connsPerHost[w.key] = n + 1
-		go t.dialConnFor(w)
+		t.startDialConnForLocked(w)
 		return
 	}
 
@@ -1491,11 +1531,24 @@
 		t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
 	}
 	q := t.connsPerHostWait[w.key]
-	q.cleanFront()
+	q.cleanFrontNotWaiting()
 	q.pushBack(w)
 	t.connsPerHostWait[w.key] = q
 }
 
+// startDialConnFor calls dialConn in a new goroutine.
+// t.connsPerHostMu must be held.
+func (t *Transport) startDialConnForLocked(w *wantConn) {
+	t.dialsInProgress.cleanFrontCanceled()
+	t.dialsInProgress.pushBack(w)
+	go func() {
+		t.dialConnFor(w)
+		t.connsPerHostMu.Lock()
+		defer t.connsPerHostMu.Unlock()
+		w.cancelCtx = nil
+	}()
+}
+
 // dialConnFor dials on behalf of w and delivers the result to w.
 // dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()].
 // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()].
@@ -1545,7 +1598,7 @@
 		for q.len() > 0 {
 			w := q.popFront()
 			if w.waiting() {
-				go t.dialConnFor(w)
+				t.startDialConnForLocked(w)
 				done = true
 				break
 			}
@@ -1626,6 +1679,8 @@
 	RoundTripErr() error
 }
 
+var testHookProxyConnectTimeout = context.WithTimeout
+
 func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
 	pconn = &persistConn{
 		t:             t,
@@ -1742,17 +1797,11 @@
 			Header: hdr,
 		}
 
-		// If there's no done channel (no deadline or cancellation
-		// from the caller possible), at least set some (long)
-		// timeout here. This will make sure we don't block forever
-		// and leak a goroutine if the connection stops replying
-		// after the TCP connect.
-		connectCtx := ctx
-		if ctx.Done() == nil {
-			newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
-			defer cancel()
-			connectCtx = newCtx
-		}
+		// Set a (long) timeout here to make sure we don't block forever
+		// and leak a goroutine if the connection stops replying after
+		// the TCP connect.
+		connectCtx, cancel := testHookProxyConnectTimeout(ctx, 1*time.Minute)
+		defer cancel()
 
 		didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
 		var (
diff --git a/src/net/http/transport_dial_test.go b/src/net/http/transport_dial_test.go
new file mode 100644
index 0000000..39e35ce
--- /dev/null
+++ b/src/net/http/transport_dial_test.go
@@ -0,0 +1,235 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+	"context"
+	"io"
+	"net"
+	"net/http"
+	"net/http/httptrace"
+	"testing"
+)
+
+func TestTransportPoolConnReusePriorConnection(t *testing.T) {
+	dt := newTransportDialTester(t, http1Mode)
+
+	// First request creates a new connection.
+	rt1 := dt.roundTrip()
+	c1 := dt.wantDial()
+	c1.finish(nil)
+	rt1.wantDone(c1)
+	rt1.finish()
+
+	// Second request reuses the first connection.
+	rt2 := dt.roundTrip()
+	rt2.wantDone(c1)
+	rt2.finish()
+}
+
+func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
+	dt := newTransportDialTester(t, http1Mode)
+
+	// First request creates a new connection.
+	rt1 := dt.roundTrip()
+	c1 := dt.wantDial()
+	c1.finish(nil)
+	rt1.wantDone(c1)
+
+	// Second request is made while the first request is still using its connection,
+	// so it goes on a new connection.
+	rt2 := dt.roundTrip()
+	c2 := dt.wantDial()
+	c2.finish(nil)
+	rt2.wantDone(c2)
+}
+
+func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
+	dt := newTransportDialTester(t, http1Mode)
+
+	// First request creates a new connection.
+	rt1 := dt.roundTrip()
+	c1 := dt.wantDial()
+	c1.finish(nil)
+	rt1.wantDone(c1)
+
+	// Second request is made while the first request is still using its connection.
+	// The first connection completes while the second Dial is in progress, so the
+	// second request uses the first connection.
+	rt2 := dt.roundTrip()
+	c2 := dt.wantDial()
+	rt1.finish()
+	rt2.wantDone(c1)
+
+	// This section is a bit overfitted to the current Transport implementation:
+	// A third request starts. We have an in-progress dial that was started by rt2,
+	// but this new request (rt3) is going to ignore it and make a dial of its own.
+	// rt3 will use the first of these dials that completes.
+	rt3 := dt.roundTrip()
+	c3 := dt.wantDial()
+	c2.finish(nil)
+	rt3.wantDone(c2)
+
+	c3.finish(nil)
+}
+
+// A transportDialTester manages a test of a connection's Dials.
+type transportDialTester struct {
+	t   *testing.T
+	cst *clientServerTest
+
+	dials chan *transportDialTesterConn // each new conn is sent to this channel
+
+	roundTripCount int
+	dialCount      int
+}
+
+// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
+type transportDialTesterRoundTrip struct {
+	t *testing.T
+
+	roundTripID int                // distinguishes RoundTrips in logs
+	cancel      context.CancelFunc // cancels the Request context
+	reqBody     io.WriteCloser     // write half of the Request.Body
+	finished    bool
+
+	done chan struct{} // closed when RoundTrip returns:w
+	res  *http.Response
+	err  error
+	conn *transportDialTesterConn
+}
+
+// A transportDialTesterConn is a client connection created by the Transport as
+// part of a dial test.
+type transportDialTesterConn struct {
+	t *testing.T
+
+	connID int        // distinguished Dials in logs
+	ready  chan error // sent on to complete the Dial
+
+	net.Conn
+}
+
+func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
+	t.Helper()
+	dt := &transportDialTester{
+		t:     t,
+		dials: make(chan *transportDialTesterConn),
+	}
+	dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Write response headers when we receive a request.
+		http.NewResponseController(w).EnableFullDuplex()
+		w.WriteHeader(200)
+		http.NewResponseController(w).Flush()
+		// Wait for the client to send the request body,
+		// to synchronize with the rest of the test.
+		io.ReadAll(r.Body)
+	}), func(tr *http.Transport) {
+		tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
+			c := &transportDialTesterConn{
+				t:     t,
+				ready: make(chan error),
+			}
+			// Notify the test that a Dial has started,
+			// and wait for the test to notify us that it should complete.
+			dt.dials <- c
+			if err := <-c.ready; err != nil {
+				return nil, err
+			}
+			nc, err := net.Dial(network, address)
+			if err != nil {
+				return nil, err
+			}
+			// Use the *transportDialTesterConn as the net.Conn,
+			// to let tests associate requests with connections.
+			c.Conn = nc
+			return c, err
+		}
+	})
+	return dt
+}
+
+// roundTrip starts a RoundTrip.
+// It returns immediately, without waiting for the RoundTrip call to complete.
+func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
+	dt.t.Helper()
+	ctx, cancel := context.WithCancel(context.Background())
+	pr, pw := io.Pipe()
+	rt := &transportDialTesterRoundTrip{
+		t:           dt.t,
+		roundTripID: dt.roundTripCount,
+		done:        make(chan struct{}),
+		reqBody:     pw,
+		cancel:      cancel,
+	}
+	dt.roundTripCount++
+	dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
+	dt.t.Cleanup(func() {
+		rt.cancel()
+		rt.finish()
+	})
+	go func() {
+		ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+			GotConn: func(info httptrace.GotConnInfo) {
+				rt.conn = info.Conn.(*transportDialTesterConn)
+			},
+		})
+		req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
+		req.Header.Set("Content-Type", "text/plain")
+		rt.res, rt.err = dt.cst.tr.RoundTrip(req)
+		dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
+		close(rt.done)
+	}()
+	return rt
+}
+
+// wantDone indicates that a RoundTrip should have returned.
+func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
+	rt.t.Helper()
+	<-rt.done
+	if rt.err != nil {
+		rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
+	}
+	if rt.conn != c {
+		rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
+	}
+}
+
+// finish completes a RoundTrip by sending the request body, consuming the response body,
+// and closing the response body.
+func (rt *transportDialTesterRoundTrip) finish() {
+	rt.t.Helper()
+
+	if rt.finished {
+		return
+	}
+	rt.finished = true
+
+	<-rt.done
+
+	if rt.err != nil {
+		return
+	}
+	rt.reqBody.Close()
+	io.ReadAll(rt.res.Body)
+	rt.res.Body.Close()
+	rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
+}
+
+// wantDial waits for the Transport to start a Dial.
+func (dt *transportDialTester) wantDial() *transportDialTesterConn {
+	c := <-dt.dials
+	c.connID = dt.dialCount
+	dt.dialCount++
+	dt.t.Logf("Dial %v: started", c.connID)
+	return c
+}
+
+// finish completes a Dial.
+func (c *transportDialTesterConn) finish(err error) {
+	c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
+	c.ready <- err
+	close(c.ready)
+}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index e8baa48..fa147e1 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -1626,11 +1626,20 @@
 // Issue 28012: verify that the Transport closes its TCP connection to http proxies
 // when they're slow to reply to HTTPS CONNECT responses.
 func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
-	setParallel(t)
-	defer afterTest(t)
+	cancelc := make(chan struct{})
+	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
+		ctx, cancel := context.WithCancel(ctx)
+		go func() {
+			select {
+			case <-cancelc:
+			case <-ctx.Done():
+			}
+			cancel()
+		}()
+		return ctx, cancel
+	})
 
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
+	defer afterTest(t)
 
 	ln := newLocalListener(t)
 	defer ln.Close()
@@ -1658,7 +1667,7 @@
 		// Now hang and never write a response; instead, cancel the request and wait
 		// for the client to close.
 		// (Prior to Issue 28012 being fixed, we never closed.)
-		cancel()
+		close(cancelc)
 		var buf [1]byte
 		_, err = br.Read(buf[:])
 		if err != io.EOF {
@@ -1674,7 +1683,7 @@
 			},
 		},
 	}
-	req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
+	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -3927,9 +3936,13 @@
 
 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
 func testTransportDialContext(t *testing.T, mode testMode) {
-	var mu sync.Mutex // guards following
-	var gotReq bool
-	var receivedContext context.Context
+	ctxKey := "some-key"
+	ctxValue := "some-value"
+	var (
+		mu          sync.Mutex // guards following
+		gotReq      bool
+		gotCtxValue any
+	)
 
 	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
 		mu.Lock()
@@ -3939,7 +3952,7 @@
 	c := ts.Client()
 	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
 		mu.Lock()
-		receivedContext = ctx
+		gotCtxValue = ctx.Value(ctxKey)
 		mu.Unlock()
 		return net.Dial(netw, addr)
 	}
@@ -3948,7 +3961,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	ctx := context.WithValue(context.Background(), "some-key", "some-value")
+	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
 	res, err := c.Do(req.WithContext(ctx))
 	if err != nil {
 		t.Fatal(err)
@@ -3958,8 +3971,8 @@
 	if !gotReq {
 		t.Error("didn't get request")
 	}
-	if receivedContext != ctx {
-		t.Error("didn't receive correct context")
+	if got, want := gotCtxValue, ctxValue; got != want {
+		t.Errorf("got context with value %v, want %v", got, want)
 	}
 }
 
@@ -3967,9 +3980,13 @@
 	run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
 }
 func testTransportDialTLSContext(t *testing.T, mode testMode) {
-	var mu sync.Mutex // guards following
-	var gotReq bool
-	var receivedContext context.Context
+	ctxKey := "some-key"
+	ctxValue := "some-value"
+	var (
+		mu          sync.Mutex // guards following
+		gotReq      bool
+		gotCtxValue any
+	)
 
 	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
 		mu.Lock()
@@ -3979,7 +3996,7 @@
 	c := ts.Client()
 	c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
 		mu.Lock()
-		receivedContext = ctx
+		gotCtxValue = ctx.Value(ctxKey)
 		mu.Unlock()
 		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
 		if err != nil {
@@ -3992,7 +4009,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	ctx := context.WithValue(context.Background(), "some-key", "some-value")
+	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
 	res, err := c.Do(req.WithContext(ctx))
 	if err != nil {
 		t.Fatal(err)
@@ -4002,8 +4019,8 @@
 	if !gotReq {
 		t.Error("didn't get request")
 	}
-	if receivedContext != ctx {
-		t.Error("didn't receive correct context")
+	if got, want := gotCtxValue, ctxValue; got != want {
+		t.Errorf("got context with value %v, want %v", got, want)
 	}
 }