singleflight: fix hangs after first Do panic

When first Do panic, the related wait group will never be done,
and all the subsequent calls would block on the same wait group forever.

Fixes golang/go#41133

Change-Id: I0ad9bfb387b6133b10766a34fc0040f200eae27e
Reviewed-on: https://go-review.googlesource.com/c/sync/+/251677
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Bryan C. Mills <bcmills@google.com>
Trust: Ian Lance Taylor <iant@golang.org>
Trust: Bryan C. Mills <bcmills@google.com>
diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go
index 97a1aa4..690eb85 100644
--- a/singleflight/singleflight.go
+++ b/singleflight/singleflight.go
@@ -6,7 +6,42 @@
 // mechanism.
 package singleflight // import "golang.org/x/sync/singleflight"
 
-import "sync"
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"runtime"
+	"runtime/debug"
+	"sync"
+)
+
+// errGoexit indicates the runtime.Goexit was called in
+// the user given function.
+var errGoexit = errors.New("runtime.Goexit was called")
+
+// A panicError is an arbitrary value recovered from a panic
+// with the stack trace during the execution of given function.
+type panicError struct {
+	value interface{}
+	stack []byte
+}
+
+// Error implements error interface.
+func (p *panicError) Error() string {
+	return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
+}
+
+func newPanicError(v interface{}) error {
+	stack := debug.Stack()
+
+	// The first line of the stack trace is of the form "goroutine N [status]:"
+	// but by the time the panic reaches Do the goroutine may no longer exist
+	// and its status will have changed. Trim out the misleading line.
+	if line := bytes.IndexByte(stack[:], '\n'); line >= 0 {
+		stack = stack[line+1:]
+	}
+	return &panicError{value: v, stack: stack}
+}
 
 // call is an in-flight or completed singleflight.Do call
 type call struct {
@@ -57,6 +92,12 @@
 		c.dups++
 		g.mu.Unlock()
 		c.wg.Wait()
+
+		if e, ok := c.err.(*panicError); ok {
+			panic(e)
+		} else if c.err == errGoexit {
+			runtime.Goexit()
+		}
 		return c.val, c.err, true
 	}
 	c := new(call)
@@ -70,6 +111,8 @@
 
 // DoChan is like Do but returns a channel that will receive the
 // results when they are ready.
+//
+// The returned channel will not be closed.
 func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
 	ch := make(chan Result, 1)
 	g.mu.Lock()
@@ -94,17 +137,66 @@
 
 // doCall handles the single call for a key.
 func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
-	c.val, c.err = fn()
-	c.wg.Done()
+	normalReturn := false
+	recovered := false
 
-	g.mu.Lock()
-	if !c.forgotten {
-		delete(g.m, key)
+	// use double-defer to distinguish panic from runtime.Goexit,
+	// more details see https://golang.org/cl/134395
+	defer func() {
+		// the given function invoked runtime.Goexit
+		if !normalReturn && !recovered {
+			c.err = errGoexit
+		}
+
+		c.wg.Done()
+		g.mu.Lock()
+		defer g.mu.Unlock()
+		if !c.forgotten {
+			delete(g.m, key)
+		}
+
+		if e, ok := c.err.(*panicError); ok {
+			// In order to prevent the waiting channels from being blocked forever,
+			// needs to ensure that this panic cannot be recovered.
+			if len(c.chans) > 0 {
+				go panic(e)
+				select {} // Keep this goroutine around so that it will appear in the crash dump.
+			} else {
+				panic(e)
+			}
+		} else if c.err == errGoexit {
+			// Already in the process of goexit, no need to call again
+		} else {
+			// Normal return
+			for _, ch := range c.chans {
+				ch <- Result{c.val, c.err, c.dups > 0}
+			}
+		}
+	}()
+
+	func() {
+		defer func() {
+			if !normalReturn {
+				// Ideally, we would wait to take a stack trace until we've determined
+				// whether this is a panic or a runtime.Goexit.
+				//
+				// Unfortunately, the only way we can distinguish the two is to see
+				// whether the recover stopped the goroutine from terminating, and by
+				// the time we know that, the part of the stack trace relevant to the
+				// panic has been discarded.
+				if r := recover(); r != nil {
+					c.err = newPanicError(r)
+				}
+			}
+		}()
+
+		c.val, c.err = fn()
+		normalReturn = true
+	}()
+
+	if !normalReturn {
+		recovered = true
 	}
-	for _, ch := range c.chans {
-		ch <- Result{c.val, c.err, c.dups > 0}
-	}
-	g.mu.Unlock()
 }
 
 // Forget tells the singleflight to forget about a key.  Future calls
diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go
index ad04037..c635edc 100644
--- a/singleflight/singleflight_test.go
+++ b/singleflight/singleflight_test.go
@@ -5,8 +5,14 @@
 package singleflight
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
+	"os"
+	"os/exec"
+	"runtime"
+	"runtime/debug"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -157,3 +163,179 @@
 		t.Errorf("We should receive result produced by second call, expected: 2, got %d", result)
 	}
 }
