errgroup: propagate panic and Goexit through Wait

Recovered panic values are wrapped and saved in Group.
Goexits are detected by a sentinel value set after the given function
returns normally. Wait propagates the first instance of a panic or
Goexit.

Code after a call to runtime.Goexit is never executed, so a bool that
is set to true only after f returns normally is used to tell a normal
return apart from a Goexit and to decide whether Wait must propagate
runtime.Goexit.

Fixes golang/go#53757

Change-Id: Ic6426fc014fd1c4368ebaceef5b0d6163770a099
Reviewed-on: https://go-review.googlesource.com/c/sync/+/644575
Reviewed-by: Sean Liao <sean@liao.dev>
Auto-Submit: Alan Donovan <adonovan@google.com>
Commit-Queue: Alan Donovan <adonovan@google.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/errgroup/errgroup.go b/errgroup/errgroup.go
index f8c3c09..cfafed5 100644
--- a/errgroup/errgroup.go
+++ b/errgroup/errgroup.go
@@ -12,6 +12,8 @@
 import (
 	"context"
 	"fmt"
+	"runtime"
+	"runtime/debug"
 	"sync"
 )
 
@@ -31,6 +33,10 @@
 
 	errOnce sync.Once
 	err     error
+
+	mu         sync.Mutex
+	panicValue any  // = PanicError | PanicValue; non-nil if some Group.Go goroutine panicked.
+	abnormal   bool // some Group.Go goroutine terminated abnormally (panic or goexit).
 }
 
 func (g *Group) done() {
@@ -50,13 +56,22 @@
 	return &Group{cancel: cancel}, ctx
 }
 
-// Wait blocks until all function calls from the Go method have returned, then
-// returns the first non-nil error (if any) from them.
+// Wait blocks until all function calls from the Go method have returned
+// normally, then returns the first non-nil error (if any) from them.
+//
+// If any of the calls panics, Wait panics with a [PanicError] or [PanicValue];
+// and if any of them calls [runtime.Goexit], Wait calls runtime.Goexit.
 func (g *Group) Wait() error {
 	g.wg.Wait()
 	if g.cancel != nil {
 		g.cancel(g.err)
 	}
+	if g.panicValue != nil {
+		panic(g.panicValue)
+	}
+	if g.abnormal {
+		runtime.Goexit()
+	}
 	return g.err
 }
 
@@ -65,18 +80,56 @@
 // It blocks until the new goroutine can be added without the number of
 // active goroutines in the group exceeding the configured limit.
 //
