sandbox: timeout runsc commands

The current mechanism for forcing a process to die after a timeout is not
sufficient. This change fixes issues that were causing processes to run
forever on the sandbox.

- Gracefully terminate processes before we kill them inside of our
gVisor process. This helps capture valuable debug output for the user.

- Return a friendlier error when our run context times out on the
playground.

- Add a test that timeouts are handled gracefully.

- Reduce concurrent goroutines in our sandbox run handler by replacing
goroutine copy functions with a custom writer (limitedWriter) that
returns an error if too much output is returned, halting the program.

- Custom writers (limitedWriter, switchWriter) also fix timing errors
when calling Wait() too soon on a Command, before we have read all of
the data. They also fix a different error from trying to read data after
a program has terminated.

- Remove goroutine from startContainer, and use a ticker + context
timeout for synchronization.

Updates golang/go#25224
Updates golang/go#38343

Change-Id: Ie9d65220e5c4f39272ea70b45c4b472bcd7069bb
Reviewed-on: https://go-review.googlesource.com/c/playground/+/227652
Run-TryBot: Alexander Rakoczy <alex@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Bryan C. Mills <bcmills@google.com>
diff --git a/Makefile b/Makefile
index e96fbd8..e670713 100644
--- a/Makefile
+++ b/Makefile
@@ -16,7 +16,7 @@
 
 test_go:
 	# Run fast tests first: (and tests whether, say, things compile)
-	GO111MODULE=on go test -v
+	GO111MODULE=on go test -v ./...
 
 test_gvisor: docker
 	docker kill sandbox_front_test || true
diff --git a/sandbox.go b/sandbox.go
index 3f44a09..07a0f07 100644
--- a/sandbox.go
+++ b/sandbox.go
@@ -526,6 +526,10 @@
 	sreq.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(bytes.NewReader(exeBytes)), nil }
 	res, err := sandboxBackendClient().Do(sreq)
 	if err != nil {
+		if ctx.Err() == context.DeadlineExceeded {
+			execRes.Error = "timeout running program"
+			return execRes, nil
+		}
 		return execRes, fmt.Errorf("POST %q: %w", sandboxBackendURL(), err)
 	}
 	defer res.Body.Close()
diff --git a/sandbox/sandbox.go b/sandbox/sandbox.go
index 603088b..07de286 100644
--- a/sandbox/sandbox.go
+++ b/sandbox/sandbox.go
@@ -43,12 +43,16 @@
 
 const (
 	maxBinarySize    = 100 << 20
+	startTimeout     = 30 * time.Second
 	runTimeout       = 5 * time.Second
 	maxOutputSize    = 100 << 20
 	memoryLimitBytes = 100 << 20
 )
 
-var errTooMuchOutput = errors.New("Output too large")
+var (
+	errTooMuchOutput = errors.New("Output too large")
+	errRunTimeout    = errors.New("timeout running program")
+)
 
 // containedStartMessage is the first thing written to stdout by the
 // gvisor-contained process when it starts up. This lets the parent HTTP
