net: fix netFD.Close races

Fixes #271.
Fixes #321.

R=rsc, agl, cw
CC=golang-dev
https://golang.org/cl/163052
diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go
index f134b0f..e1592eb 100644
--- a/src/pkg/net/fd.go
+++ b/src/pkg/net/fd.go
@@ -15,14 +15,18 @@
 
 // Network file descriptor.
 type netFD struct {
+	// locking/lifetime of sysfd
+	sysmu	sync.Mutex;
+	sysref	int;
+	closing	bool;
+
 	// immutable until Close
-	fd	int;
+	sysfd	int;
 	family	int;
 	proto	int;
-	file	*os.File;
+	sysfile	*os.File;
 	cr	chan *netFD;
 	cw	chan *netFD;
-	cc	chan *netFD;
 	net	string;
 	laddr	Addr;
 	raddr	Addr;
@@ -68,13 +72,13 @@
 // channel will be empty for the next process's request.  A larger buffer
 // might help batch requests.
 //
-// In order to prevent race conditions, pollServer has an additional cc channel
-// that receives fds to be closed. pollServer doesn't make the close system
-// call, it just sets fd.file = nil and fd.fd = -1. Because of this, pollServer
-// is always in sync with the kernel's view of a given descriptor.
+// To avoid races in closing, all fd operations are locked and
+// refcounted. When netFD.Close() is called, it calls syscall.Shutdown
+// and sets a closing flag. Only when the last reference is removed
+// will the fd be closed.
 
 type pollServer struct {
-	cr, cw, cc	chan *netFD;	// buffered >= 1
+	cr, cw		chan *netFD;	// buffered >= 1
 	pr, pw		*os.File;
 	pending		map[int]*netFD;
 	poll		*pollster;	// low-level OS hooks
@@ -85,7 +89,6 @@
 	s = new(pollServer);
 	s.cr = make(chan *netFD, 1);
 	s.cw = make(chan *netFD, 1);
-	s.cc = make(chan *netFD, 1);
 	if s.pr, s.pw, err = os.Pipe(); err != nil {
 		return nil, err
 	}
@@ -114,16 +117,7 @@
 }
 
 func (s *pollServer) AddFD(fd *netFD, mode int) {
-	// This check verifies that the underlying file descriptor hasn't been
-	// closed in the mean time. Any time a netFD is closed, the closing
-	// goroutine makes a round trip to the pollServer which sets file = nil
-	// and fd = -1. The goroutine then closes the actual file descriptor.
-	// Thus fd.fd mirrors the kernel's view of the file descriptor.
-
-	// TODO(rsc,agl): There is still a race in Read and Write,
-	// because they optimistically try to use the fd and don't
-	// call into the PollServer unless they get EAGAIN.
-	intfd := fd.fd;
+	intfd := fd.sysfd;
 	if intfd < 0 {
 		// fd closed underfoot
 		if mode == 'r' {
@@ -213,10 +207,10 @@
 			if t <= now {
 				s.pending[key] = nil, false;
 				if mode == 'r' {
-					s.poll.DelFD(fd.fd, mode);
+					s.poll.DelFD(fd.sysfd, mode);
 					fd.rdeadline = -1;
 				} else {
-					s.poll.DelFD(fd.fd, mode);
+					s.poll.DelFD(fd.sysfd, mode);
 					fd.wdeadline = -1;
 				}
 				s.WakeFD(fd, mode);
@@ -254,7 +248,6 @@
 			for nn, _ := s.pr.Read(&scratch); nn > 0; {
 				nn, _ = s.pr.Read(&scratch)
 			}
-
 			// Read from channels
 			for fd, ok := <-s.cr; ok; fd, ok = <-s.cr {
 				s.AddFD(fd, 'r')
@@ -262,11 +255,6 @@
 			for fd, ok := <-s.cw; ok; fd, ok = <-s.cw {
 				s.AddFD(fd, 'w')
 			}
-			for fd, ok := <-s.cc; ok; fd, ok = <-s.cc {
-				fd.file = nil;
-				fd.fd = -1;
-				fd.cc <- fd;
-			}
 		} else {
 			netfd := s.LookupFD(fd, mode);
 			if netfd == nil {
@@ -294,12 +282,6 @@
 	<-fd.cw;
 }
 
-func (s *pollServer) WaitCloseAck(fd *netFD) {
-	s.cc <- fd;
-	s.Wakeup();
-	<-fd.cc;
-}
-
 // Network FD methods.
 // All the network FDs use a single pollServer.
 
@@ -319,7 +301,7 @@
 		return nil, &OpError{"setnonblock", net, laddr, os.Errno(e)}
 	}
 	f = &netFD{
-		fd: fd,
+		sysfd: fd,
 		family: family,
 		proto: proto,
 		net: net,
@@ -333,13 +315,37 @@
 	if raddr != nil {
 		rs = raddr.String()
 	}
-	f.file = os.NewFile(fd, net+":"+ls+"->"+rs);
+	f.sysfile = os.NewFile(fd, net+":"+ls+"->"+rs);
 	f.cr = make(chan *netFD, 1);
 	f.cw = make(chan *netFD, 1);
-	f.cc = make(chan *netFD, 1);
 	return f, nil;
 }
 
+// Add a reference to this fd.
+func (fd *netFD) incref() {
+	fd.sysmu.Lock();
+	fd.sysref++;
+	fd.sysmu.Unlock();
+}
+
+// Remove a reference to this FD and close if we've been asked to do so (and
+// there are no references left).
+func (fd *netFD) decref() {
+	fd.sysmu.Lock();
+	fd.sysref--;
+	if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 {
+		// In case the user has set linger, switch to blocking mode so
+		// the close blocks.  As long as this doesn't happen often, we
+		// can handle the extra OS processes.  Otherwise we'll need to
+		// use the pollserver for Close too.  Sigh.
+		syscall.SetNonblock(fd.sysfd, false);
+		fd.sysfile.Close();
+		fd.sysfile = nil;
+		fd.sysfd = -1;
+	}
+	fd.sysmu.Unlock();
+}
+
 func isEAGAIN(e os.Error) bool {
 	if e1, ok := e.(*os.PathError); ok {
 		return e1.Error == os.EAGAIN
@@ -348,36 +354,32 @@
 }
 
 func (fd *netFD) Close() os.Error {
-	if fd == nil || fd.file == nil {
+	if fd == nil || fd.sysfile == nil {
 		return os.EINVAL
 	}
 
-	// In case the user has set linger,
-	// switch to blocking mode so the close blocks.
-	// As long as this doesn't happen often,
-	// we can handle the extra OS processes.
-	// Otherwise we'll need to use the pollserver
-	// for Close too.  Sigh.
-	syscall.SetNonblock(fd.file.Fd(), false);
-
-	f := fd.file;
-	pollserver.WaitCloseAck(fd);
-	return f.Close();
+	fd.incref();
+	syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR);
+	fd.closing = true;
+	fd.decref();
+	return nil;
 }
 
 func (fd *netFD) Read(p []byte) (n int, err os.Error) {
-	if fd == nil || fd.file == nil {
+	if fd == nil || fd.sysfile == nil {
 		return 0, os.EINVAL
 	}
 	fd.rio.Lock();
 	defer fd.rio.Unlock();
+	fd.incref();
+	defer fd.decref();
 	if fd.rdeadline_delta > 0 {
 		fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
 	} else {
 		fd.rdeadline = 0
 	}
 	for {
-		n, err = fd.file.Read(p);
+		n, err = fd.sysfile.Read(p);
 		if isEAGAIN(err) && fd.rdeadline >= 0 {
 			pollserver.WaitRead(fd);
 			continue;
@@ -388,11 +390,13 @@
 }
 
 func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
-	if fd == nil || fd.file == nil {
+	if fd == nil || fd.sysfile == nil {
 		return 0, nil, os.EINVAL
 	}
 	fd.rio.Lock();
 	defer fd.rio.Unlock();
+	fd.incref();
+	defer fd.decref();
 	if fd.rdeadline_delta > 0 {
 		fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
 	} else {
@@ -400,14 +404,14 @@
 	}
 	for {
 		var errno int;
-		n, sa, errno = syscall.Recvfrom(fd.fd, p, 0);
+		n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0);
 		if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
 			pollserver.WaitRead(fd);
 			continue;
 		}
 		if errno != 0 {
 			n = 0;
-			err = &os.PathError{"recvfrom", fd.file.Name(), os.Errno(errno)};
+			err = &os.PathError{"recvfrom", fd.sysfile.Name(), os.Errno(errno)};
 		}
 		break;
 	}
@@ -415,11 +419,13 @@
 }
 
 func (fd *netFD) Write(p []byte) (n int, err os.Error) {
-	if fd == nil || fd.file == nil {
+	if fd == nil || fd.sysfile == nil {
 		return 0, os.EINVAL
 	}
 	fd.wio.Lock();
 	defer fd.wio.Unlock();
+	fd.incref();
+	defer fd.decref();
 	if fd.wdeadline_delta > 0 {
 		fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
 	} else {
@@ -428,7 +434,7 @@
 	err = nil;
 	nn := 0;
 	for nn < len(p) {
-		n, err = fd.file.Write(p[nn:]);
+		n, err = fd.sysfile.Write(p[nn:]);
 		if n > 0 {
 			nn += n
 		}
@@ -447,11 +453,13 @@
 }
 
 func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
-	if fd == nil || fd.file == nil {
+	if fd == nil || fd.sysfile == nil {
 		return 0, os.EINVAL
 	}
 	fd.wio.Lock();
 	defer fd.wio.Unlock();
+	fd.incref();
+	defer fd.decref();
 	if fd.wdeadline_delta > 0 {
 		fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
 	} else {
@@ -459,13 +467,13 @@
 	}
 	err = nil;
 	for {
-		errno := syscall.Sendto(fd.fd, p, 0, sa);
+		errno := syscall.Sendto(fd.sysfd, p, 0, sa);
 		if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
 			pollserver.WaitWrite(fd);
 			continue;
 		}
 		if errno != 0 {
-			err = &os.PathError{"sendto", fd.file.Name(), os.Errno(errno)}
+			err = &os.PathError{"sendto", fd.sysfile.Name(), os.Errno(errno)}
 		}
 		break;
 	}
@@ -476,18 +484,21 @@
 }
 
 func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) {
-	if fd == nil || fd.file == nil {
+	if fd == nil || fd.sysfile == nil {
 		return nil, os.EINVAL
 	}
 
+	fd.incref();
+	defer fd.decref();
+
 	// See ../syscall/exec.go for description of ForkLock.
 	// It is okay to hold the lock across syscall.Accept
-	// because we have put fd.fd into non-blocking mode.
+	// because we have put fd.sysfd into non-blocking mode.
 	syscall.ForkLock.RLock();
 	var s, e int;
 	var sa syscall.Sockaddr;
 	for {
-		s, sa, e = syscall.Accept(fd.fd);
+		s, sa, e = syscall.Accept(fd.sysfd);
 		if e != syscall.EAGAIN {
 			break
 		}