internal/socket: tell race detector about syscall reads and writes

The syscalls that send and receive messages write to buffers provided
by the user. The race detector can't see those reads and writes by
default (they are done by the kernel), so we need to tell the race
detector explicitly about them.

Fixes golang/go#35329

Change-Id: Ibf4ef1b937535c4834aa9eeb744722d91f669a27
Reviewed-on: https://go-review.googlesource.com/c/net/+/205461
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
diff --git a/internal/socket/norace.go b/internal/socket/norace.go
new file mode 100644
index 0000000..9519ffb
--- /dev/null
+++ b/internal/socket/norace.go
@@ -0,0 +1,12 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !race
+
+package socket
+
+func (m *Message) raceRead() {
+}
+func (m *Message) raceWrite() {
+}
diff --git a/internal/socket/race.go b/internal/socket/race.go
new file mode 100644
index 0000000..df60c62
--- /dev/null
+++ b/internal/socket/race.go
@@ -0,0 +1,37 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build race
+
+package socket
+
+import (
+	"runtime"
+	"unsafe"
+)
+
+// This package reads and writes the Message buffers using a
+// direct system call, which the race detector can't see.
+// These functions tell the race detector what is going on during the syscall.
+
+func (m *Message) raceRead() {
+	for _, b := range m.Buffers {
+		if len(b) > 0 {
+			runtime.RaceReadRange(unsafe.Pointer(&b[0]), len(b))
+		}
+	}
+	if b := m.OOB; len(b) > 0 {
+		runtime.RaceReadRange(unsafe.Pointer(&b[0]), len(b))
+	}
+}
+func (m *Message) raceWrite() {
+	for _, b := range m.Buffers {
+		if len(b) > 0 {
+			runtime.RaceWriteRange(unsafe.Pointer(&b[0]), len(b))
+		}
+	}
+	if b := m.OOB; len(b) > 0 {
+		runtime.RaceWriteRange(unsafe.Pointer(&b[0]), len(b))
+	}
+}
diff --git a/internal/socket/rawconn_mmsg.go b/internal/socket/rawconn_mmsg.go
index 1f4cb3b..d01fc4c 100644
--- a/internal/socket/rawconn_mmsg.go
+++ b/internal/socket/rawconn_mmsg.go
@@ -13,6 +13,9 @@
 )
 
 func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
+	for i := range ms {
+		ms[i].raceWrite()
+	}
 	hs := make(mmsghdrs, len(ms))
 	var parseFn func([]byte, string) (net.Addr, error)
 	if c.network != "tcp" {
@@ -43,6 +46,9 @@
 }
 
 func (c *Conn) sendMsgs(ms []Message, flags int) (int, error) {
+	for i := range ms {
+		ms[i].raceRead()
+	}
 	hs := make(mmsghdrs, len(ms))
 	var marshalFn func(net.Addr) []byte
 	if c.network != "tcp" {
diff --git a/internal/socket/rawconn_msg.go b/internal/socket/rawconn_msg.go
index a972011..d5ae3f8 100644
--- a/internal/socket/rawconn_msg.go
+++ b/internal/socket/rawconn_msg.go
@@ -12,6 +12,7 @@
 )
 
 func (c *Conn) recvMsg(m *Message, flags int) error {
+	m.raceWrite()
 	var h msghdr
 	vs := make([]iovec, len(m.Buffers))
 	var sa []byte
@@ -48,6 +49,7 @@
 }
 
 func (c *Conn) sendMsg(m *Message, flags int) error {
+	m.raceRead()
 	var h msghdr
 	vs := make([]iovec, len(m.Buffers))
 	var sa []byte
diff --git a/internal/socket/socket_test.go b/internal/socket/socket_test.go
index 0b6ebf5..2300cec 100644
--- a/internal/socket/socket_test.go
+++ b/internal/socket/socket_test.go
@@ -9,8 +9,13 @@
 import (
 	"bytes"
 	"fmt"
+	"io/ioutil"
 	"net"
+	"os"
+	"os/exec"
+	"path/filepath"
 	"runtime"
+	"strings"
 	"syscall"
 	"testing"
 
@@ -296,3 +301,67 @@
 		}
 	}
 }
+
+func TestRace(t *testing.T) {
+	tests := []string{
+		`
+package main
+import "net"
+import "golang.org/x/net/ipv4"
+var g byte
+func main() {
+	c, _ := net.ListenPacket("udp", "127.0.0.1:0")
+	cc := ipv4.NewPacketConn(c)
+	sync := make(chan bool)
+	src := make([]byte, 1)
+	dst := make([]byte, 1)
+	go func() { cc.WriteTo(src, nil, c.LocalAddr()) }()
+	go func() { cc.ReadFrom(dst); sync <- true }()
+	g = dst[0]
+	<- sync
+}
+`,
+		`
+package main
+import "net"
+import "golang.org/x/net/ipv4"
+func main() {
+	c, _ := net.ListenPacket("udp", "127.0.0.1:0")
+	cc := ipv4.NewPacketConn(c)
+	sync := make(chan bool)
+	src := make([]byte, 1)
+	dst := make([]byte, 1)
+	go func() { cc.WriteTo(src, nil, c.LocalAddr()); sync <- true }()
+	src[0] = 0
+	go func() { cc.ReadFrom(dst) }()
+	<- sync
+}
+`,
+	}
+	platforms := map[string]bool{
+		"linux/amd64":   true,
+		"linux/ppc64le": true,
+		"linux/arm64":   true,
+	}
+	if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
+		t.Skip("skipping test on non-race-enabled host.")
+	}
+	dir, err := ioutil.TempDir("", "testrace")
+	if err != nil {
+		t.Fatalf("failed to create temp directory: %v", err)
+	}
+	defer os.RemoveAll(dir)
+	goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
+	for i, test := range tests {
+		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
+			src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
+			if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
+				t.Fatalf("failed to write file: %v", err)
+			}
+			got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
+			if !strings.Contains(string(got), "WARNING: DATA RACE") {
+				t.Errorf("race not detected for test %d: err:%v out:%s", i, err, string(got))
+			}
+		})
+	}
+}