@@ -68,8 +72,8 @@
 type Container struct {
 	name   string
 	stdin  io.WriteCloser
-	stdout io.ReadCloser
-	stderr io.ReadCloser
+	stdout *limitedWriter
+	stderr *limitedWriter
 	cmd    *exec.Cmd
 
 	waitOnce sync.Once
@@ -78,12 +82,12 @@
 
 func (c *Container) Close() {
 	setContainerWanted(c.name, false)
-	c.stdin.Close()
-	c.stdout.Close()
-	c.stderr.Close()
+
 	if c.cmd.Process != nil {
-		c.cmd.Process.Kill()
-		c.Wait() // just in case
+		gracefulStop(c.cmd.Process, 250*time.Millisecond)
+		if err := c.Wait(); err != nil {
+			log.Printf("error in c.Wait() for %q: %v", c.name, err)
+		}
 	}
 }
 
@@ -245,6 +249,7 @@
 	if err := ioutil.WriteFile(binPath, bin, 0755); err != nil {
 		log.Fatalf("writing contained binary: %v", err)
 	}
+	defer os.Remove(binPath) // not that it matters much, this container will be nuked
 
 	var meta processMeta
 	if err := json.NewDecoder(bytes.NewReader(metaJSON)).Decode(&meta); err != nil {
@@ -262,12 +267,34 @@
 	if err := cmd.Start(); err != nil {
 		log.Fatalf("cmd.Start(): %v", err)
 	}
+	timer := time.AfterFunc(runTimeout-(500*time.Millisecond), func() {
+		fmt.Fprintln(os.Stderr, "timeout running program")
+		gracefulStop(cmd.Process, 250*time.Millisecond)
+	})
+	defer timer.Stop()
 	err = cmd.Wait()
-	os.Remove(binPath) // not that it matters much, this container will be nuked
 	os.Exit(errExitCode(err))
 	return
 }
 
+// gracefulStop attempts to send a SIGINT before a SIGKILL.
+//
+// The process will be sent a SIGINT immediately. A SIGKILL follows unconditionally
+// after delay has passed since sending the SIGINT; a failure of either signal is only logged.
+//
+// TODO(golang.org/issue/38343) - Change SIGINT to SIGQUIT once decision is made.
+func gracefulStop(p *os.Process, delay time.Duration) {
+	// TODO(golang.org/issue/38343) - Change to syscall.SIGQUIT once decision is made.
+	if err := p.Signal(os.Interrupt); err != nil {
+		log.Printf("cmd.Process.Signal(%v): %v", os.Interrupt, err)
+	}
+	time.AfterFunc(delay, func() {
+		if err := p.Kill(); err != nil {
+			log.Printf("cmd.Process.Kill(): %v", err)
+		}
+	})
+}
+
 func makeWorkers() {
 	for {
 		c, err := startContainer(context.Background())
@@ -321,25 +348,6 @@
 func startContainer(ctx context.Context) (c *Container, err error) {
 	name := "play_run_" + randHex(8)
 	setContainerWanted(name, true)
-	var stdin io.WriteCloser
-	var stdout io.ReadCloser
-	var stderr io.ReadCloser
-	defer func() {
-		if err == nil {
-			return
-		}
-		setContainerWanted(name, false)
-		if stdin != nil {
-			stdin.Close()
-		}
-		if stdout != nil {
-			stdout.Close()
-		}
-		if stderr != nil {
-			stderr.Close()
-		}
-	}()
-
 	cmd := exec.Command("docker", "run",
 		"--name="+name,
 		"--rm",
@@ -352,46 +360,53 @@
 
 		*container,
 		"--mode=contained")
-	stdin, err = cmd.StdinPipe()
+	stdin, err := cmd.StdinPipe()
 	if err != nil {
 		return nil, err
 	}
-	stdout, err = cmd.StdoutPipe()
-	if err != nil {
-		return nil, err
-	}
-	stderr, err = cmd.StderrPipe()
-	if err != nil {
-		return nil, err
-	}
+	pr, pw := io.Pipe()
+	stdout := &limitedWriter{dst: &bytes.Buffer{}, n: maxOutputSize + int64(len(containedStartMessage))}
+	stderr := &limitedWriter{dst: &bytes.Buffer{}, n: maxOutputSize}
+	cmd.Stdout = &switchWriter{switchAfter: []byte(containedStartMessage), dst1: pw, dst2: stdout}
+	cmd.Stderr = stderr
 	if err := cmd.Start(); err != nil {
 		return nil, err
 	}
+	defer func() {
+		if err != nil {
+			log.Printf("error starting container %q: %v", name, err)
+			gracefulStop(cmd.Process, 250*time.Millisecond)
+			setContainerWanted(name, false)
+		}
+	}()
+	ctx, cancel := context.WithTimeout(ctx, startTimeout)
+	defer cancel()
 
-	errc := make(chan error, 1)
+	startErr := make(chan error, 1)
 	go func() {
 		buf := make([]byte, len(containedStartMessage))
-		if _, err := io.ReadFull(stdout, buf); err != nil {
-			errc <- fmt.Errorf("error reading header from sandbox container: %v", err)
-			return
+		_, err := io.ReadFull(pr, buf)
+		if err != nil {
+			startErr <- fmt.Errorf("error reading header from sandbox container: %v", err)
+		} else if string(buf) != containedStartMessage {
+			startErr <- fmt.Errorf("sandbox container sent wrong header %q; want %q", buf, containedStartMessage)
+		} else {
+			startErr <- nil
 		}
-		if string(buf) != containedStartMessage {
-			errc <- fmt.Errorf("sandbox container sent wrong header %q; want %q", buf, containedStartMessage)
-			return
-		}
-		errc <- nil
 	}()
+
 	select {
 	case <-ctx.Done():
-		log.Printf("timeout starting container")
-		cmd.Process.Kill()
-		return nil, ctx.Err()
-	case err := <-errc:
+		err := fmt.Errorf("timeout starting container %q: %w", name, ctx.Err())
+		pw.Close()
+		<-startErr
+		return nil, err
+	case err = <-startErr:
 		if err != nil {
-			log.Printf("error starting container: %v", err)
 			return nil, err
 		}
 	}
+	log.Printf("started container %q", name)
 	return &Container{
 		name:   name,
 		stdin:  stdin,
@@ -435,6 +450,7 @@
 
 	bin, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, maxBinarySize))
 	if err != nil {
+		log.Printf("failed to read request body: %v", err)
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
@@ -451,88 +467,136 @@
 		return
 	}
 	logf("got container %s", c.name)
-	defer c.Close()
-	defer logf("leaving handler; about to close container")
 
-	runTimer := time.NewTimer(runTimeout)
-	defer runTimer.Stop()
-
-	errc := make(chan error, 2) // user-visible error
-	waitc := make(chan error, 1)
-
-	copyOut := func(which string, dst *[]byte, r io.Reader) {
-		buf := make([]byte, 4<<10)
-		for {
-			n, err := r.Read(buf)
-			logf("%s: Read = %v, %v", which, n, err)
-			*dst = append(*dst, buf[:n]...)
-			if err == io.EOF {
-				return
-			}
-			if len(*dst) > maxOutputSize {
-				errc <- errTooMuchOutput
-				return
-			}
-			if err != nil {
-				log.Printf("reading %s: %v", which, err)
-				errc <- fmt.Errorf("error reading %v", which)
-				return
-			}
-		}
-	}
-
-	res := &sandboxtypes.Response{}
-	go func() {
-		var meta processMeta
-		meta.Args = r.Header["X-Argument"]
-		metaJSON, _ := json.Marshal(&meta)
-		metaJSON = append(metaJSON, '\n')
-		if _, err := c.stdin.Write(metaJSON); err != nil {
-			log.Printf("stdin write meta: %v", err)
-			errc <- errors.New("failed to write meta to child")
-			return
-		}
-		if _, err := c.stdin.Write(bin); err != nil {
-			log.Printf("stdin write: %v", err)
-			errc <- errors.New("failed to write binary to child")
-			return
-		}
-		c.stdin.Close()
-		logf("wrote+closed")
-		go copyOut("stdout", &res.Stdout, c.stdout)
-		go copyOut("stderr", &res.Stderr, c.stderr)
-		waitc <- c.Wait()
+	ctx, cancel := context.WithTimeout(context.Background(), runTimeout)
+	closed := make(chan struct{})
+	defer func() {
+		logf("leaving handler; about to close container")
+		cancel()
+		<-closed
 	}()
-	var waitErr error
-	select {
-	case waitErr = <-waitc:
-		logf("waited: %v", waitErr)
-	case err := <-errc:
-		logf("got error: %v", err)
-		if err == errTooMuchOutput {
-			sendError(w, err.Error())
-			return
+	go func() {
+		<-ctx.Done()
+		if ctx.Err() == context.DeadlineExceeded {
+			logf("timeout")
 		}
-		if err != nil {
-			http.Error(w, "failed to read stdout from docker run", http.StatusInternalServerError)
-			return
-		}
-	case <-runTimer.C:
-		logf("timeout")
-		sendError(w, "timeout running program")
+		c.Close()
+		close(closed)
+	}()
+	var meta processMeta
+	meta.Args = r.Header["X-Argument"]
+	metaJSON, _ := json.Marshal(&meta)
+	metaJSON = append(metaJSON, '\n')
+	if _, err := c.stdin.Write(metaJSON); err != nil {
+		log.Printf("failed to write meta to child: %v", err)
+		http.Error(w, "unknown error during docker run", http.StatusInternalServerError)
 		return
 	}
-
-	res.ExitCode = errExitCode(waitErr)
-	res.Stderr = cleanStderr(res.Stderr)
+	if _, err := c.stdin.Write(bin); err != nil {
+		log.Printf("failed to write binary to child: %v", err)
+		http.Error(w, "unknown error during docker run", http.StatusInternalServerError)
+		return
+	}
+	c.stdin.Close()
+	logf("wrote+closed")
+	err = c.Wait()
+	select {
+	case <-ctx.Done():
+		// Timed out or canceled before or exactly as Wait returned.
+		// Either way, treat it as a timeout.
+		sendError(w, "timeout running program")
+		return
+	default:
+		logf("finished running; about to close container")
+		cancel()
+	}
+	res := &sandboxtypes.Response{}
+	if err != nil {
+		if c.stderr.n < 0 || c.stdout.n < 0 {
+			// Do not send truncated output, just send the error.
+			sendError(w, errTooMuchOutput.Error())
+			return
+		}
+		var ee *exec.ExitError
+		if !errors.As(err, &ee) {
+			http.Error(w, "unknown error during docker run", http.StatusInternalServerError)
+			return
+		}
+		res.ExitCode = ee.ExitCode()
+	}
+	res.Stdout = c.stdout.dst.Bytes()
+	res.Stderr = cleanStderr(c.stderr.dst.Bytes())
 	sendResponse(w, res)
 }
 
+// limitedWriter is an io.Writer that returns an errTooMuchOutput when the cap (n) is hit.
+type limitedWriter struct {
+	dst *bytes.Buffer
+	n   int64 // max bytes remaining
+}
+
+// Write is an io.Writer function that returns errTooMuchOutput when the cap (n) is hit.
+//
+// Partial data will be written to dst if p is larger than n, but errTooMuchOutput will be returned.
+func (l *limitedWriter) Write(p []byte) (int, error) {
+	defer func() { l.n -= int64(len(p)) }()
+
+	if l.n <= 0 {
+		return 0, errTooMuchOutput
+	}
+
+	if int64(len(p)) > l.n {
+		n, err := l.dst.Write(p[:l.n])
+		if err != nil {
+			return n, err
+		}
+		return n, errTooMuchOutput
+	}
+
+	return l.dst.Write(p)
+}
+
+// switchWriter writes to dst1 until switchAfter is written, then it writes to dst2.
+type switchWriter struct {
+	dst1        io.Writer
+	dst2        io.Writer
+	switchAfter []byte
+	buf         []byte
+	found       bool
+}
+
+func (s *switchWriter) Write(p []byte) (int, error) {
+	if s.found {
+		return s.dst2.Write(p)
+	}
+
+	s.buf = append(s.buf, p...)
+	i := bytes.Index(s.buf, s.switchAfter)
+	if i == -1 {
+		if len(s.buf) >= len(s.switchAfter) {
+			s.buf = s.buf[len(s.buf)-len(s.switchAfter)+1:]
+		}
+		return s.dst1.Write(p)
+	}
+
+	s.found = true
+	nAfter := len(s.buf) - (i + len(s.switchAfter))
+	s.buf = nil
+
+	n, err := s.dst1.Write(p[:len(p)-nAfter])
+	if err != nil {
+		return n, err
+	}
+	n2, err := s.dst2.Write(p[len(p)-nAfter:])
+	return n + n2, err
+}
+
 func errExitCode(err error) int {
 	if err == nil {
 		return 0
 	}
-	if ee, ok := err.(*exec.ExitError); ok {
+	var ee *exec.ExitError
+	if errors.As(err, &ee) {
 		return ee.ExitCode()
 	}
 	return 1
diff --git a/sandbox/sandbox_test.go b/sandbox/sandbox_test.go
new file mode 100644
index 0000000..7b32ec6
--- /dev/null
+++ b/sandbox/sandbox_test.go
@@ -0,0 +1,180 @@
+package main
+
+import (
+	"bytes"
+	"io"
+	"strings"
+	"testing"
+	"testing/iotest"
+)
+
+func TestLimitedWriter(t *testing.T) {
+	cases := []struct {
+		desc          string
+		lw            *limitedWriter
+		in            []byte
+		want          []byte
+		wantN         int64
+		wantRemaining int64
+		err           error
+	}{
+		{
+			desc:          "simple",
+			lw:            &limitedWriter{dst: &bytes.Buffer{}, n: 10},
+			in:            []byte("hi"),
+			want:          []byte("hi"),
+			wantN:         2,
+			wantRemaining: 8,
+		},
+		{
+			desc:          "writing nothing",
+			lw:            &limitedWriter{dst: &bytes.Buffer{}, n: 10},
+			in:            []byte(""),
+			want:          []byte(""),
+			wantN:         0,
+			wantRemaining: 10,
+		},
+		{
+			desc:          "writing exactly enough",
+			lw:            &limitedWriter{dst: &bytes.Buffer{}, n: 6},
+			in:            []byte("enough"),
+			want:          []byte("enough"),
+			wantN:         6,
+			wantRemaining: 0,
+			err:           nil,
+		},
+		{
+			desc:          "writing too much",
+			lw:            &limitedWriter{dst: &bytes.Buffer{}, n: 10},
+			in:            []byte("this is much longer than 10"),
+			want:          []byte("this is mu"),
+			wantN:         10,
+			wantRemaining: -1,
+			err:           errTooMuchOutput,
+		},
+	}
+	for _, c := range cases {
+		t.Run(c.desc, func(t *testing.T) {
+			n, err := io.Copy(c.lw, iotest.OneByteReader(bytes.NewReader(c.in)))
+			if err != c.err {
+				t.Errorf("c.lw.Write(%q) = %d, %q, wanted %d, %q", c.in, n, err, c.wantN, c.err)
+			}
+			if n != c.wantN {
+				t.Errorf("c.lw.Write(%q) = %d, %q, wanted %d, %q", c.in, n, err, c.wantN, c.err)
+			}
+			if c.lw.n != c.wantRemaining {
+				t.Errorf("c.lw.n = %d, wanted %d", c.lw.n, c.wantRemaining)
+			}
+			if string(c.lw.dst.Bytes()) != string(c.want) {
+				t.Errorf("c.lw.dst.Bytes() = %q, wanted %q", c.lw.dst.Bytes(), c.want)
+			}
+		})
+	}
+}
+
+func TestSwitchWriter(t *testing.T) {
+	cases := []struct {
+		desc      string
+		sw        *switchWriter
+		in        []byte
+		want1     []byte
+		want2     []byte
+		wantN     int64
+		wantFound bool
+		err       error
+	}{
+		{
+			desc:      "not found",
+			sw:        &switchWriter{switchAfter: []byte("UNIQUE")},
+			in:        []byte("hi"),
+			want1:     []byte("hi"),
+			want2:     []byte(""),
+			wantN:     2,
+			wantFound: false,
+		},
+		{
+			desc:      "writing nothing",
+			sw:        &switchWriter{switchAfter: []byte("UNIQUE")},
+			in:        []byte(""),
+			want1:     []byte(""),
+			want2:     []byte(""),
+			wantN:     0,
+			wantFound: false,
+		},
+		{
+			desc:      "writing exactly switchAfter",
+			sw:        &switchWriter{switchAfter: []byte("UNIQUE")},
+			in:        []byte("UNIQUE"),
+			want1:     []byte("UNIQUE"),
+			want2:     []byte(""),
+			wantN:     6,
+			wantFound: true,
+		},
+		{
+			desc:      "writing before and after switchAfter",
+			sw:        &switchWriter{switchAfter: []byte("UNIQUE")},
+			in:        []byte("this is before UNIQUE and this is after"),
+			want1:     []byte("this is before UNIQUE"),
+			want2:     []byte(" and this is after"),
+			wantN:     39,
+			wantFound: true,
+		},
+	}
+	for _, c := range cases {
+		t.Run(c.desc, func(t *testing.T) {
+			dst1, dst2 := &bytes.Buffer{}, &bytes.Buffer{}
+			c.sw.dst1, c.sw.dst2 = dst1, dst2
+			n, err := io.Copy(c.sw, iotest.OneByteReader(bytes.NewReader(c.in)))
+			if err != c.err {
+				t.Errorf("c.sw.Write(%q) = %d, %q, wanted %d, %q", c.in, n, err, c.wantN, c.err)
+			}
+			if n != c.wantN {
+				t.Errorf("c.sw.Write(%q) = %d, %q, wanted %d, %q", c.in, n, err, c.wantN, c.err)
+			}
+			if c.sw.found != c.wantFound {
+				t.Errorf("c.sw.found = %v, wanted %v", c.sw.found, c.wantFound)
+			}
+			if string(dst1.Bytes()) != string(c.want1) {
+				t.Errorf("dst1.Bytes() = %q, wanted %q", dst1.Bytes(), c.want1)
+			}
+			if string(dst2.Bytes()) != string(c.want2) {
+				t.Errorf("dst2.Bytes() = %q, wanted %q", dst2.Bytes(), c.want2)
+			}
+		})
+	}
+}
+
+func TestSwitchWriterMultipleWrites(t *testing.T) {
+	dst1, dst2 := &bytes.Buffer{}, &bytes.Buffer{}
+	sw := &switchWriter{
+		dst1:        dst1,
+		dst2:        dst2,
+		switchAfter: []byte("GOPHER"),
+	}
+	n, err := io.Copy(sw, iotest.OneByteReader(strings.NewReader("this is before GO")))
+	if err != nil || n != 17 {
+		t.Errorf("sw.Write(%q) = %d, %q, wanted %d, no error", "this is before GO", n, err, 17)
+	}
+	if sw.found {
+		t.Errorf("sw.found = %v, wanted %v", sw.found, false)
+	}
+	if string(dst1.Bytes()) != "this is before GO" {
+		t.Errorf("dst1.Bytes() = %q, wanted %q", dst1.Bytes(), "this is before GO")
+	}
+	if string(dst2.Bytes()) != "" {
+		t.Errorf("dst2.Bytes() = %q, wanted %q", dst2.Bytes(), "")
+	}
+	n, err = io.Copy(sw, iotest.OneByteReader(strings.NewReader("PHER and this is after")))
+	if err != nil || n != 22 {
+		t.Errorf("sw.Write(%q) = %d, %q, wanted %d, no error", "this is before GO", n, err, 22)
+	}
+	if !sw.found {
+		t.Errorf("sw.found = %v, wanted %v", sw.found, true)
+	}
+	if string(dst1.Bytes()) != "this is before GOPHER" {
+		t.Errorf("dst1.Bytes() = %q, wanted %q", dst1.Bytes(), "this is before GOPHEr")
+	}
+	if string(dst2.Bytes()) != " and this is after" {
+		t.Errorf("dst2.Bytes() = %q, wanted %q", dst2.Bytes(), " and this is after")
+	}
+}
diff --git a/tests.go b/tests.go
index 9e0d55f..3db7e63 100644
--- a/tests.go
+++ b/tests.go
@@ -67,7 +67,7 @@
 			continue
 		}
 		if resp.Errors != "" {
-			stdlog.Print(resp.Errors)
+			stdlog.Printf("resp.Errors = %q, want %q", resp.Errors, t.errors)
 			failed = true
 			continue
 		}
@@ -548,4 +548,26 @@
 func Hello() { fmt.Println("hello world") }
 `,
 	},
+	{
+		name: "timeouts_handled_gracefully",
+		prog: `
+package main
+
+import (
+	"time"
+)
+
+func main() {
+	c := make(chan struct{})
+
+	go func() {
+		defer close(c)
+		for {
+			time.Sleep(10 * time.Millisecond)
+		}
+	}()
+
+	<-c
+}
+`, want: "timeout running program"},
 }