internal/lsp: use the token supplied by the client for progress

Our WorkDone reporting was generating a random token for each unit of
work, even if a token was supplied by the client.  Change this to use
the client token if it is non-empty, and skip the
workDoneProgress/create request.

After this change we can no longer rely on tokens being a string.
Update our progress tracking accordingly.

For golang/go#40527

Change-Id: I702f739c466efb613b69303aaf07005addd3b5e2
Reviewed-on: https://go-review.googlesource.com/c/tools/+/247321
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/command.go b/internal/lsp/command.go
index 15ff4e2..8dbd9a8 100644
--- a/internal/lsp/command.go
+++ b/internal/lsp/command.go
@@ -107,7 +107,7 @@
 		if !ok {
 			return nil, err
 		}
-		go s.runTest(ctx, snapshot, []string{flag, funcName})
+		go s.runTest(ctx, snapshot, []string{flag, funcName}, params.WorkDoneToken)
 	case source.CommandGenerate:
 		var uri protocol.DocumentURI
 		var recursive bool
@@ -119,7 +119,7 @@
 		if !ok {
 			return nil, err
 		}
-		go s.runGoGenerate(xcontext.Detach(ctx), snapshot, uri.SpanURI(), recursive)
+		go s.runGoGenerate(xcontext.Detach(ctx), snapshot, uri.SpanURI(), recursive, params.WorkDoneToken)
 	case source.CommandRegenerateCgo:
 		var uri protocol.DocumentURI
 		if err := source.UnmarshalArgs(params.Arguments, &uri); err != nil {
@@ -193,13 +193,13 @@
 	return snapshot.RunGoCommandDirect(ctx, verb, args)
 }
 
