blob: 982a2b6330c5f82e93f34d05a32c07e2d536374c [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os_test
import (
"bytes"
"internal/poll"
"io"
"math/rand"
"os"
. "os"
"path/filepath"
"strconv"
"strings"
"syscall"
"testing"
"time"
)
// TestCopyFileRange exercises (*File).ReadFrom, reached through io.Copy, and
// uses copyFileRangeHook to observe whether package os routed the copy
// through poll.CopyFileRange.
func TestCopyFileRange(t *testing.T) {
	// Source sizes: tiny, small, just past 1 KiB, just past one page, and
	// just past 32 KiB.
	sizes := []int{
		1,
		42,
		1025,
		syscall.Getpagesize() + 1,
		32769,
	}
	// Basic: unlimited copies (limit -1 tells testCopyFileRange not to wrap
	// the source in an io.LimitedReader).
	t.Run("Basic", func(t *testing.T) {
		for _, size := range sizes {
			t.Run(strconv.Itoa(size), func(t *testing.T) {
				testCopyFileRange(t, int64(size), -1)
			})
		}
	})
	// Limited: copies through an io.LimitedReader whose N is derived from the
	// source size, covering limits just below, well below, and above it.
	t.Run("Limited", func(t *testing.T) {
		limits := []struct {
			name  string
			limit func(size int64) int64
		}{
			{"OneLess", func(size int64) int64 { return size - 1 }},
			{"Half", func(size int64) int64 { return size / 2 }},
			{"More", func(size int64) int64 { return size + 7 }},
		}
		for _, lc := range limits {
			// t.Run blocks until the subtest finishes (none call t.Parallel),
			// so capturing the loop variables in these closures is safe.
			t.Run(lc.name, func(t *testing.T) {
				for _, size := range sizes {
					t.Run(strconv.Itoa(size), func(t *testing.T) {
						testCopyFileRange(t, int64(size), lc.limit(int64(size)))
					})
				}
			})
		}
	})
	// A destination opened with O_APPEND must not be handed to
	// poll.CopyFileRange; the copy still has to succeed by other means.
	t.Run("DoesntTryInAppendMode", func(t *testing.T) {
		dst, src, data, hook := newCopyFileRangeTest(t, 42)
		dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
		if err != nil {
			t.Fatal(err)
		}
		defer dst2.Close()
		if _, err := io.Copy(dst2, src); err != nil {
			t.Fatal(err)
		}
		if hook.called {
			t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
		}
		mustSeekStart(t, dst2)
		mustContainData(t, dst2, data) // through traditional means
	})
	// Copying a file onto itself: the accelerated path must decline and the
	// fallback must append the file's contents to itself.
	t.Run("CopyFileItself", func(t *testing.T) {
		hook := hookCopyFileRange(t)
		f, err := os.CreateTemp("", "file-readfrom-itself-test")
		if err != nil {
			t.Fatalf("failed to create tmp file: %v", err)
		}
		t.Cleanup(func() {
			f.Close()
			os.Remove(f.Name())
		})
		data := []byte("hello world!")
		if _, err := f.Write(data); err != nil {
			t.Fatalf("failed to create and feed the file: %v", err)
		}
		if err := f.Sync(); err != nil {
			t.Fatalf("failed to save the file: %v", err)
		}
		// Rewind it.
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			t.Fatalf("failed to rewind the file: %v", err)
		}
		// Read data from the file itself.
		if _, err := io.Copy(f, f); err != nil {
			t.Fatalf("failed to read from the file: %v", err)
		}
		// The kernel rejects an overlapping same-file range (EINVAL per
		// copy_file_range(2)). The assertion expects that to surface from
		// poll.CopyFileRange as written=0, handled=false, err=nil, so that
		// io.Copy falls back to its generic loop (checked below by the
		// doubled contents).
		if !hook.called || hook.written != 0 || hook.handled || hook.err != nil {
			t.Fatalf("poll.CopyFileRange should be called and report EINVAL as unhandled (written=0, handled=false, err=nil), but got hook.called=%t, hook.written=%d, hook.handled=%t, hook.err=%v", hook.called, hook.written, hook.handled, hook.err)
		}
		// Rewind it.
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			t.Fatalf("failed to rewind the file: %v", err)
		}
		data2, err := io.ReadAll(f)
		if err != nil {
			t.Fatalf("failed to read from the file: %v", err)
		}
		// It should wind up a double of the original data.
		if strings.Repeat(string(data), 2) != string(data2) {
			t.Fatalf("data mismatch: %s != %s", string(data), string(data2))
		}
	})
	// Non-regular files (pipes) on either or both ends: the hook must still
	// be invoked and the data must arrive intact.
	t.Run("NotRegular", func(t *testing.T) {
		t.Run("BothPipes", func(t *testing.T) {
			hook := hookCopyFileRange(t)
			pr1, pw1, err := Pipe()
			if err != nil {
				t.Fatal(err)
			}
			defer pr1.Close()
			defer pw1.Close()
			pr2, pw2, err := Pipe()
			if err != nil {
				t.Fatal(err)
			}
			defer pr2.Close()
			defer pw2.Close()
			// The pipe is empty, and PIPE_BUF is large enough
			// for this, by (POSIX) definition, so there is no
			// need for an additional goroutine.
			data := []byte("hello")
			if _, err := pw1.Write(data); err != nil {
				t.Fatal(err)
			}
			pw1.Close()
			n, err := io.Copy(pw2, pr1)
			if err != nil {
				t.Fatal(err)
			}
			if n != int64(len(data)) {
				t.Fatalf("transferred %d, want %d", n, len(data))
			}
			if !hook.called {
				t.Fatalf("should have called poll.CopyFileRange")
			}
			pw2.Close()
			mustContainData(t, pr2, data)
		})
		t.Run("DstPipe", func(t *testing.T) {
			dst, src, data, hook := newCopyFileRangeTest(t, 255)
			dst.Close()
			pr, pw, err := Pipe()
			if err != nil {
				t.Fatal(err)
			}
			defer pr.Close()
			defer pw.Close()
			n, err := io.Copy(pw, src)
			if err != nil {
				t.Fatal(err)
			}
			if n != int64(len(data)) {
				t.Fatalf("transferred %d, want %d", n, len(data))
			}
			if !hook.called {
				t.Fatalf("should have called poll.CopyFileRange")
			}
			pw.Close()
			mustContainData(t, pr, data)
		})
		t.Run("SrcPipe", func(t *testing.T) {
			dst, src, data, hook := newCopyFileRangeTest(t, 255)
			src.Close()
			pr, pw, err := Pipe()
			if err != nil {
				t.Fatal(err)
			}
			defer pr.Close()
			defer pw.Close()
			// The pipe is empty, and PIPE_BUF is large enough
			// for this, by (POSIX) definition, so there is no
			// need for an additional goroutine.
			if _, err := pw.Write(data); err != nil {
				t.Fatal(err)
			}
			pw.Close()
			n, err := io.Copy(dst, pr)
			if err != nil {
				t.Fatal(err)
			}
			if n != int64(len(data)) {
				t.Fatalf("transferred %d, want %d", n, len(data))
			}
			if !hook.called {
				t.Fatalf("should have called poll.CopyFileRange")
			}
			mustSeekStart(t, dst)
			mustContainData(t, dst, data)
		})
	})
	// Nil *File on either side of io.Copy or ReadFrom must yield ErrInvalid.
	t.Run("Nil", func(t *testing.T) {
		var nilFile *File
		anyFile, err := os.CreateTemp("", "")
		if err != nil {
			t.Fatal(err)
		}
		defer Remove(anyFile.Name())
		defer anyFile.Close()
		if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
			t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
			t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
			t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
			t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
			t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
			t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
		}
	})
}
// testCopyFileRange copies a size-byte random source file into an empty
// destination with io.Copy and verifies the outcome. A negative limit means
// "no limit"; otherwise the source is wrapped in an io.LimitedReader with
// N = limit, and only the first min(limit, size) bytes are expected.
//
// It checks that poll.CopyFileRange was invoked (unless the limit is zero)
// with the right file descriptors. It also checks that both file offsets
// advanced by the number of bytes copied, that the reported count matches,
// that the destination holds exactly the expected bytes, and that the
// LimitedReader's N was reduced by the bytes written.
func testCopyFileRange(t *testing.T, size int64, limit int64) {
	dst, src, data, hook := newCopyFileRangeTest(t, size)
	// If we have a limit, wrap the reader.
	var (
		realsrc io.Reader
		lr      *io.LimitedReader
	)
	if limit >= 0 {
		lr = &io.LimitedReader{N: limit, R: src}
		realsrc = lr
		if limit < int64(len(data)) {
			data = data[:limit]
		}
	} else {
		realsrc = src
	}
	// Now call ReadFrom (through io.Copy), which will hopefully call
	// poll.CopyFileRange.
	n, err := io.Copy(dst, realsrc)
	if err != nil {
		t.Fatal(err)
	}
	// Unless the limit was zero (nothing to transfer), we should have called
	// poll.CopyFileRange with the right file descriptor arguments. This
	// covers both the unlimited case (limit < 0) and positive limits.
	if limit != 0 && !hook.called {
		t.Fatal("never called poll.CopyFileRange")
	}
	if hook.called && hook.dstfd != int(dst.Fd()) {
		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
	}
	if hook.called && hook.srcfd != int(src.Fd()) {
		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
	}
	// Check that the offsets after the transfer make sense, that the size
	// of the transfer was reported correctly, and that the destination
	// file contains exactly the bytes we expect it to contain.
	dstoff, err := dst.Seek(0, io.SeekCurrent)
	if err != nil {
		t.Fatal(err)
	}
	srcoff, err := src.Seek(0, io.SeekCurrent)
	if err != nil {
		t.Fatal(err)
	}
	if dstoff != srcoff {
		t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
	}
	if dstoff != int64(len(data)) {
		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
	}
	if n != int64(len(data)) {
		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
	}
	mustSeekStart(t, dst)
	mustContainData(t, dst, data)
	// If we had a limit, check that it was updated.
	if lr != nil {
		if want := limit - n; lr.N != want {
			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
		}
	}
}
// newCopyFileRangeTest sets up the fixtures shared by the copy_file_range
// tests.
//
// It installs a copyFileRangeHook (removed automatically at test cleanup).
// It creates "src" and "dst" files in a fresh temporary directory, both closed
// at cleanup. It fills src with size pseudo-random bytes and rewinds src to
// offset 0 so a subsequent copy reads it from the start. The bytes written to
// src are returned as data.
func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
	t.Helper()
	hook = hookCopyFileRange(t)
	dir := t.TempDir()

	// create makes dir/name and arranges for it to be closed at cleanup.
	create := func(name string) *File {
		t.Helper()
		f, err := Create(filepath.Join(dir, name))
		if err != nil {
			t.Fatal(err)
		}
		t.Cleanup(func() { f.Close() })
		return f
	}
	src = create("src")
	dst = create("dst")

	// Random payload, then rewind so the copy consumes it from offset 0.
	data = make([]byte, size)
	rand.New(rand.NewSource(time.Now().Unix())).Read(data)
	if _, err := src.Write(data); err != nil {
		t.Fatal(err)
	}
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
	return dst, src, data, hook
}
// mustContainData reads len(data) bytes from f's current offset, requires
// them to equal data, and then requires f to be at EOF. Any mismatch or read
// error fails the test immediately.
func mustContainData(t *testing.T, f *File, data []byte) {
	t.Helper()
	buf := make([]byte, len(data))
	if _, err := io.ReadFull(f, buf); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, data) {
		t.Fatalf("didn't get the same data back from %s", f.Name())
	}
	// One more byte must not be available: nothing may follow data.
	var extra [1]byte
	if _, err := f.Read(extra[:]); err != io.EOF {
		t.Fatalf("not at EOF")
	}
}
// mustSeekStart rewinds f to offset 0, failing the test immediately if the
// seek returns an error. It is marked as a helper so failures are reported
// at the caller's line, like the other must* helpers.
func mustSeekStart(t *testing.T, f *File) {
	t.Helper()
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
}
// hookCopyFileRange installs a fresh copyFileRangeHook and registers its
// uninstall with t.Cleanup, so the original function is restored when the
// test (or subtest) finishes.
func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
	hook := &copyFileRangeHook{}
	hook.install()
	t.Cleanup(hook.uninstall)
	return hook
}
// copyFileRangeHook records what happened on the most recent call that went
// through *PollCopyFileRangeP while the hook was installed.
type copyFileRangeHook struct {
	// called is set on the first intercepted call and never cleared.
	called bool

	// Arguments observed on the most recent call.
	dstfd  int   // dst.Sysfd
	srcfd  int   // src.Sysfd
	remain int64 // remain argument, as passed

	// Results the original function returned on the most recent call.
	written int64
	handled bool
	err     error

	// original is the function that occupied *PollCopyFileRangeP before
	// install replaced it.
	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
}
// install saves the current *PollCopyFileRangeP and replaces it with a
// wrapper. The wrapper records its arguments in h, delegates to the saved
// function, records the results in h, and returns them unchanged.
func (h *copyFileRangeHook) install() {
	h.original = *PollCopyFileRangeP
	wrapper := func(dst, src *poll.FD, remain int64) (written int64, handled bool, err error) {
		h.called, h.dstfd, h.srcfd, h.remain = true, dst.Sysfd, src.Sysfd, remain
		written, handled, err = h.original(dst, src, remain)
		h.written, h.handled, h.err = written, handled, err
		return written, handled, err
	}
	*PollCopyFileRangeP = wrapper
}
// uninstall puts back the function that install displaced.
func (h *copyFileRangeHook) uninstall() {
	prev := h.original
	*PollCopyFileRangeP = prev
}
// TestProcCopy checks that io.Copy from a /proc file into a regular file
// reproduces the /proc file's contents exactly. On some kernels
// copy_file_range fails on files in /proc. The test is skipped when
// /proc/self/cmdline cannot be read.
func TestProcCopy(t *testing.T) {
	const cmdlineFile = "/proc/self/cmdline"
	cmdline, err := os.ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}
	in, err := os.Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer in.Close()
	outFile := filepath.Join(t.TempDir(), "cmdline")
	out, err := os.Create(outFile)
	if err != nil {
		t.Fatal(err)
	}
	// Ensure out is closed even if io.Copy fails. On success the explicit,
	// error-checked Close below runs first, and this deferred second Close
	// just returns an ignored error.
	defer out.Close()
	if _, err := io.Copy(out, in); err != nil {
		t.Fatal(err)
	}
	if err := out.Close(); err != nil {
		t.Fatal(err)
	}
	// Named "got" rather than "copy" to avoid shadowing the builtin.
	got, err := os.ReadFile(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(cmdline, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, cmdline)
	}
}