ogle/probe: handle read requests from network
Basic plumbing to accept a connection and handle requests to read data.
Test it by calling one's own process, which is slightly slippery because of
potential race conditions. Tests of writing will have to be cleverer.
LGTM=nigeltao
R=nigeltao
https://golang.org/cl/60550047
diff --git a/probe/addr_test.go b/probe/addr_test.go
index 65b1cc9..89c52ae 100644
--- a/probe/addr_test.go
+++ b/probe/addr_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package main
+package probe
import (
"reflect"
diff --git a/probe/net.go b/probe/net.go
new file mode 100644
index 0000000..55a088a
--- /dev/null
+++ b/probe/net.go
@@ -0,0 +1,283 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// TODO: Document the protocol once it settles.
+
+package probe
+
+import (
+ "errors"
+ "io" // Used only for the definitions of the various interfaces and errors.
+ "net"
+)
+
+var (
+ port = ":54321" // TODO: how to choose port number?
+ tracing = false
+)
+
+// init starts a network listener and leaves it in the background waiting for connections.
+func init() {
+ go demon()
+}
+
+// demon answers consecutive connection requests and starts a server to manage each one.
+// The server runs in the same goroutine as the demon, so a new connection cannot be
+// established until the previous one is completed.
+func demon() {
+ listener, err := net.Listen("tcp", port)
+ if err != nil {
+ trace("listen:", err)
+ return
+ }
+ trace("listening")
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ trace("accept", err)
+ continue
+ }
+ trace("accepted a connection")
+ serve(conn)
+ conn.Close()
+ }
+}
+
+// stringer is the same as fmt.Stringer. We redefine it here to avoid pulling in fmt.
+type stringer interface {
+ String() string
+}
+
+func printHex(b byte) {
+ const hex = "0123456789ABCDEF"
+ b1, b0 := b>>4&0xF, b&0xF
+ print(hex[b1:b1+1], hex[b0:b0+1])
+}
+
+// trace is a simple version of println that is enabled by the tracing boolean.
+func trace(args ...interface{}) {
+ if !tracing {
+ return
+ }
+ print("ogle demon: ")
+ for i, arg := range args {
+ if i > 0 {
+ print(" ")
+ }
+ // A little help. Built-in print isn't very capable.
+ switch arg := arg.(type) {
+ case stringer:
+ print(arg.String())
+ case error:
+ print(arg.Error())
+ case []byte:
+ print("[")
+ for i := range arg {
+ if i > 0 {
+ print(" ")
+ }
+ printHex(arg[i])
+ }
+ print("]")
+ case int:
+ print(arg)
+ case string:
+ print(arg)
+ case uintptr:
+ print("0x")
+ for i := ptrSize - 1; i >= 0; i-- {
+ printHex(byte(arg >> uint(8*i)))
+ }
+ default:
+ print(arg)
+ }
+ }
+ print("\n")
+}
+
+func serve(conn net.Conn) {
+ const (
+ bufSize = 1 << 16
+ )
+ var buf [bufSize]byte
+ network := &pipe{
+ rw: conn,
+ }
+ for {
+ // One message per loop.
+ n, err := network.Read(buf[:1])
+ if n != 1 || err != nil {
+ return
+ }
+ switch buf[0] {
+ case 'r':
+ // Read: ['r', address, size] => [0, size, size bytes]
+ u, err := network.readUintptr()
+ if err != nil {
+ return
+ }
+ n, err := network.readInt()
+ if err != nil {
+ return
+ }
+ if !validRead(u, n) {
+ trace("read", err)
+ network.error("invalid read address")
+ continue
+ }
+ network.sendReadResponse(u, n)
+ default:
+ // TODO: shut down connection?
+ trace("unknown message type:", buf[0])
+ }
+ }
+}
+
+// pipe is a buffered network connection (actually just a reader/writer) that
+// implements Read and ReadByte as well as readFull.
+// It also has support routines to make it easier to read and write
+// network messages.
+type pipe struct {
+ rw io.ReadWriter
+ pos int
+ end int
+ oneByte [1]byte
+ buf [4096]byte
+}
+
+// readFull fills the argument slice with data from the wire. If it cannot fill the
+// slice, it returns an error.
+// TODO: unused for now; write will need it.
+func (p *pipe) readFull(buf []byte) error {
+ for len(buf) > 0 {
+ n, err := p.rw.Read(buf)
+ if n == len(buf) {
+ return nil
+ }
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return err
+ }
+ if n == 0 {
+ return io.EOF
+ }
+ buf = buf[n:]
+ }
+ return nil
+}
+
+// Read satisfies io.Reader.
+func (p *pipe) Read(buf []byte) (int, error) {
+ n := len(buf)
+ if p.end == p.pos {
+ p.pos = 0
+ // Read from network
+ var err error
+ p.end, err = p.rw.Read(p.buf[:])
+ if err != nil {
+ trace("read:", err)
+ return p.end, err
+ }
+ if p.end == 0 {
+ trace("read: eof")
+ return p.end, io.EOF
+ }
+ }
+ if n > p.end-p.pos {
+ n = p.end - p.pos
+ }
+ copy(buf, p.buf[p.pos:p.pos+n])
+ p.pos += n
+ return n, nil
+}
+
+// ReadByte satisfies io.ByteReader.
+func (p *pipe) ReadByte() (byte, error) {
+ _, err := p.Read(p.oneByte[:])
+ return p.oneByte[0], err
+}
+
+// readUintptr reads a varint-encoded uinptr value from the connection.
+func (p *pipe) readUintptr() (uintptr, error) {
+ u, err := readUvarint(p)
+ if err != nil {
+ trace("read uintptr:", err)
+ return 0, err
+ }
+ if u > uint64(^uintptr(0)) {
+ trace("read uintptr: overflow")
+ return 0, err
+ }
+ return uintptr(u), nil
+}
+
+var intOverflow = errors.New("ogle probe: varint overflows int")
+
+// readInt reads an varint-encoded int value from the connection.
+// The transported value is always a uint64; this routine
+// verifies that it fits in an int.
+func (p *pipe) readInt() (int, error) {
+ u, err := readUvarint(p)
+ if err != nil {
+ trace("read int:", err)
+ return 0, err
+ }
+ // Does it fit in an int?
+ if u > maxInt {
+ trace("int overflow")
+ return 0, intOverflow
+ }
+ return int(u), nil
+}
+
+// error writes an error message to the connection.
+// The format is [size, size bytes].
+func (p *pipe) error(msg string) {
+ // A zero-length message is problematic. It should never arise, but be safe.
+ if len(msg) == 0 {
+ msg = "undefined error"
+ }
+ // Truncate if necessary. Extremely unlikely.
+ if len(msg) > len(p.buf)-maxVarintLen64 {
+ msg = msg[:len(p.buf)-maxVarintLen64]
+ }
+ n := putUvarint(p.buf[:], uint64(len(msg)))
+ n += copy(p.buf[n:], msg)
+ _, err := p.rw.Write(p.buf[:n])
+ if err != nil {
+ trace("write:", err)
+ // TODO: shut down connection?
+ }
+}
+
+// sendReadResponse sends a read response to the connection.
+// The format is [0, size, size bytes].
+func (p *pipe) sendReadResponse(addr uintptr, size int) {
+ trace("sendRead:", addr, size)
+ m := 0
+ m += putUvarint(p.buf[m:], 0) // No error.
+ m += putUvarint(p.buf[m:], uint64(size)) // Number of bytes to follow.
+ for m > 0 || size > 0 {
+ n := len(p.buf) - m
+ if n > size {
+ n = size
+ }
+ if !read(addr, p.buf[m:m+n]) {
+ trace("copy error")
+ // TODO: shut down connection?
+ // TODO: for now, continue delivering data. We said we would.
+ }
+ _, err := p.rw.Write(p.buf[:m+n])
+ if err != nil {
+ trace("write:", err)
+ // TODO: shut down connection?
+ }
+ addr += uintptr(n)
+ size -= n
+ // Next time we can use the whole buffer.
+ m = 0
+ }
+}
diff --git a/probe/net_test.go b/probe/net_test.go
new file mode 100644
index 0000000..bdb449b
--- /dev/null
+++ b/probe/net_test.go
@@ -0,0 +1,222 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package probe
+
+import (
+ "bufio"
+ "io"
+ "net"
+ "testing"
+ "unsafe"
+)
+
+// traceThisFunction turns on tracing and returns a function to turn it off.
+// It is intended for use as "defer traceThisFunction()()".
+func traceThisFunction() func() {
+ // TODO: This should be done atomically to guarantee the probe can see the update.
+ tracing = true
+ return func() {
+ tracing = false
+ }
+}
+
+type Conn struct {
+ conn net.Conn
+ input *bufio.Reader
+ output *bufio.Writer
+}
+
+func (c *Conn) close() {
+ c.output.Flush()
+ c.conn.Close()
+}
+
+// newConn makes a connection.
+func newConn(t *testing.T) *Conn {
+ // defer traceThisFunction()()
+ conn, err := net.Dial("tcp", "localhost"+port)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return &Conn{
+ conn: conn,
+ input: bufio.NewReader(conn),
+ output: bufio.NewWriter(conn),
+ }
+}
+
+// bytesToUint64 returns the uint64 stored in the 8 bytes of buf.
+func bytesToUint64(buf []byte) uint64 {
+ // We're using same machine here, so byte order is the same.
+ // We can just fetch it, but on some architectures it
+ // must be aligned so copy first.
+ var tmp [8]byte
+ copy(tmp[:], buf)
+ return *(*uint64)(unsafe.Pointer(&tmp[0]))
+}
+
+// Test that we get an error back for a request to read an illegal address.
+func TestReadBadAddress(t *testing.T) {
+ //defer traceThisFunction()()
+
+ conn := newConn(t)
+ defer conn.close()
+
+ // Read the elements in pseudo-random order.
+ var tmp [100]byte
+ // Request a read of a bad address.
+ conn.output.WriteByte('r')
+ // Address.
+ n := putUvarint(tmp[:], uint64(base()-8))
+ conn.output.Write(tmp[:n])
+ // Length. Any length will do.
+ n = putUvarint(tmp[:], 8)
+ conn.output.Write(tmp[:n])
+ // Send it.
+ err := conn.output.Flush()
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Read result.
+ // We expect an initial non-zero value, the number of bytes of the error message.
+ u, err := readUvarint(conn.input)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u == 0 {
+ t.Fatalf("expected error return; got none")
+ }
+ // We expect a particular error.
+ const expect = "invalid read address"
+ if u != uint64(len(expect)) {
+ t.Fatalf("got %d bytes; expected %d", u, len(expect))
+ }
+ _, err = io.ReadFull(conn.input, tmp[:u])
+ if err != nil {
+ t.Fatal(err)
+ }
+ msg := string(tmp[:u])
+ if msg != expect {
+ t.Fatalf("got %q; expected %q", msg, expect)
+ }
+}
+
+// Test that we can read some data from the address space on the other side of the connection.
+func TestReadUint64(t *testing.T) {
+ //defer traceThisFunction()()
+
+ conn := newConn(t)
+ defer conn.close()
+
+ // Some data to send over the wire.
+ data := make([]uint64, 10)
+ for i := range data {
+ data[i] = 0x1234567887654321 + 12345*uint64(i)
+ }
+ // TODO: To be righteous we should put a memory barrier here.
+
+ // Read the elements in pseudo-random order.
+ var tmp [100]byte
+ which := 0
+ for i := 0; i < 100; i++ {
+ which = (which + 7) % len(data)
+ // Request a read of data[which].
+ conn.output.WriteByte('r')
+ // Address.
+ n := putUvarint(tmp[:], uint64(addr(&data[which])))
+ conn.output.Write(tmp[:n])
+ // Length
+ n = putUvarint(tmp[:], 8)
+ conn.output.Write(tmp[:n])
+ // Send it.
+ err := conn.output.Flush()
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Read result.
+ // We expect 10 bytes: the initial zero, followed by 8 (the count), followed by 8 bytes of data.
+ u, err := readUvarint(conn.input)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u != 0 {
+ t.Fatalf("expected leading zero byte; got %#x\n", u)
+ }
+ // N bytes of data.
+ u, err = readUvarint(conn.input)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u != 8 {
+ t.Fatalf("got %d bytes of data; expected 8", u)
+ }
+ _, err = io.ReadFull(conn.input, tmp[:u])
+ if err != nil {
+ t.Fatal(err)
+ }
+ u = bytesToUint64(tmp[:u])
+ if u != data[which] {
+ t.Fatalf("got %#x; expected %#x", u, data[which])
+ }
+ }
+}
+
+// Test that we can read an array bigger than the pipe's buffer size.
+func TestBigRead(t *testing.T) {
+ // defer traceThisFunction()()
+
+ conn := newConn(t)
+ defer conn.close()
+
+ // A big array.
+ data := make([]byte, 3*len(pipe{}.buf))
+ noise := 17
+ for i := range data {
+ data[i] = byte(noise)
+ noise += 23
+ }
+ // TODO: To be righteous we should put a memory barrier here.
+
+ // Read the elements in pseudo-random order.
+ tmp := make([]byte, len(data))
+ conn.output.WriteByte('r')
+ // Address.
+ n := putUvarint(tmp[:], uint64(addr(&data[0])))
+ conn.output.Write(tmp[:n])
+ // Length
+ n = putUvarint(tmp[:], uint64(len(data)))
+ conn.output.Write(tmp[:n])
+ // Send it.
+ err := conn.output.Flush()
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Read result.
+ // We expect the full data back.
+ u, err := readUvarint(conn.input)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u != 0 {
+ t.Fatalf("expected leading zero byte; got %#x\n", u)
+ }
+ // N bytes of data.
+ u, err = readUvarint(conn.input)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u != uint64(len(data)) {
+ t.Fatalf("got %d bytes of data; expected 8", u)
+ }
+ _, err = io.ReadFull(conn.input, tmp)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for i, c := range data {
+ if tmp[i] != c {
+ t.Fatalf("at offset %d expected %#x; got %#x", i, c, tmp[i])
+ }
+ }
+}
diff --git a/probe/probe.go b/probe/probe.go
index 002f671..d932695 100644
--- a/probe/probe.go
+++ b/probe/probe.go
@@ -4,8 +4,13 @@
// Package probe is imported by programs to provide (possibly remote)
// access to a separate debugger program.
-package main
+package probe
+import (
+ "unsafe"
+)
+
+// Defined in assembler.
func base() uintptr
func etext() uintptr
func edata() uintptr
@@ -50,3 +55,29 @@
}
return false
}
+
+// read copies into the argument buffer the contents of memory starting at address p.
+// Its boolean return tells whether it succeeded. If it fails, no bytes were copied.
+func read(p uintptr, buf []byte) bool {
+ if !validRead(p, len(buf)) {
+ return false
+ }
+ for i := range buf {
+ buf[i] = *(*byte)(unsafe.Pointer(p))
+ p++
+ }
+ return true
+}
+
+// write copies the argument buffer to memory starting at address p.
+// Its boolean return tells whether it succeeded. If it fails, no bytes were copied.
+func write(p uintptr, buf []byte) bool {
+ if !validWrite(p, len(buf)) {
+ return false
+ }
+ for i := range buf {
+ *(*byte)(unsafe.Pointer(p)) = buf[i]
+ p++
+ }
+ return true
+}
diff --git a/probe/size_amd64.go b/probe/size_amd64.go
new file mode 100644
index 0000000..08783a5
--- /dev/null
+++ b/probe/size_amd64.go
@@ -0,0 +1,12 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package probe
+
+// Sizes of values for the AMD64 architecture.
+
+const (
+ ptrSize = 8
+ maxInt = 1<<63 - 1
+)
diff --git a/probe/varint.go b/probe/varint.go
new file mode 100644
index 0000000..2d4f1b3
--- /dev/null
+++ b/probe/varint.go
@@ -0,0 +1,82 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package probe
+
+// This file contains an implementation of "varint" encoding and decoding.
+// Code is adapted from encoding/binary/varint.go, copied here to avoid dependencies,
+// simplified somewhat, and made local to the package.
+// It handles unsigned integers only.
+
+import (
+ "errors"
+ "io"
+)
+
+// maxVarintLenN is the maximum length of a varint-encoded N-bit integer.
+const (
+ maxVarintLen16 = 3
+ maxVarintLen32 = 5
+ maxVarintLen64 = 10
+)
+
+// putUvarint encodes a uint64 into buf and returns the number of bytes written.
+// If the buffer is too small, putUvarint will panic.
+func putUvarint(buf []byte, x uint64) int {
+ i := 0
+ for x >= 0x80 {
+ buf[i] = byte(x) | 0x80
+ x >>= 7
+ i++
+ }
+ buf[i] = byte(x)
+ return i + 1
+}
+
+// getUvarint decodes a uint64 from buf and returns that value and the
+// number of bytes read (> 0). If an error occurred, the value is 0
+// and the number of bytes n is <= 0 meaning:
+//
+// n == 0: buf too small
+// n < 0: value larger than 64 bits (overflow)
+// and -n is the number of bytes read
+//
+// TODO: Unused. Delete if it doesn't get used.
+func getUvarint(buf []byte) (uint64, int) {
+ var x uint64
+ var s uint
+ for i, b := range buf {
+ if b < 0x80 {
+ if i > 9 || i == 9 && b > 1 {
+ return 0, -(i + 1) // overflow
+ }
+ return x | uint64(b)<<s, i + 1
+ }
+ x |= uint64(b&0x7f) << s
+ s += 7
+ }
+ return 0, 0
+}
+
+var overflow = errors.New("ogle probe: varint overflows a 64-bit integer")
+
+// readUvarint reads an encoded unsigned integer from r and returns it as a uint64.
+func readUvarint(r io.ByteReader) (uint64, error) {
+ var x uint64
+ var s uint
+ for i := 0; ; i++ {
+ b, err := r.ReadByte()
+ if err != nil {
+ return x, err
+ }
+ if b < 0x80 {
+ if i > 9 || i == 9 && b > 1 {
+ return x, overflow
+ }
+ return x | uint64(b)<<s, nil
+ }
+ x |= uint64(b&0x7f) << s
+ s += 7
+ }
+}