| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package os_test |
| |
| import ( |
| "bytes" |
| "errors" |
| "fmt" |
| "io" |
| "math/rand/v2" |
| "net" |
| "os" |
| "runtime" |
| "sync" |
| "testing" |
| |
| "golang.org/x/net/nettest" |
| ) |
| |
| // Exercise sendfile/splice fast paths with a moderately large file. |
| // |
| // https://go.dev/issue/70000 |
| |
| func TestLargeCopyViaNetwork(t *testing.T) { |
| const size = 10 * 1024 * 1024 |
| dir := t.TempDir() |
| |
| src, err := os.Create(dir + "/src") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer src.Close() |
| if _, err := io.CopyN(src, newRandReader(), size); err != nil { |
| t.Fatal(err) |
| } |
| if _, err := src.Seek(0, 0); err != nil { |
| t.Fatal(err) |
| } |
| |
| dst, err := os.Create(dir + "/dst") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer dst.Close() |
| |
| client, server := createSocketPair(t, "tcp") |
| var wg sync.WaitGroup |
| wg.Add(2) |
| go func() { |
| defer wg.Done() |
| if n, err := io.Copy(dst, server); n != size || err != nil { |
| t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size) |
| } |
| }() |
| go func() { |
| defer wg.Done() |
| defer client.Close() |
| if n, err := io.Copy(client, src); n != size || err != nil { |
| t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size) |
| } |
| }() |
| wg.Wait() |
| |
| if _, err := dst.Seek(0, 0); err != nil { |
| t.Fatal(err) |
| } |
| if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| func TestCopyFileToFile(t *testing.T) { |
| const size = 1 * 1024 * 1024 |
| dir := t.TempDir() |
| |
| src, err := os.Create(dir + "/src") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer src.Close() |
| if _, err := io.CopyN(src, newRandReader(), size); err != nil { |
| t.Fatal(err) |
| } |
| if _, err := src.Seek(0, 0); err != nil { |
| t.Fatal(err) |
| } |
| |
| mustSeek := func(f *os.File, offset int64, whence int) int64 { |
| ret, err := f.Seek(offset, whence) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return ret |
| } |
| |
| for _, srcStart := range []int64{0, 100, size} { |
| remaining := size - srcStart |
| for _, dstStart := range []int64{0, 200} { |
| for _, limit := range []int64{remaining, remaining - 100, size * 2, 0} { |
| if limit < 0 { |
| continue |
| } |
| name := fmt.Sprintf("srcStart=%v/dstStart=%v/limit=%v", srcStart, dstStart, limit) |
| t.Run(name, func(t *testing.T) { |
| dst, err := os.CreateTemp(dir, "dst") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer dst.Close() |
| defer os.Remove(dst.Name()) |
| |
| mustSeek(src, srcStart, io.SeekStart) |
| if _, err := io.CopyN(dst, zeroReader{}, dstStart); err != nil { |
| t.Fatal(err) |
| } |
| |
| var copied int64 |
| if limit == 0 { |
| copied, err = io.Copy(dst, src) |
| } else { |
| copied, err = io.CopyN(dst, src, limit) |
| } |
| if limit > remaining { |
| if err != io.EOF { |
| t.Errorf("Copy: %v; want io.EOF", err) |
| } |
| } else { |
| if err != nil { |
| t.Errorf("Copy: %v; want nil", err) |
| } |
| } |
| |
| wantCopied := remaining |
| if limit != 0 { |
| wantCopied = min(limit, wantCopied) |
| } |
| if copied != wantCopied { |
| t.Errorf("copied %v bytes, want %v", copied, wantCopied) |
| } |
| |
| srcPos := mustSeek(src, 0, io.SeekCurrent) |
| wantSrcPos := srcStart + wantCopied |
| if srcPos != wantSrcPos { |
| t.Errorf("source position = %v, want %v", srcPos, wantSrcPos) |
| } |
| |
| dstPos := mustSeek(dst, 0, io.SeekCurrent) |
| wantDstPos := dstStart + wantCopied |
| if dstPos != wantDstPos { |
| t.Errorf("destination position = %v, want %v", dstPos, wantDstPos) |
| } |
| |
| mustSeek(dst, 0, io.SeekStart) |
| rr := newRandReader() |
| io.CopyN(io.Discard, rr, srcStart) |
| wantReader := io.MultiReader( |
| io.LimitReader(zeroReader{}, dstStart), |
| io.LimitReader(rr, wantCopied), |
| ) |
| if err := compareReaders(dst, wantReader); err != nil { |
| t.Fatal(err) |
| } |
| }) |
| |
| } |
| } |
| } |
| } |
| |
// compareReaders reads a and b to completion in 4096-byte chunks and reports
// whether they produced the same bytes.
//
// It returns nil when both streams have identical contents and length. A read
// error other than io.EOF or io.ErrUnexpectedEOF is returned unchanged. If
// the streams differ, the returned error names the byte offset of the first
// difference; when one stream is a strict prefix of the other, that offset is
// the length of the shorter stream.
func compareReaders(a, b io.Reader) error {
	// atEnd reports whether err only signals that the stream ran out of data.
	atEnd := func(err error) bool {
		return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
	}

	bufa := make([]byte, 4096)
	bufb := make([]byte, 4096)
	var off int64 // stream offset of the start of the current chunk
	for {
		na, erra := io.ReadFull(a, bufa)
		if erra != nil && !atEnd(erra) {
			return erra
		}
		nb, errb := io.ReadFull(b, bufb)
		if errb != nil && !atEnd(errb) {
			return errb
		}
		if !bytes.Equal(bufa[:na], bufb[:nb]) {
			// Find the first differing byte within the common prefix; if
			// there is none, the chunks differ only in length.
			i := 0
			for i < na && i < nb && bufa[i] == bufb[i] {
				i++
			}
			return fmt.Errorf("contents mismatch at offset %d", off+int64(i))
		}
		// The chunks are equal, so they have the same length; a short chunk
		// on both sides means both streams are exhausted.
		if erra != nil && errb != nil {
			return nil
		}
		off += int64(len(bufa))
	}
}
| |
// zeroReader is an io.Reader that yields an endless stream of zero bytes.
type zeroReader struct{}

