| // Copyright 2018 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // +build linux |
| |
| package net |
| |
| import ( |
| "bytes" |
| "fmt" |
| "io" |
| "testing" |
| ) |
| |
| func TestSplice(t *testing.T) { |
| t.Run("simple", testSpliceSimple) |
| t.Run("multipleWrite", testSpliceMultipleWrite) |
| t.Run("big", testSpliceBig) |
| t.Run("honorsLimitedReader", testSpliceHonorsLimitedReader) |
| t.Run("readerAtEOF", testSpliceReaderAtEOF) |
| } |
| |
| func testSpliceSimple(t *testing.T) { |
| srv, err := newSpliceTestServer() |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer srv.Close() |
| copyDone := srv.Copy() |
| msg := []byte("splice test") |
| if _, err := srv.Write(msg); err != nil { |
| t.Fatal(err) |
| } |
| got := make([]byte, len(msg)) |
| if _, err := io.ReadFull(srv, got); err != nil { |
| t.Fatal(err) |
| } |
| if !bytes.Equal(got, msg) { |
| t.Errorf("got %q, wrote %q", got, msg) |
| } |
| srv.CloseWrite() |
| srv.CloseRead() |
| if err := <-copyDone; err != nil { |
| t.Errorf("splice: %v", err) |
| } |
| } |
| |
| func testSpliceMultipleWrite(t *testing.T) { |
| srv, err := newSpliceTestServer() |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer srv.Close() |
| copyDone := srv.Copy() |
| msg1 := []byte("splice test part 1 ") |
| msg2 := []byte(" splice test part 2") |
| if _, err := srv.Write(msg1); err != nil { |
| t.Fatalf("Write: %v", err) |
| } |
| if _, err := srv.Write(msg2); err != nil { |
| t.Fatal(err) |
| } |
| got := make([]byte, len(msg1)+len(msg2)) |
| if _, err := io.ReadFull(srv, got); err != nil { |
| t.Fatal(err) |
| } |
| want := append(msg1, msg2...) |
| if !bytes.Equal(got, want) { |
| t.Errorf("got %q, wrote %q", got, want) |
| } |
| srv.CloseWrite() |
| srv.CloseRead() |
| if err := <-copyDone; err != nil { |
| t.Errorf("splice: %v", err) |
| } |
| } |
| |
| func testSpliceBig(t *testing.T) { |
| size := 1<<31 - 1 |
| if testing.Short() { |
| size = 1 << 25 |
| } |
| srv, err := newSpliceTestServer() |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer srv.Close() |
| big := make([]byte, size) |
| copyDone := srv.Copy() |
| type readResult struct { |
| b []byte |
| err error |
| } |
| readDone := make(chan readResult) |
| go func() { |
| got := make([]byte, len(big)) |
| _, err := io.ReadFull(srv, got) |
| readDone <- readResult{got, err} |
| }() |
| if _, err := srv.Write(big); err != nil { |
| t.Fatal(err) |
| } |
| res := <-readDone |
| if res.err != nil { |
| t.Fatal(res.err) |
| } |
| got := res.b |
| if !bytes.Equal(got, big) { |
| t.Errorf("input and output differ") |
| } |
| srv.CloseWrite() |
| srv.CloseRead() |
| if err := <-copyDone; err != nil { |
| t.Errorf("splice: %v", err) |
| } |
| } |
| |
| func testSpliceHonorsLimitedReader(t *testing.T) { |
| t.Run("stopsAfterN", testSpliceStopsAfterN) |
| t.Run("updatesN", testSpliceUpdatesN) |
| } |
| |
| func testSpliceStopsAfterN(t *testing.T) { |
| clientUp, serverUp, err := spliceTestSocketPair("tcp") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer clientUp.Close() |
| defer serverUp.Close() |
| clientDown, serverDown, err := spliceTestSocketPair("tcp") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer clientDown.Close() |
| defer serverDown.Close() |
| count := 128 |
| copyDone := make(chan error) |
| lr := &io.LimitedReader{ |
| N: int64(count), |
| R: serverUp, |
| } |
| go func() { |
| _, err := io.Copy(serverDown, lr) |
| serverDown.Close() |
| copyDone <- err |
| }() |
| msg := make([]byte, 2*count) |
| if _, err := clientUp.Write(msg); err != nil { |
| t.Fatal(err) |
| } |
| clientUp.Close() |
| var buf bytes.Buffer |
| if _, err := io.Copy(&buf, clientDown); err != nil { |
| t.Fatal(err) |
| } |
| if buf.Len() != count { |
| t.Errorf("splice transferred %d bytes, want to stop after %d", buf.Len(), count) |
| } |
| clientDown.Close() |
| if err := <-copyDone; err != nil { |
| t.Errorf("splice: %v", err) |
| } |
| } |
| |
| func testSpliceUpdatesN(t *testing.T) { |
| clientUp, serverUp, err := spliceTestSocketPair("tcp") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer clientUp.Close() |
| defer serverUp.Close() |
| clientDown, serverDown, err := spliceTestSocketPair("tcp") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer clientDown.Close() |
| defer serverDown.Close() |
| count := 128 |
| copyDone := make(chan error) |
| lr := &io.LimitedReader{ |
| N: int64(100 + count), |
| R: serverUp, |
| } |
| go func() { |
| _, err := io.Copy(serverDown, lr) |
| copyDone <- err |
| }() |
| msg := make([]byte, count) |
| if _, err := clientUp.Write(msg); err != nil { |
| t.Fatal(err) |
| } |
| clientUp.Close() |
| got := make([]byte, count) |
| if _, err := io.ReadFull(clientDown, got); err != nil { |
| t.Fatal(err) |
| } |
| clientDown.Close() |
| if err := <-copyDone; err != nil { |
| t.Errorf("splice: %v", err) |
| } |
| wantN := int64(100) |
| if lr.N != wantN { |
| t.Errorf("lr.N = %d, want %d", lr.N, wantN) |
| } |
| } |
| |
| func testSpliceReaderAtEOF(t *testing.T) { |
| clientUp, serverUp, err := spliceTestSocketPair("tcp") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer clientUp.Close() |
| defer serverUp.Close() |
| clientDown, serverDown, err := spliceTestSocketPair("tcp") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer clientDown.Close() |
| defer serverDown.Close() |
| |
| serverUp.Close() |
| _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp) |
| if !handled { |
| t.Errorf("closed connection: got err = %v, handled = %t, want handled = true", err, handled) |
| } |
| lr := &io.LimitedReader{ |
| N: 0, |
| R: serverUp, |
| } |
| _, err, handled = splice(serverDown.(*TCPConn).fd, lr) |
| if !handled { |
| t.Errorf("exhausted LimitedReader: got err = %v, handled = %t, want handled = true", err, handled) |
| } |
| } |
| |
| func BenchmarkTCPReadFrom(b *testing.B) { |
| testHookUninstaller.Do(uninstallTestHooks) |
| |
| var chunkSizes []int |
| for i := uint(10); i <= 20; i++ { |
| chunkSizes = append(chunkSizes, 1<<i) |
| } |
| // To benchmark the genericReadFrom code path, set this to false. |
| useSplice := true |
| for _, chunkSize := range chunkSizes { |
| b.Run(fmt.Sprint(chunkSize), func(b *testing.B) { |
| benchmarkSplice(b, chunkSize, useSplice) |
| }) |
| } |
| } |
| |
| func benchmarkSplice(b *testing.B, chunkSize int, useSplice bool) { |
| srv, err := newSpliceTestServer() |
| if err != nil { |
| b.Fatal(err) |
| } |
| defer srv.Close() |
| var copyDone <-chan error |
| if useSplice { |
| copyDone = srv.Copy() |
| } else { |
| copyDone = srv.CopyNoSplice() |
| } |
| chunk := make([]byte, chunkSize) |
| discardDone := make(chan struct{}) |
| go func() { |
| for { |
| buf := make([]byte, chunkSize) |
| _, err := srv.Read(buf) |
| if err != nil { |
| break |
| } |
| } |
| discardDone <- struct{}{} |
| }() |
| b.SetBytes(int64(chunkSize)) |
| b.ResetTimer() |
| for i := 0; i < b.N; i++ { |
| srv.Write(chunk) |
| } |
| srv.CloseWrite() |
| <-copyDone |
| srv.CloseRead() |
| <-discardDone |
| } |
| |
| type spliceTestServer struct { |
| clientUp io.WriteCloser |
| clientDown io.ReadCloser |
| serverUp io.ReadCloser |
| serverDown io.WriteCloser |
| } |
| |
| func newSpliceTestServer() (*spliceTestServer, error) { |
| // For now, both networks are hard-coded to TCP. |
| // If splice is enabled for non-tcp upstream connections, |
| // newSpliceTestServer will need to take a network parameter. |
| clientUp, serverUp, err := spliceTestSocketPair("tcp") |
| if err != nil { |
| return nil, err |
| } |
| clientDown, serverDown, err := spliceTestSocketPair("tcp") |
| if err != nil { |
| clientUp.Close() |
| serverUp.Close() |
| return nil, err |
| } |
| return &spliceTestServer{clientUp, clientDown, serverUp, serverDown}, nil |
| } |
| |
| // Read reads from the downstream connection. |
| func (srv *spliceTestServer) Read(b []byte) (int, error) { |
| return srv.clientDown.Read(b) |
| } |
| |
| // Write writes to the upstream connection. |
| func (srv *spliceTestServer) Write(b []byte) (int, error) { |
| return srv.clientUp.Write(b) |
| } |
| |
| // Close closes the server. |
| func (srv *spliceTestServer) Close() error { |
| err := srv.closeUp() |
| err1 := srv.closeDown() |
| if err == nil { |
| return err1 |
| } |
| return err |
| } |
| |
| // CloseWrite closes the client side of the upstream connection. |
| func (srv *spliceTestServer) CloseWrite() error { |
| return srv.clientUp.Close() |
| } |
| |
| // CloseRead closes the client side of the downstream connection. |
| func (srv *spliceTestServer) CloseRead() error { |
| return srv.clientDown.Close() |
| } |
| |
| // Copy copies from the server side of the upstream connection |
| // to the server side of the downstream connection, in a separate |
| // goroutine. Copy is done when the first send on the returned |
| // channel succeeds. |
| func (srv *spliceTestServer) Copy() <-chan error { |
| ch := make(chan error) |
| go func() { |
| _, err := io.Copy(srv.serverDown, srv.serverUp) |
| ch <- err |
| close(ch) |
| }() |
| return ch |
| } |
| |
| // CopyNoSplice is like Copy, but ensures that the splice code path |
| // is not reached. |
| func (srv *spliceTestServer) CopyNoSplice() <-chan error { |
| type onlyReader struct { |
| io.Reader |
| } |
| ch := make(chan error) |
| go func() { |
| _, err := io.Copy(srv.serverDown, onlyReader{srv.serverUp}) |
| ch <- err |
| close(ch) |
| }() |
| return ch |
| } |
| |
| func (srv *spliceTestServer) closeUp() error { |
| var err, err1 error |
| if srv.serverUp != nil { |
| err = srv.serverUp.Close() |
| } |
| if srv.clientUp != nil { |
| err1 = srv.clientUp.Close() |
| } |
| if err == nil { |
| return err1 |
| } |
| return err |
| } |
| |
| func (srv *spliceTestServer) closeDown() error { |
| var err, err1 error |
| if srv.serverDown != nil { |
| err = srv.serverDown.Close() |
| } |
| if srv.clientDown != nil { |
| err1 = srv.clientDown.Close() |
| } |
| if err == nil { |
| return err1 |
| } |
| return err |
| } |
| |
| func spliceTestSocketPair(net string) (client, server Conn, err error) { |
| ln, err := newLocalListener(net) |
| if err != nil { |
| return nil, nil, err |
| } |
| defer ln.Close() |
| var cerr, serr error |
| acceptDone := make(chan struct{}) |
| go func() { |
| server, serr = ln.Accept() |
| acceptDone <- struct{}{} |
| }() |
| client, cerr = Dial(ln.Addr().Network(), ln.Addr().String()) |
| <-acceptDone |
| if cerr != nil { |
| if server != nil { |
| server.Close() |
| } |
| return nil, nil, cerr |
| } |
| if serr != nil { |
| if client != nil { |
| client.Close() |
| } |
| return nil, nil, serr |
| } |
| return client, server, nil |
| } |