internal/lsp/lsprpc: implement cancellation using jsonrpc2_v2

Use the jsonrpc2_v2 Preemption option to support request cancellation.

Also fix the TestRequestCancellation to actually test request
cancellation, and add a V2 version of this test. For now, the
ForwardBinder is not exercised.

Factor out test set-up and tear down.

Change-Id: Ic104e922fa2d0ae570b69c3928e371175db28a9f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/321014
Trust: Robert Findley <rfindley@google.com>
Trust: Ian Cottrell <iancottrell@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/lsprpc/binder.go b/internal/lsp/lsprpc/binder.go
index 15fdda2..3f5cb3b 100644
--- a/internal/lsp/lsprpc/binder.go
+++ b/internal/lsp/lsprpc/binder.go
@@ -6,9 +6,11 @@
 
 import (
 	"context"
+	"encoding/json"
 
 	jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
 	"golang.org/x/tools/internal/lsp/protocol"
+	errors "golang.org/x/xerrors"
 )
 
 type ServerFunc func(context.Context, protocol.ClientCloser) protocol.Server
@@ -33,11 +35,40 @@
 		ctx = protocol.WithClient(ctx, client)
 		return serverHandler.Handle(ctx, req)
 	})
+	preempter := &canceler{
+		conn: conn,
+	}
 	return jsonrpc2_v2.ConnectionOptions{
-		Handler: wrapped,
+		Handler:   wrapped,
+		Preempter: preempter,
 	}, nil
 }
 
+type canceler struct {
+	conn *jsonrpc2_v2.Connection
+}
+
+func (c *canceler) Preempt(ctx context.Context, req *jsonrpc2_v2.Request) (interface{}, error) {
+	if req.Method != "$/cancelRequest" {
+		return nil, jsonrpc2_v2.ErrNotHandled
+	}
+	var params protocol.CancelParams
+	if err := json.Unmarshal(req.Params, &params); err != nil {
+		return nil, errors.Errorf("%w: %v", jsonrpc2_v2.ErrParse, err)
+	}
+	var id jsonrpc2_v2.ID
+	switch raw := params.ID.(type) {
+	case float64:
+		id = jsonrpc2_v2.Int64ID(int64(raw))
+	case string:
+		id = jsonrpc2_v2.StringID(raw)
+	default:
+		return nil, errors.Errorf("%w: invalid ID type %T", jsonrpc2_v2.ErrParse, params.ID)
+	}
+	c.conn.Cancel(id)
+	return nil, nil
+}
+
 type ForwardBinder struct {
 	dialer jsonrpc2_v2.Dialer
 }
diff --git a/internal/lsp/lsprpc/binder_test.go b/internal/lsp/lsprpc/binder_test.go
index aa9c9d4..d29de0f 100644
--- a/internal/lsp/lsprpc/binder_test.go
+++ b/internal/lsp/lsprpc/binder_test.go
@@ -9,8 +9,8 @@
 
 import (
 	"context"
-	"log"
 	"regexp"
+	"strings"
 	"testing"
 	"time"
 
@@ -18,29 +18,57 @@
 	"golang.org/x/tools/internal/lsp/protocol"
 )
 
-func TestClientLoggingV2(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
+type testEnv struct {
+	listener  jsonrpc2_v2.Listener
+	conn      *jsonrpc2_v2.Connection
+	rpcServer *jsonrpc2_v2.Server
+}
 
+func (e testEnv) Shutdown(t *testing.T) {
+	if err := e.listener.Close(); err != nil {
+		t.Error(err)
+	}
+	if err := e.conn.Close(); err != nil {
+		t.Error(err)
+	}
+	if err := e.rpcServer.Wait(); err != nil {
+		t.Error(err)
+	}
+}
+
+func startServing(ctx context.Context, t *testing.T, server protocol.Server, client protocol.Client) testEnv {
 	listener, err := jsonrpc2_v2.NetPipe(ctx)
 	if err != nil {
 		t.Fatal(err)
 	}
 	newServer := func(ctx context.Context, client protocol.ClientCloser) protocol.Server {
-		return pingServer{}
+		return server
 	}
 	serverBinder := NewServerBinder(newServer)
-	server, err := jsonrpc2_v2.Serve(ctx, listener, serverBinder)
+	rpcServer, err := jsonrpc2_v2.Serve(ctx, listener, serverBinder)
 	if err != nil {
 		t.Fatal(err)
 	}
-	client := fakeClient{logs: make(chan string, 10)}
 	clientBinder := NewClientBinder(func(context.Context, protocol.Server) protocol.Client { return client })
 	conn, err := jsonrpc2_v2.Dial(ctx, listener.Dialer(), clientBinder)
 	if err != nil {
 		t.Fatal(err)
 	}
-	if err := protocol.ServerDispatcherV2(conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil {
+	return testEnv{
+		listener:  listener,
+		rpcServer: rpcServer,
+		conn:      conn,
+	}
+}
+
+func TestClientLoggingV2(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	client := fakeClient{logs: make(chan string, 10)}
+	env := startServing(ctx, t, pingServer{}, client)
+	defer env.Shutdown(t)
+	if err := protocol.ServerDispatcherV2(env.conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil {
 		t.Errorf("DidOpen: %v", err)
 	}
 	select {
@@ -56,13 +84,34 @@
 	case <-time.After(1 * time.Second):
 		t.Error("timeout waiting for client log")
 	}
-	if err := listener.Close(); err != nil {
-		t.Error(err)
+}
+
+func TestRequestCancellationV2(t *testing.T) {
+	ctx := context.Background()
+
+	server := waitableServer{
+		started:   make(chan struct{}),
+		completed: make(chan error),
 	}
-	if err := conn.Close(); err != nil {
-		t.Fatal(err)
+	client := fakeClient{logs: make(chan string, 10)}
+	env := startServing(ctx, t, server, client)
+	defer env.Shutdown(t)
+
+	sd := protocol.ServerDispatcherV2(env.conn)
+	ctx, cancel := context.WithCancel(ctx)
+
+	result := make(chan error)
+	go func() {
+		_, err := sd.Hover(ctx, &protocol.HoverParams{})
+		result <- err
+	}()
+	// Wait for the Hover request to start.
+	<-server.started
+	cancel()
+	if err := <-result; err == nil {
+		t.Error("nil error for cancelled Hover(), want non-nil")
 	}
-	if err := server.Wait(); err != nil {
-		log.Fatal(err)
+	if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
+		t.Errorf("Hover(): unexpected server-side error %v", err)
 	}
 }
diff --git a/internal/lsp/lsprpc/lsprpc_test.go b/internal/lsp/lsprpc/lsprpc_test.go
index 1bdde59..2f2cf1a 100644
--- a/internal/lsp/lsprpc/lsprpc_test.go
+++ b/internal/lsp/lsprpc/lsprpc_test.go
@@ -6,8 +6,9 @@
 
 import (
 	"context"
+	"errors"
 	"regexp"
-	"sync"
+	"strings"
 	"testing"
 	"time"
 
@@ -89,15 +90,19 @@
 type waitableServer struct {
 	fakeServer
 
-	started chan struct{}
+	started   chan struct{}
+	completed chan error
 }
 
-func (s waitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (*protocol.Hover, error) {
+func (s waitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (_ *protocol.Hover, err error) {
 	s.started <- struct{}{}
+	defer func() {
+		s.completed <- err
+	}()
 	select {
 	case <-ctx.Done():
-		return nil, ctx.Err()
-	case <-time.After(200 * time.Millisecond):
+		return nil, errors.New("cancelled hover")
+	case <-time.After(10 * time.Second):
 	}
 	return &protocol.Hover{}, nil
 }
@@ -132,7 +137,8 @@
 func TestRequestCancellation(t *testing.T) {
 	ctx := context.Background()
 	server := waitableServer{
-		started: make(chan struct{}),
+		started:   make(chan struct{}),
+		completed: make(chan error),
 	}
 	tsDirect, tsForwarded, cleanup := setupForwarding(ctx, t, server)
 	defer cleanup()
@@ -153,32 +159,21 @@
 					jsonrpc2.MethodNotFound))
 
 			ctx := context.Background()
-			ctx1, cancel1 := context.WithCancel(ctx)
-			var (
-				err1, err2 error
-				wg         sync.WaitGroup
-			)
-			wg.Add(2)
+			ctx, cancel := context.WithCancel(ctx)
+
+			result := make(chan error)
 			go func() {
-				defer wg.Done()
-				_, err1 = sd.Hover(ctx1, &protocol.HoverParams{})
-			}()
-			go func() {
-				defer wg.Done()
-				_, err2 = sd.Resolve(ctx, &protocol.CompletionItem{})
+				_, err := sd.Hover(ctx, &protocol.HoverParams{})
+				result <- err
 			}()
 			// Wait for the Hover request to start.
 			<-server.started
-			cancel1()
-			wg.Wait()
-			if err1 == nil {
-				t.Errorf("cancelled Hover(): got nil err")
+			cancel()
+			if err := <-result; err == nil {
+				t.Error("nil error for cancelled Hover(), want non-nil")
 			}
-			if err2 != nil {
-				t.Errorf("uncancelled Hover(): err: %v", err2)
-			}
-			if _, err := sd.Resolve(ctx, &protocol.CompletionItem{}); err != nil {
-				t.Errorf("subsequent Hover(): %v", err)
+			if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
+				t.Errorf("Hover(): unexpected server-side error %v", err)
 			}
 		})
 	}
diff --git a/internal/lsp/protocol/protocol.go b/internal/lsp/protocol/protocol.go
index 05adf41..a8b3354 100644
--- a/internal/lsp/protocol/protocol.go
+++ b/internal/lsp/protocol/protocol.go
@@ -86,7 +86,13 @@
 }
 
 func (c clientConnV2) Call(ctx context.Context, method string, params interface{}, result interface{}) error {
-	return c.conn.Call(ctx, method, params).Await(ctx, result)
+	call := c.conn.Call(ctx, method, params)
+	err := call.Await(ctx, result)
+	if ctx.Err() != nil {
+		detached := xcontext.Detach(ctx)
+		c.conn.Notify(detached, "$/cancelRequest", &CancelParams{ID: call.ID().Raw()})
+	}
+	return err
 }
 
 // ServerDispatcher returns a Server that dispatches LSP requests across the