net: don't let cancelation of a DNS lookup affect another lookup

Updates #8602
Updates #20703
Fixes #22724

Change-Id: I27b72311b2c66148c59977361bd3f5101e47b51d
Reviewed-on: https://go-review.googlesource.com/100840
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/src/internal/singleflight/singleflight.go b/src/internal/singleflight/singleflight.go
index 1e9960d..b2d82e2 100644
--- a/src/internal/singleflight/singleflight.go
+++ b/src/internal/singleflight/singleflight.go
@@ -103,11 +103,21 @@
 	g.mu.Unlock()
 }
 
-// Forget tells the singleflight to forget about a key.  Future calls
-// to Do for this key will call the function rather than waiting for
-// an earlier call to complete.
-func (g *Group) Forget(key string) {
+// ForgetUnshared tells the singleflight to forget about a key if it is not
+// shared with any other goroutines. Future calls to Do for a forgotten key
+// will call the function rather than waiting for an earlier call to complete.
+// Returns whether the key was forgotten or unknown--that is, whether no
+// other goroutines are waiting for the result.
+func (g *Group) ForgetUnshared(key string) bool {
 	g.mu.Lock()
-	delete(g.m, key)
-	g.mu.Unlock()
+	defer g.mu.Unlock()
+	c, ok := g.m[key]
+	if !ok {
+		return true
+	}
+	if c.dups == 0 {
+		delete(g.m, key)
+		return true
+	}
+	return false
 }
diff --git a/src/net/lookup.go b/src/net/lookup.go
index 6844b11..dffbc01 100644
--- a/src/net/lookup.go
+++ b/src/net/lookup.go
@@ -194,10 +194,16 @@
 		resolverFunc = alt
 	}
 
+	// We don't want a cancelation of ctx to affect the
+	// lookupGroup operation. Otherwise if our context gets
+	// canceled it might cause an error to be returned to a lookup
+	// using a completely different context.
+	lookupGroupCtx, lookupGroupCancel := context.WithCancel(context.Background())
+
 	dnsWaitGroup.Add(1)
 	ch, called := lookupGroup.DoChan(host, func() (interface{}, error) {
 		defer dnsWaitGroup.Done()
-		return testHookLookupIP(ctx, resolverFunc, host)
+		return testHookLookupIP(lookupGroupCtx, resolverFunc, host)
 	})
 	if !called {
 		dnsWaitGroup.Done()
@@ -205,20 +211,28 @@
 
 	select {
 	case <-ctx.Done():
-		// If the DNS lookup timed out for some reason, force
-		// future requests to start the DNS lookup again
-		// rather than waiting for the current lookup to
-		// complete. See issue 8602.
-		ctxErr := ctx.Err()
-		if ctxErr == context.DeadlineExceeded {
-			lookupGroup.Forget(host)
+		// Our context was canceled. If we are the only
+		// goroutine looking up this key, then drop the key
+		// from the lookupGroup and cancel the lookup.
+		// If there are other goroutines looking up this key,
+		// let the lookup continue uncanceled, and let later
+		// lookups with the same key share the result.
+		// See issues 8602, 20703, 22724.
+		if lookupGroup.ForgetUnshared(host) {
+			lookupGroupCancel()
+		} else {
+			go func() {
+				<-ch
+				lookupGroupCancel()
+			}()
 		}
-		err := mapErr(ctxErr)
+		err := mapErr(ctx.Err())
 		if trace != nil && trace.DNSDone != nil {
 			trace.DNSDone(nil, false, err)
 		}
 		return nil, err
 	case r := <-ch:
+		lookupGroupCancel()
 		if trace != nil && trace.DNSDone != nil {
 			addrs, _ := r.Val.([]IPAddr)
 			trace.DNSDone(ipAddrsEface(addrs), r.Shared, r.Err)
diff --git a/src/net/lookup_test.go b/src/net/lookup_test.go
index f9f79e6..481c6d4 100644
--- a/src/net/lookup_test.go
+++ b/src/net/lookup_test.go
@@ -789,3 +789,28 @@
 		t.Fatalf("lookup error = %v, want %v", err, errNoSuchHost)
 	}
 }
+
+func TestLookupContextCancel(t *testing.T) {
+	if testenv.Builder() == "" {
+		testenv.MustHaveExternalNetwork(t)
+	}
+	if runtime.GOOS == "nacl" {
+		t.Skip("skip on nacl")
+	}
+
+	defer dnsWaitGroup.Wait()
+
+	ctx, ctxCancel := context.WithCancel(context.Background())
+	ctxCancel()
+	_, err := DefaultResolver.LookupIPAddr(ctx, "google.com")
+	if err != errCanceled {
+		testenv.SkipFlakyNet(t)
+		t.Fatal(err)
+	}
+	ctx = context.Background()
+	_, err = DefaultResolver.LookupIPAddr(ctx, "google.com")
+	if err != nil {
+		testenv.SkipFlakyNet(t)
+		t.Fatal(err)
+	}
+}
diff --git a/src/net/tcpsock_unix_test.go b/src/net/tcpsock_unix_test.go
index 3af1834..95c02d2 100644
--- a/src/net/tcpsock_unix_test.go
+++ b/src/net/tcpsock_unix_test.go
@@ -87,6 +87,7 @@
 	if testenv.Builder() == "" {
 		testenv.MustHaveExternalNetwork(t)
 	}
+	defer dnsWaitGroup.Wait()
 	t.Parallel()
 	const tries = 10000
 	var wg sync.WaitGroup