singleflight: add panicError.Unwrap method
Currently when singleflight recovers from a panic, it wraps it with the private
error type panicError. This change adds an `Unwrap` method to panicError to
allow wrapped errors to be returned.
Updates golang/go#62511
Change-Id: Ia510ad7d5881207ef71f9eb89c1766835af19b6b
Reviewed-on: https://go-review.googlesource.com/c/sync/+/526171
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Than McIntosh <thanm@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go
index 8473fb7..4051830 100644
--- a/singleflight/singleflight.go
+++ b/singleflight/singleflight.go
@@ -31,6 +31,15 @@
return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
}
+func (p *panicError) Unwrap() error {
+ err, ok := p.value.(error)
+ if !ok {
+ return nil
+ }
+
+ return err
+}
+
func newPanicError(v interface{}) error {
stack := debug.Stack()
diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go
index bb25a1e..1e85b17 100644
--- a/singleflight/singleflight_test.go
+++ b/singleflight/singleflight_test.go
@@ -19,6 +19,69 @@
"time"
)
+type errValue struct{}
+
+func (err *errValue) Error() string {
+ return "error value"
+}
+
+func TestPanicErrorUnwrap(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ panicValue interface{}
+ wrappedErrorType bool
+ }{
+ {
+ name: "panicError wraps non-error type",
+ panicValue: &panicError{value: "string value"},
+ wrappedErrorType: false,
+ },
+ {
+ name: "panicError wraps error type",
+ panicValue: &panicError{value: new(errValue)},
+ wrappedErrorType: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ var recovered interface{}
+
+ group := &Group{}
+
+ func() {
+ defer func() {
+ recovered = recover()
+ t.Logf("after panic(%#v) in group.Do, recovered %#v", tc.panicValue, recovered)
+ }()
+
+ _, _, _ = group.Do(tc.name, func() (interface{}, error) {
+ panic(tc.panicValue)
+ })
+ }()
+
+ if recovered == nil {
+ t.Fatal("expected a non-nil panic value")
+ }
+
+ err, ok := recovered.(error)
+ if !ok {
+ t.Fatalf("recovered non-error type: %T", recovered)
+ }
+
+ if !errors.Is(err, new(errValue)) && tc.wrappedErrorType {
+ t.Errorf("unexpected wrapped error type %T; want %T", err, new(errValue))
+ }
+ })
+ }
+}
+
func TestDo(t *testing.T) {
var g Group
v, err, _ := g.Do("key", func() (interface{}, error) {