sandbox: timeout runsc commands
The current mechanism for forcing a process to die after a timeout is not
sufficient. This change fixes issues that were causing processes to run
forever on the sandbox.
- Gracefully terminate processes before we kill them inside of our
gVisor process. This helps capture valuable debug output for the user.
- Return a friendlier error when our run context times out on the
playground.
- Add a test that timeouts are handled gracefully.
- Reduce concurrent goroutines in our sandbox run handler by replacing
goroutine copy functions with a custom writer (limitedWriter) that
returns an error if too much output is returned, halting the program.
- Custom writers (limitedWriter, switchWriter) also fix timing errors
when calling Wait() too soon on a Command, before we have read all of
the data. It also fixes a different error from trying to read data after
a program has terminated.
- Remove goroutine from startContainer, and use a ticker + context
timeout for synchronization.
Updates golang/go#25224
Updates golang/go#38343
Change-Id: Ie9d65220e5c4f39272ea70b45c4b472bcd7069bb
Reviewed-on: https://go-review.googlesource.com/c/playground/+/227652
Run-TryBot: Alexander Rakoczy <alex@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Bryan C. Mills <bcmills@google.com>
diff --git a/Makefile b/Makefile
index e96fbd8..e670713 100644
--- a/Makefile
+++ b/Makefile
@@ -16,7 +16,7 @@
test_go:
# Run fast tests first: (and tests whether, say, things compile)
- GO111MODULE=on go test -v
+ GO111MODULE=on go test -v ./...
test_gvisor: docker
docker kill sandbox_front_test || true
diff --git a/sandbox.go b/sandbox.go
index 3f44a09..07a0f07 100644
--- a/sandbox.go
+++ b/sandbox.go
@@ -526,6 +526,10 @@
sreq.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(bytes.NewReader(exeBytes)), nil }
res, err := sandboxBackendClient().Do(sreq)
if err != nil {
+ if ctx.Err() == context.DeadlineExceeded {
+ execRes.Error = "timeout running program"
+ return execRes, nil
+ }
return execRes, fmt.Errorf("POST %q: %w", sandboxBackendURL(), err)
}
defer res.Body.Close()
diff --git a/sandbox/sandbox.go b/sandbox/sandbox.go
index 603088b..07de286 100644
--- a/sandbox/sandbox.go
+++ b/sandbox/sandbox.go
@@ -43,12 +43,16 @@
const (
maxBinarySize = 100 << 20
+ startTimeout = 30 * time.Second
runTimeout = 5 * time.Second
maxOutputSize = 100 << 20
memoryLimitBytes = 100 << 20
)
-var errTooMuchOutput = errors.New("Output too large")
+var (
+ errTooMuchOutput = errors.New("Output too large")
+ errRunTimeout = errors.New("timeout running program")
+)
// containedStartMessage is the first thing written to stdout by the
// gvisor-contained process when it starts up. This lets the parent HTTP
@@ -68,8 +72,8 @@
type Container struct {
name string
stdin io.WriteCloser
- stdout io.ReadCloser
- stderr io.ReadCloser
+ stdout *limitedWriter
+ stderr *limitedWriter
cmd *exec.Cmd
waitOnce sync.Once
@@ -78,12 +82,12 @@
func (c *Container) Close() {
setContainerWanted(c.name, false)
- c.stdin.Close()
- c.stdout.Close()
- c.stderr.Close()
+
if c.cmd.Process != nil {
- c.cmd.Process.Kill()
- c.Wait() // just in case
+ gracefulStop(c.cmd.Process, 250*time.Millisecond)
+ if err := c.Wait(); err != nil {
+ log.Printf("error in c.Wait() for %q: %v", c.name, err)
+ }
}
}
@@ -245,6 +249,7 @@
if err := ioutil.WriteFile(binPath, bin, 0755); err != nil {
log.Fatalf("writing contained binary: %v", err)
}
+ defer os.Remove(binPath) // not that it matters much, this container will be nuked
var meta processMeta
if err := json.NewDecoder(bytes.NewReader(metaJSON)).Decode(&meta); err != nil {
@@ -262,12 +267,34 @@
if err := cmd.Start(); err != nil {
log.Fatalf("cmd.Start(): %v", err)
}
+ timer := time.AfterFunc(runTimeout-(500*time.Millisecond), func() {
+ fmt.Fprintln(os.Stderr, "timeout running program")
+ gracefulStop(cmd.Process, 250*time.Millisecond)
+ })
+ defer timer.Stop()
err = cmd.Wait()
- os.Remove(binPath) // not that it matters much, this container will be nuked
os.Exit(errExitCode(err))
return
}
+// gracefulStop attempts to send a SIGINT before a SIGKILL.
+//
+// The process will be sent a SIGINT immediately. Regardless of how it responds,
+// a SIGKILL will follow once delay has passed since sending the SIGINT.
+//
+// TODO(golang.org/issue/38343) - Change SIGINT to SIGQUIT once decision is made.
+func gracefulStop(p *os.Process, delay time.Duration) {
+ // TODO(golang.org/issue/38343) - Change to syscall.SIGQUIT once decision is made.
+ if err := p.Signal(os.Interrupt); err != nil {
+ log.Printf("cmd.Process.Signal(%v): %v", os.Interrupt, err)
+ }
+ time.AfterFunc(delay, func() {
+ if err := p.Kill(); err != nil {
+ log.Printf("cmd.Process.Kill(): %v", err)
+ }
+ })
+}
+
func makeWorkers() {
for {
c, err := startContainer(context.Background())
@@ -321,25 +348,6 @@
func startContainer(ctx context.Context) (c *Container, err error) {
name := "play_run_" + randHex(8)
setContainerWanted(name, true)
- var stdin io.WriteCloser
- var stdout io.ReadCloser
- var stderr io.ReadCloser
- defer func() {
- if err == nil {
- return
- }
- setContainerWanted(name, false)
- if stdin != nil {
- stdin.Close()
- }
- if stdout != nil {
- stdout.Close()
- }
- if stderr != nil {
- stderr.Close()
- }
- }()
-
cmd := exec.Command("docker", "run",
"--name="+name,
"--rm",
@@ -352,46 +360,53 @@
*container,
"--mode=contained")
- stdin, err = cmd.StdinPipe()
+ stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}
- stdout, err = cmd.StdoutPipe()
- if err != nil {
- return nil, err
- }
- stderr, err = cmd.StderrPipe()
- if err != nil {
- return nil, err
- }
+ pr, pw := io.Pipe()
+ stdout := &limitedWriter{dst: &bytes.Buffer{}, n: maxOutputSize + int64(len(containedStartMessage))}
+ stderr := &limitedWriter{dst: &bytes.Buffer{}, n: maxOutputSize}
+ cmd.Stdout = &switchWriter{switchAfter: []byte(containedStartMessage), dst1: pw, dst2: stdout}
+ cmd.Stderr = stderr
if err := cmd.Start(); err != nil {
return nil, err
}
+ defer func() {
+ if err != nil {
+ log.Printf("error starting container %q: %v", name, err)
+ gracefulStop(cmd.Process, 250*time.Millisecond)
+ setContainerWanted(name, false)
+ }
+ }()
+ ctx, cancel := context.WithTimeout(ctx, startTimeout)
+ defer cancel()
- errc := make(chan error, 1)
+ startErr := make(chan error, 1)
go func() {
buf := make([]byte, len(containedStartMessage))
- if _, err := io.ReadFull(stdout, buf); err != nil {
- errc <- fmt.Errorf("error reading header from sandbox container: %v", err)
- return
+ _, err := io.ReadFull(pr, buf)
+ if err != nil {
+ startErr <- fmt.Errorf("error reading header from sandbox container: %v", err)
+ } else if string(buf) != containedStartMessage {
+ startErr <- fmt.Errorf("sandbox container sent wrong header %q; want %q", buf, containedStartMessage)
+ } else {
+ startErr <- nil
}
- if string(buf) != containedStartMessage {
- errc <- fmt.Errorf("sandbox container sent wrong header %q; want %q", buf, containedStartMessage)
- return
- }
- errc <- nil
}()
+
select {
case <-ctx.Done():
- log.Printf("timeout starting container")
- cmd.Process.Kill()
- return nil, ctx.Err()
- case err := <-errc:
+ err := fmt.Errorf("timeout starting container %q: %w", name, ctx.Err())
+ pw.Close()
+ <-startErr
+ return nil, err
+ case err = <-startErr:
if err != nil {
- log.Printf("error starting container: %v", err)
return nil, err
}
}
+ log.Printf("started container %q", name)
return &Container{
name: name,
stdin: stdin,
@@ -435,6 +450,7 @@
bin, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, maxBinarySize))
if err != nil {
+ log.Printf("failed to read request body: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -451,88 +467,136 @@
return
}
logf("got container %s", c.name)
- defer c.Close()
- defer logf("leaving handler; about to close container")
- runTimer := time.NewTimer(runTimeout)
- defer runTimer.Stop()
-
- errc := make(chan error, 2) // user-visible error
- waitc := make(chan error, 1)
-
- copyOut := func(which string, dst *[]byte, r io.Reader) {
- buf := make([]byte, 4<<10)
- for {
- n, err := r.Read(buf)
- logf("%s: Read = %v, %v", which, n, err)
- *dst = append(*dst, buf[:n]...)
- if err == io.EOF {
- return
- }
- if len(*dst) > maxOutputSize {
- errc <- errTooMuchOutput
- return
- }
- if err != nil {
- log.Printf("reading %s: %v", which, err)
- errc <- fmt.Errorf("error reading %v", which)
- return
- }
- }
- }
-
- res := &sandboxtypes.Response{}
- go func() {
- var meta processMeta
- meta.Args = r.Header["X-Argument"]
- metaJSON, _ := json.Marshal(&meta)
- metaJSON = append(metaJSON, '\n')
- if _, err := c.stdin.Write(metaJSON); err != nil {
- log.Printf("stdin write meta: %v", err)
- errc <- errors.New("failed to write meta to child")
- return
- }
- if _, err := c.stdin.Write(bin); err != nil {
- log.Printf("stdin write: %v", err)
- errc <- errors.New("failed to write binary to child")
- return
- }
- c.stdin.Close()
- logf("wrote+closed")
- go copyOut("stdout", &res.Stdout, c.stdout)
- go copyOut("stderr", &res.Stderr, c.stderr)
- waitc <- c.Wait()
+ ctx, cancel := context.WithTimeout(context.Background(), runTimeout)
+ closed := make(chan struct{})
+ defer func() {
+ logf("leaving handler; about to close container")
+ cancel()
+ <-closed
}()
- var waitErr error
- select {
- case waitErr = <-waitc:
- logf("waited: %v", waitErr)
- case err := <-errc:
- logf("got error: %v", err)
- if err == errTooMuchOutput {
- sendError(w, err.Error())
- return
+ go func() {
+ <-ctx.Done()
+ if ctx.Err() == context.DeadlineExceeded {
+ logf("timeout")
}
- if err != nil {
- http.Error(w, "failed to read stdout from docker run", http.StatusInternalServerError)
- return
- }
- case <-runTimer.C:
- logf("timeout")
- sendError(w, "timeout running program")
+ c.Close()
+ close(closed)
+ }()
+ var meta processMeta
+ meta.Args = r.Header["X-Argument"]
+ metaJSON, _ := json.Marshal(&meta)
+ metaJSON = append(metaJSON, '\n')
+ if _, err := c.stdin.Write(metaJSON); err != nil {
+ log.Printf("failed to write meta to child: %v", err)
+ http.Error(w, "unknown error during docker run", http.StatusInternalServerError)
return
}
-
- res.ExitCode = errExitCode(waitErr)
- res.Stderr = cleanStderr(res.Stderr)
+ if _, err := c.stdin.Write(bin); err != nil {
+ log.Printf("failed to write binary to child: %v", err)
+ http.Error(w, "unknown error during docker run", http.StatusInternalServerError)
+ return
+ }
+ c.stdin.Close()
+ logf("wrote+closed")
+ err = c.Wait()
+ select {
+ case <-ctx.Done():
+ // Timed out or canceled before or exactly as Wait returned.
+ // Either way, treat it as a timeout.
+ sendError(w, "timeout running program")
+ return
+ default:
+ logf("finished running; about to close container")
+ cancel()
+ }
+ res := &sandboxtypes.Response{}
+ if err != nil {
+ if c.stderr.n < 0 || c.stdout.n < 0 {
+ // Do not send truncated output, just send the error.
+ sendError(w, errTooMuchOutput.Error())
+ return
+ }
+ var ee *exec.ExitError
+ if !errors.As(err, &ee) {
+ http.Error(w, "unknown error during docker run", http.StatusInternalServerError)
+ return
+ }
+ res.ExitCode = ee.ExitCode()
+ }
+ res.Stdout = c.stdout.dst.Bytes()
+ res.Stderr = cleanStderr(c.stderr.dst.Bytes())
sendResponse(w, res)
}
+// limitedWriter is an io.Writer that returns an errTooMuchOutput when the cap (n) is hit.
+type limitedWriter struct {
+ dst *bytes.Buffer
+ n int64 // max bytes remaining
+}
+
+// Write is an io.Writer function that returns errTooMuchOutput when the cap (n) is hit.
+//
+// Partial data will be written to dst if p is larger than n, but errTooMuchOutput will be returned.
+func (l *limitedWriter) Write(p []byte) (int, error) {
+ defer func() { l.n -= int64(len(p)) }()
+
+ if l.n <= 0 {
+ return 0, errTooMuchOutput
+ }
+
+ if int64(len(p)) > l.n {
+ n, err := l.dst.Write(p[:l.n])
+ if err != nil {
+ return n, err
+ }
+ return n, errTooMuchOutput
+ }
+
+ return l.dst.Write(p)
+}
+
+// switchWriter writes to dst1 until switchAfter is written, then it writes to dst2.
+type switchWriter struct {
+ dst1 io.Writer
+ dst2 io.Writer
+ switchAfter []byte
+ buf []byte
+ found bool
+}
+
+func (s *switchWriter) Write(p []byte) (int, error) {
+ if s.found {
+ return s.dst2.Write(p)
+ }
+
+ s.buf = append(s.buf, p...)
+ i := bytes.Index(s.buf, s.switchAfter)
+ if i == -1 {
+ if len(s.buf) >= len(s.switchAfter) {
+ s.buf = s.buf[len(s.buf)-len(s.switchAfter)+1:]
+ }
+ return s.dst1.Write(p)
+ }
+
+ s.found = true
+ nAfter := len(s.buf) - (i + len(s.switchAfter))
+ s.buf = nil
+
+ n, err := s.dst1.Write(p[:len(p)-nAfter])
+ if err != nil {
+ return n, err
+ }
+ n2, err := s.dst2.Write(p[len(p)-nAfter:])
+ return n + n2, err
+}
+
func errExitCode(err error) int {
if err == nil {
return 0
}
- if ee, ok := err.(*exec.ExitError); ok {
+ var ee *exec.ExitError
+ if errors.As(err, &ee) {
return ee.ExitCode()
}
return 1
diff --git a/sandbox/sandbox_test.go b/sandbox/sandbox_test.go
new file mode 100644
index 0000000..7b32ec6
--- /dev/null
+++ b/sandbox/sandbox_test.go
@@ -0,0 +1,180 @@
+package main
+
+import (
+ "bytes"
+ "io"
+ "strings"
+ "testing"
+ "testing/iotest"
+)
+
+func TestLimitedWriter(t *testing.T) {
+ cases := []struct {
+ desc string
+ lw *limitedWriter
+ in []byte
+ want []byte
+ wantN int64
+ wantRemaining int64
+ err error
+ }{
+ {
+ desc: "simple",
+ lw: &limitedWriter{dst: &bytes.Buffer{}, n: 10},
+ in: []byte("hi"),
+ want: []byte("hi"),
+ wantN: 2,
+ wantRemaining: 8,
+ },
+ {
+ desc: "writing nothing",
+ lw: &limitedWriter{dst: &bytes.Buffer{}, n: 10},
+ in: []byte(""),
+ want: []byte(""),
+ wantN: 0,
+ wantRemaining: 10,
+ },
+ {
+ desc: "writing exactly enough",
+ lw: &limitedWriter{dst: &bytes.Buffer{}, n: 6},
+ in: []byte("enough"),
+ want: []byte("enough"),
+ wantN: 6,
+ wantRemaining: 0,
+ err: nil,
+ },
+ {
+ desc: "writing too much",
+ lw: &limitedWriter{dst: &bytes.Buffer{}, n: 10},
+ in: []byte("this is much longer than 10"),
+ want: []byte("this is mu"),
+ wantN: 10,
+ wantRemaining: -1,
+ err: errTooMuchOutput,
+ },
+ }
+ for _, c := range cases {
+ t.Run(c.desc, func(t *testing.T) {
+ n, err := io.Copy(c.lw, iotest.OneByteReader(bytes.NewReader(c.in)))
+ if err != c.err {
+ t.Errorf("c.lw.Write(%q) = %d, %q, wanted %d, %q", c.in, n, err, c.wantN, c.err)
+ }
+ if n != c.wantN {
+ t.Errorf("c.lw.Write(%q) = %d, %q, wanted %d, %q", c.in, n, err, c.wantN, c.err)
+ }
+ if c.lw.n != c.wantRemaining {
+ t.Errorf("c.lw.n = %d, wanted %d", c.lw.n, c.wantRemaining)
+ }
+ if string(c.lw.dst.Bytes()) != string(c.want) {
+ t.Errorf("c.lw.dst.Bytes() = %q, wanted %q", c.lw.dst.Bytes(), c.want)
+ }
+ })
+ }
+}
+
+func TestSwitchWriter(t *testing.T) {
+ cases := []struct {
+ desc string
+ sw *switchWriter
+ in []byte
+ want1 []byte
+ want2 []byte
+ wantN int64
+ wantFound bool
+ err error
+ }{
+ {
+ desc: "not found",
+ sw: &switchWriter{switchAfter: []byte("UNIQUE")},
+ in: []byte("hi"),
+ want1: []byte("hi"),
+ want2: []byte(""),
+ wantN: 2,
+ wantFound: false,
+ },
+ {
+ desc: "writing nothing",
+ sw: &switchWriter{switchAfter: []byte("UNIQUE")},
+ in: []byte(""),
+ want1: []byte(""),
+ want2: []byte(""),
+ wantN: 0,
+ wantFound: false,
+ },
+ {
+ desc: "writing exactly switchAfter",
+ sw: &switchWriter{switchAfter: []byte("UNIQUE")},
+ in: []byte("UNIQUE"),
+ want1: []byte("UNIQUE"),
+ want2: []byte(""),
+ wantN: 6,
+ wantFound: true,
+ },
+ {
+ desc: "writing before and after switchAfter",
+ sw: &switchWriter{switchAfter: []byte("UNIQUE")},
+ in: []byte("this is before UNIQUE and this is after"),
+ want1: []byte("this is before UNIQUE"),
+ want2: []byte(" and this is after"),
+ wantN: 39,
+ wantFound: true,
+ },
+ }
+ for _, c := range cases {
+ t.Run(c.desc, func(t *testing.T) {
+ dst1, dst2 := &bytes.Buffer{}, &bytes.Buffer{}
+ c.sw.dst1, c.sw.dst2 = dst1, dst2
+ n, err := io.Copy(c.sw, iotest.OneByteReader(bytes.NewReader(c.in)))
+ if err != c.err {
+ t.Errorf("c.sw.Write(%q) = %d, %q, wanted %d, %q", c.in, n, err, c.wantN, c.err)
+ }
+ if n != c.wantN {
+ t.Errorf("c.sw.Write(%q) = %d, %q, wanted %d, %q", c.in, n, err, c.wantN, c.err)
+ }
+ if c.sw.found != c.wantFound {
+ t.Errorf("c.sw.found = %v, wanted %v", c.sw.found, c.wantFound)
+ }
+ if string(dst1.Bytes()) != string(c.want1) {
+ t.Errorf("dst1.Bytes() = %q, wanted %q", dst1.Bytes(), c.want1)
+ }
+ if string(dst2.Bytes()) != string(c.want2) {
+ t.Errorf("dst2.Bytes() = %q, wanted %q", dst2.Bytes(), c.want2)
+ }
+ })
+ }
+}
+
+func TestSwitchWriterMultipleWrites(t *testing.T) {
+ dst1, dst2 := &bytes.Buffer{}, &bytes.Buffer{}
+ sw := &switchWriter{
+ dst1: dst1,
+ dst2: dst2,
+ switchAfter: []byte("GOPHER"),
+ }
+ n, err := io.Copy(sw, iotest.OneByteReader(strings.NewReader("this is before GO"))) // stops mid-marker: only "GO" of "GOPHER" is written
+ if err != nil || n != 17 {
+ t.Errorf("sw.Write(%q) = %d, %q, wanted %d, no error", "this is before GO", n, err, 17)
+ }
+ if sw.found {
+ t.Errorf("sw.found = %v, wanted %v", sw.found, false)
+ }
+ if string(dst1.Bytes()) != "this is before GO" {
+ t.Errorf("dst1.Bytes() = %q, wanted %q", dst1.Bytes(), "this is before GO")
+ }
+ if string(dst2.Bytes()) != "" {
+ t.Errorf("dst2.Bytes() = %q, wanted %q", dst2.Bytes(), "")
+ }
+ n, err = io.Copy(sw, iotest.OneByteReader(strings.NewReader("PHER and this is after"))) // completes the marker across separate writes
+ if err != nil || n != 22 {
+ t.Errorf("sw.Write(%q) = %d, %q, wanted %d, no error", "PHER and this is after", n, err, 22)
+ }
+ if !sw.found {
+ t.Errorf("sw.found = %v, wanted %v", sw.found, true)
+ }
+ if string(dst1.Bytes()) != "this is before GOPHER" {
+ t.Errorf("dst1.Bytes() = %q, wanted %q", dst1.Bytes(), "this is before GOPHER")
+ }
+ if string(dst2.Bytes()) != " and this is after" {
+ t.Errorf("dst2.Bytes() = %q, wanted %q", dst2.Bytes(), " and this is after")
+ }
+}
diff --git a/tests.go b/tests.go
index 9e0d55f..3db7e63 100644
--- a/tests.go
+++ b/tests.go
@@ -67,7 +67,7 @@
continue
}
if resp.Errors != "" {
- stdlog.Print(resp.Errors)
+ stdlog.Printf("resp.Errors = %q, want %q", resp.Errors, t.errors)
failed = true
continue
}
@@ -548,4 +548,26 @@
func Hello() { fmt.Println("hello world") }
`,
},
+ {
+ name: "timeouts_handled_gracefully",
+ prog: `
+package main
+
+import (
+ "time"
+)
+
+func main() {
+ c := make(chan struct{})
+
+ go func() {
+ defer close(c)
+ for {
+ time.Sleep(10 * time.Millisecond)
+ }
+ }()
+
+ <-c
+}
+`, want: "timeout running program"},
}