errgroup: use WithCancelCause to cancel context
Fixes golang/go#59355
Change-Id: Ib6a88e7e5fefe7b0d5672035af16d109aabcbf1e
Reviewed-on: https://go-review.googlesource.com/c/sync/+/481255
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Bryan Mills <bcmills@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
diff --git a/errgroup/errgroup.go b/errgroup/errgroup.go
index cbee7a4..b18efb7 100644
--- a/errgroup/errgroup.go
+++ b/errgroup/errgroup.go
@@ -20,7 +20,7 @@
// A zero Group is valid, has no limit on the number of active goroutines,
// and does not cancel on error.
type Group struct {
- cancel func()
+ cancel func(error)
wg sync.WaitGroup
@@ -43,7 +43,7 @@
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
- ctx, cancel := context.WithCancel(ctx)
+ ctx, cancel := withCancelCause(ctx)
return &Group{cancel: cancel}, ctx
}
@@ -52,7 +52,7 @@
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
- g.cancel()
+ g.cancel(g.err)
}
return g.err
}
@@ -76,7 +76,7 @@
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
- g.cancel()
+ g.cancel(g.err)
}
})
}
@@ -105,7 +105,7 @@
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
- g.cancel()
+ g.cancel(g.err)
}
})
}
diff --git a/errgroup/go120.go b/errgroup/go120.go
new file mode 100644
index 0000000..7d419d3
--- /dev/null
+++ b/errgroup/go120.go
@@ -0,0 +1,14 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.20
+// +build go1.20
+
+package errgroup
+
+import "context"
+
+// withCancelCause is the Go 1.20+ implementation: it derives a cancelable
+// context from parent using context.WithCancelCause, so the error handed to
+// the returned function is recorded as the context's cancellation cause.
+func withCancelCause(parent context.Context) (context.Context, func(error)) {
+	ctx, cancel := context.WithCancelCause(parent)
+	return ctx, cancel
+}
diff --git a/errgroup/go120_test.go b/errgroup/go120_test.go
new file mode 100644
index 0000000..0c354a1
--- /dev/null
+++ b/errgroup/go120_test.go
@@ -0,0 +1,55 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.20
+// +build go1.20
+
+package errgroup_test
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "golang.org/x/sync/errgroup"
+)
+
+// TestCancelCause verifies that the context returned by WithContext reports,
+// via context.Cause, the first non-nil error returned by a function started
+// in the group, or context.Canceled when Wait returns without any error.
+func TestCancelCause(t *testing.T) {
+	errDoom := errors.New("group_test: doomed")
+
+	cases := []struct {
+		errs []error
+		want error
+	}{
+		{want: nil},
+		{errs: []error{nil}, want: nil},
+		{errs: []error{errDoom}, want: errDoom},
+		{errs: []error{errDoom, nil}, want: errDoom},
+	}
+
+	for _, tc := range cases {
+		g, ctx := errgroup.WithContext(context.Background())
+
+		for _, err := range tc.errs {
+			err := err // per-iteration copy captured by the closure below
+			g.TryGo(func() error { return err })
+		}
+
+		if err := g.Wait(); err != tc.want {
+			t.Errorf("after %T.TryGo(func() error { return err }) for err in %v\n"+
+				"g.Wait() = %v; want %v",
+				g, tc.errs, err, tc.want)
+		}
+
+		// When no function fails, Wait cancels the context with a nil cause,
+		// which the context package reports as context.Canceled. Keep the
+		// expectation in a local instead of overwriting the test case.
+		wantCause := tc.want
+		if wantCause == nil {
+			wantCause = context.Canceled
+		}
+
+		if err := context.Cause(ctx); err != wantCause {
+			t.Errorf("after %T.TryGo(func() error { return err }) for err in %v\n"+
+				"context.Cause(ctx) = %v; want %v",
+				g, tc.errs, err, wantCause)
+		}
+	}
+}
diff --git a/errgroup/pre_go120.go b/errgroup/pre_go120.go
new file mode 100644
index 0000000..1795c18
--- /dev/null
+++ b/errgroup/pre_go120.go
@@ -0,0 +1,15 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !go1.20
+// +build !go1.20
+
+package errgroup
+
+import "context"
+
+// withCancelCause is the fallback for toolchains older than Go 1.20. It
+// derives a cancelable context from parent with context.WithCancel and
+// returns a cancel function that accepts an error only to match the Go 1.20
+// signature; the error is ignored.
+func withCancelCause(parent context.Context) (context.Context, func(error)) {
+	ctx, stop := context.WithCancel(parent)
+	discardCause := func(error) { stop() }
+	return ctx, discardCause
+}