net/http: don't cancel Dials when requests are canceled
Currently, when a Transport creates a new connection for a request,
it uses the request's Context to make the Dial. If a request
times out or is canceled before a Dial completes, the Dial is
canceled.
Change this so that the lifetime of a Dial call is not bound
by the request that originated it.
This change avoids a scenario where a Transport can start and
then cancel many Dial calls in rapid succession:
- Request starts a Dial.
- A previous request completes, making its connection available.
- The new request uses the now-idle connection, and completes.
- The request Context is canceled, and the Dial is aborted.
Fixes #59017
Change-Id: I996ffabc56d3b1b43129cbfd9b3e9ea7d53d263c
Reviewed-on: https://go-review.googlesource.com/c/go/+/576555
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go
index 569b58c..33e6946 100644
--- a/src/net/http/client_test.go
+++ b/src/net/http/client_test.go
@@ -1938,21 +1938,25 @@
}
}
+type testRoundTripper func(*Request) (*Response, error)
+
+func (t testRoundTripper) RoundTrip(req *Request) (*Response, error) {
+ return t(req)
+}
+
func TestClientPropagatesTimeoutToContext(t *testing.T) {
- errDial := errors.New("not actually dialing")
c := &Client{
Timeout: 5 * time.Second,
- Transport: &Transport{
- DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) {
- deadline, ok := ctx.Deadline()
- if !ok {
- t.Error("no deadline")
- } else {
- t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
- }
- return nil, errDial
- },
- },
+ Transport: testRoundTripper(func(req *Request) (*Response, error) {
+ ctx := req.Context()
+ deadline, ok := ctx.Deadline()
+ if !ok {
+ t.Error("no deadline")
+ } else {
+ t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
+ }
+ return nil, errors.New("not actually making a request")
+ }),
}
c.Get("https://example.tld/")
}
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index 8a6f4f1..56ebda1 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -86,6 +86,14 @@
func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
+func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) {
+ orig := testHookProxyConnectTimeout
+ t.Cleanup(func() {
+ testHookProxyConnectTimeout = orig
+ })
+ testHookProxyConnectTimeout = f
+}
+
func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
return &timeoutHandler{
handler: handler,
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index d97298e..e6a97a0 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -108,6 +108,7 @@
connsPerHostMu sync.Mutex
connsPerHost map[connectMethodKey]int
connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns
+ dialsInProgress wantConnQueue
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
@@ -807,6 +808,13 @@
pconn.close(errCloseIdleConns)
}
}
+ t.connsPerHostMu.Lock()
+ t.dialsInProgress.all(func(w *wantConn) {
+ if w.cancelCtx != nil && !w.waiting() {
+ w.cancelCtx()
+ }
+ })
+ t.connsPerHostMu.Unlock()
if t2 := t.h2transport; t2 != nil {
t2.CloseIdleConnections()
}
@@ -1116,7 +1124,7 @@
t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
}
q := t.idleConnWait[w.key]
- q.cleanFront()
+ q.cleanFrontNotWaiting()
q.pushBack(w)
t.idleConnWait[w.key] = q
return false
@@ -1230,10 +1238,11 @@
beforeDial func()
afterDial func()
- mu sync.Mutex // protects ctx, done and sending of the result
- ctx context.Context // context for dial, cleared after delivered or canceled
- done bool // true after delivered or canceled
- result chan connOrError // channel to deliver connection or error
+ mu sync.Mutex // protects ctx, done and sending of the result
+ ctx context.Context // context for dial, cleared after delivered or canceled
+ cancelCtx context.CancelFunc
+ done bool // true after delivered or canceled
+ result chan connOrError // channel to deliver connection or error
}
type connOrError struct {
@@ -1352,9 +1361,9 @@
return nil
}
-// cleanFront pops any wantConns that are no longer waiting from the head of the
+// cleanFrontNotWaiting pops any wantConns that are no longer waiting from the head of the
// queue, reporting whether any were popped.
-func (q *wantConnQueue) cleanFront() (cleaned bool) {
+func (q *wantConnQueue) cleanFrontNotWaiting() (cleaned bool) {
for {
w := q.peekFront()
if w == nil || w.waiting() {
@@ -1365,6 +1374,28 @@
}
}
+// cleanFrontCanceled pops any wantConns whose dial has finished (cancelCtx is nil) from the head of the queue.
+func (q *wantConnQueue) cleanFrontCanceled() {
+ for {
+ w := q.peekFront()
+ if w == nil || w.cancelCtx != nil {
+ return
+ }
+ q.popFront()
+ }
+}
+
+// all iterates over all wantConns in the queue.
+// The caller must not modify the queue while iterating.
+func (q *wantConnQueue) all(f func(*wantConn)) {
+ for _, w := range q.head[q.headPos:] {
+ f(w)
+ }
+ for _, w := range q.tail {
+ f(w)
+ }
+}
+
func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if t.DialTLSContext != nil {
conn, err = t.DialTLSContext(ctx, network, addr)
@@ -1389,10 +1420,18 @@
trace.GetConn(cm.addr())
}
+ // Detach from the request context's cancellation signal.
+ // The dial should proceed even if the request is canceled,
+ // because a future request may be able to make use of the connection.
+ //
+ // We retain the request context's values.
+ dialCtx, dialCancel := context.WithCancel(context.WithoutCancel(ctx))
+
w := &wantConn{
cm: cm,
key: cm.key(),
- ctx: ctx,
+ ctx: dialCtx,
+ cancelCtx: dialCancel,
result: make(chan connOrError, 1),
beforeDial: testHookPrePendingDial,
afterDial: testHookPostPendingDial,
@@ -1470,20 +1509,21 @@
// Once w receives permission to dial, it will do so in a separate goroutine.
func (t *Transport) queueForDial(w *wantConn) {
w.beforeDial()
- if t.MaxConnsPerHost <= 0 {
- go t.dialConnFor(w)
- return
- }
t.connsPerHostMu.Lock()
defer t.connsPerHostMu.Unlock()
+ if t.MaxConnsPerHost <= 0 {
+ t.startDialConnForLocked(w)
+ return
+ }
+
if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost {
if t.connsPerHost == nil {
t.connsPerHost = make(map[connectMethodKey]int)
}
t.connsPerHost[w.key] = n + 1
- go t.dialConnFor(w)
+ t.startDialConnForLocked(w)
return
}
@@ -1491,11 +1531,24 @@
t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
}
q := t.connsPerHostWait[w.key]
- q.cleanFront()
+ q.cleanFrontNotWaiting()
q.pushBack(w)
t.connsPerHostWait[w.key] = q
}
+// startDialConnForLocked calls dialConnFor in a new goroutine.
+// t.connsPerHostMu must be held.
+func (t *Transport) startDialConnForLocked(w *wantConn) {
+ t.dialsInProgress.cleanFrontCanceled()
+ t.dialsInProgress.pushBack(w)
+ go func() {
+ t.dialConnFor(w)
+ t.connsPerHostMu.Lock()
+ defer t.connsPerHostMu.Unlock()
+ w.cancelCtx = nil
+ }()
+}
+
// dialConnFor dials on behalf of w and delivers the result to w.
// dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()].
// If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()].
@@ -1545,7 +1598,7 @@
for q.len() > 0 {
w := q.popFront()
if w.waiting() {
- go t.dialConnFor(w)
+ t.startDialConnForLocked(w)
done = true
break
}
@@ -1626,6 +1679,8 @@
RoundTripErr() error
}
+var testHookProxyConnectTimeout = context.WithTimeout
+
func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
pconn = &persistConn{
t: t,
@@ -1742,17 +1797,11 @@
Header: hdr,
}
- // If there's no done channel (no deadline or cancellation
- // from the caller possible), at least set some (long)
- // timeout here. This will make sure we don't block forever
- // and leak a goroutine if the connection stops replying
- // after the TCP connect.
- connectCtx := ctx
- if ctx.Done() == nil {
- newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
- defer cancel()
- connectCtx = newCtx
- }
+ // Set a (long) timeout here to make sure we don't block forever
+ // and leak a goroutine if the connection stops replying after
+ // the TCP connect.
+ connectCtx, cancel := testHookProxyConnectTimeout(ctx, 1*time.Minute)
+ defer cancel()
didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
var (
diff --git a/src/net/http/transport_dial_test.go b/src/net/http/transport_dial_test.go
new file mode 100644
index 0000000..39e35ce
--- /dev/null
+++ b/src/net/http/transport_dial_test.go
@@ -0,0 +1,235 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "context"
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptrace"
+ "testing"
+)
+
+func TestTransportPoolConnReusePriorConnection(t *testing.T) {
+ dt := newTransportDialTester(t, http1Mode)
+
+ // First request creates a new connection.
+ rt1 := dt.roundTrip()
+ c1 := dt.wantDial()
+ c1.finish(nil)
+ rt1.wantDone(c1)
+ rt1.finish()
+
+ // Second request reuses the first connection.
+ rt2 := dt.roundTrip()
+ rt2.wantDone(c1)
+ rt2.finish()
+}
+
+func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
+ dt := newTransportDialTester(t, http1Mode)
+
+ // First request creates a new connection.
+ rt1 := dt.roundTrip()
+ c1 := dt.wantDial()
+ c1.finish(nil)
+ rt1.wantDone(c1)
+
+ // Second request is made while the first request is still using its connection,
+ // so it goes on a new connection.
+ rt2 := dt.roundTrip()
+ c2 := dt.wantDial()
+ c2.finish(nil)
+ rt2.wantDone(c2)
+}
+
+func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
+ dt := newTransportDialTester(t, http1Mode)
+
+ // First request creates a new connection.
+ rt1 := dt.roundTrip()
+ c1 := dt.wantDial()
+ c1.finish(nil)
+ rt1.wantDone(c1)
+
+ // Second request is made while the first request is still using its connection.
+ // The first connection completes while the second Dial is in progress, so the
+ // second request uses the first connection.
+ rt2 := dt.roundTrip()
+ c2 := dt.wantDial()
+ rt1.finish()
+ rt2.wantDone(c1)
+
+ // This section is a bit overfitted to the current Transport implementation:
+ // A third request starts. We have an in-progress dial that was started by rt2,
+ // but this new request (rt3) is going to ignore it and make a dial of its own.
+ // rt3 will use the first of these dials that completes.
+ rt3 := dt.roundTrip()
+ c3 := dt.wantDial()
+ c2.finish(nil)
+ rt3.wantDone(c2)
+
+ c3.finish(nil)
+}
+
+// A transportDialTester manages a test of a connection's Dials.
+type transportDialTester struct {
+ t *testing.T
+ cst *clientServerTest
+
+ dials chan *transportDialTesterConn // each new conn is sent to this channel
+
+ roundTripCount int
+ dialCount int
+}
+
+// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
+type transportDialTesterRoundTrip struct {
+ t *testing.T
+
+ roundTripID int // distinguishes RoundTrips in logs
+ cancel context.CancelFunc // cancels the Request context
+ reqBody io.WriteCloser // write half of the Request.Body
+ finished bool
+
+ done chan struct{} // closed when RoundTrip returns
+ res *http.Response
+ err error
+ conn *transportDialTesterConn
+}
+
+// A transportDialTesterConn is a client connection created by the Transport as
+// part of a dial test.
+type transportDialTesterConn struct {
+ t *testing.T
+
+ connID int // distinguishes Dials in logs
+ ready chan error // sent on to complete the Dial
+
+ net.Conn
+}
+
+func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
+ t.Helper()
+ dt := &transportDialTester{
+ t: t,
+ dials: make(chan *transportDialTesterConn),
+ }
+ dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Write response headers when we receive a request.
+ http.NewResponseController(w).EnableFullDuplex()
+ w.WriteHeader(200)
+ http.NewResponseController(w).Flush()
+ // Wait for the client to send the request body,
+ // to synchronize with the rest of the test.
+ io.ReadAll(r.Body)
+ }), func(tr *http.Transport) {
+ tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
+ c := &transportDialTesterConn{
+ t: t,
+ ready: make(chan error),
+ }
+ // Notify the test that a Dial has started,
+ // and wait for the test to notify us that it should complete.
+ dt.dials <- c
+ if err := <-c.ready; err != nil {
+ return nil, err
+ }
+ nc, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ // Use the *transportDialTesterConn as the net.Conn,
+ // to let tests associate requests with connections.
+ c.Conn = nc
+ return c, err
+ }
+ })
+ return dt
+}
+
+// roundTrip starts a RoundTrip.
+// It returns immediately, without waiting for the RoundTrip call to complete.
+func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
+ dt.t.Helper()
+ ctx, cancel := context.WithCancel(context.Background())
+ pr, pw := io.Pipe()
+ rt := &transportDialTesterRoundTrip{
+ t: dt.t,
+ roundTripID: dt.roundTripCount,
+ done: make(chan struct{}),
+ reqBody: pw,
+ cancel: cancel,
+ }
+ dt.roundTripCount++
+ dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
+ dt.t.Cleanup(func() {
+ rt.cancel()
+ rt.finish()
+ })
+ go func() {
+ ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+ GotConn: func(info httptrace.GotConnInfo) {
+ rt.conn = info.Conn.(*transportDialTesterConn)
+ },
+ })
+ req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
+ req.Header.Set("Content-Type", "text/plain")
+ rt.res, rt.err = dt.cst.tr.RoundTrip(req)
+ dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
+ close(rt.done)
+ }()
+ return rt
+}
+
+// wantDone indicates that a RoundTrip should have returned.
+func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
+ rt.t.Helper()
+ <-rt.done
+ if rt.err != nil {
+ rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
+ }
+ if rt.conn != c {
+ rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
+ }
+}
+
+// finish completes a RoundTrip by sending the request body, consuming the response body,
+// and closing the response body.
+func (rt *transportDialTesterRoundTrip) finish() {
+ rt.t.Helper()
+
+ if rt.finished {
+ return
+ }
+ rt.finished = true
+
+ <-rt.done
+
+ if rt.err != nil {
+ return
+ }
+ rt.reqBody.Close()
+ io.ReadAll(rt.res.Body)
+ rt.res.Body.Close()
+ rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
+}
+
+// wantDial waits for the Transport to start a Dial.
+func (dt *transportDialTester) wantDial() *transportDialTesterConn {
+ c := <-dt.dials
+ c.connID = dt.dialCount
+ dt.dialCount++
+ dt.t.Logf("Dial %v: started", c.connID)
+ return c
+}
+
+// finish completes a Dial.
+func (c *transportDialTesterConn) finish(err error) {
+ c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
+ c.ready <- err
+ close(c.ready)
+}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index e8baa48..fa147e1 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -1626,11 +1626,20 @@
// Issue 28012: verify that the Transport closes its TCP connection to http proxies
// when they're slow to reply to HTTPS CONNECT responses.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+ cancelc := make(chan struct{})
+ SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
+ ctx, cancel := context.WithCancel(ctx)
+ go func() {
+ select {
+ case <-cancelc:
+ case <-ctx.Done():
+ }
+ cancel()
+ }()
+ return ctx, cancel
+ })
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+ defer afterTest(t)
ln := newLocalListener(t)
defer ln.Close()
@@ -1658,7 +1667,7 @@
// Now hang and never write a response; instead, cancel the request and wait
// for the client to close.
// (Prior to Issue 28012 being fixed, we never closed.)
- cancel()
+ close(cancelc)
var buf [1]byte
_, err = br.Read(buf[:])
if err != io.EOF {
@@ -1674,7 +1683,7 @@
},
},
}
- req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
+ req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
if err != nil {
t.Fatal(err)
}
@@ -3927,9 +3936,13 @@
func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
func testTransportDialContext(t *testing.T, mode testMode) {
- var mu sync.Mutex // guards following
- var gotReq bool
- var receivedContext context.Context
+ ctxKey := "some-key"
+ ctxValue := "some-value"
+ var (
+ mu sync.Mutex // guards following
+ gotReq bool
+ gotCtxValue any
+ )
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
@@ -3939,7 +3952,7 @@
c := ts.Client()
c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
mu.Lock()
- receivedContext = ctx
+ gotCtxValue = ctx.Value(ctxKey)
mu.Unlock()
return net.Dial(netw, addr)
}
@@ -3948,7 +3961,7 @@
if err != nil {
t.Fatal(err)
}
- ctx := context.WithValue(context.Background(), "some-key", "some-value")
+ ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
res, err := c.Do(req.WithContext(ctx))
if err != nil {
t.Fatal(err)
@@ -3958,8 +3971,8 @@
if !gotReq {
t.Error("didn't get request")
}
- if receivedContext != ctx {
- t.Error("didn't receive correct context")
+ if got, want := gotCtxValue, ctxValue; got != want {
+ t.Errorf("got context with value %v, want %v", got, want)
}
}
@@ -3967,9 +3980,13 @@
run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
}
func testTransportDialTLSContext(t *testing.T, mode testMode) {
- var mu sync.Mutex // guards following
- var gotReq bool
- var receivedContext context.Context
+ ctxKey := "some-key"
+ ctxValue := "some-value"
+ var (
+ mu sync.Mutex // guards following
+ gotReq bool
+ gotCtxValue any
+ )
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
@@ -3979,7 +3996,7 @@
c := ts.Client()
c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
mu.Lock()
- receivedContext = ctx
+ gotCtxValue = ctx.Value(ctxKey)
mu.Unlock()
c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
if err != nil {
@@ -3992,7 +4009,7 @@
if err != nil {
t.Fatal(err)
}
- ctx := context.WithValue(context.Background(), "some-key", "some-value")
+ ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
res, err := c.Do(req.WithContext(ctx))
if err != nil {
t.Fatal(err)
@@ -4002,8 +4019,8 @@
if !gotReq {
t.Error("didn't get request")
}
- if receivedContext != ctx {
- t.Error("didn't receive correct context")
+ if got, want := gotCtxValue, ctxValue; got != want {
+ t.Errorf("got context with value %v, want %v", got, want)
}
}