-// The first call to return a non-nil error cancels the group's context, if the
-// group was created by calling WithContext. The error will be returned by Wait.
+// It blocks until the new goroutine can be added without the number of
+// goroutines in the group exceeding the configured limit.
+//
+// The first goroutine in the group that returns a non-nil error, panics, or
+// invokes [runtime.Goexit] will cancel the associated Context, if any.
 func (g *Group) Go(f func() error) {
 	if g.sem != nil {
 		g.sem <- token{}
 	}
 
+	g.add(f)
+}
+
+func (g *Group) add(f func() error) {
 	g.wg.Add(1)
 	go func() {
 		defer g.done()
+		normalReturn := false
+		defer func() {
+			if normalReturn {
+				return
+			}
+			v := recover()
+			g.mu.Lock()
+			defer g.mu.Unlock()
+			if !g.abnormal {
+				if g.cancel != nil {
+					g.cancel(g.err)
+				}
+				g.abnormal = true
+			}
+			if v != nil && g.panicValue == nil {
+				switch v := v.(type) {
+				case error:
+					g.panicValue = PanicError{
+						Recovered: v,
+						Stack:     debug.Stack(),
+					}
+				default:
+					g.panicValue = PanicValue{
+						Recovered: v,
+						Stack:     debug.Stack(),
+					}
+				}
+			}
+		}()
 
-		if err := f(); err != nil {
+		err := f()
+		normalReturn = true
+		if err != nil {
 			g.errOnce.Do(func() {
 				g.err = err
 				if g.cancel != nil {
@@ -101,19 +154,7 @@
 		}
 	}
 
-	g.wg.Add(1)
-	go func() {
-		defer g.done()
-
-		if err := f(); err != nil {
-			g.errOnce.Do(func() {
-				g.err = err
-				if g.cancel != nil {
-					g.cancel(g.err)
-				}
-			})
-		}
-	}()
+	g.add(f)
 	return true
 }
 
@@ -135,3 +176,33 @@
 	}
 	g.sem = make(chan token, n)
 }
+
+// PanicError wraps an error recovered from an unhandled panic
+// when calling a function passed to Go or TryGo.
+type PanicError struct {
+	Recovered error
+	Stack     []byte // result of call to [debug.Stack]
+}
+
+func (p PanicError) Error() string {
+	// A Go Error method conventionally does not include a stack dump, so omit it
+	// here. (Callers who care can extract it from the Stack field.)
+	return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered)
+}
+
+func (p PanicError) Unwrap() error { return p.Recovered }
+
+// PanicValue wraps a value that does not implement the error interface,
+// recovered from an unhandled panic when calling a function passed to Go or
+// TryGo.
+type PanicValue struct {
+	Recovered any
+	Stack     []byte // result of call to [debug.Stack]
+}
+
+func (p PanicValue) String() string {
+	if len(p.Stack) > 0 {
+		return fmt.Sprintf("recovered from errgroup.Group: %v\n%s", p.Recovered, p.Stack)
+	}
+	return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered)
+}
diff --git a/errgroup/errgroup_test.go b/errgroup/errgroup_test.go
index 2a491bf..4684259 100644
--- a/errgroup/errgroup_test.go
+++ b/errgroup/errgroup_test.go
@@ -10,6 +10,7 @@
 	"fmt"
 	"net/http"
 	"os"
+	"strings"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -289,6 +290,69 @@
 	}
 }
 
+func TestPanic(t *testing.T) {
+	t.Run("error", func(t *testing.T) {
+		g := &errgroup.Group{}
+		p := errors.New("")
+		g.Go(func() error {
+			panic(p)
+		})
+		defer func() {
+			err := recover()
+			if err == nil {
+				t.Fatalf("should propagate panic through Wait")
+			}
+			pe, ok := err.(errgroup.PanicError)
+			if !ok {
+				t.Fatalf("type should is errgroup.PanicError, but is %T", err)
+			}
+			if pe.Recovered != p {
+				t.Fatalf("got %v, want %v", pe.Recovered, p)
+			}
+			if !strings.Contains(string(pe.Stack), "TestPanic.func") {
+				t.Log(string(pe.Stack))
+				t.Fatalf("stack trace incomplete")
+			}
+		}()
+		g.Wait()
+	})
+	t.Run("any", func(t *testing.T) {
+		g := &errgroup.Group{}
+		g.Go(func() error {
+			panic(1)
+		})
+		defer func() {
+			err := recover()
+			if err == nil {
+				t.Fatalf("should propagate panic through Wait")
+			}
+			pe, ok := err.(errgroup.PanicValue)
+			if !ok {
+				t.Fatalf("type should is errgroup.PanicValue, but is %T", err)
+			}
+			if pe.Recovered != 1 {
+				t.Fatalf("got %v, want %v", pe.Recovered, 1)
+			}
+			if !strings.Contains(string(pe.Stack), "TestPanic.func") {
+				t.Log(string(pe.Stack))
+				t.Fatalf("stack trace incomplete")
+			}
+		}()
+		g.Wait()
+	})
+}
+
+func TestGoexit(t *testing.T) {
+	g := &errgroup.Group{}
+	g.Go(func() error {
+		t.Skip()
+		t.Fatalf("Goexit fail")
+		return nil
+	})
+	g.Wait()
+	t.Fatalf("should call runtime.Goexit from Wait")
+}
+
 func BenchmarkGo(b *testing.B) {
 	fn := func() {}
 	g := &errgroup.Group{}