singleflight: fix duplicate deleting key when Forget called

When Forget was called, we delete key associated with current call from
map. When that call is done, it does delete key again, causing the same
key set by other call after Forget lost.

To fix it, adding a boolean value to check whether the call is forgotten,
the call only does delete key if Forget is not called.

Fixes golang/go#31420

Change-Id: I9708352ca3ff76c77f659916b37a496fdeb480d2
Reviewed-on: https://go-review.googlesource.com/c/sync/+/171897
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go
index 9a4f8d5..97a1aa4 100644
--- a/singleflight/singleflight.go
+++ b/singleflight/singleflight.go
@@ -17,6 +17,10 @@
 	val interface{}
 	err error
 
+	// forgotten indicates whether Forget was called with this call's key
+	// while the call was still in flight.
+	forgotten bool
+
 	// These fields are read and written with the singleflight
 	// mutex held before the WaitGroup is done, and are read but
 	// not written after the WaitGroup is done.
@@ -94,7 +98,9 @@
 	c.wg.Done()
 
 	g.mu.Lock()
-	delete(g.m, key)
+	if !c.forgotten {
+		delete(g.m, key)
+	}
 	for _, ch := range c.chans {
 		ch <- Result{c.val, c.err, c.dups > 0}
 	}
@@ -106,6 +112,9 @@
 // an earlier call to complete.
 func (g *Group) Forget(key string) {
 	g.mu.Lock()
+	if c, ok := g.m[key]; ok {
+		c.forgotten = true
+	}
 	delete(g.m, key)
 	g.mu.Unlock()
 }
diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go
index 5e6f1b3..ad04037 100644
--- a/singleflight/singleflight_test.go
+++ b/singleflight/singleflight_test.go
@@ -85,3 +85,75 @@
 		t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
 	}
 }
+
+// Test that singleflight behaves correctly after Forget called.
+// See https://github.com/golang/go/issues/31420
+func TestForget(t *testing.T) {
+	var g Group
+
+	var firstStarted, firstFinished sync.WaitGroup
+
+	firstStarted.Add(1)
+	firstFinished.Add(1)
+
+	firstCh := make(chan struct{})
+	go func() {
+		g.Do("key", func() (i interface{}, e error) {
+			firstStarted.Done()
+			<-firstCh
+			firstFinished.Done()
+			return
+		})
+	}()
+
+	firstStarted.Wait()
+	g.Forget("key") // from this point no two function using same key should be executed concurrently
+
+	var secondStarted int32
+	var secondFinished int32
+	var thirdStarted int32
+
+	secondCh := make(chan struct{})
+	secondRunning := make(chan struct{})
+	go func() {
+		g.Do("key", func() (i interface{}, e error) {
+			defer func() {
+			}()
+			atomic.AddInt32(&secondStarted, 1)
+			// Notify that we started
+			secondCh <- struct{}{}
+			// Wait other get above signal
+			<-secondRunning
+			<-secondCh
+			atomic.AddInt32(&secondFinished, 1)
+			return 2, nil
+		})
+	}()
+
+	close(firstCh)
+	firstFinished.Wait() // wait for first execution (which should not affect execution after Forget)
+
+	<-secondCh
+	// Notify second that we got the signal that it started
+	secondRunning <- struct{}{}
+	if atomic.LoadInt32(&secondStarted) != 1 {
+		t.Fatal("Second execution should be executed due to usage of forget")
+	}
+
+	if atomic.LoadInt32(&secondFinished) == 1 {
+		t.Fatal("Second execution should be still active")
+	}
+
+	close(secondCh)
+	result, _, _ := g.Do("key", func() (i interface{}, e error) {
+		atomic.AddInt32(&thirdStarted, 1)
+		return 3, nil
+	})
+
+	if atomic.LoadInt32(&thirdStarted) != 0 {
+		t.Error("Third call should not be started because was started during second execution")
+	}
+	if result != 2 {
+		t.Errorf("We should receive result produced by second call, expected: 2, got %d", result)
+	}
+}