net/http: fix race between dialing and canceling

In the brief window between getConn and persistConn.roundTrip,
a cancel could end up going missing.

Fix by making it possible to inspect if a cancel function was cleared
and checking if we were canceled before entering roundTrip.

Fixes #10511

Change-Id: If6513e63fbc2edb703e36d6356ccc95a1dc33144
Reviewed-on: https://go-review.googlesource.com/9181
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index 69757bd..b656aa9 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -82,6 +82,10 @@
 	testHookPersistConnClosedGotRes = f
 }
 
+func SetEnterRoundTripHook(f func()) {
+	testHookEnterRoundTrip = f
+}
+
 func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
 	f := func() <-chan time.Time {
 		return ch
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index b754472..e31ae93 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -475,6 +475,25 @@
 	}
 }
 
+// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
+// for the request, we don't set the function and return false.
+// Since CancelRequest will clear the canceler, we can use the return value to detect if
+// the request was canceled since the last setReqCancel call.
+func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
+	t.reqMu.Lock()
+	defer t.reqMu.Unlock()
+	_, ok := t.reqCanceler[r]
+	if !ok {
+		return false
+	}
+	if fn != nil {
+		t.reqCanceler[r] = fn
+	} else {
+		delete(t.reqCanceler, r)
+	}
+	return true
+}
+
 func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
 	if t.Dial != nil {
 		return t.Dial(network, addr)
@@ -491,6 +510,10 @@
 // is ready to write requests to.
 func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
 	if pc := t.getIdleConn(cm); pc != nil {
+		// set request canceler to some non-nil function so we
+		// can detect whether it was cleared between now and when
+		// we enter roundTrip
+		t.setReqCanceler(req, func() {})
 		return pc, nil
 	}
 
@@ -1063,10 +1086,20 @@
 var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
 var errRequestCanceled = errors.New("net/http: request canceled")
 
-var testHookPersistConnClosedGotRes func() // nil except for tests
+// nil except for tests
+var (
+	testHookPersistConnClosedGotRes func()
+	testHookEnterRoundTrip          func()
+)
 
 func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
-	pc.t.setReqCanceler(req.Request, pc.cancelRequest)
+	if hook := testHookEnterRoundTrip; hook != nil {
+		hook()
+	}
+	if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
+		pc.t.putIdleConn(pc)
+		return nil, errRequestCanceled
+	}
 	pc.lk.Lock()
 	pc.numExpectedResponses++
 	headerFn := pc.mutateHeaderFunc
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 2d52f17..d20ba13 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -2400,6 +2400,32 @@
 	res.Body.Close()
 }
 
+func TestTransportDialCancelRace(t *testing.T) {
+	defer afterTest(t)
+
+	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+	defer ts.Close()
+
+	tr := &Transport{}
+	defer tr.CloseIdleConnections()
+
+	req, err := NewRequest("GET", ts.URL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	SetEnterRoundTripHook(func() {
+		tr.CancelRequest(req)
+	})
+	defer SetEnterRoundTripHook(nil)
+	res, err := tr.RoundTrip(req)
+	if err != ExportErrRequestCanceled {
+		t.Errorf("expected canceled request error; got %v", err)
+		if err == nil {
+			res.Body.Close()
+		}
+	}
+}
+
 func wantBody(res *http.Response, err error, want string) error {
 	if err != nil {
 		return err