net: prevent cancelation goroutine from adjusting fd timeout after connect

This was previously fixed in https://golang.org/cl/21497 but not enough.

Fixes #16523

Change-Id: I678543a656304c82d654e25e12fb094cd6cc87e8
Reviewed-on: https://go-review.googlesource.com/25330
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/net/dial_unix_test.go b/src/net/dial_unix_test.go
new file mode 100644
index 0000000..4705254
--- /dev/null
+++ b/src/net/dial_unix_test.go
@@ -0,0 +1,108 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package net
+
+import (
+	"context"
+	"syscall"
+	"testing"
+	"time"
+)
+
+// Issue 16523
+func TestDialContextCancelRace(t *testing.T) {
+	oldConnectFunc := connectFunc
+	oldGetsockoptIntFunc := getsockoptIntFunc
+	oldTestHookCanceledDial := testHookCanceledDial
+	defer func() {
+		connectFunc = oldConnectFunc
+		getsockoptIntFunc = oldGetsockoptIntFunc
+		testHookCanceledDial = oldTestHookCanceledDial
+	}()
+
+	ln, err := newLocalListener("tcp")
+	if err != nil {
+		t.Fatal(err)
+	}
+	listenerDone := make(chan struct{})
+	go func() {
+		defer close(listenerDone)
+		c, err := ln.Accept()
+		if err == nil {
+			c.Close()
+		}
+	}()
+	defer func() { <-listenerDone }()
+	defer ln.Close()
+
+	sawCancel := make(chan bool, 1)
+	testHookCanceledDial = func() {
+		sawCancel <- true
+	}
+
+	ctx, cancelCtx := context.WithCancel(context.Background())
+
+	connectFunc = func(fd int, addr syscall.Sockaddr) error {
+		err := oldConnectFunc(fd, addr)
+		t.Logf("connect(%d, addr) = %v", fd, err)
+		if err == nil {
+			// On some operating systems, localhost
+			// connects _sometimes_ succeed immediately.
+			// Prevent that, so we exercise the code path
+			// we're interested in testing. This seems
+			// harmless. It makes FreeBSD 10.10 work when
+			// run with many iterations. It failed about
+			// half the time previously.
+			return syscall.EINPROGRESS
+		}
+		return err
+	}
+
+	getsockoptIntFunc = func(fd, level, opt int) (val int, err error) {
+		val, err = oldGetsockoptIntFunc(fd, level, opt)
+		t.Logf("getsockoptIntFunc(%d, %d, %d) = (%v, %v)", fd, level, opt, val, err)
+		if level == syscall.SOL_SOCKET && opt == syscall.SO_ERROR && err == nil && val == 0 {
+			t.Logf("canceling context")
+
+			// Cancel the context at just the moment which
+			// caused the race in issue 16523.
+			cancelCtx()
+
+			// And wait for the "interrupter" goroutine to
+			// cancel the dial by messing with its write
+			// timeout before returning.
+			select {
+			case <-sawCancel:
+				t.Logf("saw cancel")
+			case <-time.After(5 * time.Second):
+				t.Errorf("didn't see cancel after 5 seconds")
+			}
+		}
+		return
+	}
+
+	var d Dialer
+	c, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+	if err == nil {
+		c.Close()
+		t.Fatal("unexpected successful dial; want context canceled error")
+	}
+
+	select {
+	case <-ctx.Done():
+	case <-time.After(5 * time.Second):
+		t.Fatal("expected context to be canceled")
+	}
+
+	oe, ok := err.(*OpError)
+	if !ok || oe.Op != "dial" {
+		t.Fatalf("Dial error = %#v; want dial *OpError", err)
+	}
+	if oe.Err != ctx.Err() {
+		t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, ctx.Err())
+	}
+}
diff --git a/src/net/fd_unix.go b/src/net/fd_unix.go
index 0f80bc7..11dde76 100644
--- a/src/net/fd_unix.go
+++ b/src/net/fd_unix.go
@@ -64,7 +64,7 @@
 	return fd.net + ":" + ls + "->" + rs
 }
 
-func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
+func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret error) {
 	// Do not need to call fd.writeLock here,
 	// because fd is not yet accessible to user,
 	// so no concurrent operations are possible.
@@ -101,21 +101,44 @@
 		defer fd.setWriteDeadline(noDeadline)
 	}
 
-	// Wait for the goroutine converting context.Done into a write timeout
-	// to exist, otherwise our caller might cancel the context and
-	// cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
-	done := make(chan bool) // must be unbuffered
-	defer func() { done <- true }()
-	go func() {
-		select {
-		case <-ctx.Done():
-			// Force the runtime's poller to immediately give
-			// up waiting for writability.
-			fd.setWriteDeadline(aLongTimeAgo)
-			<-done
-		case <-done:
-		}
-	}()
+	// Start the "interrupter" goroutine, if this context might be canceled.
+	// (The background context cannot)
+	//
+	// The interrupter goroutine waits for the context to be done and
+	// interrupts the dial (by altering the fd's write deadline, which
+	// wakes up waitWrite).
+	if ctx != context.Background() {
+		// Wait for the interrupter goroutine to exit before returning
+		// from connect.
+		done := make(chan struct{})
+		interruptRes := make(chan error)
+		defer func() {
+			close(done)
+			if ctxErr := <-interruptRes; ctxErr != nil && ret == nil {
+				// The interrupter goroutine called setWriteDeadline,
+				// but the connect code below had returned from
+				// waitWrite already and did a successful connect (ret
+				// == nil). Because we've now poisoned the connection
+				// by making it unwritable, don't return a successful
+				// dial. This was issue 16523.
+				ret = ctxErr
+				fd.Close() // prevent a leak
+			}
+		}()
+		go func() {
+			select {
+			case <-ctx.Done():
+				// Force the runtime's poller to immediately give up
+				// waiting for writability, unblocking waitWrite
+				// below.
+				fd.setWriteDeadline(aLongTimeAgo)
+				testHookCanceledDial()
+				interruptRes <- ctx.Err()
+			case <-done:
+				interruptRes <- nil
+			}
+		}()
+	}
 
 	for {
 		// Performing multiple connect system calls on a
diff --git a/src/net/hook_unix.go b/src/net/hook_unix.go
index 361ca59..cf52567 100644
--- a/src/net/hook_unix.go
+++ b/src/net/hook_unix.go
@@ -9,7 +9,8 @@
 import "syscall"
 
 var (
-	testHookDialChannel = func() {} // see golang.org/issue/5349
+	testHookDialChannel  = func() {} // for golang.org/issue/5349
+	testHookCanceledDial = func() {} // for golang.org/issue/16523
 
 	// Placeholders for socket system calls.
 	socketFunc        func(int, int, int) (int, error)         = syscall.Socket