net, os, internal/poll: test for use of sendfile

The net package's sendfile tests exercise various paths where
we expect sendfile to be used, but don't verify that sendfile
was in fact used.

Add a hook to internal/poll.SendFile to let us verify that
sendfile was called when expected. Update os package tests
(which use their own hook mechanism) to use this hook as well.

For #66988

Change-Id: I7afb130dcfe0063d60c6ea0f8560cf8665ad5a81
Reviewed-on: https://go-review.googlesource.com/c/go/+/581778
Reviewed-by: Ian Lance Taylor <iant@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/src/internal/poll/sendfile.go b/src/internal/poll/sendfile.go
new file mode 100644
index 0000000..41b0481
--- /dev/null
+++ b/src/internal/poll/sendfile.go
@@ -0,0 +1,7 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package poll
+
+var TestHookDidSendFile = func(dstFD *FD, src int, written int64, err error, handled bool) {}
diff --git a/src/internal/poll/sendfile_bsd.go b/src/internal/poll/sendfile_bsd.go
index 8fcdb1c..669df94 100644
--- a/src/internal/poll/sendfile_bsd.go
+++ b/src/internal/poll/sendfile_bsd.go
@@ -14,6 +14,9 @@
 
 // SendFile wraps the sendfile system call.
 func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error, handled bool) {
+	defer func() {
+		TestHookDidSendFile(dstFD, src, written, err, handled)
+	}()
 	if err := dstFD.writeLock(); err != nil {
 		return 0, err, false
 	}
diff --git a/src/internal/poll/sendfile_linux.go b/src/internal/poll/sendfile_linux.go
index c2a0653..d1c4d5c 100644
--- a/src/internal/poll/sendfile_linux.go
+++ b/src/internal/poll/sendfile_linux.go
@@ -12,6 +12,9 @@
 
 // SendFile wraps the sendfile system call.
 func SendFile(dstFD *FD, src int, remain int64) (written int64, err error, handled bool) {
+	defer func() {
+		TestHookDidSendFile(dstFD, src, written, err, handled)
+	}()
 	if err := dstFD.writeLock(); err != nil {
 		return 0, err, false
 	}
diff --git a/src/internal/poll/sendfile_solaris.go b/src/internal/poll/sendfile_solaris.go
index 1ba0c8d..ec67583 100644
--- a/src/internal/poll/sendfile_solaris.go
+++ b/src/internal/poll/sendfile_solaris.go
@@ -17,6 +17,9 @@
 
 // SendFile wraps the sendfile system call.
 func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error, handled bool) {
+	defer func() {
+		TestHookDidSendFile(dstFD, src, written, err, handled)
+	}()
 	if err := dstFD.writeLock(); err != nil {
 		return 0, err, false
 	}
diff --git a/src/internal/poll/sendfile_windows.go b/src/internal/poll/sendfile_windows.go
index 8c3353b..2ae8a8d 100644
--- a/src/internal/poll/sendfile_windows.go
+++ b/src/internal/poll/sendfile_windows.go
@@ -11,6 +11,9 @@
 
 // SendFile wraps the TransmitFile call.
 func SendFile(fd *FD, src syscall.Handle, n int64) (written int64, err error) {
+	defer func() {
+		TestHookDidSendFile(fd, 0, written, err, written > 0)
+	}()
 	if fd.kind == kindPipe {
 		// TransmitFile does not work with pipes
 		return 0, syscall.ESPIPE
diff --git a/src/net/sendfile_linux.go b/src/net/sendfile_linux.go
index 9a7d005..f8a7bec 100644
--- a/src/net/sendfile_linux.go
+++ b/src/net/sendfile_linux.go
@@ -10,6 +10,8 @@
 	"os"
 )
 
+const supportsSendfile = true
+
 // sendFile copies the contents of r to c using the sendfile
 // system call to minimize copies.
 //
diff --git a/src/net/sendfile_stub.go b/src/net/sendfile_stub.go
index a4fdd99..7f31cc6 100644
--- a/src/net/sendfile_stub.go
+++ b/src/net/sendfile_stub.go
@@ -2,12 +2,14 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build aix || js || netbsd || openbsd || ios || wasip1
+//go:build !(linux || (darwin && !ios) || dragonfly || freebsd || solaris || windows)
 
 package net
 
 import "io"
 
+const supportsSendfile = false
+
 func sendFile(c *netFD, r io.Reader) (n int64, err error, handled bool) {
 	return 0, nil, false
 }
diff --git a/src/net/sendfile_test.go b/src/net/sendfile_test.go
index 8fadb47..4f34115 100644
--- a/src/net/sendfile_test.go
+++ b/src/net/sendfile_test.go
@@ -11,6 +11,7 @@
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"internal/poll"
 	"io"
 	"os"
 	"runtime"
@@ -26,6 +27,48 @@
 	newtonSHA256 = "d4a9ac22462b35e7821a4f2706c211093da678620a8f9997989ee7cf8d507bbd"
 )
 
+// expectSendfile runs f, and verifies that internal/poll.SendFile successfully handles
+// a write to wantConn during f's execution.
+//
+// On platforms where supportsSendfile is false, expectSendfile runs f but does not
+// expect a call to SendFile.
+func expectSendfile(t *testing.T, wantConn Conn, f func()) {
+	t.Helper()
+	if !supportsSendfile {
+		f()
+		return
+	}
+	orig := poll.TestHookDidSendFile
+	defer func() {
+		poll.TestHookDidSendFile = orig
+	}()
+	var (
+		called     bool
+		gotHandled bool
+		gotFD      *poll.FD
+	)
+	poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
+		if called {
+			t.Error("internal/poll.SendFile called multiple times, want one call")
+		}
+		called = true
+		gotHandled = handled
+		gotFD = dstFD
+	}
+	f()
+	if !called {
+		t.Error("internal/poll.SendFile was not called, want it to be")
+		return
+	}
+	if !gotHandled {
+		t.Error("internal/poll.SendFile did not handle the write, want it to")
+		return
+	}
+	if &wantConn.(*TCPConn).fd.pfd != gotFD {
+		t.Error("internal.poll.SendFile called with unexpected FD")
+	}
+}
+
 func TestSendfile(t *testing.T) {
 	ln := newLocalListener(t, "tcp")
 	defer ln.Close()
@@ -53,7 +96,17 @@
 
 			// Return file data using io.Copy, which should use
 			// sendFile if available.
-			sbytes, err := io.Copy(conn, f)
+			var sbytes int64
+			switch runtime.GOOS {
+			case "windows":
+				// Windows is not using sendfile for some reason:
+				// https://go.dev/issue/67042
+				sbytes, err = io.Copy(conn, f)
+			default:
+				expectSendfile(t, conn, func() {
+					sbytes, err = io.Copy(conn, f)
+				})
+			}
 			if err != nil {
 				errc <- err
 				return
@@ -121,7 +174,9 @@
 			for i := 0; i < 3; i++ {
 				// Return file data using io.CopyN, which should use
 				// sendFile if available.
-				_, err = io.CopyN(conn, f, 3)
+				expectSendfile(t, conn, func() {
+					_, err = io.CopyN(conn, f, 3)
+				})
 				if err != nil {
 					errc <- err
 					return
@@ -180,7 +235,9 @@
 				return
 			}
 
-			_, err = io.CopyN(conn, f, sendSize)
+			expectSendfile(t, conn, func() {
+				_, err = io.CopyN(conn, f, sendSize)
+			})
 			if err != nil {
 				errc <- err
 				return
@@ -240,6 +297,10 @@
 			return
 		}
 		defer conn.Close()
+		// The comment above states that this should call into sendfile,
+		// but empirically it doesn't seem to do so at this time.
+		// If it does, or does on some platforms, this CopyN should be wrapped
+		// in expectSendfile.
 		_, err = io.CopyN(conn, r, 1)
 		if err != nil {
 			t.Error(err)
@@ -333,6 +394,10 @@
 		}
 		defer f.Close()
 
+		// We expect this to use sendfile, but as of the time this comment was written
+		// poll.SendFile on an FD past its timeout can return an error indicating that
+		// it didn't handle the operation, resulting in a non-sendfile retry.
+		// So don't use expectSendfile here.
 		_, err = io.Copy(conn, f)
 		if errors.Is(err, os.ErrDeadlineExceeded) {
 			return nil
diff --git a/src/net/sendfile_unix_alt.go b/src/net/sendfile_unix_alt.go
index 5a10540..9e46c4e 100644
--- a/src/net/sendfile_unix_alt.go
+++ b/src/net/sendfile_unix_alt.go
@@ -13,6 +13,8 @@
 	"syscall"
 )
 
+const supportsSendfile = true
+
 // sendFile copies the contents of r to c using the sendfile
 // system call to minimize copies.
 //
@@ -35,6 +37,8 @@
 			return 0, nil, true
 		}
 	}
+	// r might be an *os.File or an os.fileWithoutWriteTo.
+	// Type assert to an interface rather than *os.File directly to handle the latter case.
 	f, ok := r.(interface {
 		fs.File
 		io.Seeker
diff --git a/src/net/sendfile_windows.go b/src/net/sendfile_windows.go
index 59b1b0d..0377a48 100644
--- a/src/net/sendfile_windows.go
+++ b/src/net/sendfile_windows.go
@@ -11,6 +11,8 @@
 	"syscall"
 )
 
+const supportsSendfile = true
+
 // sendFile copies the contents of r to c using the TransmitFile
 // system call to minimize copies.
 //
diff --git a/src/os/export_linux_test.go b/src/os/export_linux_test.go
index 942b48a..78c2d14 100644
--- a/src/os/export_linux_test.go
+++ b/src/os/export_linux_test.go
@@ -7,6 +7,5 @@
 var (
 	PollCopyFileRangeP  = &pollCopyFileRange
 	PollSpliceFile      = &pollSplice
-	PollSendFile        = &pollSendFile
 	GetPollFDAndNetwork = getPollFDAndNetwork
 )
diff --git a/src/os/writeto_linux_test.go b/src/os/writeto_linux_test.go
index 5ffab88..e390063 100644
--- a/src/os/writeto_linux_test.go
+++ b/src/os/writeto_linux_test.go
@@ -109,8 +109,18 @@
 
 func hookSendFile(t *testing.T) *sendFileHook {
 	h := new(sendFileHook)
-	h.install()
-	t.Cleanup(h.uninstall)
+	orig := poll.TestHookDidSendFile
+	t.Cleanup(func() {
+		poll.TestHookDidSendFile = orig
+	})
+	poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
+		h.called = true
+		h.dstfd = dstFD.Sysfd
+		h.srcfd = src
+		h.written = written
+		h.err = err
+		h.handled = handled
+	}
 	return h
 }
 
@@ -118,29 +128,10 @@
 	called bool
 	dstfd  int
 	srcfd  int
-	remain int64
 
 	written int64
 	handled bool
 	err     error
-
-	original func(dst *poll.FD, src int, remain int64) (int64, error, bool)
-}
-
-func (h *sendFileHook) install() {
-	h.original = *PollSendFile
-	*PollSendFile = func(dst *poll.FD, src int, remain int64) (int64, error, bool) {
-		h.called = true
-		h.dstfd = dst.Sysfd
-		h.srcfd = src
-		h.remain = remain
-		h.written, h.err, h.handled = h.original(dst, src, remain)
-		return h.written, h.err, h.handled
-	}
-}
-
-func (h *sendFileHook) uninstall() {
-	*PollSendFile = h.original
 }
 
 func createTempFile(t *testing.T, size int64) (*File, []byte) {
diff --git a/src/os/zero_copy_linux.go b/src/os/zero_copy_linux.go
index 70a05ff..0afc19e 100644
--- a/src/os/zero_copy_linux.go
+++ b/src/os/zero_copy_linux.go
@@ -13,7 +13,6 @@
 var (
 	pollCopyFileRange = poll.CopyFileRange
 	pollSplice        = poll.Splice
-	pollSendFile      = poll.SendFile
 )
 
 // wrapSyscallError takes an error and a syscall name. If the error is
@@ -38,7 +37,7 @@
 	}
 
 	rerr := sc.Read(func(fd uintptr) (done bool) {
-		written, err, handled = pollSendFile(pfd, int(fd), 1<<63-1)
+		written, err, handled = poll.SendFile(pfd, int(fd), 1<<63-1)
 		return true
 	})