+
+func TestDoChan(t *testing.T) {
+	var g Group
+	ch := g.DoChan("key", func() (interface{}, error) {
+		return "bar", nil
+	})
+
+	res := <-ch
+	v := res.Val
+	err := res.Err
+	if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
+		t.Errorf("Do = %v; want %v", got, want)
+	}
+	if err != nil {
+		t.Errorf("Do error = %v", err)
+	}
+}
+
+// Test singleflight behaves correctly after Do panic.
+// See https://github.com/golang/go/issues/41133
+func TestPanicDo(t *testing.T) {
+	var g Group
+	fn := func() (interface{}, error) {
+		panic("invalid memory address or nil pointer dereference")
+	}
+
+	const n = 5
+	waited := int32(n)
+	panicCount := int32(0)
+	done := make(chan struct{})
+	for i := 0; i < n; i++ {
+		go func() {
+			defer func() {
+				if err := recover(); err != nil {
+					t.Logf("Got panic: %v\n%s", err, debug.Stack())
+					atomic.AddInt32(&panicCount, 1)
+				}
+
+				if atomic.AddInt32(&waited, -1) == 0 {
+					close(done)
+				}
+			}()
+
+			g.Do("key", fn)
+		}()
+	}
+
+	select {
+	case <-done:
+		if panicCount != n {
+			t.Errorf("Expect %d panic, but got %d", n, panicCount)
+		}
+	case <-time.After(time.Second):
+		t.Fatalf("Do hangs")
+	}
+}
+
+func TestGoexitDo(t *testing.T) {
+	var g Group
+	fn := func() (interface{}, error) {
+		runtime.Goexit()
+		return nil, nil
+	}
+
+	const n = 5
+	waited := int32(n)
+	done := make(chan struct{})
+	for i := 0; i < n; i++ {
+		go func() {
+			var err error
+			defer func() {
+				if err != nil {
+					t.Errorf("Error should be nil, but got: %v", err)
+				}
+				if atomic.AddInt32(&waited, -1) == 0 {
+					close(done)
+				}
+			}()
+			_, err, _ = g.Do("key", fn)
+		}()
+	}
+
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		t.Fatalf("Do hangs")
+	}
+}
+
+func TestPanicDoChan(t *testing.T) {
+	if os.Getenv("TEST_PANIC_DOCHAN") != "" {
+		defer func() {
+			recover()
+		}()
+
+		g := new(Group)
+		ch := g.DoChan("", func() (interface{}, error) {
+			panic("Panicking in DoChan")
+		})
+		<-ch
+		t.Fatalf("DoChan unexpectedly returned")
+	}
+
+	t.Parallel()
+
+	cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v")
+	cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1")
+	out := new(bytes.Buffer)
+	cmd.Stdout = out
+	cmd.Stderr = out
+	if err := cmd.Start(); err != nil {
+		t.Fatal(err)
+	}
+
+	err := cmd.Wait()
+	t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out)
+	if err == nil {
+		t.Errorf("Test subprocess passed; want a crash due to panic in DoChan")
+	}
+	if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) {
+		t.Errorf("Test subprocess failed with an unexpected failure mode.")
+	}
+	if !bytes.Contains(out.Bytes(), []byte("Panicking in DoChan")) {
+		t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in DoChan")
+	}
+}
+
+func TestPanicDoSharedByDoChan(t *testing.T) {
+	if os.Getenv("TEST_PANIC_DOCHAN") != "" {
+		blocked := make(chan struct{})
+		unblock := make(chan struct{})
+
+		g := new(Group)
+		go func() {
+			defer func() {
+				recover()
+			}()
+			g.Do("", func() (interface{}, error) {
+				close(blocked)
+				<-unblock
+				panic("Panicking in Do")
+			})
+		}()
+
+		<-blocked
+		ch := g.DoChan("", func() (interface{}, error) {
+			panic("DoChan unexpectedly executed callback")
+		})
+		close(unblock)
+		<-ch
+		t.Fatalf("DoChan unexpectedly returned")
+	}
+
+	t.Parallel()
+
+	cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v")
+	cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1")
+	out := new(bytes.Buffer)
+	cmd.Stdout = out
+	cmd.Stderr = out
+	if err := cmd.Start(); err != nil {
+		t.Fatal(err)
+	}
+
+	err := cmd.Wait()
+	t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out)
+	if err == nil {
+		t.Errorf("Test subprocess passed; want a crash due to panic in Do shared by DoChan")
+	}
+	if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) {
+		t.Errorf("Test subprocess failed with an unexpected failure mode.")
+	}
+	if !bytes.Contains(out.Bytes(), []byte("Panicking in Do")) {
+		t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do")
+	}
+}