internal/lsp/progress: detach context for all progress notifications

Use a detached Context for all progress notifications. In particular,
using a detached Context for the window/workDoneProgress/create
notification avoids races where the $/cancelRequest notification and
create response cross paths, such that the client has created a progress
dialog but the server thinks that starting progress failed.

Also, as a matter of best practice don't store a context on the WorkDone
type, despite the fact that this Context is detached. Instead, only
close over a Context in the WorkDoneWriter, which requires a Context in
order to function but which implements the io.Writer interface.

The TestProgressBarErrors test should now pass reliably.

Fixes golang/go#46930

Change-Id: I0d115ed3a62de97fe545c8dc0403e7bb55f6e481
Reviewed-on: https://go-review.googlesource.com/c/tools/+/409936
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
diff --git a/gopls/internal/regtest/diagnostics/diagnostics_test.go b/gopls/internal/regtest/diagnostics/diagnostics_test.go
index 6ce0cdb..6f5db4c 100644
--- a/gopls/internal/regtest/diagnostics/diagnostics_test.go
+++ b/gopls/internal/regtest/diagnostics/diagnostics_test.go
@@ -1519,7 +1519,6 @@
 // TestProgressBarErrors confirms that critical workspace load errors are shown
 // and updated via progress reports.
 func TestProgressBarErrors(t *testing.T) {
-	t.Skip("too flaky: golang/go#46930")
 	testenv.NeedsGo1Point(t, 14)
 
 	const pkg = `
diff --git a/internal/lsp/cache/load.go b/internal/lsp/cache/load.go
index 41b1ad5..f91c961 100644
--- a/internal/lsp/cache/load.go
+++ b/internal/lsp/cache/load.go
@@ -92,7 +92,7 @@
 	if s.view.Options().VerboseWorkDoneProgress {
 		work := s.view.session.progress.Start(ctx, "Load", fmt.Sprintf("Loading query=%s", query), nil, nil)
 		defer func() {
-			work.End("Done.")
+			work.End(ctx, "Done.")
 		}()
 	}
 
diff --git a/internal/lsp/command.go b/internal/lsp/command.go
index 9bc979a..862af60 100644
--- a/internal/lsp/command.go
+++ b/internal/lsp/command.go
@@ -109,12 +109,12 @@
 		if deps.work != nil {
 			switch {
 			case errors.Is(err, context.Canceled):
-				deps.work.End("canceled")
+				deps.work.End(ctx, "canceled")
 			case err != nil:
 				event.Error(ctx, "command error", err)
-				deps.work.End("failed")
+				deps.work.End(ctx, "failed")
 			default:
-				deps.work.End("completed")
+				deps.work.End(ctx, "completed")
 			}
 		}
 		return err
diff --git a/internal/lsp/diagnostics.go b/internal/lsp/diagnostics.go
index 7ba6a4a..0837b22 100644
--- a/internal/lsp/diagnostics.go
+++ b/internal/lsp/diagnostics.go
@@ -428,10 +428,10 @@
 	// If an error is already shown to the user, update it or mark it as
 	// resolved.
 	if errMsg == "" {
-		s.criticalErrorStatus.End("Done.")
+		s.criticalErrorStatus.End(ctx, "Done.")
 		s.criticalErrorStatus = nil
 	} else {
-		s.criticalErrorStatus.Report(errMsg, 0)
+		s.criticalErrorStatus.Report(ctx, errMsg, 0)
 	}
 }
 
diff --git a/internal/lsp/general.go b/internal/lsp/general.go
index ab74778..478152b 100644
--- a/internal/lsp/general.go
+++ b/internal/lsp/general.go
@@ -234,7 +234,7 @@
 		defer func() {
 			go func() {
 				wg.Wait()
-				work.End("Done.")
+				work.End(ctx, "Done.")
 			}()
 		}()
 	}
@@ -253,7 +253,7 @@
 		}
 		if err != nil {
 			viewErrors[uri] = err
-			work.End(fmt.Sprintf("Error loading packages: %s", err))
+			work.End(ctx, fmt.Sprintf("Error loading packages: %s", err))
 			continue
 		}
 		var swg sync.WaitGroup
@@ -263,7 +263,7 @@
 			defer swg.Done()
 			defer allFoldersWg.Done()
 			snapshot.AwaitInitialized(ctx)
-			work.End("Finished loading packages.")
+			work.End(ctx, "Finished loading packages.")
 		}()
 
 		// Print each view's environment.
diff --git a/internal/lsp/progress/progress.go b/internal/lsp/progress/progress.go
index d9a01bd..d6794cf 100644
--- a/internal/lsp/progress/progress.go
+++ b/internal/lsp/progress/progress.go
@@ -64,8 +64,8 @@
 //	  // Do the work...
 //	}
 func (t *Tracker) Start(ctx context.Context, title, message string, token protocol.ProgressToken, cancel func()) *WorkDone {
+	ctx = xcontext.Detach(ctx) // progress messages should not be cancelled
 	wd := &WorkDone{
-		ctx:    xcontext.Detach(ctx),
 		client: t.client,
 		token:  token,
 		cancel: cancel,
@@ -78,7 +78,7 @@
 		//
 		// Just show a simple message. Clients can implement workDone progress
 		// reporting to get cancellation support.
-		if err := wd.client.ShowMessage(wd.ctx, &protocol.ShowMessageParams{
+		if err := wd.client.ShowMessage(ctx, &protocol.ShowMessageParams{
 			Type:    protocol.Log,
 			Message: message,
 		}); err != nil {
@@ -123,7 +123,7 @@
 	return wd
 }
 
-func (t *Tracker) Cancel(ctx context.Context, token protocol.ProgressToken) error {
+func (t *Tracker) Cancel(token protocol.ProgressToken) error {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 	wd, ok := t.inProgress[token]
@@ -140,8 +140,6 @@
 // WorkDone represents a unit of work that is reported to the client via the
 // progress API.
 type WorkDone struct {
-	// ctx is detached, for sending $/progress updates.
-	ctx    context.Context
 	client protocol.Client
 	// If token is nil, this workDone object uses the ShowMessage API, rather
 	// than $/progress.
@@ -170,7 +168,8 @@
 }
 
 // report reports an update on WorkDone report back to the client.
-func (wd *WorkDone) Report(message string, percentage float64) {
+func (wd *WorkDone) Report(ctx context.Context, message string, percentage float64) {
+	ctx = xcontext.Detach(ctx) // progress messages should not be cancelled
 	if wd == nil {
 		return
 	}
@@ -186,7 +185,7 @@
 		return
 	}
 	message = strings.TrimSuffix(message, "\n")
-	err := wd.client.Progress(wd.ctx, &protocol.ProgressParams{
+	err := wd.client.Progress(ctx, &protocol.ProgressParams{
 		Token: wd.token,
 		Value: &protocol.WorkDoneProgressReport{
 			Kind: "report",
@@ -199,12 +198,13 @@
 		},
 	})
 	if err != nil {
-		event.Error(wd.ctx, "reporting progress", err)
+		event.Error(ctx, "reporting progress", err)
 	}
 }
 
 // end reports a workdone completion back to the client.
-func (wd *WorkDone) End(message string) {
+func (wd *WorkDone) End(ctx context.Context, message string) {
+	ctx = xcontext.Detach(ctx) // progress messages should not be cancelled
 	if wd == nil {
 		return
 	}
@@ -214,12 +214,12 @@
 		// There is a prior error.
 	case wd.token == nil:
 		// We're falling back to message-based reporting.
-		err = wd.client.ShowMessage(wd.ctx, &protocol.ShowMessageParams{
+		err = wd.client.ShowMessage(ctx, &protocol.ShowMessageParams{
 			Type:    protocol.Info,
 			Message: message,
 		})
 	default:
-		err = wd.client.Progress(wd.ctx, &protocol.ProgressParams{
+		err = wd.client.Progress(ctx, &protocol.ProgressParams{
 			Token: wd.token,
 			Value: &protocol.WorkDoneProgressEnd{
 				Kind:    "end",
@@ -228,7 +228,7 @@
 		})
 	}
 	if err != nil {
-		event.Error(wd.ctx, "ending work", err)
+		event.Error(ctx, "ending work", err)
 	}
 	if wd.cleanup != nil {
 		wd.cleanup()
@@ -255,15 +255,17 @@
 // WorkDoneWriter wraps a workDone handle to provide a Writer interface,
 // so that workDone reporting can more easily be hooked into commands.
 type WorkDoneWriter struct {
-	wd *WorkDone
+	// In order to implement the io.Writer interface, we must close over ctx.
+	ctx context.Context
+	wd  *WorkDone
 }
 
 func NewWorkDoneWriter(wd *WorkDone) *WorkDoneWriter {
 	return &WorkDoneWriter{wd: wd}
 }
 
-func (wdw WorkDoneWriter) Write(p []byte) (n int, err error) {
-	wdw.wd.Report(string(p), 0)
+func (wdw *WorkDoneWriter) Write(p []byte) (n int, err error) {
+	wdw.wd.Report(wdw.ctx, string(p), 0)
 	// Don't fail just because of a failure to report progress.
 	return len(p), nil
 }
diff --git a/internal/lsp/progress/progress_test.go b/internal/lsp/progress/progress_test.go
index b3c8219..6e901d1 100644
--- a/internal/lsp/progress/progress_test.go
+++ b/internal/lsp/progress/progress_test.go
@@ -124,14 +124,14 @@
 				t.Errorf("got %d work begun, want %d", gotBegun, test.wantBegun)
 			}
 			// Ignore errors: this is just testing the reporting behavior.
-			work.Report("report", 50)
+			work.Report(ctx, "report", 50)
 			client.mu.Lock()
 			gotReported := client.reported
 			client.mu.Unlock()
 			if gotReported != test.wantReported {
 				t.Errorf("got %d progress reports, want %d", gotReported, test.wantCreated)
 			}
-			work.End("done")
+			work.End(ctx, "done")
 			client.mu.Lock()
 			gotEnded, gotMessages := client.ended, client.messages
 			client.mu.Unlock()
@@ -151,7 +151,7 @@
 		var canceled bool
 		cancel := func() { canceled = true }
 		work := tracker.Start(ctx, "work", "message", token, cancel)
-		if err := tracker.Cancel(ctx, work.Token()); err != nil {
+		if err := tracker.Cancel(work.Token()); err != nil {
 			t.Fatal(err)
 		}
 		if !canceled {
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index 3b86f47..fb820cc 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -123,8 +123,8 @@
 	changes      []source.FileModification
 }
 
-func (s *Server) workDoneProgressCancel(ctx context.Context, params *protocol.WorkDoneProgressCancelParams) error {
-	return s.progress.Cancel(ctx, params.Token)
+func (s *Server) workDoneProgressCancel(params *protocol.WorkDoneProgressCancelParams) error {
+	return s.progress.Cancel(params.Token)
 }
 
 func (s *Server) nonstandardRequest(ctx context.Context, method string, params interface{}) (interface{}, error) {
diff --git a/internal/lsp/server_gen.go b/internal/lsp/server_gen.go
index 2062693..93b2f99 100644
--- a/internal/lsp/server_gen.go
+++ b/internal/lsp/server_gen.go
@@ -317,5 +317,5 @@
 }
 
 func (s *Server) WorkDoneProgressCancel(ctx context.Context, params *protocol.WorkDoneProgressCancelParams) error {
-	return s.workDoneProgressCancel(ctx, params)
+	return s.workDoneProgressCancel(params)
 }
diff --git a/internal/lsp/text_synchronization.go b/internal/lsp/text_synchronization.go
index ff153d7..3276a47 100644
--- a/internal/lsp/text_synchronization.go
+++ b/internal/lsp/text_synchronization.go
@@ -211,7 +211,7 @@
 		defer func() {
 			go func() {
 				<-diagnoseDone
-				work.End("Done.")
+				work.End(ctx, "Done.")
 			}()
 		}()
 	}