internal/socket: support MSG_DONTWAIT

Explicitly handle MSG_DONTWAIT in read and send calls on platforms where this is defined, to get the per-call non-blocking behavior that would be expected when calling recvmsg/sendmsg in C.
When MSG_DONTWAIT is set, we always return true from the function passed to syscall.RawConn.Read/Write, to avoid entering the polling state.

Fixes golang/go#46891

Change-Id: I4809577477554db1c45b6f4825a03d98208199d7
GitHub-Last-Rev: 4022e9b52c4375536e23162a88e5aa4d5637f134
GitHub-Pull-Request: golang/net#108
Reviewed-on: https://go-review.googlesource.com/c/net/+/333469
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
diff --git a/internal/socket/complete_dontwait.go b/internal/socket/complete_dontwait.go
new file mode 100644
index 0000000..5b1d50a
--- /dev/null
+++ b/internal/socket/complete_dontwait.go
@@ -0,0 +1,26 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
+// +build darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package socket
+
+import (
+	"syscall"
+)
+
+// ioComplete reports whether a syscall.RawConn.Read or Write callback should
+// return true (done) or false (wait and retry), given the call's flags and error.
+func ioComplete(flags int, operr error) bool {
+	if flags&syscall.MSG_DONTWAIT != 0 {
+		// MSG_DONTWAIT is set: report done whatever operr is, so RawConn never polls.
+		return true
+	}
+	if operr == syscall.EAGAIN || operr == syscall.EWOULDBLOCK {
+		// The operation would block: report not done so it is retried once I/O is ready.
+		return false
+	}
+	return true
+}
diff --git a/internal/socket/complete_nodontwait.go b/internal/socket/complete_nodontwait.go
new file mode 100644
index 0000000..be63409
--- /dev/null
+++ b/internal/socket/complete_nodontwait.go
@@ -0,0 +1,22 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || windows || zos
+// +build aix windows zos
+
+package socket
+
+import (
+	"syscall"
+)
+
+// ioComplete reports whether a syscall.RawConn.Read or Write callback is done,
+// judging by operr alone; flags is accepted but not examined in this variant.
+func ioComplete(flags int, operr error) bool {
+	if operr == syscall.EAGAIN || operr == syscall.EWOULDBLOCK {
+		// The operation would block: report not done so it is retried once I/O is ready.
+		return false
+	}
+	return true
+}
diff --git a/internal/socket/rawconn_mmsg.go b/internal/socket/rawconn_mmsg.go
index d80a15c..3fcb51b 100644
--- a/internal/socket/rawconn_mmsg.go
+++ b/internal/socket/rawconn_mmsg.go
@@ -10,7 +10,6 @@
 import (
 	"net"
 	"os"
-	"syscall"
 )
 
 func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
