internal/lsp: consolidate progress reporting

This change contains several improvements for progress reporting:

 + Consolidate the 'progressWriter' interface into the workDone
   interface.  Now all progress reporting should use workDone, and the
   workDoneWriter struct is just an io.Writer adapter on top of
   workDone.
 + Factor out the pattern of progress reporting, and use for all
   asynchronous commands.
 + Make several commands that were previously synchronous async.
 + Add a test for cancellation when the WorkDone API is not supported.
 + Always report workdone progress using a detached context.
 + Update 'run tests' to use the -v option, and merge stderr and stdout,
   to increase the amount of information reported.
 + Since $/progress reporting is now always run with a detached context,
   the 'NoOutstandingWork' expectation should now behave correctly. Use
   it in a few places.

A follow-up CL will improve the messages reported on command completion.

For golang/go#40634

Change-Id: I7401ae62f7ed22d76e558ccc046e981622a64b12
Reviewed-on: https://go-review.googlesource.com/c/tools/+/248918
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/internal/lsp/command.go b/internal/lsp/command.go
index 1e4c126..cbc9fa9 100644
--- a/internal/lsp/command.go
+++ b/internal/lsp/command.go
@@ -6,12 +6,13 @@
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"io"
+	"log"
 	"path"
 
 	"golang.org/x/tools/internal/event"
-	"golang.org/x/tools/internal/lsp/debug/tag"
 	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/lsp/source"
 	"golang.org/x/tools/internal/span"
@@ -94,47 +95,78 @@
 		}
 		return nil, nil
 	}
