internal/lsp: move some per-command set-up into a helper

Create the 'prepareAndRun' helper to offload some common command set-up
within the command handler. In subsequent CLs, this will be used to hold
all configuration of the implementation, including whether the command
will execute asynchronously, and whether to show progress.

Change-Id: I6d0f072e805dade5c7df37fa5cdf993d397fa717
Reviewed-on: https://go-review.googlesource.com/c/tools/+/288494
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/internal/lsp/command.go b/internal/lsp/command.go
index 8f8f916..a4492cd 100644
--- a/internal/lsp/command.go
+++ b/internal/lsp/command.go
@@ -46,30 +46,6 @@
 	if !match {
 		return nil, fmt.Errorf("%s is not a supported command", command.ID())
 	}
-	// Some commands require that all files are saved to disk. If we detect
-	// unsaved files, warn the user instead of running the commands.
-	unsaved := false
-	for _, overlay := range s.session.Overlays() {
-		if !overlay.Saved() {
-			unsaved = true
-			break
-		}
-	}
-	if unsaved {
-		switch params.Command {
-		case source.CommandTest.ID(),
-			source.CommandGenerate.ID(),
-			source.CommandToggleDetails.ID(),
-			source.CommandAddDependency.ID(),
-			source.CommandUpgradeDependency.ID(),
-			source.CommandRemoveDependency.ID(),
-			source.CommandVendor.ID():
-			// TODO(PJW): for Toggle, not an error if it is being disabled
-			err := errors.New("All files must be saved first")
-			s.showCommandError(ctx, command.Title, err)
-			return nil, nil
-		}
-	}
 	ctx, cancel := context.WithCancel(xcontext.Detach(ctx))
 
 	var work *workDone
@@ -148,6 +124,31 @@
 	}
 }
 
