net: add Buffers type, do writev on unix
No fast path currently for solaris, windows, nacl, plan9.
Fixes #13451
Change-Id: I24b3233a2e3a57fc6445e276a5c0d7b097884007
Reviewed-on: https://go-review.googlesource.com/29951
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/net/fd_unix.go b/src/net/fd_unix.go
index 11dde76..1296bc5 100644
--- a/src/net/fd_unix.go
+++ b/src/net/fd_unix.go
@@ -29,6 +29,9 @@
laddr Addr
raddr Addr
+ // writev cache.
+ iovecs *[]syscall.Iovec
+
// wait server
pd pollDesc
}
diff --git a/src/net/net.go b/src/net/net.go
index d6812d1..8ab952a 100644
--- a/src/net/net.go
+++ b/src/net/net.go
@@ -604,3 +604,66 @@
func releaseThread() {
<-threadLimit
}
+
+// buffersWriter is the interface implemented by Conns that support a
+// "writev"-like batch write optimization. Buffers.WriteTo probes for
+// it with a type assertion and, when it is present, hands the whole
+// write over to writeBuffers.
+//
+// writeBuffers should fully consume and write all chunks from the
+// provided Buffers, else it should report a non-nil error.
+// The int64 result is the number of bytes written.
+type buffersWriter interface {
+ writeBuffers(*Buffers) (int64, error)
+}
+
+// testHookDidWritev is invoked by the writev fast path after each
+// writev system call with the byte count attributed to that call.
+// It is a no-op by default; tests replace it to observe batching.
+var testHookDidWritev = func(wrote int) {}
+
+// Buffers contains zero or more runs of bytes to write.
+//
+// On certain machines, for certain types of connections, this is
+// optimized into an OS-specific batch write operation (such as
+// "writev").
+//
+// Its Read and WriteTo methods consume the receiver: bytes that have
+// been read or written are removed from the front of the slice.
+type Buffers [][]byte
+
+// Compile-time checks that *Buffers implements io.WriterTo and io.Reader.
+var (
+ _ io.WriterTo = (*Buffers)(nil)
+ _ io.Reader = (*Buffers)(nil)
+)
+
+// WriteTo writes the contents of the buffers to w, implementing
+// io.WriterTo.
+//
+// If w implements buffersWriter, the whole write is delegated to its
+// writeBuffers method. Otherwise every chunk is passed to w.Write in
+// order, stopping at the first error. On that fallback path the bytes
+// reported as written are removed from v before returning, so after
+// an error v holds only the unwritten remainder.
+func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) {
+	if bw, ok := w.(buffersWriter); ok {
+		return bw.writeBuffers(v)
+	}
+	// The range expression is evaluated once, so v itself is only
+	// trimmed after the loop has finished.
+	for _, chunk := range *v {
+		wrote, werr := w.Write(chunk)
+		n += int64(wrote)
+		if werr != nil {
+			err = werr
+			break
+		}
+	}
+	v.consume(n)
+	return n, err
+}
+
+// Read copies bytes from the front of the buffers into p, implementing
+// io.Reader. Bytes that are copied out are removed from v. It reports
+// io.EOF once no chunks remain, which may be on the same call that
+// returned the final bytes.
+func (v *Buffers) Read(p []byte) (n int, err error) {
+	for len(*v) > 0 && len(p) > 0 {
+		head := (*v)[0]
+		copied := copy(p, head)
+		v.consume(int64(copied))
+		p = p[copied:]
+		n += copied
+	}
+	if len(*v) > 0 {
+		return n, nil
+	}
+	return n, io.EOF
+}
+
+// consume drops the first n bytes from v. Chunks that are used up
+// entirely (including empty ones at the front) are removed from the
+// slice; a partially used chunk is re-sliced in place to its
+// remaining tail.
+func (v *Buffers) consume(n int64) {
+	rest := *v
+	for len(rest) > 0 {
+		head := rest[0]
+		size := int64(len(head))
+		if n < size {
+			// Only part of this chunk is consumed; keep its tail.
+			rest[0] = head[n:]
+			break
+		}
+		n -= size
+		rest = rest[1:]
+	}
+	*v = rest
+}
diff --git a/src/net/net_test.go b/src/net/net_test.go
index b2f825d..1968ff3 100644
--- a/src/net/net_test.go
+++ b/src/net/net_test.go
@@ -414,3 +414,38 @@
}
}
}
+
+// withTCPConnPair opens a local TCP listener, then concurrently runs
+// peer1 on the accepted connection and peer2 on the dialed
+// connection. It waits for a result from each side and fails the
+// test with the first non-nil error received.
+func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
+	ln, err := newLocalListener("tcp")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+
+	// Buffered for both peers so neither goroutine blocks on send.
+	results := make(chan error, 2)
+
+	// Accepting side.
+	go func() {
+		accepted, err := ln.Accept()
+		if err != nil {
+			results <- err
+			return
+		}
+		defer accepted.Close()
+		results <- peer1(accepted.(*TCPConn))
+	}()
+
+	// Dialing side.
+	go func() {
+		dialed, err := Dial("tcp", ln.Addr().String())
+		if err != nil {
+			results <- err
+			return
+		}
+		defer dialed.Close()
+		results <- peer2(dialed.(*TCPConn))
+	}()
+
+	for pending := 2; pending > 0; pending-- {
+		if err := <-results; err != nil {
+			t.Fatal(err)
+		}
+	}
+}
diff --git a/src/net/writev_test.go b/src/net/writev_test.go
new file mode 100644
index 0000000..cc53adc
--- /dev/null
+++ b/src/net/writev_test.go
@@ -0,0 +1,171 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "runtime"
+ "sync"
+ "testing"
+)
+
+// TestBuffers_read checks that reading a Buffers to EOF yields the
+// concatenation of its chunks and leaves the Buffers empty.
+func TestBuffers_read(t *testing.T) {
+	const story = "once upon a time in Gopherland ... "
+	words := []string{"once ", "upon ", "a ", "time ", "in ", "Gopherland ... "}
+	var buffers Buffers
+	for _, w := range words {
+		buffers = append(buffers, []byte(w))
+	}
+
+	got, err := ioutil.ReadAll(&buffers)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if s := string(got); s != story {
+		t.Errorf("read %q; want %q", got, story)
+	}
+	if remaining := len(buffers); remaining != 0 {
+		t.Errorf("len(buffers) = %d; want 0", remaining)
+	}
+}
+
+func TestBuffers_consume(t *testing.T) {
+ tests := []struct {
+ in Buffers
+ consume int64
+ want Buffers
+ }{
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 0,
+ want: Buffers{[]byte("foo"), []byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 2,
+ want: Buffers{[]byte("o"), []byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 3,
+ want: Buffers{[]byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 4,
+ want: Buffers{[]byte("ar")},
+ },
+ {
+ in: Buffers{nil, nil, nil, []byte("bar")},
+ consume: 1,
+ want: Buffers{[]byte("ar")},
+ },
+ {
+ in: Buffers{nil, nil, nil, []byte("foo")},
+ consume: 0,
+ want: Buffers{[]byte("foo")},
+ },
+ {
+ in: Buffers{nil, nil, nil},
+ consume: 0,
+ want: Buffers{},
+ },
+ }
+ for i, tt := range tests {
+ in := tt.in
+ in.consume(tt.consume)
+ if !reflect.DeepEqual(in, tt.want) {
+ t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
+ }
+ }
+}
+
+// TestBuffers_WriteTo runs testBuffer_writeTo as a subtest for every
+// combination of entry point (direct WriteTo call or io.Copy) and
+// chunk count, with counts chosen on both sides of 1024.
+func TestBuffers_WriteTo(t *testing.T) {
+	sizes := []int{0, 10, 1023, 1024, 1025}
+	for _, name := range []string{"WriteTo", "Copy"} {
+		useCopy := name == "Copy"
+		for _, size := range sizes {
+			chunks := size
+			label := fmt.Sprintf("%s/%d", name, chunks)
+			t.Run(label, func(t *testing.T) {
+				testBuffer_writeTo(t, chunks, useCopy)
+			})
+		}
+	}
+}
+
+// testBuffer_writeTo sends chunks one-byte buffers across a TCP
+// connection pair, using io.Copy when useCopy is set and a direct
+// Buffers.WriteTo call otherwise. It verifies that the reader gets
+// every byte, that the Buffers ends up empty, and that the sizes
+// reported through testHookDidWritev match what the platform's
+// writev path is expected to produce.
+func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
+ // Swap in a recording hook and restore the previous one on return.
+ oldHook := testHookDidWritev
+ defer func() { testHookDidWritev = oldHook }()
+ var writeLog struct {
+ sync.Mutex
+ log []int
+ }
+ testHookDidWritev = func(size int) {
+ writeLog.Lock()
+ writeLog.log = append(writeLog.log, size)
+ writeLog.Unlock()
+ }
+ // want holds the expected payload: one byte per chunk, value byte(i).
+ var want bytes.Buffer
+ for i := 0; i < chunks; i++ {
+ want.WriteByte(byte(i))
+ }
+
+ withTCPConnPair(t, func(c *TCPConn) error {
+ // Writer side: each buffer is a one-byte slice of want's bytes.
+ buffers := make(Buffers, chunks)
+ for i := range buffers {
+ buffers[i] = want.Bytes()[i : i+1]
+ }
+ var n int64
+ var err error
+ if useCopy {
+ n, err = io.Copy(c, &buffers)
+ } else {
+ n, err = buffers.WriteTo(c)
+ }
+ if err != nil {
+ return err
+ }
+ if len(buffers) != 0 {
+ return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
+ }
+ if n != int64(want.Len()) {
+ return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
+ }
+ return nil
+ }, func(c *TCPConn) error {
+ // Reader side: read until EOF, then inspect the hook log.
+ all, err := ioutil.ReadAll(c)
+ if !bytes.Equal(all, want.Bytes()) || err != nil {
+ return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
+ }
+
+ writeLog.Lock() // no need to unlock
+ var gotSum int
+ for _, v := range writeLog.log {
+ gotSum += v
+ }
+
+ // On platforms outside the list below both expectations stay zero,
+ // i.e. the hook is expected not to have recorded any bytes.
+ var wantSum int
+ var wantMinCalls int
+ switch runtime.GOOS {
+ case "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd":
+ wantSum = want.Len()
+ // One call per started group of 1024 chunks; 1024 mirrors the
+ // maxVec limit used by the unix writev implementation.
+ v := chunks
+ for v > 0 {
+ wantMinCalls++
+ v -= 1024
+ }
+ }
+ if len(writeLog.log) < wantMinCalls {
+ t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
+ }
+ if gotSum != wantSum {
+ t.Errorf("writev call sum = %v; want %v", gotSum, wantSum)
+ }
+ return nil
+ })
+}
diff --git a/src/net/writev_unix.go b/src/net/writev_unix.go
new file mode 100644
index 0000000..ac4f7cf
--- /dev/null
+++ b/src/net/writev_unix.go
@@ -0,0 +1,91 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux netbsd openbsd
+
+package net
+
+import (
+ "io"
+ "os"
+ "syscall"
+ "unsafe"
+)
+
+// writeBuffers implements buffersWriter for conn by delegating to the
+// underlying netFD. It returns syscall.EINVAL if the conn is not
+// usable, and wraps any write failure in an *OpError whose Op is
+// "writev". The byte count is returned in both cases.
+func (c *conn) writeBuffers(v *Buffers) (int64, error) {
+	if !c.ok() {
+		return 0, syscall.EINVAL
+	}
+	written, werr := c.fd.writeBuffers(v)
+	if werr == nil {
+		return written, nil
+	}
+	return written, &OpError{Op: "writev", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: werr}
+}
+
+// writeBuffers writes the chunks of v to fd with the writev system
+// call, passing at most maxVec non-empty chunks per call and removing
+// from v whatever the kernel accepted. It holds the fd write lock for
+// the duration of the call.
+//
+// It returns the total number of bytes written. When writev reports
+// EAGAIN it waits for the descriptor to become writable and retries.
+// Any other errno is returned wrapped by os.NewSyscallError("writev").
+// io.ErrUnexpectedEOF is returned if a call succeeds while the running
+// total is still zero.
+func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) {
+	if err := fd.writeLock(); err != nil {
+		return 0, err
+	}
+	defer fd.writeUnlock()
+	if err := fd.pd.prepareWrite(); err != nil {
+		return 0, err
+	}
+
+	// Reuse the iovec slice cached on the fd by a previous call.
+	var iovecs []syscall.Iovec
+	if fd.iovecs != nil {
+		iovecs = *fd.iovecs
+	}
+	// TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is
+	// 1024 and this seems conservative enough for now. Darwin's
+	// UIO_MAXIOV also seems to be 1024.
+	maxVec := 1024
+
+	for len(*v) > 0 {
+		iovecs = iovecs[:0]
+		for _, chunk := range *v {
+			if len(chunk) == 0 {
+				continue
+			}
+			iovecs = append(iovecs, syscall.Iovec{Base: &chunk[0]})
+			iovecs[len(iovecs)-1].SetLen(len(chunk))
+			if len(iovecs) == maxVec {
+				break
+			}
+		}
+		if len(iovecs) == 0 {
+			// Only empty chunks remain; nothing left to send.
+			break
+		}
+		fd.iovecs = &iovecs // cache
+
+		wrote, _, e0 := syscall.Syscall(syscall.SYS_WRITEV,
+			uintptr(fd.sysfd),
+			uintptr(unsafe.Pointer(&iovecs[0])),
+			uintptr(len(iovecs)))
+		// A failed writev yields -1, which arrives here as the all-ones
+		// uintptr. Being unsigned it can never compare < 0, so test for
+		// the sentinel explicitly; otherwise a failure (EAGAIN included)
+		// would subtract one from n and call consume with -1.
+		if wrote == ^uintptr(0) {
+			wrote = 0
+		}
+		testHookDidWritev(int(wrote))
+		n += int64(wrote)
+		v.consume(int64(wrote))
+		if e0 == syscall.EAGAIN {
+			// Socket buffer is full: wait for writability, then retry.
+			if err = fd.pd.waitWrite(); err == nil {
+				continue
+			}
+		} else if e0 != 0 {
+			err = syscall.Errno(e0)
+		}
+		if err != nil {
+			break
+		}
+		if n == 0 {
+			err = io.ErrUnexpectedEOF
+			break
+		}
+	}
+	// Attach the syscall name to raw errno values only; errors from
+	// waitWrite and io.ErrUnexpectedEOF are passed through unchanged.
+	if _, ok := err.(syscall.Errno); ok {
+		err = os.NewSyscallError("writev", err)
+	}
+	return n, err
+}
diff --git a/src/syscall/syscall_nacl.go b/src/syscall/syscall_nacl.go
index 88e4e3a..3247505 100644
--- a/src/syscall/syscall_nacl.go
+++ b/src/syscall/syscall_nacl.go
@@ -296,3 +296,5 @@
func ParseRoutingMessage(b []byte) ([]RoutingMessage, error) { return nil, ENOSYS }
func ParseRoutingSockaddr(msg RoutingMessage) ([]Sockaddr, error) { return nil, ENOSYS }
func SysctlUint32(name string) (value uint32, err error) { return 0, ENOSYS }
+
+// Iovec is an empty placeholder type on nacl.
+// NOTE(review): presumably it exists so that code naming syscall.Iovec
+// (such as the iovecs cache field added to net's netFD) still compiles
+// on nacl — confirm against the build tags of the referencing files.
+type Iovec struct{} // dummy