-func (s *Server) runTest(ctx context.Context, snapshot source.Snapshot, args []string) error {
+func (s *Server) runTest(ctx context.Context, snapshot source.Snapshot, args []string, token protocol.ProgressToken) error {
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
 	ew := &eventWriter{ctx: ctx, operation: "test"}
 	msg := fmt.Sprintf("running `go test %s`", strings.Join(args, " "))
-	wc := s.newProgressWriter(ctx, "test", msg, msg, cancel)
+	wc := s.newProgressWriter(ctx, "test", msg, msg, token, cancel)
 	defer wc.Close()
 
 	messageType := protocol.Info
@@ -223,12 +223,12 @@
 // generate commands. It is exported for testing purposes.
 const GenerateWorkDoneTitle = "generate"
 
-func (s *Server) runGoGenerate(ctx context.Context, snapshot source.Snapshot, uri span.URI, recursive bool) error {
+func (s *Server) runGoGenerate(ctx context.Context, snapshot source.Snapshot, uri span.URI, recursive bool, token protocol.ProgressToken) error {
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
 	er := &eventWriter{ctx: ctx, operation: "generate"}
-	wc := s.newProgressWriter(ctx, GenerateWorkDoneTitle, "running go generate", "started go generate, check logs for progress", cancel)
+	wc := s.newProgressWriter(ctx, GenerateWorkDoneTitle, "running go generate", "started go generate, check logs for progress", token, cancel)
 	defer wc.Close()
 	args := []string{"-x"}
 	if recursive {
diff --git a/internal/lsp/general.go b/internal/lsp/general.go
index 3aeab58..1fc7d0c 100644
--- a/internal/lsp/general.go
+++ b/internal/lsp/general.go
@@ -32,7 +32,7 @@
 	s.stateMu.Unlock()
 
 	s.supportsWorkDoneProgress = params.Capabilities.Window.WorkDoneProgress
-	s.inProgress = map[string]*WorkDone{}
+	s.inProgress = map[protocol.ProgressToken]*WorkDone{}
 
 	options := s.session.Options()
 	defer func() { s.session.SetOptions(options) }()
@@ -185,7 +185,7 @@
 
 	var wg sync.WaitGroup
 	if s.session.Options().VerboseWorkDoneProgress {
-		work := s.StartWork(ctx, DiagnosticWorkTitle(FromInitialWorkspaceLoad), "Calculating diagnostics for initial workspace load...", nil)
+		work := s.StartWork(ctx, DiagnosticWorkTitle(FromInitialWorkspaceLoad), "Calculating diagnostics for initial workspace load...", nil, nil)
 		defer func() {
 			go func() {
 				wg.Wait()
diff --git a/internal/lsp/progress.go b/internal/lsp/progress.go
index 15b9e3d..c628b24 100644
--- a/internal/lsp/progress.go
+++ b/internal/lsp/progress.go
@@ -21,16 +21,17 @@
 type WorkDone struct {
 	client   protocol.Client
 	startErr error
-	token    string
+	token    protocol.ProgressToken
 	cancel   func()
 	cleanup  func()
 }
 
-// StartWork creates a unique token and issues a $/progress notification to
-// begin a unit of work on the server. The returned WorkDone handle may be used
-// to report incremental progress, and to report work completion. In
-// particular, it is an error to call StartWork and not call End(...) on the
-// returned WorkDone handle.
+// StartWork issues a $/progress notification to begin a unit of work on the
+// server. The returned WorkDone handle may be used to report incremental
+// progress, and to report work completion. In particular, it is an error to
+// call StartWork and not call End(...) on the returned WorkDone handle.
+//
+// If token is empty, a token will be randomly generated.
 //
 // The progress item is considered cancellable if the given cancel func is
 // non-nil.
@@ -50,29 +51,32 @@
 //    // Do the work...
 //  }
 //
-func (s *Server) StartWork(ctx context.Context, title, message string, cancel func()) *WorkDone {
+func (s *Server) StartWork(ctx context.Context, title, message string, token protocol.ProgressToken, cancel func()) *WorkDone {
 	wd := &WorkDone{
 		client: s.client,
-		token:  strconv.FormatInt(rand.Int63(), 10),
+		token:  token,
 		cancel: cancel,
 	}
 	if !s.supportsWorkDoneProgress {
 		wd.startErr = errors.New("workdone reporting is not supported")
 		return wd
 	}
-	err := wd.client.WorkDoneProgressCreate(ctx, &protocol.WorkDoneProgressCreateParams{
-		Token: wd.token,
-	})
-	if err != nil {
-		wd.startErr = err
-		event.Error(ctx, "starting work for "+title, err)
-		return wd
+	if wd.token == nil {
+		wd.token = strconv.FormatInt(rand.Int63(), 10)
+		err := wd.client.WorkDoneProgressCreate(ctx, &protocol.WorkDoneProgressCreateParams{
+			Token: wd.token,
+		})
+		if err != nil {
+			wd.startErr = err
+			event.Error(ctx, "starting work for "+title, err)
+			return wd
+		}
 	}
 	s.addInProgress(wd)
 	wd.cleanup = func() {
 		s.removeInProgress(wd.token)
 	}
-	err = wd.client.Progress(ctx, &protocol.ProgressParams{
+	err := wd.client.Progress(ctx, &protocol.ProgressParams{
 		Token: wd.token,
 		Value: &protocol.WorkDoneProgressBegin{
 			Kind:        "begin",
@@ -139,9 +143,9 @@
 
 // newProgressWriter returns an io.WriterCloser that can be used
 // to report progress on a command based on the client capabilities.
-func (s *Server) newProgressWriter(ctx context.Context, title, beginMsg, msg string, cancel func()) io.WriteCloser {
+func (s *Server) newProgressWriter(ctx context.Context, title, beginMsg, msg string, token protocol.ProgressToken, cancel func()) io.WriteCloser {
 	if s.supportsWorkDoneProgress {
-		wd := s.StartWork(ctx, title, beginMsg, cancel)
+		wd := s.StartWork(ctx, title, beginMsg, token, cancel)
 		return &workDoneWriter{ctx, wd}
 	}
 	mw := &messageWriter{ctx, cancel, s.client}
diff --git a/internal/lsp/regtest/env.go b/internal/lsp/regtest/env.go
index 6fbd877..081d92a 100644
--- a/internal/lsp/regtest/env.go
+++ b/internal/lsp/regtest/env.go
@@ -54,7 +54,7 @@
 	// outstandingWork is a map of token->work summary. All tokens are assumed to
 	// be string, though the spec allows for numeric tokens as well.  When work
 	// completes, it is deleted from this map.
-	outstandingWork map[string]*workProgress
+	outstandingWork map[protocol.ProgressToken]*workProgress
 	completedWork   map[string]int
 }
 
@@ -119,7 +119,7 @@
 		Server:  ts,
 		state: State{
 			diagnostics:     make(map[string]*protocol.PublishDiagnosticsParams),
-			outstandingWork: make(map[string]*workProgress),
+			outstandingWork: make(map[protocol.ProgressToken]*workProgress),
 			completedWork:   make(map[string]int),
 		},
 		waiters: make(map[int]*condition),
@@ -186,18 +186,16 @@
 	e.mu.Lock()
 	defer e.mu.Unlock()
 
-	token := m.Token.(string)
-	e.state.outstandingWork[token] = &workProgress{}
+	e.state.outstandingWork[m.Token] = &workProgress{}
 	return nil
 }
 
 func (e *Env) onProgress(_ context.Context, m *protocol.ProgressParams) error {
 	e.mu.Lock()
 	defer e.mu.Unlock()
-	token := m.Token.(string)
-	work, ok := e.state.outstandingWork[token]
+	work, ok := e.state.outstandingWork[m.Token]
 	if !ok {
-		panic(fmt.Sprintf("got progress report for unknown report %s: %v", token, m))
+		panic(fmt.Sprintf("got progress report for unknown report %v: %v", m.Token, m))
 	}
 	v := m.Value.(map[string]interface{})
 	switch kind := v["kind"]; kind {
@@ -208,9 +206,9 @@
 			work.percent = pct.(float64)
 		}
 	case "end":
-		title := e.state.outstandingWork[token].title
+		title := e.state.outstandingWork[m.Token].title
 		e.state.completedWork[title] = e.state.completedWork[title] + 1
-		delete(e.state.outstandingWork, token)
+		delete(e.state.outstandingWork, m.Token)
 	}
 	e.checkConditionsLocked()
 	return nil
diff --git a/internal/lsp/regtest/env_test.go b/internal/lsp/regtest/env_test.go
index af044ed..82fb17f 100644
--- a/internal/lsp/regtest/env_test.go
+++ b/internal/lsp/regtest/env_test.go
@@ -15,7 +15,7 @@
 func TestProgressUpdating(t *testing.T) {
 	e := &Env{
 		state: State{
-			outstandingWork: make(map[string]*workProgress),
+			outstandingWork: make(map[protocol.ProgressToken]*workProgress),
 			completedWork:   make(map[string]int),
 		},
 	}
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index dda2100..3a733b7 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -96,7 +96,7 @@
 	// to determine if the client can support progress notifications
 	supportsWorkDoneProgress bool
 	inProgressMu             sync.Mutex
-	inProgress               map[string]*WorkDone
+	inProgress               map[protocol.ProgressToken]*WorkDone
 }
 
 // sentDiagnostics is used to cache diagnostics that have been sent for a given file.
@@ -140,18 +140,14 @@
 }
 
 func (s *Server) workDoneProgressCancel(ctx context.Context, params *protocol.WorkDoneProgressCancelParams) error {
-	token, ok := params.Token.(string)
-	if !ok {
-		return errors.Errorf("expected params.Token to be string but got %T", params.Token)
-	}
 	s.inProgressMu.Lock()
 	defer s.inProgressMu.Unlock()
-	wd, ok := s.inProgress[token]
+	wd, ok := s.inProgress[params.Token]
 	if !ok {
-		return errors.Errorf("token %q not found in progress", token)
+		return errors.Errorf("token %q not found in progress", params.Token)
 	}
 	if wd.cancel == nil {
-		return errors.Errorf("work %q is not cancellable", token)
+		return errors.Errorf("work %q is not cancellable", params.Token)
 	}
 	wd.cancel()
 	return nil
@@ -163,7 +159,7 @@
 	s.inProgressMu.Unlock()
 }
 
-func (s *Server) removeInProgress(token string) {
+func (s *Server) removeInProgress(token protocol.ProgressToken) {
 	s.inProgressMu.Lock()
 	delete(s.inProgress, token)
 	s.inProgressMu.Unlock()
diff --git a/internal/lsp/text_synchronization.go b/internal/lsp/text_synchronization.go
index f5db5b6..3131a4e 100644
--- a/internal/lsp/text_synchronization.go
+++ b/internal/lsp/text_synchronization.go
@@ -197,7 +197,7 @@
 	// modification.
 	var diagnosticWG sync.WaitGroup
 	if s.session.Options().VerboseWorkDoneProgress {
-		work := s.StartWork(ctx, DiagnosticWorkTitle(cause), "Calculating file diagnostics...", nil)
+		work := s.StartWork(ctx, DiagnosticWorkTitle(cause), "Calculating file diagnostics...", nil, nil)
 		defer func() {
 			go func() {
 				diagnosticWG.Wait()