// Read zeroes every byte of p and reports the whole slice as filled. It never
// returns an error, so the stream never ends.
func (zeroReader) Read(p []byte) (n int, err error) {
	n = len(p)
	clear(p)
	return n, nil
}
| |
| type randReader struct { |
| rand *rand.Rand |
| } |
| |
| func newRandReader() *randReader { |
| return &randReader{rand.New(rand.NewPCG(0, 0))} |
| } |
| |
| func (r *randReader) Read(p []byte) (int, error) { |
| for i := range p { |
| p[i] = byte(r.rand.Uint32() & 0xff) |
| } |
| return len(p), nil |
| } |
| |
| func createSocketPair(t *testing.T, proto string) (client, server net.Conn) { |
| t.Helper() |
| if !nettest.TestableNetwork(proto) { |
| t.Skipf("%s does not support %q", runtime.GOOS, proto) |
| } |
| |
| ln, err := nettest.NewLocalListener(proto) |
| if err != nil { |
| t.Fatalf("NewLocalListener error: %v", err) |
| } |
| t.Cleanup(func() { |
| if ln != nil { |
| ln.Close() |
| } |
| if client != nil { |
| client.Close() |
| } |
| if server != nil { |
| server.Close() |
| } |
| }) |
| ch := make(chan struct{}) |
| go func() { |
| var err error |
| server, err = ln.Accept() |
| if err != nil { |
| t.Errorf("Accept new connection error: %v", err) |
| } |
| ch <- struct{}{} |
| }() |
| client, err = net.Dial(proto, ln.Addr().String()) |
| <-ch |
| if err != nil { |
| t.Fatalf("Dial new connection error: %v", err) |
| } |
| return client, server |
| } |