-	// Default commands that don't have suggested fix functions.
+	title := command.Title
+	if title == "" {
+		title = command.Name
+	}
+	ctx, cancel := context.WithCancel(xcontext.Detach(ctx))
+	// Start progress prior to spinning off a goroutine specifically so that
+	// clients are aware of the work item before the command completes. This
+	// matters for regtests, where having a continuous thread of work is
+	// convenient for assertions.
+	work := s.progress.start(ctx, title, "starting...", params.WorkDoneToken, cancel)
+	go func() {
+		defer cancel()
+		err := s.runCommand(ctx, work, command, params.Arguments)
+		switch {
+		case errors.Is(err, context.Canceled):
+			work.end(command.Name + " canceled")
+		case err != nil:
+			event.Error(ctx, fmt.Sprintf("%s: command error", command.Name), err)
+			work.end(command.Name + " failed")
+			// Show a message when work completes with error, because the progress end
+			// message is typically dismissed immediately by LSP clients.
+			s.client.ShowMessage(ctx, &protocol.ShowMessageParams{
+				Type:    protocol.Error,
+				Message: fmt.Sprintf("An error occurred running %s: check the gopls logs.", command.Name),
+			})
+		default:
+			work.end(command.Name + " complete")
+		}
+	}()
+	return nil, nil
+}
+
+func (s *Server) runCommand(ctx context.Context, work *workDone, command *source.Command, args []json.RawMessage) error {
 	switch command {
 	case source.CommandTest:
 		var uri protocol.DocumentURI
 		var tests, benchmarks []string
-		if err := source.UnmarshalArgs(params.Arguments, &uri, &tests, &benchmarks); err != nil {
-			return nil, err
+		if err := source.UnmarshalArgs(args, &uri, &tests, &benchmarks); err != nil {
+			return err
 		}
 		snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
 		defer release()
 		if !ok {
-			return nil, err
+			return err
 		}
-		go s.runTests(ctx, snapshot, uri, params.WorkDoneToken, tests, benchmarks)
+		return s.runTests(ctx, snapshot, uri, work, tests, benchmarks)
 	case source.CommandGenerate:
 		var uri protocol.DocumentURI
 		var recursive bool
-		if err := source.UnmarshalArgs(params.Arguments, &uri, &recursive); err != nil {
-			return nil, err
+		if err := source.UnmarshalArgs(args, &uri, &recursive); err != nil {
+			return err
 		}
 		snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
 		defer release()
 		if !ok {
-			return nil, err
+			return err
 		}
-		go s.runGoGenerate(xcontext.Detach(ctx), snapshot, uri.SpanURI(), recursive, params.WorkDoneToken)
+		return s.runGoGenerate(ctx, snapshot, uri.SpanURI(), recursive, work)
 	case source.CommandRegenerateCgo:
 		var uri protocol.DocumentURI
-		if err := source.UnmarshalArgs(params.Arguments, &uri); err != nil {
-			return nil, err
+		if err := source.UnmarshalArgs(args, &uri); err != nil {
+			return err
 		}
 		mod := source.FileModification{
 			URI:    uri.SpanURI(),
 			Action: source.InvalidateMetadata,
 		}
-		err := s.didModifyFiles(ctx, []source.FileModification{mod}, FromRegenerateCgo)
-		return nil, err
+		return s.didModifyFiles(ctx, []source.FileModification{mod}, FromRegenerateCgo)
 	case source.CommandTidy, source.CommandVendor:
 		var uri protocol.DocumentURI
-		if err := source.UnmarshalArgs(params.Arguments, &uri); err != nil {
-			return nil, err
+		if err := source.UnmarshalArgs(args, &uri); err != nil {
+			return err
 		}
 		// The flow for `go mod tidy` and `go mod vendor` is almost identical,
 		// so we combine them into one case for convenience.
@@ -142,20 +174,18 @@
 		if command == source.CommandVendor {
 			a = "vendor"
 		}
-		err := s.directGoModCommand(ctx, uri, "mod", []string{a}...)
-		return nil, err
+		return s.directGoModCommand(ctx, uri, "mod", []string{a}...)
 	case source.CommandUpgradeDependency:
 		var uri protocol.DocumentURI
 		var goCmdArgs []string
-		if err := source.UnmarshalArgs(params.Arguments, &uri, &goCmdArgs); err != nil {
-			return nil, err
+		if err := source.UnmarshalArgs(args, &uri, &goCmdArgs); err != nil {
+			return err
 		}
-		err := s.directGoModCommand(ctx, uri, "get", goCmdArgs...)
-		return nil, err
+		return s.directGoModCommand(ctx, uri, "get", goCmdArgs...)
 	case source.CommandToggleDetails:
 		var fileURI span.URI
-		if err := source.UnmarshalArgs(params.Arguments, &fileURI); err != nil {
-			return nil, err
+		if err := source.UnmarshalArgs(args, &fileURI); err != nil {
+			return err
 		}
 		pkgDir := span.URIFromPath(path.Dir(fileURI.Filename()))
 		s.gcOptimizationDetailsMu.Lock()
@@ -171,16 +201,15 @@
 		// so find the snapshot
 		sv, err := s.session.ViewOf(fileURI)
 		if err != nil {
-			return nil, err
+			return err
 		}
 		snapshot, release := sv.Snapshot(ctx)
 		defer release()
 		s.diagnoseSnapshot(snapshot)
-		return nil, nil
 	default:
-		return nil, fmt.Errorf("unknown command: %s", params.Command)
+		return fmt.Errorf("unsupported command: %s", command.Name)
 	}
-	return nil, nil
+	return nil
 }
 
 func (s *Server) directGoModCommand(ctx context.Context, uri protocol.DocumentURI, verb string, args ...string) error {
@@ -193,10 +222,7 @@
 	return snapshot.RunGoCommandDirect(ctx, verb, args)
 }
 
-func (s *Server) runTests(ctx context.Context, snapshot source.Snapshot, uri protocol.DocumentURI, token protocol.ProgressToken, tests, benchmarks []string) error {
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
+func (s *Server) runTests(ctx context.Context, snapshot source.Snapshot, uri protocol.DocumentURI, work *workDone, tests, benchmarks []string) error {
 	pkgs, err := snapshot.PackagesForFile(ctx, uri.SpanURI(), source.TypecheckWorkspace)
 	if err != nil {
 		return err
@@ -218,17 +244,15 @@
 	} else {
 		return errors.New("No functions were provided")
 	}
-	msg := fmt.Sprintf("Running %s...", title)
-	wc := s.progress.newWriter(ctx, title, msg, msg, token, cancel)
-	defer wc.Close()
 
-	stderr := io.MultiWriter(ew, wc)
+	out := io.MultiWriter(ew, workDoneWriter{work})
 
-	// run `go test -run Func` on each test
+	// Run `go test -run Func` on each test.
 	var failedTests int
 	for _, funcName := range tests {
-		args := []string{pkgPath, "-run", fmt.Sprintf("^%s$", funcName)}
-		if err := snapshot.RunGoCommandPiped(ctx, "test", args, ew, stderr); err != nil {
+		args := []string{pkgPath, "-v", "-count=1", "-run", fmt.Sprintf("^%s$", funcName)}
+		log.Printf("running with these args: %v", args)
+		if err := snapshot.RunGoCommandPiped(ctx, "test", args, out, out); err != nil {
 			if errors.Is(err, context.Canceled) {
 				return err
 			}
@@ -236,11 +260,11 @@
 		}
 	}
 
-	// run `go test -run=^$ -bench Func` on each test
+	// Run `go test -run=^$ -bench Func` on each test.
 	var failedBenchmarks int
 	for _, funcName := range tests {
-		args := []string{pkgPath, "-run=^$", "-bench", fmt.Sprintf("^%s$", funcName)}
-		if err := snapshot.RunGoCommandPiped(ctx, "test", args, ew, stderr); err != nil {
+		args := []string{pkgPath, "-v", "-run=^$", "-bench", fmt.Sprintf("^%s$", funcName)}
+		if err := snapshot.RunGoCommandPiped(ctx, "test", args, out, out); err != nil {
 			if errors.Is(err, context.Canceled) {
 				return err
 			}
@@ -267,33 +291,23 @@
 	})
 }
 
-// GenerateWorkDoneTitle is the title used in progress reporting for go
-// generate commands. It is exported for testing purposes.
-const GenerateWorkDoneTitle = "generate"
-
-func (s *Server) runGoGenerate(ctx context.Context, snapshot source.Snapshot, uri span.URI, recursive bool, token protocol.ProgressToken) error {
+func (s *Server) runGoGenerate(ctx context.Context, snapshot source.Snapshot, uri span.URI, recursive bool, work *workDone) error {
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
 	er := &eventWriter{ctx: ctx, operation: "generate"}
-	wc := s.progress.newWriter(ctx, GenerateWorkDoneTitle, "running go generate", "started go generate, check logs for progress", token, cancel)
-	defer wc.Close()
 	args := []string{"-x"}
 	if recursive {
 		args = append(args, "./...")
 	}
 
-	stderr := io.MultiWriter(er, wc)
+	stderr := io.MultiWriter(er, workDoneWriter{work})
 
 	if err := snapshot.RunGoCommandPiped(ctx, "generate", args, er, stderr); err != nil {
-		if errors.Is(err, context.Canceled) {
-			return nil
-		}
-		event.Error(ctx, "generate: command error", err, tag.Directory.Of(uri.Filename()))
-		return s.client.ShowMessage(ctx, &protocol.ShowMessageParams{
-			Type:    protocol.Error,
-			Message: "go generate exited with an error, check gopls logs",
-		})
+		return err
 	}
-	return nil
+	return s.client.ShowMessage(ctx, &protocol.ShowMessageParams{
+		Type:    protocol.Info,
+		Message: "go generate complete",
+	})
 }
diff --git a/internal/lsp/general.go b/internal/lsp/general.go
index d499d8f..1ddc993 100644
--- a/internal/lsp/general.go
+++ b/internal/lsp/general.go
@@ -190,7 +190,7 @@
 		defer func() {
 			go func() {
 				wg.Wait()
-				work.end(ctx, "Done.")
+				work.end("Done.")
 			}()
 		}()
 	}
@@ -201,12 +201,12 @@
 		view, snapshot, release, err := s.addView(ctx, folder.Name, uri)
 		if err != nil {
 			viewErrors[uri] = err
-			work.end(ctx, fmt.Sprintf("Error loading packages: %s", err))
+			work.end(fmt.Sprintf("Error loading packages: %s", err))
 			continue
 		}
 		go func() {
 			view.AwaitInitialized(ctx)
-			work.end(ctx, "Finished loading packages.")
+			work.end("Finished loading packages.")
 		}()
 
 		for _, dir := range snapshot.WorkspaceDirectories(ctx) {
diff --git a/internal/lsp/progress.go b/internal/lsp/progress.go
index fc63fd2..cdaf172 100644
--- a/internal/lsp/progress.go
+++ b/internal/lsp/progress.go
@@ -6,7 +6,6 @@
 
 import (
 	"context"
-	"io"
 	"math/rand"
 	"strconv"
 	"sync"
@@ -14,6 +13,7 @@
 	"golang.org/x/tools/internal/event"
 	"golang.org/x/tools/internal/lsp/debug/tag"
 	"golang.org/x/tools/internal/lsp/protocol"
+	"golang.org/x/tools/internal/xcontext"
 	errors "golang.org/x/xerrors"
 )
 
@@ -32,15 +32,16 @@
 	}
 }
 
-// start issues a $/progress notification to begin a unit of work on the
-// server. The returned WorkDone handle may be used to report incremental
+// start notifies the client of work being done on the server. It uses either
+// ShowMessage RPCs or $/progress messages, depending on the capabilities of
+// the client.  The returned WorkDone handle may be used to report incremental
 // progress, and to report work completion. In particular, it is an error to
 // call start and not call end(...) on the returned WorkDone handle.
 //
 // If token is empty, a token will be randomly generated.
 //
 // The progress item is considered cancellable if the given cancel func is
-// non-nil.
+// non-nil. In this case, cancel is called when the work done
 //
 // Example:
 //  func Generate(ctx) (err error) {
@@ -59,25 +60,29 @@
 //
 func (t *progressTracker) start(ctx context.Context, title, message string, token protocol.ProgressToken, cancel func()) *workDone {
 	wd := &workDone{
+		ctx:    xcontext.Detach(ctx),
 		client: t.client,
 		token:  token,
 		cancel: cancel,
 	}
 	if !t.supportsWorkDoneProgress {
-		wd.startErr = errors.New("workdone reporting is not supported")
+		go wd.openStartMessage(message)
 		return wd
 	}
 	if wd.token == nil {
-		wd.token = strconv.FormatInt(rand.Int63(), 10)
+		token = strconv.FormatInt(rand.Int63(), 10)
 		err := wd.client.WorkDoneProgressCreate(ctx, &protocol.WorkDoneProgressCreateParams{
-			Token: wd.token,
+			Token: token,
 		})
 		if err != nil {
-			wd.startErr = err
+			wd.err = err
 			event.Error(ctx, "starting work for "+title, err)
 			return wd
 		}
+		wd.token = token
 	}
+	// At this point we have a token that the client knows about. Store the token
+	// before starting work.
 	t.mu.Lock()
 	t.inProgress[wd.token] = wd
 	t.mu.Unlock()
@@ -111,38 +116,82 @@
 	if wd.cancel == nil {
 		return errors.Errorf("work %q is not cancellable", token)
 	}
-	wd.cancel()
+	wd.doCancel()
 	return nil
 }
 
-// newProgressWriter returns an io.WriterCloser that can be used
-// to report progress on a command based on the client capabilities.
-func (t *progressTracker) newWriter(ctx context.Context, title, beginMsg, msg string, token protocol.ProgressToken, cancel func()) io.WriteCloser {
-	if t.supportsWorkDoneProgress {
-		wd := t.start(ctx, title, beginMsg, token, cancel)
-		return &workDoneWriter{ctx, wd}
-	}
-	mw := &messageWriter{ctx, cancel, t.client}
-	mw.start(msg)
-	return mw
-}
-
 // workDone represents a unit of work that is reported to the client via the
 // progress API.
 type workDone struct {
-	client   protocol.Client
-	startErr error
-	token    protocol.ProgressToken
-	cancel   func()
-	cleanup  func()
+	// ctx is detached, for sending $/progress updates.
+	ctx    context.Context
+	client protocol.Client
+	// If token is nil, this workDone object uses the ShowMessage API, rather
+	// than $/progress.
+	token protocol.ProgressToken
+	// err is set if progress reporting is broken for some reason (for example,
+	// if there was an initial error creating a token).
+	err error
+
+	cancelMu  sync.Mutex
+	cancelled bool
+	cancel    func()
+
+	cleanup func()
+}
+
+func (wd *workDone) openStartMessage(msg string) {
+	go func() {
+		if wd.cancel == nil {
+			err := wd.client.ShowMessage(wd.ctx, &protocol.ShowMessageParams{
+				Type:    protocol.Log,
+				Message: msg,
+			})
+			if err != nil {
+				event.Error(wd.ctx, "error sending message request", err)
+			}
+			return
+		}
+		const cancel = "Cancel"
+		item, err := wd.client.ShowMessageRequest(wd.ctx, &protocol.ShowMessageRequestParams{
+			Type:    protocol.Log,
+			Message: msg,
+			Actions: []protocol.MessageActionItem{{
+				Title: cancel,
+			}},
+		})
+		if err != nil {
+			event.Error(wd.ctx, "error sending message request", err)
+			return
+		}
+		if item != nil && item.Title == cancel {
+			wd.doCancel()
+		}
+	}()
+}
+
+func (wd *workDone) doCancel() {
+	wd.cancelMu.Lock()
+	defer wd.cancelMu.Unlock()
+	if !wd.cancelled {
+		wd.cancel()
+	}
 }
 
 // report reports an update on WorkDone report back to the client.
-func (wd *workDone) report(ctx context.Context, message string, percentage float64) error {
-	if wd.startErr != nil {
-		return wd.startErr
+func (wd *workDone) report(message string, percentage float64) {
+	wd.cancelMu.Lock()
+	cancelled := wd.cancelled
+	wd.cancelMu.Unlock()
+	if cancelled {
+		return
 	}
-	return wd.client.Progress(ctx, &protocol.ProgressParams{
+	if wd.err != nil || wd.token == nil {
+		// Not using the workDone API, so we do nothing. It would be far too spammy
+		// to send incremental messages.
+		return
+	}
+	err := wd.client.Progress(wd.ctx, &protocol.ProgressParams{
 		Token: wd.token,
 		Value: &protocol.WorkDoneProgressReport{
 			Kind: "report",
@@ -154,24 +203,38 @@
 			Percentage:  percentage,
 		},
 	})
+	if err != nil {
+		event.Error(wd.ctx, "reporting progress", err)
+	}
 }
 
 // end reports a workdone completion back to the client.
-func (wd *workDone) end(ctx context.Context, message string) error {
-	if wd.startErr != nil {
-		return wd.startErr
-	}
-	err := wd.client.Progress(ctx, &protocol.ProgressParams{
-		Token: wd.token,
-		Value: &protocol.WorkDoneProgressEnd{
-			Kind:    "end",
+func (wd *workDone) end(message string) {
+	var err error
+	switch {
+	case wd.err != nil:
+		// There is a prior error.
+	case wd.token == nil:
+		// We're falling back to message-based reporting.
+		err = wd.client.ShowMessage(wd.ctx, &protocol.ShowMessageParams{
+			Type:    protocol.Info,
 			Message: message,
-		},
-	})
+		})
+	default:
+		err = wd.client.Progress(wd.ctx, &protocol.ProgressParams{
+			Token: wd.token,
+			Value: &protocol.WorkDoneProgressEnd{
+				Kind:    "end",
+				Message: message,
+			},
+		})
+	}
+	if err != nil {
+		event.Error(wd.ctx, "ending work", err)
+	}
 	if wd.cleanup != nil {
 		wd.cleanup()
 	}
-	return err
 }
 
 // eventWriter writes every incoming []byte to
@@ -187,61 +250,14 @@
 	return len(p), nil
 }
 
-// messageWriter implements progressWriter and only tells the user that
-// a command has started through window/showMessage, but does not report
-// anything afterwards. This is because each log shows up as a separate window
-// and therefore would be obnoxious to show every incoming line. Request
-// cancellation happens synchronously through the ShowMessageRequest response.
-type messageWriter struct {
-	ctx    context.Context
-	cancel func()
-	client protocol.Client
-}
-
-func (lw *messageWriter) Write(p []byte) (n int, err error) {
-	return len(p), nil
-}
-
-func (lw *messageWriter) start(msg string) {
-	go func() {
-		const cancel = "Cancel"
-		item, err := lw.client.ShowMessageRequest(lw.ctx, &protocol.ShowMessageRequestParams{
-			Type:    protocol.Log,
-			Message: msg,
-			Actions: []protocol.MessageActionItem{{
-				Title: "Cancel",
-			}},
-		})
-		if err != nil {
-			event.Error(lw.ctx, "error sending message request", err)
-			return
-		}
-		if item != nil && item.Title == "Cancel" {
-			lw.cancel()
-		}
-	}()
-}
-
-func (lw *messageWriter) Close() error {
-	return lw.client.ShowMessage(lw.ctx, &protocol.ShowMessageParams{
-		Type:    protocol.Info,
-		Message: "go generate has finished",
-	})
-}
-
-// workDoneWriter implements progressWriter by sending $/progress notifications
-// to the client. Request cancellations happens separately through the
-// window/workDoneProgress/cancel request, in which case the given context will
-// be rendered done.
+// workDoneWriter wraps a workDone handle to provide a Writer interface,
+// so that workDone reporting can more easily be hooked into commands.
 type workDoneWriter struct {
-	ctx context.Context
-	wd  *workDone
+	wd *workDone
 }
 
-func (wdw *workDoneWriter) Write(p []byte) (n int, err error) {
-	return len(p), wdw.wd.report(wdw.ctx, string(p), 0)
-}
-
-func (wdw *workDoneWriter) Close() error {
-	return wdw.wd.end(wdw.ctx, "finished")
+func (wdw workDoneWriter) Write(p []byte) (n int, err error) {
+	wdw.wd.report(string(p), 0)
+	// Don't fail just because of a failure to report progress.
+	return len(p), nil
 }
diff --git a/internal/lsp/progress_test.go b/internal/lsp/progress_test.go
index 470127e..1a162a6 100644
--- a/internal/lsp/progress_test.go
+++ b/internal/lsp/progress_test.go
@@ -7,7 +7,9 @@
 import (
 	"context"
 	"fmt"
+	"sync"
 	"testing"
+	"time"
 
 	"golang.org/x/tools/internal/lsp/protocol"
 )
@@ -17,7 +19,10 @@
 
 	token protocol.ProgressToken
 
-	created, begun, reported, ended int
+	mu                                        sync.Mutex
+	created, begun, reported, messages, ended int
+
+	cancel chan struct{}
 }
 
 func (c *fakeClient) checkToken(token protocol.ProgressToken) {
@@ -30,12 +35,16 @@
 }
 
 func (c *fakeClient) WorkDoneProgressCreate(ctx context.Context, params *protocol.WorkDoneProgressCreateParams) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
 	c.checkToken(params.Token)
 	c.created++
 	return nil
 }
 
 func (c *fakeClient) Progress(ctx context.Context, params *protocol.ProgressParams) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
 	c.checkToken(params.Token)
 	switch params.Value.(type) {
 	case *protocol.WorkDoneProgressBegin:
@@ -50,8 +59,26 @@
 	return nil
 }
 
+func (c *fakeClient) ShowMessage(context.Context, *protocol.ShowMessageParams) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.messages++
+	return nil
+}
+
+func (c *fakeClient) ShowMessageRequest(ctx context.Context, params *protocol.ShowMessageRequestParams) (*protocol.MessageActionItem, error) {
+	select {
+	case <-ctx.Done():
+	case <-c.cancel:
+		return &params.Actions[0], nil
+	}
+	return nil, nil
+}
+
 func setup(token protocol.ProgressToken) (context.Context, *progressTracker, *fakeClient) {
-	c := &fakeClient{}
+	c := &fakeClient{
+		cancel: make(chan struct{}),
+	}
 	tracker := newProgressTracker(c)
 	tracker.supportsWorkDoneProgress = true
 	return context.Background(), tracker, c
@@ -63,9 +90,11 @@
 		supported                                       bool
 		token                                           protocol.ProgressToken
 		wantReported, wantCreated, wantBegun, wantEnded int
+		wantMessages                                    int
 	}{
 		{
-			name: "unsupported",
+			name:         "unsupported",
+			wantMessages: 1,
 		},
 		{
 			name:         "random token",
@@ -95,22 +124,36 @@
 		test := test
 		t.Run(test.name, func(t *testing.T) {
 			ctx, tracker, client := setup(test.token)
+			ctx, cancel := context.WithCancel(ctx)
+			defer cancel()
 			tracker.supportsWorkDoneProgress = test.supported
 			work := tracker.start(ctx, "work", "message", test.token, nil)
-			if got := client.created; got != test.wantCreated {
-				t.Errorf("got %d created tokens, want %d", got, test.wantCreated)
+			client.mu.Lock()
+			gotCreated, gotBegun := client.created, client.begun
+			client.mu.Unlock()
+			if gotCreated != test.wantCreated {
+				t.Errorf("got %d created tokens, want %d", gotCreated, test.wantCreated)
 			}
-			if got := client.begun; got != test.wantBegun {
-				t.Errorf("got %d work begun, want %d", got, test.wantBegun)
+			if gotBegun != test.wantBegun {
+				t.Errorf("got %d work begun, want %d", gotBegun, test.wantBegun)
 			}
 			// Ignore errors: this is just testing the reporting behavior.
-			work.report(ctx, "report", 50)
-			if got := client.reported; got != test.wantReported {
-				t.Errorf("got %d progress reports, want %d", got, test.wantCreated)
+			work.report("report", 50)
+			client.mu.Lock()
+			gotReported := client.reported
+			client.mu.Unlock()
+			if gotReported != test.wantReported {
+				t.Errorf("got %d progress reports, want %d", gotReported, test.wantCreated)
 			}
-			work.end(ctx, "done")
-			if got := client.ended; got != test.wantEnded {
-				t.Errorf("got %d ended reports, want %d", got, test.wantEnded)
+			work.end("done")
+			client.mu.Lock()
+			gotEnded, gotMessages := client.ended, client.messages
+			client.mu.Unlock()
+			if gotEnded != test.wantEnded {
+				t.Errorf("got %d ended reports, want %d", gotEnded, test.wantEnded)
+			}
+			if gotMessages != test.wantMessages {
+				t.Errorf("got %d messages, want %d", gotMessages, test.wantMessages)
 			}
 		})
 	}
@@ -119,14 +162,35 @@
 func TestProgressTracker_Cancellation(t *testing.T) {
 	for _, token := range []protocol.ProgressToken{nil, 1, "a"} {
 		ctx, tracker, _ := setup(token)
-		var cancelled bool
-		cancel := func() { cancelled = true }
+		var canceled bool
+		cancel := func() { canceled = true }
 		work := tracker.start(ctx, "work", "message", token, cancel)
 		if err := tracker.cancel(ctx, work.token); err != nil {
 			t.Fatal(err)
 		}
-		if !cancelled {
+		if !canceled {
 			t.Errorf("tracker.cancel(...): cancel not called")
 		}
 	}
 }
+
+func TestProgressTracker_MessageCancellation(t *testing.T) {
+	// Test that progress is canceled via the showMessageRequest dialog,
+	// when workDone is not supported.
+	ctx, tracker, client := setup(nil)
+	tracker.supportsWorkDoneProgress = false
+	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	defer cancel()
+
+	canceled := make(chan struct{})
+	canceler := func() { close(canceled) }
+	tracker.start(ctx, "work", "message", nil, canceler)
+
+	close(client.cancel)
+
+	select {
+	case <-canceled:
+	case <-ctx.Done():
+		t.Errorf("timed out waiting for cancel")
+	}
+}
diff --git a/internal/lsp/regtest/codelens_test.go b/internal/lsp/regtest/codelens_test.go
index e5bf57d..cb9443a 100644
--- a/internal/lsp/regtest/codelens_test.go
+++ b/internal/lsp/regtest/codelens_test.go
@@ -60,9 +60,9 @@
 }
 
 // This test confirms the full functionality of the code lenses for updating
-// dependencies in a go.mod file. It checks for the code lens that suggests an
-// update and then executes the command associated with that code lens.
-// A regression test for golang/go#39446.
+// dependencies in a go.mod file. It checks for the code lens that suggests
+// an update and then executes the command associated with that code lens. A
+// regression test for golang/go#39446.
 func TestUpdateCodelens(t *testing.T) {
 	const proxyWithLatest = `
 -- golang.org/x/hello@v1.3.3/go.mod --
@@ -119,6 +119,7 @@
 		}); err != nil {
 			t.Fatal(err)
 		}
+		env.Await(NoOutstandingWork())
 		got := env.ReadWorkspaceFile("go.mod")
 		const wantGoMod = `module mod.com
 
@@ -189,6 +190,7 @@
 		}); err != nil {
 			t.Fatal(err)
 		}
+		env.Await(NoOutstandingWork())
 		got := env.ReadWorkspaceFile("go.mod")
 		const wantGoMod = `module mod.com
 
diff --git a/internal/lsp/regtest/wrappers.go b/internal/lsp/regtest/wrappers.go
index 89bc656..66e19f5 100644
--- a/internal/lsp/regtest/wrappers.go
+++ b/internal/lsp/regtest/wrappers.go
@@ -9,7 +9,6 @@
 	"io"
 	"testing"
 
-	"golang.org/x/tools/internal/lsp"
 	"golang.org/x/tools/internal/lsp/fake"
 	"golang.org/x/tools/internal/lsp/protocol"
 )
@@ -217,7 +216,7 @@
 	if err := e.Editor.RunGenerate(e.Ctx, dir); err != nil {
 		e.T.Fatal(err)
 	}
-	e.Await(CompletedWork(lsp.GenerateWorkDoneTitle, 1))
+	e.Await(NoOutstandingWork())
 	// Ideally the fake.Workspace would handle all synthetic file watching, but
 	// we help it out here as we need to wait for the generate command to
 	// complete before checking the filesystem.
diff --git a/internal/lsp/text_synchronization.go b/internal/lsp/text_synchronization.go
index c709a3a..29e4ca9 100644
--- a/internal/lsp/text_synchronization.go
+++ b/internal/lsp/text_synchronization.go
@@ -185,7 +185,7 @@
 		defer func() {
 			go func() {
 				diagnosticWG.Wait()
-				work.end(ctx, "Done.")
+				work.end("Done.")
 			}()
 		}()
 	}