blob: 2f1e69ddb64106566cbf05010503fc7f22070c61 [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build linux
package net
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"sync"
"testing"
)
// TestSplice runs each splice subtest, in order, as a named subtest.
func TestSplice(t *testing.T) {
	subtests := []struct {
		name string
		fn   func(*testing.T)
	}{
		{"simple", testSpliceSimple},
		{"multipleWrite", testSpliceMultipleWrite},
		{"big", testSpliceBig},
		{"honorsLimitedReader", testSpliceHonorsLimitedReader},
		{"readerAtEOF", testSpliceReaderAtEOF},
		{"issue25985", testSpliceIssue25985},
	}
	for _, st := range subtests {
		t.Run(st.name, st.fn)
	}
}
// testSpliceSimple sends one short message through the splice test server
// and checks that the identical bytes come out on the downstream side.
func testSpliceSimple(t *testing.T) {
	srv, err := newSpliceTestServer()
	if err != nil {
		t.Fatal(err)
	}
	defer srv.Close()

	done := srv.Copy()

	want := []byte("splice test")
	if _, werr := srv.Write(want); werr != nil {
		t.Fatal(werr)
	}
	have := make([]byte, len(want))
	if _, rerr := io.ReadFull(srv, have); rerr != nil {
		t.Fatal(rerr)
	}
	if !bytes.Equal(have, want) {
		t.Errorf("got %q, wrote %q", have, want)
	}

	// Close both client ends so the server-side copy can finish, then
	// collect its result.
	srv.CloseWrite()
	srv.CloseRead()
	if copyErr := <-done; copyErr != nil {
		t.Errorf("splice: %v", copyErr)
	}
}
// testSpliceMultipleWrite performs two consecutive writes upstream and checks
// that their concatenation arrives downstream unchanged.
func testSpliceMultipleWrite(t *testing.T) {
	srv, err := newSpliceTestServer()
	if err != nil {
		t.Fatal(err)
	}
	defer srv.Close()

	done := srv.Copy()

	first := []byte("splice test part 1 ")
	second := []byte(" splice test part 2")
	if _, werr := srv.Write(first); werr != nil {
		t.Fatalf("Write: %v", werr)
	}
	if _, werr := srv.Write(second); werr != nil {
		t.Fatal(werr)
	}

	// Build the expected payload in its own buffer rather than appending
	// into first's spare capacity.
	want := make([]byte, 0, len(first)+len(second))
	want = append(want, first...)
	want = append(want, second...)

	got := make([]byte, len(want))
	if _, rerr := io.ReadFull(srv, got); rerr != nil {
		t.Fatal(rerr)
	}
	if !bytes.Equal(got, want) {
		t.Errorf("got %q, wrote %q", got, want)
	}

	srv.CloseWrite()
	srv.CloseRead()
	if copyErr := <-done; copyErr != nil {
		t.Errorf("splice: %v", copyErr)
	}
}
// testSpliceBig copies one large zero-filled buffer through the splice test
// server: 1<<31-1 bytes normally, 1<<25 bytes (32 MiB) under -short.
func testSpliceBig(t *testing.T) {
	size := 1<<31 - 1
	if testing.Short() {
		size = 1 << 25
	}
	srv, err := newSpliceTestServer()
	if err != nil {
		t.Fatal(err)
	}
	defer srv.Close()
	big := make([]byte, size)
	copyDone := srv.Copy()

	type readResult struct {
		b   []byte
		err error
	}
	// Buffered so the reader goroutine can always deliver its result and
	// exit, even if the Write below fails and t.Fatal returns before
	// anyone receives from readDone.
	readDone := make(chan readResult, 1)
	// Read concurrently so the Write below is not stalled once the
	// socket buffers fill up.
	go func() {
		got := make([]byte, len(big))
		_, err := io.ReadFull(srv, got)
		readDone <- readResult{got, err}
	}()
	if _, err := srv.Write(big); err != nil {
		t.Fatal(err)
	}
	res := <-readDone
	if res.err != nil {
		t.Fatal(res.err)
	}
	got := res.b
	if !bytes.Equal(got, big) {
		t.Errorf("input and output differ")
	}
	srv.CloseWrite()
	srv.CloseRead()
	if err := <-copyDone; err != nil {
		t.Errorf("splice: %v", err)
	}
}
// testSpliceHonorsLimitedReader groups the subtests that copy from an
// *io.LimitedReader source.
func testSpliceHonorsLimitedReader(t *testing.T) {
	for _, st := range []struct {
		name string
		fn   func(*testing.T)
	}{
		{"stopsAfterN", testSpliceStopsAfterN},
		{"updatesN", testSpliceUpdatesN},
	} {
		t.Run(st.name, st.fn)
	}
}
// testSpliceStopsAfterN writes 2*count bytes upstream but copies through an
// *io.LimitedReader with N = count. It checks that exactly count bytes reach
// the downstream client.
func testSpliceStopsAfterN(t *testing.T) {
	clientUp, serverUp, err := spliceTestSocketPair("tcp")
	if err != nil {
		t.Fatal(err)
	}
	defer clientUp.Close()
	defer serverUp.Close()
	clientDown, serverDown, err := spliceTestSocketPair("tcp")
	if err != nil {
		t.Fatal(err)
	}
	defer clientDown.Close()
	defer serverDown.Close()

	count := 128
	// Buffered so the copy goroutine can always deliver its result and
	// exit, even if the test fails (t.Fatal) before receiving from copyDone.
	copyDone := make(chan error, 1)
	lr := &io.LimitedReader{
		N: int64(count),
		R: serverUp,
	}
	go func() {
		_, err := io.Copy(serverDown, lr)
		// Closing serverDown lets the io.Copy from clientDown below hit EOF.
		serverDown.Close()
		copyDone <- err
	}()

	msg := make([]byte, 2*count)
	if _, err := clientUp.Write(msg); err != nil {
		t.Fatal(err)
	}
	clientUp.Close()

	var buf bytes.Buffer
	if _, err := io.Copy(&buf, clientDown); err != nil {
		t.Fatal(err)
	}
	if buf.Len() != count {
		t.Errorf("splice transferred %d bytes, want to stop after %d", buf.Len(), count)
	}
	clientDown.Close()
	if err := <-copyDone; err != nil {
		t.Errorf("splice: %v", err)
	}
}
// testSpliceUpdatesN copies count bytes through an *io.LimitedReader whose
// limit is 100+count. After the copy finishes, it checks that the reader's
// remaining N was decremented to 100.
func testSpliceUpdatesN(t *testing.T) {
	clientUp, serverUp, err := spliceTestSocketPair("tcp")
	if err != nil {
		t.Fatal(err)
	}
	defer clientUp.Close()
	defer serverUp.Close()
	clientDown, serverDown, err := spliceTestSocketPair("tcp")
	if err != nil {
		t.Fatal(err)
	}
	defer clientDown.Close()
	defer serverDown.Close()

	count := 128
	// Buffered so the copy goroutine can always deliver its result and
	// exit, even if the test fails (t.Fatal) before receiving from copyDone.
	copyDone := make(chan error, 1)
	lr := &io.LimitedReader{
		N: int64(100 + count),
		R: serverUp,
	}
	go func() {
		_, err := io.Copy(serverDown, lr)
		copyDone <- err
	}()

	msg := make([]byte, count)
	if _, err := clientUp.Write(msg); err != nil {
		t.Fatal(err)
	}
	clientUp.Close()

	got := make([]byte, count)
	if _, err := io.ReadFull(clientDown, got); err != nil {
		t.Fatal(err)
	}
	clientDown.Close()
	// Receiving from copyDone orders the goroutine's writes to lr.N before
	// the read of lr.N below.
	if err := <-copyDone; err != nil {
		t.Errorf("splice: %v", err)
	}
	wantN := int64(100)
	if lr.N != wantN {
		t.Errorf("lr.N = %d, want %d", lr.N, wantN)
	}
}
// testSpliceReaderAtEOF calls splice directly with two sources that cannot
// produce data and checks that both calls report handled == true:
// an already-closed connection, and an *io.LimitedReader with N == 0.
func testSpliceReaderAtEOF(t *testing.T) {
	clientUp, serverUp, err := spliceTestSocketPair("tcp")
	if err != nil {
		t.Fatal(err)
	}
	defer clientUp.Close()
	defer serverUp.Close()
	clientDown, serverDown, err := spliceTestSocketPair("tcp")
	if err != nil {
		t.Fatal(err)
	}
	defer clientDown.Close()
	defer serverDown.Close()

	// Case 1: the source connection has already been closed.
	serverUp.Close()
	dst := serverDown.(*TCPConn).fd
	_, err, handled := splice(dst, serverUp)
	if !handled {
		t.Errorf("closed connection: got err = %v, handled = %t, want handled = true", err, handled)
	}

	// Case 2: a LimitedReader with nothing left to read.
	exhausted := &io.LimitedReader{R: serverUp, N: 0}
	_, err, handled = splice(serverDown.(*TCPConn).fd, exhausted)
	if !handled {
		t.Errorf("exhausted LimitedReader: got err = %v, handled = %t, want handled = true", err, handled)
	}
}
// testSpliceIssue25985 is a regression test for Go issue 25985. It runs a
// two-way io.Copy proxy between a "front" and a "back" listener, sends a
// short message in through the front, and requires that reading everything
// the proxy forwards to the back succeeds without error.
func testSpliceIssue25985(t *testing.T) {
front, err := newLocalListener("tcp")
if err != nil {
t.Fatal(err)
}
defer front.Close()
back, err := newLocalListener("tcp")
if err != nil {
t.Fatal(err)
}
defer back.Close()
// One Done per proxy copy direction.
var wg sync.WaitGroup
wg.Add(2)
proxy := func() {
src, err := front.Accept()
if err != nil {
// NOTE(review): returning here skips both wg.Done calls, and nothing
// dials the back listener, so back.Accept / wg.Wait below would block.
return
}
dst, err := Dial("tcp", back.Addr().String())
if err != nil {
// NOTE(review): same as above; wg.Wait would never return.
return
}
// NOTE(review): these defers run as soon as proxy returns, which is right
// after the two copy goroutines start. The copies may therefore run on
// connections that are being closed concurrently. This looks deliberate
// (it reproduces the close-during-splice situation of the issue); confirm
// before "fixing" it.
defer dst.Close()
defer src.Close()
go func() {
io.Copy(src, dst)
wg.Done()
}()
go func() {
io.Copy(dst, src)
wg.Done()
}()
}
go proxy()
toFront, err := Dial("tcp", front.Addr().String())
if err != nil {
t.Fatal(err)
}
// The write error is ignored; the test only checks the back side below.
io.WriteString(toFront, "foo")
toFront.Close()
fromProxy, err := back.Accept()
if err != nil {
t.Fatal(err)
}
defer fromProxy.Close()
// Drain everything the proxy forwards; any read error fails the test.
_, err = ioutil.ReadAll(fromProxy)
if err != nil {
t.Fatal(err)
}
// Wait for both proxy copy goroutines to finish.
wg.Wait()
}
// BenchmarkTCPReadFrom measures copy throughput through the splice test
// server. It runs one sub-benchmark per power-of-two chunk size from
// 1<<10 (1 KiB) to 1<<20 (1 MiB) bytes.
func BenchmarkTCPReadFrom(b *testing.B) {
	testHookUninstaller.Do(uninstallTestHooks)

	// To benchmark the genericReadFrom code path, set this to false.
	const useSplice = true

	for shift := uint(10); shift <= 20; shift++ {
		size := 1 << shift
		b.Run(fmt.Sprint(size), func(b *testing.B) {
			benchmarkSplice(b, size, useSplice)
		})
	}
}
// benchmarkSplice writes b.N chunks of chunkSize bytes upstream while a
// background goroutine drains the downstream side. The copy between them
// uses srv.Copy when useSplice is true, or srv.CopyNoSplice (which avoids
// the splice code path) when it is false.
func benchmarkSplice(b *testing.B, chunkSize int, useSplice bool) {
	srv, err := newSpliceTestServer()
	if err != nil {
		b.Fatal(err)
	}
	defer srv.Close()

	var copyDone <-chan error
	if useSplice {
		copyDone = srv.Copy()
	} else {
		copyDone = srv.CopyNoSplice()
	}

	chunk := make([]byte, chunkSize)
	// discardDone is closed rather than sent on, so the drain goroutine
	// never blocks if the benchmark exits early via b.Fatal.
	discardDone := make(chan struct{})
	go func() {
		defer close(discardDone)
		// One reusable buffer: the data is thrown away, and allocating a
		// fresh buffer per read would add GC noise to the measurement.
		buf := make([]byte, chunkSize)
		for {
			if _, err := srv.Read(buf); err != nil {
				return
			}
		}
	}()

	b.SetBytes(int64(chunkSize))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// A failed write would otherwise yield meaningless timings.
		if _, err := srv.Write(chunk); err != nil {
			b.Fatal(err)
		}
	}
	srv.CloseWrite()
	<-copyDone
	srv.CloseRead()
	<-discardDone
}
// spliceTestServer holds two connected socket pairs, "up" and "down".
// Tests write into clientUp and read from clientDown. Copy/CopyNoSplice
// forward data from serverUp to serverDown.
//
// NOTE: newSpliceTestServer builds this with an unkeyed composite literal,
// so the field order below matters.
type spliceTestServer struct {
// clientUp is the client end of the upstream pair (used by Write/CloseWrite).
clientUp io.WriteCloser
// clientDown is the client end of the downstream pair (used by Read/CloseRead).
clientDown io.ReadCloser
// serverUp is the server end of the upstream pair; Copy reads from it.
serverUp io.ReadCloser
// serverDown is the server end of the downstream pair; Copy writes to it.
serverDown io.WriteCloser
}
// newSpliceTestServer sets up two TCP socket pairs. The upstream pair's
// client end is written by the test; the downstream pair's client end is
// read by the test. If the second pair cannot be created, the first is
// closed before the error is returned.
func newSpliceTestServer() (*spliceTestServer, error) {
	// For now, both networks are hard-coded to TCP.
	// If splice is enabled for non-tcp upstream connections,
	// newSpliceTestServer will need to take a network parameter.
	upClient, upServer, err := spliceTestSocketPair("tcp")
	if err != nil {
		return nil, err
	}
	downClient, downServer, err := spliceTestSocketPair("tcp")
	if err != nil {
		upClient.Close()
		upServer.Close()
		return nil, err
	}
	srv := &spliceTestServer{
		clientUp:   upClient,
		clientDown: downClient,
		serverUp:   upServer,
		serverDown: downServer,
	}
	return srv, nil
}
// Read reads from the client end of the downstream connection, i.e. the
// data forwarded by the server-side copy.
func (srv *spliceTestServer) Read(b []byte) (int, error) {
	n, err := srv.clientDown.Read(b)
	return n, err
}
// Write writes to the client end of the upstream connection, feeding the
// server-side copy.
func (srv *spliceTestServer) Write(b []byte) (int, error) {
	n, err := srv.clientUp.Write(b)
	return n, err
}
// Close closes both the upstream and the downstream pairs. Both pairs are
// always closed; if both report an error, the upstream one is returned.
func (srv *spliceTestServer) Close() error {
	upErr := srv.closeUp()
	downErr := srv.closeDown()
	if upErr != nil {
		return upErr
	}
	return downErr
}
// CloseWrite closes the client end of the upstream connection, ending the
// input seen by the server-side copy.
func (srv *spliceTestServer) CloseWrite() error {
	up := srv.clientUp
	return up.Close()
}
// CloseRead closes the client end of the downstream connection.
func (srv *spliceTestServer) CloseRead() error {
	down := srv.clientDown
	return down.Close()
}
// Copy copies from the server side of the upstream connection to the server
// side of the downstream connection in a separate goroutine. The copy is
// done once a value is received from the returned channel. That value is
// the error from io.Copy, and the channel is closed right after it is sent.
//
// The channel has a buffer of one, so the copying goroutine can always
// deliver its result and exit. This matters when a caller stops early
// (e.g. t.Fatal/b.Fatal) and never receives; with an unbuffered channel
// the goroutine would leak.
func (srv *spliceTestServer) Copy() <-chan error {
	ch := make(chan error, 1)
	go func() {
		_, err := io.Copy(srv.serverDown, srv.serverUp)
		ch <- err
		close(ch)
	}()
	return ch
}
// CopyNoSplice is like Copy, but ensures that the splice code path
// is not reached. Like Copy, it returns a channel with a buffer of one, so
// the copying goroutine never leaks if the caller stops before receiving.
func (srv *spliceTestServer) CopyNoSplice() <-chan error {
	// Wrapping the source hides every method except Read, so io.Copy cannot
	// use the destination's ReaderFrom fast path with the concrete
	// connection type.
	type onlyReader struct {
		io.Reader
	}
	ch := make(chan error, 1)
	go func() {
		_, err := io.Copy(srv.serverDown, onlyReader{srv.serverUp})
		ch <- err
		close(ch)
	}()
	return ch
}
// closeUp closes both ends of the upstream pair, skipping any that are nil.
// If both closes fail, the server-side error is the one returned.
func (srv *spliceTestServer) closeUp() error {
	var serverErr, clientErr error
	if srv.serverUp != nil {
		serverErr = srv.serverUp.Close()
	}
	if srv.clientUp != nil {
		clientErr = srv.clientUp.Close()
	}
	if serverErr != nil {
		return serverErr
	}
	return clientErr
}
// closeDown closes both ends of the downstream pair, skipping any that are
// nil. If both closes fail, the server-side error is the one returned.
func (srv *spliceTestServer) closeDown() error {
	var serverErr, clientErr error
	if srv.serverDown != nil {
		serverErr = srv.serverDown.Close()
	}
	if srv.clientDown != nil {
		clientErr = srv.clientDown.Close()
	}
	if serverErr != nil {
		return serverErr
	}
	return clientErr
}
// spliceTestSocketPair returns both ends of a new connection on the given
// network. It uses a temporary local listener that is closed before
// returning. On error, whichever end was established is closed and both
// returned Conns are nil.
func spliceTestSocketPair(net string) (client, server Conn, err error) {
	ln, err := newLocalListener(net)
	if err != nil {
		return nil, nil, err
	}
	defer ln.Close()

	var cerr, serr error
	acceptDone := make(chan struct{})
	go func() {
		server, serr = ln.Accept()
		close(acceptDone)
	}()
	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
	if cerr != nil {
		// No connection is coming, so Accept would block forever and the
		// wait below would deadlock (the deferred Close never gets to run).
		// Closing the listener now makes Accept return.
		ln.Close()
	}
	// Receiving here also orders the goroutine's writes to server/serr
	// before the reads below.
	<-acceptDone
	if cerr != nil {
		if server != nil {
			server.Close()
		}
		return nil, nil, cerr
	}
	if serr != nil {
		if client != nil {
			client.Close()
		}
		return nil, nil, serr
	}
	return client, server, nil
}