net: avoid Shutdown during Close
Instead of calling shutdown(2) in Close, evict all blocked
I/O from the poll server. Once that I/O has returned, the
ref count should go to zero quickly, so it should be safe
to postpone the close(2) until the final reference is dropped.
Fixes #1898.
Fixes #2116.
Fixes #2122.
R=golang-dev, mikioh.mikioh, bradfitz, fullung, iant
CC=golang-dev
https://golang.org/cl/5649076
diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go
index 607a6c1..596cf33 100644
--- a/src/pkg/net/fd.go
+++ b/src/pkg/net/fd.go
@@ -7,6 +7,7 @@
package net
import (
+ "errors"
"io"
"os"
"sync"
@@ -19,6 +20,9 @@
// locking/lifetime of sysfd
sysmu sync.Mutex
sysref int
+
+ // must lock both sysmu and pollserver to write
+ // can lock either to read
closing bool
// immutable until Close
@@ -27,8 +31,8 @@
sotype int
isConnected bool
sysfile *os.File
- cr chan bool
- cw chan bool
+ cr chan error
+ cw chan error
net string
laddr Addr
raddr Addr
@@ -86,19 +90,14 @@
deadline int64 // next deadline (nsec since 1970)
}
-func (s *pollServer) AddFD(fd *netFD, mode int) {
- intfd := fd.sysfd
- if intfd < 0 {
- // fd closed underfoot
- if mode == 'r' {
- fd.cr <- true
- } else {
- fd.cw <- true
- }
- return
- }
-
+func (s *pollServer) AddFD(fd *netFD, mode int) error {
s.Lock()
+ intfd := fd.sysfd
+ if intfd < 0 || fd.closing {
+ // fd closed underfoot
+ s.Unlock()
+ return errClosing
+ }
var t int64
key := intfd << 1
@@ -124,12 +123,28 @@
if wake {
doWakeup = true
}
-
s.Unlock()
if doWakeup {
s.Wakeup()
}
+ return nil
+}
+
+// Evict evicts fd from the pending list, unblocking
+// any I/O running on fd. The caller must have locked
+// pollserver.
+func (s *pollServer) Evict(fd *netFD) {
+ if s.pending[fd.sysfd<<1] == fd {
+ s.WakeFD(fd, 'r', errClosing)
+ s.poll.DelFD(fd.sysfd, 'r')
+ delete(s.pending, fd.sysfd<<1)
+ }
+ if s.pending[fd.sysfd<<1|1] == fd {
+ s.WakeFD(fd, 'w', errClosing)
+ s.poll.DelFD(fd.sysfd, 'w')
+ delete(s.pending, fd.sysfd<<1|1)
+ }
}
var wakeupbuf [1]byte
@@ -149,16 +164,16 @@
return netfd
}
-func (s *pollServer) WakeFD(fd *netFD, mode int) {
+func (s *pollServer) WakeFD(fd *netFD, mode int, err error) {
if mode == 'r' {
for fd.ncr > 0 {
fd.ncr--
- fd.cr <- true
+ fd.cr <- err
}
} else {
for fd.ncw > 0 {
fd.ncw--
- fd.cw <- true
+ fd.cw <- err
}
}
}
@@ -196,7 +211,7 @@
s.poll.DelFD(fd.sysfd, mode)
fd.wdeadline = -1
}
- s.WakeFD(fd, mode)
+ s.WakeFD(fd, mode, nil)
} else if next_deadline == 0 || t < next_deadline {
next_deadline = t
}
@@ -240,19 +255,25 @@
print("pollServer: unexpected wakeup for fd=", fd, " mode=", string(mode), "\n")
continue
}
- s.WakeFD(netfd, mode)
+ s.WakeFD(netfd, mode, nil)
}
}
}
-func (s *pollServer) WaitRead(fd *netFD) {
- s.AddFD(fd, 'r')
- <-fd.cr
+func (s *pollServer) WaitRead(fd *netFD) error {
+ err := s.AddFD(fd, 'r')
+ if err == nil {
+ err = <-fd.cr
+ }
+ return err
}
-func (s *pollServer) WaitWrite(fd *netFD) {
- s.AddFD(fd, 'w')
- <-fd.cw
+func (s *pollServer) WaitWrite(fd *netFD) error {
+ err := s.AddFD(fd, 'w')
+ if err == nil {
+ err = <-fd.cw
+ }
+ return err
}
// Network FD methods.
@@ -280,8 +301,8 @@
sotype: sotype,
net: net,
}
- netfd.cr = make(chan bool, 1)
- netfd.cw = make(chan bool, 1)
+ netfd.cr = make(chan error, 1)
+ netfd.cw = make(chan error, 1)
return netfd, nil
}
@@ -301,7 +322,9 @@
func (fd *netFD) connect(ra syscall.Sockaddr) error {
err := syscall.Connect(fd.sysfd, ra)
if err == syscall.EINPROGRESS {
- pollserver.WaitWrite(fd)
+ if err = pollserver.WaitWrite(fd); err != nil {
+ return err
+ }
var e int
e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
if err != nil {
@@ -314,24 +337,37 @@
return err
}
+var errClosing = errors.New("use of closed network connection")
+
// Add a reference to this fd.
-func (fd *netFD) incref() {
+// If closing==true, pollserver must be locked; mark the fd as closing.
+// Returns an error if the fd cannot be used.
+func (fd *netFD) incref(closing bool) error {
+ if fd == nil {
+ return errClosing
+ }
fd.sysmu.Lock()
+ if fd.closing {
+ fd.sysmu.Unlock()
+ return errClosing
+ }
fd.sysref++
+ if closing {
+ fd.closing = true
+ }
fd.sysmu.Unlock()
+ return nil
}
// Remove a reference to this FD and close if we've been asked to do so (and
// there are no references left.
func (fd *netFD) decref() {
+ if fd == nil {
+ return
+ }
fd.sysmu.Lock()
fd.sysref--
- if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 {
- // In case the user has set linger, switch to blocking mode so
- // the close blocks. As long as this doesn't happen often, we
- // can handle the extra OS processes. Otherwise we'll need to
- // use the pollserver for Close too. Sigh.
- syscall.SetNonblock(fd.sysfd, false)
+ if fd.closing && fd.sysref == 0 && fd.sysfile != nil {
fd.sysfile.Close()
fd.sysfile = nil
fd.sysfd = -1
@@ -340,21 +376,26 @@
}
func (fd *netFD) Close() error {
- if fd == nil || fd.sysfile == nil {
- return os.EINVAL
+ pollserver.Lock() // needed for both fd.incref(true) and pollserver.Evict
+ defer pollserver.Unlock()
+ if err := fd.incref(true); err != nil {
+ return err
}
-
- fd.incref()
- syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR)
- fd.closing = true
+ // Unblock any I/O. Once it has all unblocked and returned,
+ // and so can no longer be referring to fd.sysfd,
+ // the final decref will close fd.sysfd. This should happen
+ // fairly quickly, since all the I/O is non-blocking, and any
+ // attempts to block in the pollserver will return errClosing.
+ pollserver.Evict(fd)
fd.decref()
return nil
}
func (fd *netFD) shutdown(how int) error {
- if fd == nil || fd.sysfile == nil {
- return os.EINVAL
+ if err := fd.incref(false); err != nil {
+ return err
}
+ defer fd.decref()
err := syscall.Shutdown(fd.sysfd, how)
if err != nil {
return &OpError{"shutdown", fd.net, fd.laddr, err}
@@ -371,24 +412,21 @@
}
func (fd *netFD) Read(p []byte) (n int, err error) {
- if fd == nil {
- return 0, os.EINVAL
- }
fd.rio.Lock()
defer fd.rio.Unlock()
- fd.incref()
- defer fd.decref()
- if fd.sysfile == nil {
- return 0, os.EINVAL
+ if err := fd.incref(false); err != nil {
+ return 0, err
}
+ defer fd.decref()
for {
- n, err = syscall.Read(int(fd.sysfile.Fd()), p)
+ n, err = syscall.Read(int(fd.sysfd), p)
if err == syscall.EAGAIN {
- if fd.rdeadline >= 0 {
- pollserver.WaitRead(fd)
- continue
- }
err = errTimeout
+ if fd.rdeadline >= 0 {
+ if err = pollserver.WaitRead(fd); err == nil {
+ continue
+ }
+ }
}
if err != nil {
n = 0
@@ -404,49 +442,49 @@
}
func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
- if fd == nil || fd.sysfile == nil {
- return 0, nil, os.EINVAL
- }
fd.rio.Lock()
defer fd.rio.Unlock()
- fd.incref()
+ if err := fd.incref(false); err != nil {
+ return 0, nil, err
+ }
defer fd.decref()
for {
n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
if err == syscall.EAGAIN {
- if fd.rdeadline >= 0 {
- pollserver.WaitRead(fd)
- continue
- }
err = errTimeout
+ if fd.rdeadline >= 0 {
+ if err = pollserver.WaitRead(fd); err == nil {
+ continue
+ }
+ }
}
if err != nil {
n = 0
}
break
}
- if err != nil {
+ if err != nil && err != io.EOF {
err = &OpError{"read", fd.net, fd.laddr, err}
}
return
}
func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
- if fd == nil || fd.sysfile == nil {
- return 0, 0, 0, nil, os.EINVAL
- }
fd.rio.Lock()
defer fd.rio.Unlock()
- fd.incref()
+ if err := fd.incref(false); err != nil {
+ return 0, 0, 0, nil, err
+ }
defer fd.decref()
for {
n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
if err == syscall.EAGAIN {
- if fd.rdeadline >= 0 {
- pollserver.WaitRead(fd)
- continue
- }
err = errTimeout
+ if fd.rdeadline >= 0 {
+ if err = pollserver.WaitRead(fd); err == nil {
+ continue
+ }
+ }
}
if err == nil && n == 0 {
err = io.EOF
@@ -461,12 +499,11 @@
}
func (fd *netFD) Write(p []byte) (int, error) {
- if fd == nil {
- return 0, os.EINVAL
- }
fd.wio.Lock()
defer fd.wio.Unlock()
- fd.incref()
+ if err := fd.incref(false); err != nil {
+ return 0, err
+ }
defer fd.decref()
if fd.sysfile == nil {
return 0, os.EINVAL
@@ -476,7 +513,7 @@
nn := 0
for {
var n int
- n, err = syscall.Write(int(fd.sysfile.Fd()), p[nn:])
+ n, err = syscall.Write(int(fd.sysfd), p[nn:])
if n > 0 {
nn += n
}
@@ -484,11 +521,12 @@
break
}
if err == syscall.EAGAIN {
- if fd.wdeadline >= 0 {
- pollserver.WaitWrite(fd)
- continue
- }
err = errTimeout
+ if fd.wdeadline >= 0 {
+ if err = pollserver.WaitWrite(fd); err == nil {
+ continue
+ }
+ }
}
if err != nil {
n = 0
@@ -506,21 +544,21 @@
}
func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
- if fd == nil || fd.sysfile == nil {
- return 0, os.EINVAL
- }
fd.wio.Lock()
defer fd.wio.Unlock()
- fd.incref()
+ if err := fd.incref(false); err != nil {
+ return 0, err
+ }
defer fd.decref()
for {
err = syscall.Sendto(fd.sysfd, p, 0, sa)
if err == syscall.EAGAIN {
- if fd.wdeadline >= 0 {
- pollserver.WaitWrite(fd)
- continue
- }
err = errTimeout
+ if fd.wdeadline >= 0 {
+ if err = pollserver.WaitWrite(fd); err == nil {
+ continue
+ }
+ }
}
break
}
@@ -533,21 +571,21 @@
}
func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
- if fd == nil || fd.sysfile == nil {
- return 0, 0, os.EINVAL
- }
fd.wio.Lock()
defer fd.wio.Unlock()
- fd.incref()
+ if err := fd.incref(false); err != nil {
+ return 0, 0, err
+ }
defer fd.decref()
for {
err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
if err == syscall.EAGAIN {
- if fd.wdeadline >= 0 {
- pollserver.WaitWrite(fd)
- continue
- }
err = errTimeout
+ if fd.wdeadline >= 0 {
+ if err = pollserver.WaitWrite(fd); err == nil {
+ continue
+ }
+ }
}
break
}
@@ -561,11 +599,9 @@
}
func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err error) {
- if fd == nil || fd.sysfile == nil {
- return nil, os.EINVAL
+ if err := fd.incref(false); err != nil {
+ return nil, err
}
-
- fd.incref()
defer fd.decref()
// See ../syscall/exec.go for description of ForkLock.
@@ -574,19 +610,17 @@
var s int
var rsa syscall.Sockaddr
for {
- if fd.closing {
- return nil, os.EINVAL
- }
syscall.ForkLock.RLock()
s, rsa, err = syscall.Accept(fd.sysfd)
if err != nil {
syscall.ForkLock.RUnlock()
if err == syscall.EAGAIN {
- if fd.rdeadline >= 0 {
- pollserver.WaitRead(fd)
- continue
- }
err = errTimeout
+ if fd.rdeadline >= 0 {
+ if err = pollserver.WaitRead(fd); err == nil {
+ continue
+ }
+ }
}
return nil, &OpError{"accept", fd.net, fd.laddr, err}
}