internal/lsp/command: pass Context to commands

Smuggling the Context was too fancy, and unidiomatic.

Change-Id: Iabca39ed73d5a40bfe7d500358228700eefbc60f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/290790
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/internal/lsp/command.go b/internal/lsp/command.go
index a11b8b3..487e182 100644
--- a/internal/lsp/command.go
+++ b/internal/lsp/command.go
@@ -39,16 +39,13 @@
 	}
 
 	handler := &commandHandler{
-		ctx:    ctx,
 		s:      s,
 		params: params,
 	}
-	return command.Dispatch(params, handler)
+	return command.Dispatch(ctx, params, handler)
 }
 
 type commandHandler struct {
-	// ctx is temporarily held so that we may implement the command.Interface interface.
-	ctx    context.Context
 	s      *Server
 	params *protocol.ExecuteCommandParams
 }
@@ -72,7 +69,7 @@
 
 type commandFunc func(context.Context, commandDeps) error
 
-func (c *commandHandler) run(cfg commandConfig, run commandFunc) (err error) {
+func (c *commandHandler) run(ctx context.Context, cfg commandConfig, run commandFunc) (err error) {
 	if cfg.requireSave {
 		for _, overlay := range c.s.session.Overlays() {
 			if !overlay.Saved() {
@@ -84,13 +81,13 @@
 	if cfg.forURI != "" {
 		var ok bool
 		var release func()
-		deps.snapshot, deps.fh, ok, release, err = c.s.beginFileRequest(c.ctx, cfg.forURI, source.UnknownKind)
+		deps.snapshot, deps.fh, ok, release, err = c.s.beginFileRequest(ctx, cfg.forURI, source.UnknownKind)
 		defer release()
 		if !ok {
 			return err
 		}
 	}
-	ctx, cancel := context.WithCancel(xcontext.Detach(c.ctx))
+	ctx, cancel := context.WithCancel(xcontext.Detach(ctx))
 	if cfg.progress != "" {
 		deps.work = c.s.progress.start(ctx, cfg.progress, "Running...", c.params.WorkDoneToken, cancel)
 	}
@@ -115,8 +112,8 @@
 	return runcmd()
 }
 
-func (c *commandHandler) ApplyFix(args command.ApplyFixArgs) error {
-	return c.run(commandConfig{
+func (c *commandHandler) ApplyFix(ctx context.Context, args command.ApplyFixArgs) error {
+	return c.run(ctx, commandConfig{
 		// Note: no progress here. Applying fixes should be quick.
 		forURI: args.URI,
 	}, func(ctx context.Context, deps commandDeps) error {
@@ -139,20 +136,20 @@
 	})
 }
 
-func (c *commandHandler) RegenerateCgo(args command.URIArg) error {
-	return c.run(commandConfig{
+func (c *commandHandler) RegenerateCgo(ctx context.Context, args command.URIArg) error {
+	return c.run(ctx, commandConfig{
 		progress: "Regenerating Cgo",
 	}, func(ctx context.Context, deps commandDeps) error {
 		mod := source.FileModification{
 			URI:    args.URI.SpanURI(),
 			Action: source.InvalidateMetadata,
 		}
-		return c.s.didModifyFiles(c.ctx, []source.FileModification{mod}, FromRegenerateCgo)
+		return c.s.didModifyFiles(ctx, []source.FileModification{mod}, FromRegenerateCgo)
 	})
 }
 
-func (c *commandHandler) CheckUpgrades(args command.CheckUpgradesArgs) error {
-	return c.run(commandConfig{
+func (c *commandHandler) CheckUpgrades(ctx context.Context, args command.CheckUpgradesArgs) error {
+	return c.run(ctx, commandConfig{
 		forURI:   args.URI,
 		progress: "Checking for upgrades",
 	}, func(ctx context.Context, deps commandDeps) error {
@@ -167,16 +164,16 @@
 	})
 }
 
-func (c *commandHandler) AddDependency(args command.DependencyArgs) error {
-	return c.GoGetModule(args)
+func (c *commandHandler) AddDependency(ctx context.Context, args command.DependencyArgs) error {
+	return c.GoGetModule(ctx, args)
 }
 
-func (c *commandHandler) UpgradeDependency(args command.DependencyArgs) error {
-	return c.GoGetModule(args)
+func (c *commandHandler) UpgradeDependency(ctx context.Context, args command.DependencyArgs) error {
+	return c.GoGetModule(ctx, args)
 }
 
-func (c *commandHandler) GoGetModule(args command.DependencyArgs) error {
-	return c.run(commandConfig{
+func (c *commandHandler) GoGetModule(ctx context.Context, args command.DependencyArgs) error {
+	return c.run(ctx, commandConfig{
 		requireSave: true,
 		progress:    "Running go get",
 		forURI:      args.URI,
@@ -187,8 +184,8 @@
 
 // TODO(rFindley): UpdateGoSum, Tidy, and Vendor could probably all be one command.
 
-func (c *commandHandler) UpdateGoSum(args command.URIArg) error {
-	return c.run(commandConfig{
+func (c *commandHandler) UpdateGoSum(ctx context.Context, args command.URIArg) error {
+	return c.run(ctx, commandConfig{
 		requireSave: true,
 		progress:    "Updating go.sum",
 		forURI:      args.URI,
@@ -197,8 +194,8 @@
 	})
 }
 
-func (c *commandHandler) Tidy(args command.URIArg) error {
-	return c.run(commandConfig{
+func (c *commandHandler) Tidy(ctx context.Context, args command.URIArg) error {
+	return c.run(ctx, commandConfig{
 		requireSave: true,
 		progress:    "Running go mod tidy",
 		forURI:      args.URI,
@@ -207,8 +204,8 @@
 	})
 }
 
-func (c *commandHandler) Vendor(args command.URIArg) error {
-	return c.run(commandConfig{
+func (c *commandHandler) Vendor(ctx context.Context, args command.URIArg) error {
+	return c.run(ctx, commandConfig{
 		requireSave: true,
 		progress:    "Running go mod vendor",
 		forURI:      args.URI,
@@ -217,8 +214,8 @@
 	})
 }
 
-func (c *commandHandler) RemoveDependency(args command.RemoveDependencyArgs) error {
-	return c.run(commandConfig{
+func (c *commandHandler) RemoveDependency(ctx context.Context, args command.RemoveDependencyArgs) error {
+	return c.run(ctx, commandConfig{
 		requireSave: true,
 		progress:    "Removing dependency",
 		forURI:      args.URI,
@@ -290,16 +287,16 @@
 	return source.ToProtocolEdits(pm.Mapper, diff)
 }
 
-func (c *commandHandler) Test(uri protocol.DocumentURI, tests, benchmarks []string) error {
-	return c.RunTests(command.RunTestsArgs{
+func (c *commandHandler) Test(ctx context.Context, uri protocol.DocumentURI, tests, benchmarks []string) error {
+	return c.RunTests(ctx, command.RunTestsArgs{
 		URI:        uri,
 		Tests:      tests,
 		Benchmarks: benchmarks,
 	})
 }
 
-func (c *commandHandler) RunTests(args command.RunTestsArgs) error {
-	return c.run(commandConfig{
+func (c *commandHandler) RunTests(ctx context.Context, args command.RunTestsArgs) error {
+	return c.run(ctx, commandConfig{
 		async:       true,
 		progress:    "Running go test",
 		requireSave: true,
@@ -395,12 +392,12 @@
 	})
 }
 
-func (c *commandHandler) Generate(args command.GenerateArgs) error {
+func (c *commandHandler) Generate(ctx context.Context, args command.GenerateArgs) error {
 	title := "Running go generate ."
 	if args.Recursive {
 		title = "Running go generate ./..."
 	}
-	return c.run(commandConfig{
+	return c.run(ctx, commandConfig{
 		requireSave: true,
 		progress:    title,
 		forURI:      args.Dir,
@@ -424,8 +421,8 @@
 	})
 }
 
-func (c *commandHandler) GoGetPackage(args command.GoGetPackageArgs) error {
-	return c.run(commandConfig{
+func (c *commandHandler) GoGetPackage(ctx context.Context, args command.GoGetPackageArgs) error {
+	return c.run(ctx, commandConfig{
 		forURI:   args.URI,
 		progress: "Running go get",
 	}, func(ctx context.Context, deps commandDeps) error {
@@ -489,12 +486,12 @@
 	return upgrades, nil
 }
 
-func (c *commandHandler) GCDetails(uri protocol.DocumentURI) error {
-	return c.ToggleGCDetails(command.URIArg{URI: uri})
+func (c *commandHandler) GCDetails(ctx context.Context, uri protocol.DocumentURI) error {
+	return c.ToggleGCDetails(ctx, command.URIArg{URI: uri})
 }
 
-func (c *commandHandler) ToggleGCDetails(args command.URIArg) error {
-	return c.run(commandConfig{
+func (c *commandHandler) ToggleGCDetails(ctx context.Context, args command.URIArg) error {
+	return c.run(ctx, commandConfig{
 		requireSave: true,
 		progress:    "Toggling GC Details",
 		forURI:      args.URI,
@@ -513,9 +510,9 @@
 	})
 }
 
-func (c *commandHandler) GenerateGoplsMod(args command.URIArg) error {
+func (c *commandHandler) GenerateGoplsMod(ctx context.Context, args command.URIArg) error {
 	// TODO: go back to using URI
-	return c.run(commandConfig{
+	return c.run(ctx, commandConfig{
 		requireSave: true,
 		progress:    "Generating gopls.mod",
 	}, func(ctx context.Context, deps commandDeps) error {
diff --git a/internal/lsp/command/command_gen.go b/internal/lsp/command/command_gen.go
index 2bd2170..7fda594 100644
--- a/internal/lsp/command/command_gen.go
+++ b/internal/lsp/command/command_gen.go
@@ -11,6 +11,7 @@
 // Code generated by generate.go. DO NOT EDIT.
 
 import (
+	"context"
 	"fmt"
 
 	"golang.org/x/tools/internal/lsp/protocol"
@@ -54,77 +55,77 @@
 	Vendor,
 }
 
-func Dispatch(params *protocol.ExecuteCommandParams, s Interface) (interface{}, error) {
+func Dispatch(ctx context.Context, params *protocol.ExecuteCommandParams, s Interface) (interface{}, error) {
 	switch params.Command {
 	case "gopls.add_dependency":
 		var a0 DependencyArgs
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.AddDependency(a0)
+		err := s.AddDependency(ctx, a0)
 		return nil, err
 	case "gopls.apply_fix":
 		var a0 ApplyFixArgs
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.ApplyFix(a0)
+		err := s.ApplyFix(ctx, a0)
 		return nil, err
 	case "gopls.check_upgrades":
 		var a0 CheckUpgradesArgs
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.CheckUpgrades(a0)
+		err := s.CheckUpgrades(ctx, a0)
 		return nil, err
 	case "gopls.gc_details":
 		var a0 protocol.DocumentURI
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.GCDetails(a0)
+		err := s.GCDetails(ctx, a0)
 		return nil, err
 	case "gopls.generate":
 		var a0 GenerateArgs
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.Generate(a0)
+		err := s.Generate(ctx, a0)
 		return nil, err
 	case "gopls.generate_gopls_mod":
 		var a0 URIArg
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.GenerateGoplsMod(a0)
+		err := s.GenerateGoplsMod(ctx, a0)
 		return nil, err
 	case "gopls.go_get_package":
 		var a0 GoGetPackageArgs
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.GoGetPackage(a0)
+		err := s.GoGetPackage(ctx, a0)
 		return nil, err
 	case "gopls.regenerate_cgo":
 		var a0 URIArg
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.RegenerateCgo(a0)
+		err := s.RegenerateCgo(ctx, a0)
 		return nil, err
 	case "gopls.remove_dependency":
 		var a0 RemoveDependencyArgs
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.RemoveDependency(a0)
+		err := s.RemoveDependency(ctx, a0)
 		return nil, err
 	case "gopls.run_tests":
 		var a0 RunTestsArgs
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.RunTests(a0)
+		err := s.RunTests(ctx, a0)
 		return nil, err
 	case "gopls.test":
 		var a0 protocol.DocumentURI
@@ -133,42 +134,42 @@
 		if err := UnmarshalArgs(params.Arguments, &a0, &a1, &a2); err != nil {
 			return nil, err
 		}
-		err := s.Test(a0, a1, a2)
+		err := s.Test(ctx, a0, a1, a2)
 		return nil, err
 	case "gopls.tidy":
 		var a0 URIArg
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.Tidy(a0)
+		err := s.Tidy(ctx, a0)
 		return nil, err
 	case "gopls.toggle_gc_details":
 		var a0 URIArg
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.ToggleGCDetails(a0)
+		err := s.ToggleGCDetails(ctx, a0)
 		return nil, err
 	case "gopls.update_go_sum":
 		var a0 URIArg
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.UpdateGoSum(a0)
+		err := s.UpdateGoSum(ctx, a0)
 		return nil, err
 	case "gopls.upgrade_dependency":
 		var a0 DependencyArgs
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.UpgradeDependency(a0)
+		err := s.UpgradeDependency(ctx, a0)
 		return nil, err
 	case "gopls.vendor":
 		var a0 URIArg
 		if err := UnmarshalArgs(params.Arguments, &a0); err != nil {
 			return nil, err
 		}
-		err := s.Vendor(a0)
+		err := s.Vendor(ctx, a0)
 		return nil, err
 	}
 	return nil, fmt.Errorf("unsupported command %q", params.Command)
diff --git a/internal/lsp/command/commandmeta/meta.go b/internal/lsp/command/commandmeta/meta.go
index 70b1fe2..c036d7a 100644
--- a/internal/lsp/command/commandmeta/meta.go
+++ b/internal/lsp/command/commandmeta/meta.go
@@ -122,6 +122,15 @@
 		if err != nil {
 			return nil, err
 		}
+		if i == 0 {
+			// Lazy check that the first argument is a context. We could relax this,
+			// but then the generated code gets more complicated.
+			if named, ok := fld.Type.(*types.Named); !ok || named.Obj().Name() != "Context" || named.Obj().Pkg().Path() != "context" {
+				return nil, fmt.Errorf("first method parameter must be context.Context")
+			}
+			// Skip the context argument, as it is implied.
+			continue
+		}
 		c.Args = append(c.Args, fld)
 	}
 	return c, nil
diff --git a/internal/lsp/command/generate/generate.go b/internal/lsp/command/generate/generate.go
index eb24bbd..019574b 100644
--- a/internal/lsp/command/generate/generate.go
+++ b/internal/lsp/command/generate/generate.go
@@ -46,7 +46,7 @@
 {{- end}}
 }
 
-func Dispatch(params *protocol.ExecuteCommandParams, s Interface) (interface{}, error) {
+func Dispatch(ctx context.Context, params *protocol.ExecuteCommandParams, s Interface) (interface{}, error) {
 	switch params.Command {
 	{{- range .Commands}}
 	case "{{.ID}}":
@@ -58,7 +58,7 @@
 			return nil, err
 		}
 		{{end -}}
-		{{- if .Result -}}res, {{end}}err := s.{{.MethodName}}({{block "callargs" .}}{{range $i, $v := .Args}}{{if $i}}, {{end}}a{{$i}}{{end}}{{end}})
+		{{- if .Result -}}res, {{end}}err := s.{{.MethodName}}(ctx{{range $i, $v := .Args}}, a{{$i}}{{end}})
 		return {{if .Result}}res{{else}}nil{{end}}, err
 	{{- end}}
 	}
@@ -67,7 +67,7 @@
 {{- range .Commands}}
 
 func New{{.MethodName}}Command(title string, {{range $i, $v := .Args}}{{if $i}}, {{end}}a{{$i}} {{typeString $v.Type}}{{end}}) (protocol.Command, error) {
-	args, err := MarshalArgs({{template "callargs" .}})
+	args, err := MarshalArgs({{range $i, $v := .Args}}{{if $i}}, {{end}}a{{$i}}{{end}})
 	if err != nil {
 		return protocol.Command{}, err
 	}
@@ -107,7 +107,8 @@
 	d := data{
 		Commands: cmds,
 		Imports: map[string]bool{
-			"fmt": true,
+			"context": true,
+			"fmt":     true,
 			"golang.org/x/tools/internal/lsp/protocol": true,
 		},
 	}
diff --git a/internal/lsp/command/interface.go b/internal/lsp/command/interface.go
index 0cff1a4..9de4b32 100644
--- a/internal/lsp/command/interface.go
+++ b/internal/lsp/command/interface.go
@@ -14,7 +14,11 @@
 
 //go:generate go run -tags=generate generate.go
 
-import "golang.org/x/tools/internal/lsp/protocol"
+import (
+	"context"
+
+	"golang.org/x/tools/internal/lsp/protocol"
+)
 
 // Interface defines the interface gopls exposes for the
 // workspace/executeCommand request.
@@ -31,85 +35,85 @@
 	// ApplyFix: Apply a fix
 	//
 	// Applies a fix to a region of source code.
-	ApplyFix(ApplyFixArgs) error
+	ApplyFix(context.Context, ApplyFixArgs) error
 	// Test: Run test(s) (legacy)
 	//
 	// Runs `go test` for a specific set of test or benchmark functions.
-	Test(protocol.DocumentURI, []string, []string) error
+	Test(context.Context, protocol.DocumentURI, []string, []string) error
 
 	// TODO: deprecate Test in favor of RunTests below.
 
 	// Test: Run test(s)
 	//
 	// Runs `go test` for a specific set of test or benchmark functions.
-	RunTests(RunTestsArgs) error
+	RunTests(context.Context, RunTestsArgs) error
 
 	// Generate: Run go generate
 	//
 	// Runs `go generate` for a given directory.
-	Generate(GenerateArgs) error
+	Generate(context.Context, GenerateArgs) error
 
 	// RegenerateCgo: Regenerate cgo
 	//
 	// Regenerates cgo definitions.
-	RegenerateCgo(URIArg) error
+	RegenerateCgo(context.Context, URIArg) error
 
 	// Tidy: Run go mod tidy
 	//
 	// Runs `go mod tidy` for a module.
-	Tidy(URIArg) error
+	Tidy(context.Context, URIArg) error
 
 	// Vendor: Run go mod vendor
 	//
 	// Runs `go mod vendor` for a module.
-	Vendor(URIArg) error
+	Vendor(context.Context, URIArg) error
 
 	// UpdateGoSum: Update go.sum
 	//
 	// Updates the go.sum file for a module.
-	UpdateGoSum(URIArg) error
+	UpdateGoSum(context.Context, URIArg) error
 
 	// CheckUpgrades: Check for upgrades
 	//
 	// Checks for module upgrades.
-	CheckUpgrades(CheckUpgradesArgs) error
+	CheckUpgrades(context.Context, CheckUpgradesArgs) error
 
 	// AddDependency: Add dependency
 	//
 	// Adds a dependency to the go.mod file for a module.
-	AddDependency(DependencyArgs) error
+	AddDependency(context.Context, DependencyArgs) error
 
 	// UpgradeDependency: Upgrade dependency
 	//
 	// Upgrades a dependency in the go.mod file for a module.
-	UpgradeDependency(DependencyArgs) error
+	UpgradeDependency(context.Context, DependencyArgs) error
 
 	// RemoveDependency: Remove dependency
 	//
 	// Removes a dependency from the go.mod file of a module.
-	RemoveDependency(RemoveDependencyArgs) error
+	RemoveDependency(context.Context, RemoveDependencyArgs) error
 
 	// GoGetPackage: go get package
 	//
 	// Runs `go get` to fetch a package.
-	GoGetPackage(GoGetPackageArgs) error
+	GoGetPackage(context.Context, GoGetPackageArgs) error
 
 	// GCDetails: Toggle gc_details
 	//
 	// Toggle the calculation of gc annotations.
-	GCDetails(protocol.DocumentURI) error
+	GCDetails(context.Context, protocol.DocumentURI) error
 
 	// TODO: deprecate GCDetails in favor of ToggleGCDetails below.
 
 	// ToggleGCDetails: Toggle gc_details
 	//
 	// Toggle the calculation of gc annotations.
-	ToggleGCDetails(URIArg) error
+	ToggleGCDetails(context.Context, URIArg) error
 
 	// GenerateGoplsMod: Generate gopls.mod
 	//
 	// (Re)generate the gopls.mod file for a workspace.
-	GenerateGoplsMod(URIArg) error
+	GenerateGoplsMod(context.Context, URIArg) error
 }
 
 type RunTestsArgs struct {