internal/lsp: cancel early

This change allows us to hanel cancel messages as they go into the queue, and
cancel messages that are ahead of them in the queue but not being processed yet.
This should reduce the amount of redundant work that we do when we are handling
a cancel storm.

Change-Id: Id1a58991407d75b68d65bacf96350a4dd69d4d2b
Reviewed-on: https://go-review.googlesource.com/c/tools/+/200766
Run-TryBot: Ian Cottrell <iancottrell@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/jsonrpc2/handler.go b/internal/jsonrpc2/handler.go
index 598c79b..bf8bfb3 100644
--- a/internal/jsonrpc2/handler.go
+++ b/internal/jsonrpc2/handler.go
@@ -38,9 +38,9 @@
 	// response
 
 	// Request is called near the start of processing any request.
-	Request(ctx context.Context, direction Direction, r *WireRequest) context.Context
+	Request(ctx context.Context, conn *Conn, direction Direction, r *WireRequest) context.Context
 	// Response is called near the start of processing any response.
-	Response(ctx context.Context, direction Direction, r *WireResponse) context.Context
+	Response(ctx context.Context, conn *Conn, direction Direction, r *WireResponse) context.Context
 	// Done is called when any request is fully processed.
 	// For calls, this means the response has also been processed, for notifies
 	// this is as soon as the message has been written to the stream.
@@ -90,11 +90,11 @@
 	return false
 }
 
