internal/lsp/lsprpc: implement cancellation using jsonrpc2_v2
Use the jsonrpc2_v2 Preemption option to support request cancellation.
Also fix the TestRequestCancellation to actually test request
cancellation, and add a V2 version of this test. For now, the
ForwardBinder is not exercised.
Factor out test set-up and tear down.
Change-Id: Ic104e922fa2d0ae570b69c3928e371175db28a9f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/321014
Trust: Robert Findley <rfindley@google.com>
Trust: Ian Cottrell <iancottrell@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/lsprpc/binder.go b/internal/lsp/lsprpc/binder.go
index 15fdda2..3f5cb3b 100644
--- a/internal/lsp/lsprpc/binder.go
+++ b/internal/lsp/lsprpc/binder.go
@@ -6,9 +6,11 @@
import (
"context"
+ "encoding/json"
jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
"golang.org/x/tools/internal/lsp/protocol"
+ errors "golang.org/x/xerrors"
)
type ServerFunc func(context.Context, protocol.ClientCloser) protocol.Server
@@ -33,11 +35,40 @@
ctx = protocol.WithClient(ctx, client)
return serverHandler.Handle(ctx, req)
})
+ preempter := &canceler{
+ conn: conn,
+ }
return jsonrpc2_v2.ConnectionOptions{
- Handler: wrapped,
+ Handler: wrapped,
+ Preempter: preempter,
}, nil
}
+type canceler struct {
+ conn *jsonrpc2_v2.Connection
+}
+
+func (c *canceler) Preempt(ctx context.Context, req *jsonrpc2_v2.Request) (interface{}, error) {
+ if req.Method != "$/cancelRequest" {
+ return nil, jsonrpc2_v2.ErrNotHandled
+ }
+ var params protocol.CancelParams
+ if err := json.Unmarshal(req.Params, ¶ms); err != nil {
+ return nil, errors.Errorf("%w: %v", jsonrpc2_v2.ErrParse, err)
+ }
+ var id jsonrpc2_v2.ID
+ switch raw := params.ID.(type) {
+ case float64:
+ id = jsonrpc2_v2.Int64ID(int64(raw))
+ case string:
+ id = jsonrpc2_v2.StringID(raw)
+ default:
+ return nil, errors.Errorf("%w: invalid ID type %T", jsonrpc2_v2.ErrParse, params.ID)
+ }
+ c.conn.Cancel(id)
+ return nil, nil
+}
+
type ForwardBinder struct {
dialer jsonrpc2_v2.Dialer
}
diff --git a/internal/lsp/lsprpc/binder_test.go b/internal/lsp/lsprpc/binder_test.go
index aa9c9d4..d29de0f 100644
--- a/internal/lsp/lsprpc/binder_test.go
+++ b/internal/lsp/lsprpc/binder_test.go
@@ -9,8 +9,8 @@
import (
"context"
- "log"
"regexp"
+ "strings"
"testing"
"time"
@@ -18,29 +18,57 @@
"golang.org/x/tools/internal/lsp/protocol"
)
-func TestClientLoggingV2(t *testing.T) {
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+type testEnv struct {
+ listener jsonrpc2_v2.Listener
+ conn *jsonrpc2_v2.Connection
+ rpcServer *jsonrpc2_v2.Server
+}
+func (e testEnv) Shutdown(t *testing.T) {
+ if err := e.listener.Close(); err != nil {
+ t.Error(err)
+ }
+ if err := e.conn.Close(); err != nil {
+ t.Error(err)
+ }
+ if err := e.rpcServer.Wait(); err != nil {
+ t.Error(err)
+ }
+}
+
+func startServing(ctx context.Context, t *testing.T, server protocol.Server, client protocol.Client) testEnv {
listener, err := jsonrpc2_v2.NetPipe(ctx)
if err != nil {
t.Fatal(err)
}
newServer := func(ctx context.Context, client protocol.ClientCloser) protocol.Server {
- return pingServer{}
+ return server
}
serverBinder := NewServerBinder(newServer)
- server, err := jsonrpc2_v2.Serve(ctx, listener, serverBinder)
+ rpcServer, err := jsonrpc2_v2.Serve(ctx, listener, serverBinder)
if err != nil {
t.Fatal(err)
}
- client := fakeClient{logs: make(chan string, 10)}
clientBinder := NewClientBinder(func(context.Context, protocol.Server) protocol.Client { return client })
conn, err := jsonrpc2_v2.Dial(ctx, listener.Dialer(), clientBinder)
if err != nil {
t.Fatal(err)
}
- if err := protocol.ServerDispatcherV2(conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil {
+ return testEnv{
+ listener: listener,
+ rpcServer: rpcServer,
+ conn: conn,
+ }
+}
+
+func TestClientLoggingV2(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ client := fakeClient{logs: make(chan string, 10)}
+ env := startServing(ctx, t, pingServer{}, client)
+ defer env.Shutdown(t)
+ if err := protocol.ServerDispatcherV2(env.conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil {
t.Errorf("DidOpen: %v", err)
}
select {
@@ -56,13 +84,34 @@
case <-time.After(1 * time.Second):
t.Error("timeout waiting for client log")
}
- if err := listener.Close(); err != nil {
- t.Error(err)
+}
+
+func TestRequestCancellationV2(t *testing.T) {
+ ctx := context.Background()
+
+ server := waitableServer{
+ started: make(chan struct{}),
+ completed: make(chan error),
}
- if err := conn.Close(); err != nil {
- t.Fatal(err)
+ client := fakeClient{logs: make(chan string, 10)}
+ env := startServing(ctx, t, server, client)
+ defer env.Shutdown(t)
+
+ sd := protocol.ServerDispatcherV2(env.conn)
+ ctx, cancel := context.WithCancel(ctx)
+
+ result := make(chan error)
+ go func() {
+ _, err := sd.Hover(ctx, &protocol.HoverParams{})
+ result <- err
+ }()
+ // Wait for the Hover request to start.
+ <-server.started
+ cancel()
+ if err := <-result; err == nil {
+ t.Error("nil error for cancelled Hover(), want non-nil")
}
- if err := server.Wait(); err != nil {
- log.Fatal(err)
+ if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
+ t.Errorf("Hover(): unexpected server-side error %v", err)
}
}
diff --git a/internal/lsp/lsprpc/lsprpc_test.go b/internal/lsp/lsprpc/lsprpc_test.go
index 1bdde59..2f2cf1a 100644
--- a/internal/lsp/lsprpc/lsprpc_test.go
+++ b/internal/lsp/lsprpc/lsprpc_test.go
@@ -6,8 +6,9 @@
import (
"context"
+ "errors"
"regexp"
- "sync"
+ "strings"
"testing"
"time"
@@ -89,15 +90,19 @@
type waitableServer struct {
fakeServer
- started chan struct{}
+ started chan struct{}
+ completed chan error
}
-func (s waitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (*protocol.Hover, error) {
+func (s waitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (_ *protocol.Hover, err error) {
s.started <- struct{}{}
+ defer func() {
+ s.completed <- err
+ }()
select {
case <-ctx.Done():
- return nil, ctx.Err()
- case <-time.After(200 * time.Millisecond):
+ return nil, errors.New("cancelled hover")
+ case <-time.After(10 * time.Second):
}
return &protocol.Hover{}, nil
}
@@ -132,7 +137,8 @@
func TestRequestCancellation(t *testing.T) {
ctx := context.Background()
server := waitableServer{
- started: make(chan struct{}),
+ started: make(chan struct{}),
+ completed: make(chan error),
}
tsDirect, tsForwarded, cleanup := setupForwarding(ctx, t, server)
defer cleanup()
@@ -153,32 +159,21 @@
jsonrpc2.MethodNotFound))
ctx := context.Background()
- ctx1, cancel1 := context.WithCancel(ctx)
- var (
- err1, err2 error
- wg sync.WaitGroup
- )
- wg.Add(2)
+ ctx, cancel := context.WithCancel(ctx)
+
+ result := make(chan error)
go func() {
- defer wg.Done()
- _, err1 = sd.Hover(ctx1, &protocol.HoverParams{})
- }()
- go func() {
- defer wg.Done()
- _, err2 = sd.Resolve(ctx, &protocol.CompletionItem{})
+ _, err := sd.Hover(ctx, &protocol.HoverParams{})
+ result <- err
}()
// Wait for the Hover request to start.
<-server.started
- cancel1()
- wg.Wait()
- if err1 == nil {
- t.Errorf("cancelled Hover(): got nil err")
+ cancel()
+ if err := <-result; err == nil {
+ t.Error("nil error for cancelled Hover(), want non-nil")
}
- if err2 != nil {
- t.Errorf("uncancelled Hover(): err: %v", err2)
- }
- if _, err := sd.Resolve(ctx, &protocol.CompletionItem{}); err != nil {
- t.Errorf("subsequent Hover(): %v", err)
+ if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
+ t.Errorf("Hover(): unexpected server-side error %v", err)
}
})
}
diff --git a/internal/lsp/protocol/protocol.go b/internal/lsp/protocol/protocol.go
index 05adf41..a8b3354 100644
--- a/internal/lsp/protocol/protocol.go
+++ b/internal/lsp/protocol/protocol.go
@@ -86,7 +86,13 @@
}
func (c clientConnV2) Call(ctx context.Context, method string, params interface{}, result interface{}) error {
- return c.conn.Call(ctx, method, params).Await(ctx, result)
+ call := c.conn.Call(ctx, method, params)
+ err := call.Await(ctx, result)
+ if ctx.Err() != nil {
+ detached := xcontext.Detach(ctx)
+ c.conn.Notify(detached, "$/cancelRequest", &CancelParams{ID: call.ID().Raw()})
+ }
+ return err
}
// ServerDispatcher returns a Server that dispatches LSP requests across the