net/http: fix race between dialing and canceling
In the brief window between getConn and persistConn.roundTrip,
a cancel could end up going missing.
Fix by making it possible to inspect if a cancel function was cleared
and checking if we were canceled before entering roundTrip.
Fixes #10511
Change-Id: If6513e63fbc2edb703e36d6356ccc95a1dc33144
Reviewed-on: https://go-review.googlesource.com/9181
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index 69757bd..b656aa9 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -82,6 +82,10 @@
testHookPersistConnClosedGotRes = f
}
+func SetEnterRoundTripHook(f func()) {
+ testHookEnterRoundTrip = f
+}
+
func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
f := func() <-chan time.Time {
return ch
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index b754472..e31ae93 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -475,6 +475,25 @@
}
}
+// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
+// for the request, we don't set the function and return false.
+// Since CancelRequest will clear the canceler, we can use the return value to detect if
+// the request was canceled since the last setReqCancel call.
+func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ _, ok := t.reqCanceler[r]
+ if !ok {
+ return false
+ }
+ if fn != nil {
+ t.reqCanceler[r] = fn
+ } else {
+ delete(t.reqCanceler, r)
+ }
+ return true
+}
+
func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
if t.Dial != nil {
return t.Dial(network, addr)
@@ -491,6 +510,10 @@
// is ready to write requests to.
func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
if pc := t.getIdleConn(cm); pc != nil {
+ // set request canceler to some non-nil function so we
+ // can detect whether it was cleared between now and when
+ // we enter roundTrip
+ t.setReqCanceler(req, func() {})
return pc, nil
}
@@ -1063,10 +1086,20 @@
var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
var errRequestCanceled = errors.New("net/http: request canceled")
-var testHookPersistConnClosedGotRes func() // nil except for tests
+// nil except for tests
+var (
+ testHookPersistConnClosedGotRes func()
+ testHookEnterRoundTrip func()
+)
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
- pc.t.setReqCanceler(req.Request, pc.cancelRequest)
+ if hook := testHookEnterRoundTrip; hook != nil {
+ hook()
+ }
+ if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
+ pc.t.putIdleConn(pc)
+ return nil, errRequestCanceled
+ }
pc.lk.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 2d52f17..d20ba13 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -2400,6 +2400,32 @@
res.Body.Close()
}
+func TestTransportDialCancelRace(t *testing.T) {
+ defer afterTest(t)
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ SetEnterRoundTripHook(func() {
+ tr.CancelRequest(req)
+ })
+ defer SetEnterRoundTripHook(nil)
+ res, err := tr.RoundTrip(req)
+ if err != ExportErrRequestCanceled {
+ t.Errorf("expected canceled request error; got %v", err)
+ if err == nil {
+ res.Body.Close()
+ }
+ }
+}
+
func wantBody(res *http.Response, err error, want string) error {
if err != nil {
return err