internal/jsonrpc2: fix races in cancellation

We had a deadlock in cases where a request was cancelled (1) after being
written to the stream, but (2) before a response was received. This
resulted in the request ID being removed from the pending map while the
server still had the request, after which the server's response would
cause Conn.Run to hang trying to send to a nil channel.

After fixing this nil send there was still a race: it was possible that
Conn.Run could get the pending request, and Conn.Call would select
ctx.Done before Conn.Run could send to the response channel, again
resulting in a blocking send. Fix this by adding a buffer to the
response channel.

The response channel management is also made less forgiving, because we
should be able to reason precisely about how many sends and receives
will occur:
 + Don't close the response channel after sending a response: there
   should only be one recipient.
 + Don't delete the ID from the pending map twice: it should only be
   cleaned up by Conn.Call.

Cancellation tests in the lsprpc package are updated to exercise the
race conditions.

Fixes golang/go#37159

Change-Id: Ie3207442ea910f79247b18d8647fd52f39fb15db
Reviewed-on: https://go-review.googlesource.com/c/tools/+/219126
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go
index 963f818..6cbd607 100644
--- a/internal/jsonrpc2/jsonrpc2.go
+++ b/internal/jsonrpc2/jsonrpc2.go
@@ -147,14 +147,15 @@
 	for _, h := range c.handlers {
 		ctx = h.Request(ctx, c, Send, request)
 	}
-	// we have to add ourselves to the pending map before we send, otherwise we
-	// are racing the response
-	rchan := make(chan *WireResponse)
+	// We have to add ourselves to the pending map before we send, otherwise we
+	// are racing the response. Also add a buffer to rchan, so that if we get a
+	// wire response between the time this call is cancelled and id is deleted
+	// from c.pending, the send to rchan will not block.
+	rchan := make(chan *WireResponse, 1)
 	c.pendingMu.Lock()
 	c.pending[id] = rchan
 	c.pendingMu.Unlock()
 	defer func() {
-		// clean up the pending response handler on the way out
 		c.pendingMu.Lock()
 		delete(c.pending, id)
 		c.pendingMu.Unlock()
@@ -189,7 +190,7 @@
 		}
 		return nil
 	case <-ctx.Done():
-		// allow the handler to propagate the cancel
+		// Allow the handler to propagate the cancel.
 		cancelled := false
 		for _, h := range c.handlers {
 			if h.Cancel(ctx, c, id, cancelled) {
@@ -328,10 +329,10 @@
 			}
 			continue
 		}
-		// work out which kind of message we have
+		// Work out whether this is a request or response.
 		switch {
 		case msg.Method != "":
-			// if method is set it must be a request
+			// If method is set it must be a request.
 			reqCtx, cancelReq := context.WithCancel(runCtx)
 			thisRequest := nextRequest
 			nextRequest = make(chan struct{})
@@ -373,21 +374,19 @@
 				}
 			}()
 		case msg.ID != nil:
-			// we have a response, get the pending entry from the map
+			// If method is not set, this should be a response, in which case we must
+			// have an id to send the response back to the caller.
 			c.pendingMu.Lock()
-			rchan := c.pending[*msg.ID]
-			if rchan != nil {
-				delete(c.pending, *msg.ID)
-			}
+			rchan, ok := c.pending[*msg.ID]
 			c.pendingMu.Unlock()
-			// and send the reply to the channel
-			response := &WireResponse{
-				Result: msg.Result,
-				Error:  msg.Error,
-				ID:     msg.ID,
+			if ok {
+				response := &WireResponse{
+					Result: msg.Result,
+					Error:  msg.Error,
+					ID:     msg.ID,
+				}
+				rchan <- response
 			}
-			rchan <- response
-			close(rchan)
 		default:
 			for _, h := range c.handlers {
 				h.Error(runCtx, fmt.Errorf("message not a call, notify or response, ignoring"))
diff --git a/internal/lsp/lsprpc/lsprpc_test.go b/internal/lsp/lsprpc/lsprpc_test.go
index a36affc..b7c20fe 100644
--- a/internal/lsp/lsprpc/lsprpc_test.go
+++ b/internal/lsp/lsprpc/lsprpc_test.go
@@ -7,6 +7,7 @@
 import (
 	"context"
 	"regexp"
+	"sync"
 	"testing"
 	"time"
 
@@ -61,40 +62,37 @@
 		if !matched {
 			t.Errorf("got log %q, want a log containing %q", got, want)
 		}
-	case <-time.After(1000 * time.Second):
+	case <-time.After(1 * time.Second):
 		t.Error("timeout waiting for client log")
 	}
 }
 
+// waitableServer instruments LSP requests so that we can control their timing.
+// The requests chosen are arbitrary: we simply needed one that blocks, and
+// another that doesn't.
 type waitableServer struct {
 	protocol.Server
 
 	started chan struct{}
-	// finished records whether the request ended with a cancellation or not
-	// (true means the request was cancelled).
-	finished chan bool
 }
 
-func (s waitableServer) CodeLens(ctx context.Context, params *protocol.CodeLensParams) ([]protocol.CodeLens, error) {
+func (s waitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (*protocol.Hover, error) {
 	s.started <- struct{}{}
-	cancelled := false
-	defer func() {
-		s.finished <- cancelled
-	}()
 	select {
 	case <-ctx.Done():
-		cancelled = true
 		return nil, ctx.Err()
-	case <-time.After(1 * time.Second):
-		cancelled = false
+	case <-time.After(200 * time.Millisecond):
 	}
-	return []protocol.CodeLens{}, nil
+	return &protocol.Hover{}, nil
+}
+
+func (s waitableServer) Resolve(_ context.Context, item *protocol.CompletionItem) (*protocol.CompletionItem, error) {
+	return item, nil
 }
 
 func TestRequestCancellation(t *testing.T) {
 	server := waitableServer{
-		started:  make(chan struct{}),
-		finished: make(chan bool),
+		started: make(chan struct{}),
 	}
 	ss := &StreamServer{
 		accept: func(c protocol.Client) protocol.Server {
@@ -119,14 +117,33 @@
 		t.Run(test.serverType, func(t *testing.T) {
 			cc := test.ts.Connect(ctx)
 			cc.AddHandler(protocol.Canceller{})
-			lensCtx, cancelLens := context.WithCancel(context.Background())
+			ctx := context.Background()
+			ctx1, cancel1 := context.WithCancel(ctx)
+			var (
+				err1, err2 error
+				wg         sync.WaitGroup
+			)
+			wg.Add(2)
 			go func() {
-				protocol.ServerDispatcher(cc).CodeLens(lensCtx, &protocol.CodeLensParams{})
+				defer wg.Done()
+				_, err1 = protocol.ServerDispatcher(cc).Hover(ctx1, &protocol.HoverParams{})
 			}()
+			go func() {
+				defer wg.Done()
+				_, err2 = protocol.ServerDispatcher(cc).Resolve(ctx, &protocol.CompletionItem{})
+			}()
+			// Wait for the Hover request to start.
 			<-server.started
-			cancelLens()
-			if got, want := <-server.finished, true; got != want {
-				t.Errorf("CodeLens was cancelled: %t, want %t", got, want)
+			cancel1()
+			wg.Wait()
+			if err1 == nil {
+				t.Errorf("cancelled Hover(): got nil err")
+			}
+			if err2 != nil {
+				t.Errorf("uncancelled Hover(): err: %v", err2)
+			}
+			if _, err := protocol.ServerDispatcher(cc).Resolve(ctx, &protocol.CompletionItem{}); err != nil {
+				t.Errorf("subsequent Hover(): %v", err)
 			}
 		})
 	}