+type commandConfig struct {
+	requireSave bool                 // whether all files must be saved for the command to work
+	forURI      protocol.DocumentURI // URI to resolve to a snapshot. If unset, snapshot will be nil.
+}
+
+func (s *Server) prepareAndRun(ctx context.Context, cfg commandConfig, run func(source.Snapshot) error) error {
+	if cfg.requireSave {
+		for _, overlay := range s.session.Overlays() {
+			if !overlay.Saved() {
+				return errors.New("All files must be saved first")
+			}
+		}
+	}
+	var snapshot source.Snapshot
+	if cfg.forURI != "" {
+		snap, _, ok, release, err := s.beginFileRequest(ctx, cfg.forURI, source.UnknownKind)
+		defer release()
+		if !ok {
+			return err
+		}
+		snapshot = snap
+	}
+	return run(snapshot)
+}
+
 func (s *Server) runCommand(ctx context.Context, work *workDone, command *source.Command, args []json.RawMessage) (err error) {
 	// If the command has a suggested fix function available, use it and apply
 	// the edits to the workspace.
@@ -161,24 +162,24 @@
 		if err := source.UnmarshalArgs(args, &uri, &tests, &benchmarks); err != nil {
 			return err
 		}
-		snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
-		defer release()
-		if !ok {
-			return err
-		}
-		return s.runTests(ctx, snapshot, uri, work, tests, benchmarks)
+		return s.prepareAndRun(ctx, commandConfig{
+			requireSave: true,
+			forURI:      uri,
+		}, func(snapshot source.Snapshot) error {
+			return s.runTests(ctx, snapshot, uri, work, tests, benchmarks)
+		})
 	case source.CommandGenerate:
 		var uri protocol.DocumentURI
 		var recursive bool
 		if err := source.UnmarshalArgs(args, &uri, &recursive); err != nil {
 			return err
 		}
-		snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
-		defer release()
-		if !ok {
-			return err
-		}
-		return s.runGoGenerate(ctx, snapshot, uri.SpanURI(), recursive, work)
+		return s.prepareAndRun(ctx, commandConfig{
+			requireSave: true,
+			forURI:      uri,
+		}, func(snapshot source.Snapshot) error {
+			return s.runGoGenerate(ctx, snapshot, uri.SpanURI(), recursive, work)
+		})
 	case source.CommandRegenerateCgo:
 		var uri protocol.DocumentURI
 		if err := source.UnmarshalArgs(args, &uri); err != nil {
@@ -194,29 +195,29 @@
 		if err := source.UnmarshalArgs(args, &uri); err != nil {
 			return err
 		}
-		snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
-		defer release()
-		if !ok {
-			return err
-		}
 		// The flow for `go mod tidy` and `go mod vendor` is almost identical,
 		// so we combine them into one case for convenience.
 		action := "tidy"
 		if command == source.CommandVendor {
 			action = "vendor"
 		}
-		return runSimpleGoCommand(ctx, snapshot, source.UpdateUserModFile|source.AllowNetwork, uri.SpanURI(), "mod", []string{action})
+		return s.prepareAndRun(ctx, commandConfig{
+			requireSave: true,
+			forURI:      uri,
+		}, func(snapshot source.Snapshot) error {
+			return runSimpleGoCommand(ctx, snapshot, source.UpdateUserModFile|source.AllowNetwork, uri.SpanURI(), "mod", []string{action})
+		})
 	case source.CommandUpdateGoSum:
 		var uri protocol.DocumentURI
 		if err := source.UnmarshalArgs(args, &uri); err != nil {
 			return err
 		}
-		snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
-		defer release()
-		if !ok {
-			return err
-		}
-		return runSimpleGoCommand(ctx, snapshot, source.UpdateUserModFile|source.AllowNetwork, uri.SpanURI(), "list", []string{"all"})
+		return s.prepareAndRun(ctx, commandConfig{
+			requireSave: true,
+			forURI:      uri,
+		}, func(snapshot source.Snapshot) error {
+			return runSimpleGoCommand(ctx, snapshot, source.UpdateUserModFile|source.AllowNetwork, uri.SpanURI(), "list", []string{"all"})
+		})
 	case source.CommandCheckUpgrades:
 		var uri protocol.DocumentURI
 		var modules []string
@@ -243,12 +244,12 @@
 		if err := source.UnmarshalArgs(args, &uri, &addRequire, &goCmdArgs); err != nil {
 			return err
 		}
-		snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
-		defer release()
-		if !ok {
-			return err
-		}
-		return s.runGoGetModule(ctx, snapshot, uri.SpanURI(), addRequire, goCmdArgs)
+		return s.prepareAndRun(ctx, commandConfig{
+			requireSave: true,
+			forURI:      uri,
+		}, func(snapshot source.Snapshot) error {
+			return s.runGoGetModule(ctx, snapshot, uri.SpanURI(), addRequire, goCmdArgs)
+		})
 	case source.CommandRemoveDependency:
 		var uri protocol.DocumentURI
 		var modulePath string
@@ -256,6 +257,86 @@
 		if err := source.UnmarshalArgs(args, &uri, &onlyDiagnostic, &modulePath); err != nil {
 			return err
 		}
+		return s.removeDependency(ctx, modulePath, uri, onlyDiagnostic)
+	case source.CommandGoGetPackage:
+		var uri protocol.DocumentURI
+		var pkg string
+		var addRequire bool
+		if err := source.UnmarshalArgs(args, &uri, &addRequire, &pkg); err != nil {
+			return err
+		}
+		return s.prepareAndRun(ctx, commandConfig{
+			forURI: uri,
+		}, func(snapshot source.Snapshot) error {
+			return s.runGoGetPackage(ctx, snapshot, uri.SpanURI(), addRequire, pkg)
+		})
+
+	case source.CommandToggleDetails:
+		var uri protocol.DocumentURI
+		if err := source.UnmarshalArgs(args, &uri); err != nil {
+			return err
+		}
+		return s.prepareAndRun(ctx, commandConfig{
+			requireSave: true,
+			forURI:      uri,
+		}, func(snapshot source.Snapshot) error {
+			pkgDir := span.URIFromPath(filepath.Dir(uri.SpanURI().Filename()))
+			s.gcOptimizationDetailsMu.Lock()
+			if _, ok := s.gcOptimizationDetails[pkgDir]; ok {
+				delete(s.gcOptimizationDetails, pkgDir)
+				s.clearDiagnosticSource(gcDetailsSource)
+			} else {
+				s.gcOptimizationDetails[pkgDir] = struct{}{}
+			}
+			s.gcOptimizationDetailsMu.Unlock()
+			s.diagnoseSnapshot(snapshot, nil, false)
+			return nil
+		})
+	case source.CommandGenerateGoplsMod:
+		var v source.View
+		if len(args) == 0 {
+			views := s.session.Views()
+			if len(views) != 1 {
+				return fmt.Errorf("cannot resolve view: have %d views", len(views))
+			}
+			v = views[0]
+		} else {
+			var uri protocol.DocumentURI
+			if err := source.UnmarshalArgs(args, &uri); err != nil {
+				return err
+			}
+			var err error
+			v, err = s.session.ViewOf(uri.SpanURI())
+			if err != nil {
+				return err
+			}
+		}
+		snapshot, release := v.Snapshot(ctx)
+		defer release()
+		modFile, err := cache.BuildGoplsMod(ctx, v.Folder(), snapshot)
+		if err != nil {
+			return errors.Errorf("getting workspace mod file: %w", err)
+		}
+		content, err := modFile.Format()
+		if err != nil {
+			return errors.Errorf("formatting mod file: %w", err)
+		}
+		filename := filepath.Join(v.Folder().Filename(), "gopls.mod")
+		if err := ioutil.WriteFile(filename, content, 0644); err != nil {
+			return errors.Errorf("writing mod file: %w", err)
+		}
+	default:
+		return fmt.Errorf("unsupported command: %s", command.ID())
+	}
+	return nil
+}
+
+func (s *Server) removeDependency(ctx context.Context, modulePath string, uri protocol.DocumentURI, onlyDiagnostic bool) error {
+	return s.prepareAndRun(ctx, commandConfig{
+		requireSave: true,
+		forURI:      uri,
+	}, func(source.Snapshot) error {
+
 		snapshot, fh, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
 		defer release()
 		if !ok {
@@ -299,79 +380,8 @@
 		if !response.Applied {
 			return fmt.Errorf("edits not applied because of %s", response.FailureReason)
 		}
-	case source.CommandGoGetPackage:
-		var uri protocol.DocumentURI
-		var pkg string
-		var addRequire bool
-		if err := source.UnmarshalArgs(args, &uri, &addRequire, &pkg); err != nil {
-			return err
-		}
-		snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
-		defer release()
-		if !ok {
-			return err
-		}
-		return s.runGoGetPackage(ctx, snapshot, uri.SpanURI(), addRequire, pkg)
-
-	case source.CommandToggleDetails:
-		var fileURI protocol.DocumentURI
-		if err := source.UnmarshalArgs(args, &fileURI); err != nil {
-			return err
-		}
-		pkgDir := span.URIFromPath(filepath.Dir(fileURI.SpanURI().Filename()))
-		s.gcOptimizationDetailsMu.Lock()
-		if _, ok := s.gcOptimizationDetails[pkgDir]; ok {
-			delete(s.gcOptimizationDetails, pkgDir)
-			s.clearDiagnosticSource(gcDetailsSource)
-		} else {
-			s.gcOptimizationDetails[pkgDir] = struct{}{}
-		}
-		s.gcOptimizationDetailsMu.Unlock()
-		// need to recompute diagnostics.
-		// so find the snapshot
-		snapshot, _, ok, release, err := s.beginFileRequest(ctx, fileURI, source.UnknownKind)
-		defer release()
-		if !ok {
-			return err
-		}
-		s.diagnoseSnapshot(snapshot, nil, false)
-	case source.CommandGenerateGoplsMod:
-		var v source.View
-		if len(args) == 0 {
-			views := s.session.Views()
-			if len(views) != 1 {
-				return fmt.Errorf("cannot resolve view: have %d views", len(views))
-			}
-			v = views[0]
-		} else {
-			var uri protocol.DocumentURI
-			if err := source.UnmarshalArgs(args, &uri); err != nil {
-				return err
-			}
-			var err error
-			v, err = s.session.ViewOf(uri.SpanURI())
-			if err != nil {
-				return err
-			}
-		}
-		snapshot, release := v.Snapshot(ctx)
-		defer release()
-		modFile, err := cache.BuildGoplsMod(ctx, v.Folder(), snapshot)
-		if err != nil {
-			return errors.Errorf("getting workspace mod file: %w", err)
-		}
-		content, err := modFile.Format()
-		if err != nil {
-			return errors.Errorf("formatting mod file: %w", err)
-		}
-		filename := filepath.Join(v.Folder().Filename(), "gopls.mod")
-		if err := ioutil.WriteFile(filename, content, 0644); err != nil {
-			return errors.Errorf("writing mod file: %w", err)
-		}
-	default:
-		return fmt.Errorf("unsupported command: %s", command.ID())
-	}
-	return nil
+		return nil
+	})
 }
 
 // dropDependency returns the edits to remove the given require from the go.mod