net: fix netFD.Close races
Fixes #271.
Fixes #321.
R=rsc, agl, cw
CC=golang-dev
https://golang.org/cl/163052
diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go
index f134b0f..e1592eb 100644
--- a/src/pkg/net/fd.go
+++ b/src/pkg/net/fd.go
@@ -15,14 +15,18 @@
// Network file descriptor.
type netFD struct {
+ // locking/lifetime of sysfd
+ sysmu sync.Mutex;
+ sysref int;
+ closing bool;
+
// immutable until Close
- fd int;
+ sysfd int;
family int;
proto int;
- file *os.File;
+ sysfile *os.File;
cr chan *netFD;
cw chan *netFD;
- cc chan *netFD;
net string;
laddr Addr;
raddr Addr;
@@ -68,13 +72,13 @@
// channel will be empty for the next process's request. A larger buffer
// might help batch requests.
//
-// In order to prevent race conditions, pollServer has an additional cc channel
-// that receives fds to be closed. pollServer doesn't make the close system
-// call, it just sets fd.file = nil and fd.fd = -1. Because of this, pollServer
-// is always in sync with the kernel's view of a given descriptor.
+// To avoid races in closing, all fd operations are locked and
+// refcounted. When netFD.Close() is called, it calls syscall.Shutdown
+// and sets a closing flag. The fd is closed only when the last
+// reference is removed.
type pollServer struct {
- cr, cw, cc chan *netFD; // buffered >= 1
+ cr, cw chan *netFD; // buffered >= 1
pr, pw *os.File;
pending map[int]*netFD;
poll *pollster; // low-level OS hooks
@@ -85,7 +89,6 @@
s = new(pollServer);
s.cr = make(chan *netFD, 1);
s.cw = make(chan *netFD, 1);
- s.cc = make(chan *netFD, 1);
if s.pr, s.pw, err = os.Pipe(); err != nil {
return nil, err
}
@@ -114,16 +117,7 @@
}
func (s *pollServer) AddFD(fd *netFD, mode int) {
- // This check verifies that the underlying file descriptor hasn't been
- // closed in the mean time. Any time a netFD is closed, the closing
- // goroutine makes a round trip to the pollServer which sets file = nil
- // and fd = -1. The goroutine then closes the actual file descriptor.
- // Thus fd.fd mirrors the kernel's view of the file descriptor.
-
- // TODO(rsc,agl): There is still a race in Read and Write,
- // because they optimistically try to use the fd and don't
- // call into the PollServer unless they get EAGAIN.
- intfd := fd.fd;
+ intfd := fd.sysfd;
if intfd < 0 {
// fd closed underfoot
if mode == 'r' {
@@ -213,10 +207,10 @@
if t <= now {
s.pending[key] = nil, false;
if mode == 'r' {
- s.poll.DelFD(fd.fd, mode);
+ s.poll.DelFD(fd.sysfd, mode);
fd.rdeadline = -1;
} else {
- s.poll.DelFD(fd.fd, mode);
+ s.poll.DelFD(fd.sysfd, mode);
fd.wdeadline = -1;
}
s.WakeFD(fd, mode);
@@ -254,7 +248,6 @@
for nn, _ := s.pr.Read(&scratch); nn > 0; {
nn, _ = s.pr.Read(&scratch)
}
-
// Read from channels
for fd, ok := <-s.cr; ok; fd, ok = <-s.cr {
s.AddFD(fd, 'r')
@@ -262,11 +255,6 @@
for fd, ok := <-s.cw; ok; fd, ok = <-s.cw {
s.AddFD(fd, 'w')
}
- for fd, ok := <-s.cc; ok; fd, ok = <-s.cc {
- fd.file = nil;
- fd.fd = -1;
- fd.cc <- fd;
- }
} else {
netfd := s.LookupFD(fd, mode);
if netfd == nil {
@@ -294,12 +282,6 @@
<-fd.cw;
}
-func (s *pollServer) WaitCloseAck(fd *netFD) {
- s.cc <- fd;
- s.Wakeup();
- <-fd.cc;
-}
-
// Network FD methods.
// All the network FDs use a single pollServer.
@@ -319,7 +301,7 @@
return nil, &OpError{"setnonblock", net, laddr, os.Errno(e)}
}
f = &netFD{
- fd: fd,
+ sysfd: fd,
family: family,
proto: proto,
net: net,
@@ -333,13 +315,37 @@
if raddr != nil {
rs = raddr.String()
}
- f.file = os.NewFile(fd, net+":"+ls+"->"+rs);
+ f.sysfile = os.NewFile(fd, net+":"+ls+"->"+rs);
f.cr = make(chan *netFD, 1);
f.cw = make(chan *netFD, 1);
- f.cc = make(chan *netFD, 1);
return f, nil;
}
+// Add a reference to this fd; while any reference is held, decref will not close sysfd.
+func (fd *netFD) incref() {
+ fd.sysmu.Lock();
+ fd.sysref++;
+ fd.sysmu.Unlock();
+}
+
+// Remove a reference to this fd, closing it if Close has been called
+// and no references remain (the last decref performs the close).
+func (fd *netFD) decref() {
+ fd.sysmu.Lock();
+ fd.sysref--;
+ if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 {
+ // In case the user has set linger, switch to blocking mode so
+ // the close blocks. As long as this doesn't happen often, we
+ // can handle the extra OS processes. Otherwise we'll need to
+ // use the pollserver for Close too. Sigh.
+ syscall.SetNonblock(fd.sysfd, false);
+ fd.sysfile.Close(); // error is discarded: netFD.Close always returns nil
+ fd.sysfile = nil;
+ fd.sysfd = -1; // the sysfd >= 0 check above then prevents a second close
+ }
+ fd.sysmu.Unlock();
+}
+
func isEAGAIN(e os.Error) bool {
if e1, ok := e.(*os.PathError); ok {
return e1.Error == os.EAGAIN
@@ -348,36 +354,32 @@
}
func (fd *netFD) Close() os.Error {
- if fd == nil || fd.file == nil {
+ if fd == nil || fd.sysfile == nil {
return os.EINVAL
}
- // In case the user has set linger,
- // switch to blocking mode so the close blocks.
- // As long as this doesn't happen often,
- // we can handle the extra OS processes.
- // Otherwise we'll need to use the pollserver
- // for Close too. Sigh.
- syscall.SetNonblock(fd.file.Fd(), false);
-
- f := fd.file;
- pollserver.WaitCloseAck(fd);
- return f.Close();
+ fd.incref(); // hold a reference so no decref can close sysfd while we shut it down
+ syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR);
+ fd.sysmu.Lock(); fd.closing = true; fd.sysmu.Unlock(); // decref reads closing under sysmu, so write it under sysmu too
+ fd.decref(); // closes sysfd now if no other reference is outstanding; otherwise the last decref does
+ return nil; // the real close may happen later in another goroutine, so its error cannot be reported here
}
func (fd *netFD) Read(p []byte) (n int, err os.Error) {
- if fd == nil || fd.file == nil {
+ if fd == nil || fd.sysfile == nil {
return 0, os.EINVAL
}
fd.rio.Lock();
defer fd.rio.Unlock();
+ fd.incref();
+ defer fd.decref();
if fd.rdeadline_delta > 0 {
fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
} else {
fd.rdeadline = 0
}
for {
- n, err = fd.file.Read(p);
+ n, err = fd.sysfile.Read(p);
if isEAGAIN(err) && fd.rdeadline >= 0 {
pollserver.WaitRead(fd);
continue;
@@ -388,11 +390,13 @@
}
func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
- if fd == nil || fd.file == nil {
+ if fd == nil || fd.sysfile == nil {
return 0, nil, os.EINVAL
}
fd.rio.Lock();
defer fd.rio.Unlock();
+ fd.incref();
+ defer fd.decref();
if fd.rdeadline_delta > 0 {
fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
} else {
@@ -400,14 +404,14 @@
}
for {
var errno int;
- n, sa, errno = syscall.Recvfrom(fd.fd, p, 0);
+ n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0);
if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
pollserver.WaitRead(fd);
continue;
}
if errno != 0 {
n = 0;
- err = &os.PathError{"recvfrom", fd.file.Name(), os.Errno(errno)};
+ err = &os.PathError{"recvfrom", fd.sysfile.Name(), os.Errno(errno)};
}
break;
}
@@ -415,11 +419,13 @@
}
func (fd *netFD) Write(p []byte) (n int, err os.Error) {
- if fd == nil || fd.file == nil {
+ if fd == nil || fd.sysfile == nil {
return 0, os.EINVAL
}
fd.wio.Lock();
defer fd.wio.Unlock();
+ fd.incref();
+ defer fd.decref();
if fd.wdeadline_delta > 0 {
fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
} else {
@@ -428,7 +434,7 @@
err = nil;
nn := 0;
for nn < len(p) {
- n, err = fd.file.Write(p[nn:]);
+ n, err = fd.sysfile.Write(p[nn:]);
if n > 0 {
nn += n
}
@@ -447,11 +453,13 @@
}
func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
- if fd == nil || fd.file == nil {
+ if fd == nil || fd.sysfile == nil {
return 0, os.EINVAL
}
fd.wio.Lock();
defer fd.wio.Unlock();
+ fd.incref();
+ defer fd.decref();
if fd.wdeadline_delta > 0 {
fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
} else {
@@ -459,13 +467,13 @@
}
err = nil;
for {
- errno := syscall.Sendto(fd.fd, p, 0, sa);
+ errno := syscall.Sendto(fd.sysfd, p, 0, sa);
if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
pollserver.WaitWrite(fd);
continue;
}
if errno != 0 {
- err = &os.PathError{"sendto", fd.file.Name(), os.Errno(errno)}
+ err = &os.PathError{"sendto", fd.sysfile.Name(), os.Errno(errno)}
}
break;
}
@@ -476,18 +484,21 @@
}
func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) {
- if fd == nil || fd.file == nil {
+ if fd == nil || fd.sysfile == nil {
return nil, os.EINVAL
}
+ fd.incref();
+ defer fd.decref();
+
// See ../syscall/exec.go for description of ForkLock.
// It is okay to hold the lock across syscall.Accept
- // because we have put fd.fd into non-blocking mode.
+ // because we have put fd.sysfd into non-blocking mode.
syscall.ForkLock.RLock();
var s, e int;
var sa syscall.Sockaddr;
for {
- s, sa, e = syscall.Accept(fd.fd);
+ s, sa, e = syscall.Accept(fd.sysfd);
if e != syscall.EAGAIN {
break
}