-func (EmptyHandler) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context {
+func (EmptyHandler) Request(ctx context.Context, conn *Conn, direction Direction, r *WireRequest) context.Context {
 	return ctx
 }
 
-func (EmptyHandler) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context {
+func (EmptyHandler) Response(ctx context.Context, conn *Conn, direction Direction, r *WireResponse) context.Context {
 	return ctx
 }
 
diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go
index f9c4021..62c8141 100644
--- a/internal/jsonrpc2/jsonrpc2.go
+++ b/internal/jsonrpc2/jsonrpc2.go
@@ -110,7 +110,7 @@
 		return fmt.Errorf("marshalling notify request: %v", err)
 	}
 	for _, h := range c.handlers {
-		ctx = h.Request(ctx, Send, request)
+		ctx = h.Request(ctx, c, Send, request)
 	}
 	defer func() {
 		for _, h := range c.handlers {
@@ -145,7 +145,7 @@
 		return fmt.Errorf("marshalling call request: %v", err)
 	}
 	for _, h := range c.handlers {
-		ctx = h.Request(ctx, Send, request)
+		ctx = h.Request(ctx, c, Send, request)
 	}
 	// we have to add ourselves to the pending map before we send, otherwise we
 	// are racing the response
@@ -175,7 +175,7 @@
 	select {
 	case response := <-rchan:
 		for _, h := range c.handlers {
-			ctx = h.Response(ctx, Receive, response)
+			ctx = h.Response(ctx, c, Receive, response)
 		}
 		// is it an error response?
 		if response.Error != nil {
@@ -262,7 +262,7 @@
 		return err
 	}
 	for _, h := range r.conn.handlers {
-		ctx = h.Response(ctx, Send, response)
+		ctx = h.Response(ctx, r.conn, Send, response)
 	}
 	n, err := r.conn.stream.Write(ctx, data)
 	for _, h := range r.conn.handlers {
@@ -347,7 +347,7 @@
 				},
 			}
 			for _, h := range c.handlers {
-				reqCtx = h.Request(reqCtx, Receive, &req.WireRequest)
+				reqCtx = h.Request(reqCtx, c, Receive, &req.WireRequest)
 				reqCtx = h.Read(reqCtx, n)
 			}
 			c.setHandling(req, true)
diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go
index 89252fd..192a5e8 100644
--- a/internal/jsonrpc2/jsonrpc2_test.go
+++ b/internal/jsonrpc2/jsonrpc2_test.go
@@ -164,7 +164,7 @@
 	return false
 }
 
-func (h *handle) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
+func (h *handle) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
 	if h.log {
 		if r.ID != nil {
 			log.Printf("%v call [%v] %s %v", direction, r.ID, r.Method, r.Params)
@@ -177,7 +177,7 @@
 	return ctx
 }
 
-func (h *handle) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
+func (h *handle) Response(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
 	if h.log {
 		method := ctx.Value("method")
 		elapsed := time.Since(ctx.Value("start").(time.Time))
diff --git a/internal/lsp/cmd/serve.go b/internal/lsp/cmd/serve.go
index 70370af..ceca1f1 100644
--- a/internal/lsp/cmd/serve.go
+++ b/internal/lsp/cmd/serve.go
@@ -149,7 +149,7 @@
 	return false
 }
 
-func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
+func (h *handler) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
 	if r.Method == "" {
 		panic("no method in rpc stats")
 	}
@@ -174,7 +174,7 @@
 	return ctx
 }
 
-func (h *handler) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
+func (h *handler) Response(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
 	return ctx
 }
 
diff --git a/internal/lsp/protocol/protocol.go b/internal/lsp/protocol/protocol.go
index 80be75d..8aaa3ef 100644
--- a/internal/lsp/protocol/protocol.go
+++ b/internal/lsp/protocol/protocol.go
@@ -6,6 +6,7 @@
 
 import (
 	"context"
+	"encoding/json"
 
 	"golang.org/x/tools/internal/jsonrpc2"
 	"golang.org/x/tools/internal/telemetry/log"
@@ -13,6 +14,11 @@
 	"golang.org/x/tools/internal/xcontext"
 )
 
+const (
+	// RequestCancelledError should be used when a request is cancelled early.
+	RequestCancelledError = -32800
+)
+
 type DocumentUri = string
 
 type canceller struct{ jsonrpc2.EmptyHandler }
@@ -27,6 +33,18 @@
 	server Server
 }
 
+func (canceller) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
+	if direction == jsonrpc2.Receive && r.Method == "$/cancelRequest" {
+		var params CancelParams
+		if err := json.Unmarshal(*r.Params, &params); err != nil {
+			log.Error(ctx, "", err)
+		} else {
+			conn.Cancel(params.ID)
+		}
+	}
+	return ctx
+}
+
 func (canceller) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID, cancelled bool) bool {
 	if cancelled {
 		return false
diff --git a/internal/lsp/protocol/tsclient.go b/internal/lsp/protocol/tsclient.go
index 113989f..969e055 100644
--- a/internal/lsp/protocol/tsclient.go
+++ b/internal/lsp/protocol/tsclient.go
@@ -8,6 +8,7 @@
 
 	"golang.org/x/tools/internal/jsonrpc2"
 	"golang.org/x/tools/internal/telemetry/log"
+	"golang.org/x/tools/internal/xcontext"
 )
 
 type Client interface {
@@ -27,15 +28,12 @@
 	if delivered {
 		return false
 	}
-	switch r.Method {
-	case "$/cancelRequest":
-		var params CancelParams
-		if err := json.Unmarshal(*r.Params, &params); err != nil {
-			sendParseError(ctx, r, err)
-			return true
-		}
-		r.Conn().Cancel(params.ID)
+	if ctx.Err() != nil {
+		ctx := xcontext.Detach(ctx)
+		r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, ""))
 		return true
+	}
+	switch r.Method {
 	case "window/showMessage": // notif
 		var params ShowMessageParams
 		if err := json.Unmarshal(*r.Params, &params); err != nil {
diff --git a/internal/lsp/protocol/tsserver.go b/internal/lsp/protocol/tsserver.go
index 1882ade..9ee423b 100644
--- a/internal/lsp/protocol/tsserver.go
+++ b/internal/lsp/protocol/tsserver.go
@@ -8,6 +8,7 @@
 
 	"golang.org/x/tools/internal/jsonrpc2"
 	"golang.org/x/tools/internal/telemetry/log"
+	"golang.org/x/tools/internal/xcontext"
 )
 
 type Server interface {
@@ -46,13 +47,13 @@
 	Symbol(context.Context, *WorkspaceSymbolParams) ([]SymbolInformation, error)
 	CodeLens(context.Context, *CodeLensParams) ([]CodeLens, error)
 	ResolveCodeLens(context.Context, *CodeLens) (*CodeLens, error)
+	DocumentLink(context.Context, *DocumentLinkParams) ([]DocumentLink, error)
+	ResolveDocumentLink(context.Context, *DocumentLink) (*DocumentLink, error)
 	Formatting(context.Context, *DocumentFormattingParams) ([]TextEdit, error)
 	RangeFormatting(context.Context, *DocumentRangeFormattingParams) ([]TextEdit, error)
 	OnTypeFormatting(context.Context, *DocumentOnTypeFormattingParams) ([]TextEdit, error)
 	Rename(context.Context, *RenameParams) (*WorkspaceEdit, error)
 	PrepareRename(context.Context, *PrepareRenameParams) (*Range, error)
-	DocumentLink(context.Context, *DocumentLinkParams) ([]DocumentLink, error)
-	ResolveDocumentLink(context.Context, *DocumentLink) (*DocumentLink, error)
 	ExecuteCommand(context.Context, *ExecuteCommandParams) (interface{}, error)
 }
 
@@ -60,15 +61,12 @@
 	if delivered {
 		return false
 	}
-	switch r.Method {
-	case "$/cancelRequest":
-		var params CancelParams
-		if err := json.Unmarshal(*r.Params, &params); err != nil {
-			sendParseError(ctx, r, err)
-			return true
-		}
-		r.Conn().Cancel(params.ID)
+	if ctx.Err() != nil {
+		ctx := xcontext.Detach(ctx)
+		r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, ""))
 		return true
+	}
+	switch r.Method {
 	case "workspace/didChangeWorkspaceFolders": // notif
 		var params DidChangeWorkspaceFoldersParams
 		if err := json.Unmarshal(*r.Params, &params); err != nil {
@@ -435,6 +433,28 @@
 			log.Error(ctx, "", err)
 		}
 		return true
+	case "textDocument/documentLink": // req
+		var params DocumentLinkParams
+		if err := json.Unmarshal(*r.Params, &params); err != nil {
+			sendParseError(ctx, r, err)
+			return true
+		}
+		resp, err := h.server.DocumentLink(ctx, &params)
+		if err := r.Reply(ctx, resp, err); err != nil {
+			log.Error(ctx, "", err)
+		}
+		return true
+	case "documentLink/resolve": // req
+		var params DocumentLink
+		if err := json.Unmarshal(*r.Params, &params); err != nil {
+			sendParseError(ctx, r, err)
+			return true
+		}
+		resp, err := h.server.ResolveDocumentLink(ctx, &params)
+		if err := r.Reply(ctx, resp, err); err != nil {
+			log.Error(ctx, "", err)
+		}
+		return true
 	case "textDocument/formatting": // req
 		var params DocumentFormattingParams
 		if err := json.Unmarshal(*r.Params, &params); err != nil {
@@ -490,28 +510,6 @@
 			log.Error(ctx, "", err)
 		}
 		return true
-	case "textDocument/documentLink": // req
-		var params DocumentLinkParams
-		if err := json.Unmarshal(*r.Params, &params); err != nil {
-			sendParseError(ctx, r, err)
-			return true
-		}
-		resp, err := h.server.DocumentLink(ctx, &params)
-		if err := r.Reply(ctx, resp, err); err != nil {
-			log.Error(ctx, "", err)
-		}
-		return true
-	case "documentLink/resolve": // req
-		var params DocumentLink
-		if err := json.Unmarshal(*r.Params, &params); err != nil {
-			sendParseError(ctx, r, err)
-			return true
-		}
-		resp, err := h.server.ResolveDocumentLink(ctx, &params)
-		if err := r.Reply(ctx, resp, err); err != nil {
-			log.Error(ctx, "", err)
-		}
-		return true
 	case "workspace/executeCommand": // req
 		var params ExecuteCommandParams
 		if err := json.Unmarshal(*r.Params, &params); err != nil {
@@ -756,6 +754,22 @@
 	return &result, nil
 }
 
+func (s *serverDispatcher) DocumentLink(ctx context.Context, params *DocumentLinkParams) ([]DocumentLink, error) {
+	var result []DocumentLink
+	if err := s.Conn.Call(ctx, "textDocument/documentLink", params, &result); err != nil {
+		return nil, err
+	}
+	return result, nil
+}
+
+func (s *serverDispatcher) ResolveDocumentLink(ctx context.Context, params *DocumentLink) (*DocumentLink, error) {
+	var result DocumentLink
+	if err := s.Conn.Call(ctx, "documentLink/resolve", params, &result); err != nil {
+		return nil, err
+	}
+	return &result, nil
+}
+
 func (s *serverDispatcher) Formatting(ctx context.Context, params *DocumentFormattingParams) ([]TextEdit, error) {
 	var result []TextEdit
 	if err := s.Conn.Call(ctx, "textDocument/formatting", params, &result); err != nil {
@@ -796,22 +810,6 @@
 	return &result, nil
 }
 
-func (s *serverDispatcher) DocumentLink(ctx context.Context, params *DocumentLinkParams) ([]DocumentLink, error) {
-	var result []DocumentLink
-	if err := s.Conn.Call(ctx, "textDocument/documentLink", params, &result); err != nil {
-		return nil, err
-	}
-	return result, nil
-}
-
-func (s *serverDispatcher) ResolveDocumentLink(ctx context.Context, params *DocumentLink) (*DocumentLink, error) {
-	var result DocumentLink
-	if err := s.Conn.Call(ctx, "documentLink/resolve", params, &result); err != nil {
-		return nil, err
-	}
-	return &result, nil
-}
-
 func (s *serverDispatcher) ExecuteCommand(ctx context.Context, params *ExecuteCommandParams) (interface{}, error) {
 	var result interface{}
 	if err := s.Conn.Call(ctx, "workspace/executeCommand", params, &result); err != nil {
diff --git a/internal/lsp/protocol/typescript/requests.ts b/internal/lsp/protocol/typescript/requests.ts
index 0115568..373b925 100644
--- a/internal/lsp/protocol/typescript/requests.ts
+++ b/internal/lsp/protocol/typescript/requests.ts
@@ -224,7 +224,8 @@
 
     "golang.org/x/tools/internal/jsonrpc2"
     "golang.org/x/tools/internal/telemetry/log"
-  )
+    "golang.org/x/tools/internal/xcontext"
+    )
   `);
   const a = side.name[0].toUpperCase() + side.name.substring(1)
   f(`type ${a} interface {`);
@@ -235,15 +236,12 @@
       if delivered {
         return false
       }
-      switch r.Method {
-      case "$/cancelRequest":
-        var params CancelParams
-        if err := json.Unmarshal(*r.Params, &params); err != nil {
-          sendParseError(ctx, r, err)
-          return true
-        }
-        r.Conn().Cancel(params.ID)
-        return true`);
+      if ctx.Err() != nil {
+        ctx := xcontext.Detach(ctx)
+        r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, ""))
+        return true
+      }
+      switch r.Method {`);
   side.cases.forEach((v) => {f(v)});
   f(`
   default: