ogle/socket: new package.
LGTM=r
R=r
https://golang.org/cl/61540047
diff --git a/probe/net.go b/probe/net.go
index 55a088a..f5a7275 100644
--- a/probe/net.go
+++ b/probe/net.go
@@ -10,14 +10,16 @@
"errors"
"io" // Used only for the definitions of the various interfaces and errors.
"net"
+
+ "code.google.com/p/ogle/socket"
)
var (
- port = ":54321" // TODO: how to choose port number?
- tracing = false
+ tracing = false
+ listening = make(chan struct{})
)
-// init starts a network listener and leaves it in the background waiting for connections.
+// init starts a listener and leaves it in the background waiting for connections.
func init() {
go demon()
}
@@ -26,7 +28,8 @@
// The server runs in the same goroutine as the demon, so a new connection cannot be
// established until the previous one is completed.
func demon() {
- listener, err := net.Listen("tcp", port)
+ listener, err := socket.Listen()
+ close(listening)
if err != nil {
trace("listen:", err)
return
diff --git a/probe/net_test.go b/probe/net_test.go
index bdb449b..a107419 100644
--- a/probe/net_test.go
+++ b/probe/net_test.go
@@ -8,8 +8,11 @@
"bufio"
"io"
"net"
+ "os"
"testing"
"unsafe"
+
+ "code.google.com/p/ogle/socket"
)
// traceThisFunction turns on tracing and returns a function to turn it off.
@@ -36,7 +39,8 @@
// newConn makes a connection.
func newConn(t *testing.T) *Conn {
// defer traceThisFunction()()
- conn, err := net.Dial("tcp", "localhost"+port)
+ <-listening
+ conn, err := socket.Dial(os.Getuid(), os.Getpid())
if err != nil {
t.Fatal(err)
}
@@ -220,3 +224,10 @@
}
}
}
+
+// TestCollectGarbage doesn't actually test anything, but it does collect any
+// garbage sockets that are no longer used. It is a courtesy for computers that
+// run this test suite often.
+func TestCollectGarbage(t *testing.T) {
+ socket.CollectGarbage()
+}
diff --git a/socket/socket.go b/socket/socket.go
new file mode 100644
index 0000000..cc14dc4
--- /dev/null
+++ b/socket/socket.go
@@ -0,0 +1,107 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package socket provides a way for multiple processes from the same user to
+// communicate over a Unix domain socket.
+package socket
+
+// TODO: euid instead of uid?
+// TODO: Windows support.
+
+import (
+ "net"
+ "os"
+ "syscall"
+)
+
+// atoi is like strconv.Atoi but we aim to minimize this package's dependencies.
+func atoi(s string) (i int, ok bool) {
+ for _, c := range s {
+ if c < '0' || '9' < c {
+ return 0, false
+ }
+ i = 10*i + int(c-'0')
+ }
+ return i, true
+}
+
+// itoa is like strconv.Itoa but we aim to minimize this package's dependencies.
+func itoa(i int) string {
+ var buf [30]byte
+ n := len(buf)
+ neg := false
+ if i < 0 {
+ i = -i
+ neg = true
+ }
+ ui := uint(i)
+ for ui > 0 || n == len(buf) {
+ n--
+ buf[n] = byte('0' + ui%10)
+ ui /= 10
+ }
+ if neg {
+ n--
+ buf[n] = '-'
+ }
+ return string(buf[n:])
+}
+
+func names(uid, pid int) (dirName, socketName string) {
+ dirName = "/tmp/ogle-socket-uid" + itoa(uid)
+ socketName = dirName + "/pid" + itoa(pid)
+ return
+}
+
+// Listen creates a PID-specific socket under a UID-specific sub-directory of
+// /tmp. That sub-directory is created with 0700 permission bits (before
+// umasking), so that only processes with the same UID can dial that socket.
+func Listen() (net.Listener, error) {
+ dirName, socketName := names(os.Getuid(), os.Getpid())
+ if err := os.MkdirAll(dirName, 0700); err != nil {
+ return nil, err
+ }
+ if err := os.Remove(socketName); err != nil && !os.IsNotExist(err) {
+ return nil, err
+ }
+ return net.Listen("unix", socketName)
+}
+
+// Dial dials the Unix domain socket created by the process with the given UID
+// and PID.
+func Dial(uid, pid int) (net.Conn, error) {
+ _, socketName := names(uid, pid)
+ return net.Dial("unix", socketName)
+}
+
+// CollectGarbage deletes any no-longer-used sockets in the UID-specific sub-
+// directory of /tmp.
+func CollectGarbage() {
+ dirName, _ := names(os.Getuid(), os.Getpid())
+ dir, err := os.Open(dirName)
+ if err != nil {
+ return
+ }
+ defer dir.Close()
+ fileNames, err := dir.Readdirnames(-1)
+ if err != nil {
+ return
+ }
+ for _, fileName := range fileNames {
+ if len(fileName) < 3 || fileName[:3] != "pid" {
+ continue
+ }
+ pid, ok := atoi(fileName[3:])
+ if !ok {
+ continue
+ }
+ // See if there is a process with the given PID. The os.FindProcess function
+ // looks relevant, but on Unix that always succeeds even if there is no such
+ // process. Instead, we send signal 0 and look for ESRCH.
+ if syscall.Kill(pid, 0) != syscall.ESRCH {
+ continue
+ }
+ os.Remove(dirName + "/" + fileName)
+ }
+}
diff --git a/socket/socket_test.go b/socket/socket_test.go
new file mode 100644
index 0000000..b07a732
--- /dev/null
+++ b/socket/socket_test.go
@@ -0,0 +1,79 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socket
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "testing"
+ "time"
+)
+
+func TestSocket(t *testing.T) {
+ const msg = "Zoich!"
+
+ l, err := Listen()
+ if err != nil {
+ t.Fatalf("listen: %v", err)
+ }
+ defer l.Close()
+
+ wc := make(chan string, 1)
+ go func(c chan string) {
+ w, err := Dial(os.Getuid(), os.Getpid())
+ if err != nil {
+ c <- fmt.Sprintf("dial: %v", err)
+ return
+ }
+ defer w.Close()
+ _, err = w.Write([]byte(msg))
+ if err != nil {
+ c <- fmt.Sprintf("write: %v", err)
+ return
+ }
+ c <- ""
+ }(wc)
+
+ rc := make(chan string, 1)
+ go func(c chan string) {
+ r, err := l.Accept()
+ if err != nil {
+ c <- fmt.Sprintf("accept: %v", err)
+ return
+ }
+ defer r.Close()
+ s, err := ioutil.ReadAll(r)
+ if err != nil {
+ c <- fmt.Sprintf("readAll: %v", err)
+ return
+ }
+ c <- string(s)
+ }(rc)
+
+ for wc != nil || rc != nil {
+ select {
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("timed out")
+ case errStr := <-wc:
+ if errStr != "" {
+ t.Fatal(errStr)
+ }
+ wc = nil
+ case got := <-rc:
+ if got != msg {
+ t.Fatalf("got %q, want %q", got, msg)
+ }
+ rc = nil
+ }
+ }
+}
+
+// TestCollectGarbage doesn't actually test anything, but it does collect any
+// garbage sockets that are no longer used. It is a courtesy for computers that
+// run this test suite often.
+func TestCollectGarbage(t *testing.T) {
+ CollectGarbage()
+}