@@ -28,10 +27,7 @@
 	var n int
 	fn := func(s uintptr) bool {
 		n, operr = recvmmsg(s, hs, flags)
-		if operr == syscall.EAGAIN {
-			return false
-		}
-		return true
+		return ioComplete(flags, operr)
 	}
 	if err := c.c.Read(fn); err != nil {
 		return n, err
@@ -60,10 +56,7 @@
 	var n int
 	fn := func(s uintptr) bool {
 		n, operr = sendmmsg(s, hs, flags)
-		if operr == syscall.EAGAIN {
-			return false
-		}
-		return true
+		return ioComplete(flags, operr)
 	}
 	if err := c.c.Write(fn); err != nil {
 		return n, err
diff --git a/internal/socket/rawconn_msg.go b/internal/socket/rawconn_msg.go
index 2e2d61b..ba53f56 100644
--- a/internal/socket/rawconn_msg.go
+++ b/internal/socket/rawconn_msg.go
@@ -9,7 +9,6 @@
 
 import (
 	"os"
-	"syscall"
 )
 
 func (c *Conn) recvMsg(m *Message, flags int) error {
@@ -25,10 +24,7 @@
 	var n int
 	fn := func(s uintptr) bool {
 		n, operr = recvmsg(s, &h, flags)
-		if operr == syscall.EAGAIN || operr == syscall.EWOULDBLOCK {
-			return false
-		}
-		return true
+		return ioComplete(flags, operr)
 	}
 	if err := c.c.Read(fn); err != nil {
 		return err
@@ -64,10 +60,7 @@
 	var n int
 	fn := func(s uintptr) bool {
 		n, operr = sendmsg(s, &h, flags)
-		if operr == syscall.EAGAIN || operr == syscall.EWOULDBLOCK {
-			return false
-		}
-		return true
+		return ioComplete(flags, operr)
 	}
 	if err := c.c.Write(fn); err != nil {
 		return err
diff --git a/internal/socket/socket_dontwait_test.go b/internal/socket/socket_dontwait_test.go
new file mode 100644
index 0000000..8eab990
--- /dev/null
+++ b/internal/socket/socket_dontwait_test.go
@@ -0,0 +1,125 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
+// +build darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package socket_test
+
+import (
+	"bytes"
+	"errors"
+	"net"
+	"runtime"
+	"syscall"
+	"testing"
+
+	"golang.org/x/net/internal/socket"
+	"golang.org/x/net/nettest"
+)
+
+func TestUDPDontwait(t *testing.T) {
+	c, err := nettest.NewLocalPacketListener("udp")
+	if err != nil {
+		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+	}
+	defer c.Close()
+	cc, err := socket.NewConn(c.(*net.UDPConn))
+	if err != nil {
+		t.Fatal(err)
+	}
+	isErrWouldblock := func(err error) bool { // reports whether err wraps EAGAIN or EWOULDBLOCK
+		var errno syscall.Errno
+		return errors.As(err, &errno) && (errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK)
+	}
+
+	t.Run("Message-dontwait", func(t *testing.T) {
+		// Read before anything was sent; expect EAGAIN/EWOULDBLOCK.
+		b := make([]byte, 32)
+		rm := socket.Message{
+			Buffers: [][]byte{b},
+		}
+		if err := cc.RecvMsg(&rm, syscall.MSG_DONTWAIT); !isErrWouldblock(err) {
+			t.Fatal(err)
+		}
+		// To trigger EWOULDBLOCK from SendMsg, we would have to send faster than
+		// the system/network can process. Whether that is possible depends on
+		// the system, specifically on write buffer sizes and the speed of the
+		// network interface.
+		// We cannot expect to trigger this quickly and reliably, especially
+		// since this test sends data over a (fast) loopback. Consequently, we
+		// only check that sending with MSG_DONTWAIT works at all and do not
+		// attempt to test that we would eventually get EWOULDBLOCK here.
+		data := []byte("HELLO-R-U-THERE")
+		wm := socket.Message{
+			Buffers: [][]byte{data},
+			Addr:    c.LocalAddr(),
+		}
+		// Send one message, retrying while we get EWOULDBLOCK. This will likely succeed at the first attempt.
+		for {
+			err := cc.SendMsg(&wm, syscall.MSG_DONTWAIT)
+			if err == nil {
+				break
+			} else if !isErrWouldblock(err) {
+				t.Fatal(err)
+			}
+		}
+		// Read the message sent above, retrying on EWOULDBLOCK; again, this will likely succeed at the first attempt.
+		for {
+			err := cc.RecvMsg(&rm, syscall.MSG_DONTWAIT)
+			if err == nil {
+				break
+			} else if !isErrWouldblock(err) {
+				t.Fatal(err)
+			}
+		}
+		if !bytes.Equal(b[:rm.N], data) {
+			t.Fatalf("got %#v; want %#v", b[:rm.N], data)
+		}
+	})
+	switch runtime.GOOS {
+	case "android", "linux": // the batch RecvMsgs/SendMsgs calls are exercised only on these platforms
+		t.Run("Messages", func(t *testing.T) {
+			data := []byte("HELLO-R-U-THERE")
+			wmbs := bytes.SplitAfter(data, []byte("-"))
+			wms := []socket.Message{
+				{Buffers: wmbs[:1], Addr: c.LocalAddr()},
+				{Buffers: wmbs[1:], Addr: c.LocalAddr()},
+			}
+			b := make([]byte, 32)
+			rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
+			rms := []socket.Message{
+				{Buffers: rmbs[0]},
+				{Buffers: rmbs[1]},
+			}
+			_, err := cc.RecvMsgs(rms, syscall.MSG_DONTWAIT)
+			if !isErrWouldblock(err) {
+				t.Fatal(err)
+			}
+			for ntot := 0; ntot < len(wms); { // retry until every message is sent; n may be a partial count
+				n, err := cc.SendMsgs(wms[ntot:], syscall.MSG_DONTWAIT)
+				if err == nil {
+					ntot += n
+				} else if !isErrWouldblock(err) {
+					t.Fatal(err)
+				}
+			}
+			for ntot := 0; ntot < len(rms); { // retry until every message is received; n may be a partial count
+				n, err := cc.RecvMsgs(rms[ntot:], syscall.MSG_DONTWAIT)
+				if err == nil {
+					ntot += n
+				} else if !isErrWouldblock(err) {
+					t.Fatal(err)
+				}
+			}
+			nn := 0
+			for i := 0; i < len(rms); i++ {
+				nn += rms[i].N
+			}
+			if !bytes.Equal(b[:nn], data) {
+				t.Fatalf("got %#v; want %#v", b[:nn], data)
+			}
+		})
+	}
+}