internal/lsp/progress: refactor progress reporting

Progress reporting has gotten complicated, and has had a couple bugs.
Factor out progress-related behavior to a new progressTracker type, and
use this to implement some unit tests.

Also rename some methods to remove stuttering, and reorganize the code
to be more logical.

Fixes golang/go#40527

Change-Id: I93d53a67982460e7171f892021e99f4523fe3e5d
Reviewed-on: https://go-review.googlesource.com/c/tools/+/247407
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/command.go b/internal/lsp/command.go
index 8dbd9a8..28c8d48 100644
--- a/internal/lsp/command.go
+++ b/internal/lsp/command.go
@@ -199,7 +199,7 @@
 
 	ew := &eventWriter{ctx: ctx, operation: "test"}
 	msg := fmt.Sprintf("running `go test %s`", strings.Join(args, " "))
-	wc := s.newProgressWriter(ctx, "test", msg, msg, token, cancel)
+	wc := s.progress.newWriter(ctx, "test", msg, msg, token, cancel)
 	defer wc.Close()
 
 	messageType := protocol.Info
@@ -228,7 +228,7 @@
 	defer cancel()
 
 	er := &eventWriter{ctx: ctx, operation: "generate"}
-	wc := s.newProgressWriter(ctx, GenerateWorkDoneTitle, "running go generate", "started go generate, check logs for progress", token, cancel)
+	wc := s.progress.newWriter(ctx, GenerateWorkDoneTitle, "running go generate", "started go generate, check logs for progress", token, cancel)
 	defer wc.Close()
 	args := []string{"-x"}
 	if recursive {
diff --git a/internal/lsp/general.go b/internal/lsp/general.go
index 1fc7d0c..ddb8ecf 100644
--- a/internal/lsp/general.go
+++ b/internal/lsp/general.go
@@ -31,8 +31,7 @@
 	s.state = serverInitializing
 	s.stateMu.Unlock()
 
-	s.supportsWorkDoneProgress = params.Capabilities.Window.WorkDoneProgress
-	s.inProgress = map[protocol.ProgressToken]*WorkDone{}
+	s.progress.supportsWorkDoneProgress = params.Capabilities.Window.WorkDoneProgress
 
 	options := s.session.Options()
 	defer func() { s.session.SetOptions(options) }()
@@ -185,11 +184,11 @@
 
 	var wg sync.WaitGroup
 	if s.session.Options().VerboseWorkDoneProgress {
-		work := s.StartWork(ctx, DiagnosticWorkTitle(FromInitialWorkspaceLoad), "Calculating diagnostics for initial workspace load...", nil, nil)
+		work := s.progress.start(ctx, DiagnosticWorkTitle(FromInitialWorkspaceLoad), "Calculating diagnostics for initial workspace load...", nil, nil)
 		defer func() {
 			go func() {
 				wg.Wait()
-				work.End(ctx, "Done.")
+				work.end(ctx, "Done.")
 			}()
 		}()
 	}
diff --git a/internal/lsp/progress.go b/internal/lsp/progress.go
index c628b24..fc63fd2 100644
--- a/internal/lsp/progress.go
+++ b/internal/lsp/progress.go
@@ -6,30 +6,36 @@
 
 import (
 	"context"
-	"errors"
 	"io"
 	"math/rand"
 	"strconv"
+	"sync"
 
 	"golang.org/x/tools/internal/event"
 	"golang.org/x/tools/internal/lsp/debug/tag"
 	"golang.org/x/tools/internal/lsp/protocol"
+	errors "golang.org/x/xerrors"
 )
 
-// WorkDone represents a unit of work that is reported to the client via the
-// progress API.
-type WorkDone struct {
-	client   protocol.Client
-	startErr error
-	token    protocol.ProgressToken
-	cancel   func()
-	cleanup  func()
+type progressTracker struct {
+	client                   protocol.Client
+	supportsWorkDoneProgress bool
+
+	mu         sync.Mutex
+	inProgress map[protocol.ProgressToken]*workDone
 }
 
-// StartWork issues a $/progress notification to begin a unit of work on the
+func newProgressTracker(client protocol.Client) *progressTracker {
+	return &progressTracker{
+		client:     client,
+		inProgress: make(map[protocol.ProgressToken]*workDone),
+	}
+}
+
+// start issues a $/progress notification to begin a unit of work on the
 // server. The returned WorkDone handle may be used to report incremental
 // progress, and to report work completion. In particular, it is an error to
-// call StartWork and not call End(...) on the returned WorkDone handle.
+// call start and not call end(...) on the returned WorkDone handle.
 //
 // If token is empty, a token will be randomly generated.
 //
@@ -40,24 +46,24 @@
 //  func Generate(ctx) (err error) {
 //    ctx, cancel := context.WithCancel(ctx)
 //    defer cancel()
-//    work := s.StartWork(ctx, "generate", "running go generate", cancel)
+//    work := s.progress.start(ctx, "generate", "running go generate", cancel)
 //    defer func() {
 //      if err != nil {
-//        work.End(ctx, fmt.Sprintf("generate failed: %v", err))
+//        work.end(ctx, fmt.Sprintf("generate failed: %v", err))
 //      } else {
-//        work.End(ctx, "done")
+//        work.end(ctx, "done")
 //      }
 //    }()
 //    // Do the work...
 //  }
 //
-func (s *Server) StartWork(ctx context.Context, title, message string, token protocol.ProgressToken, cancel func()) *WorkDone {
-	wd := &WorkDone{
-		client: s.client,
+func (t *progressTracker) start(ctx context.Context, title, message string, token protocol.ProgressToken, cancel func()) *workDone {
+	wd := &workDone{
+		client: t.client,
 		token:  token,
 		cancel: cancel,
 	}
-	if !s.supportsWorkDoneProgress {
+	if !t.supportsWorkDoneProgress {
 		wd.startErr = errors.New("workdone reporting is not supported")
 		return wd
 	}
@@ -72,9 +78,13 @@
 			return wd
 		}
 	}
-	s.addInProgress(wd)
+	t.mu.Lock()
+	t.inProgress[wd.token] = wd
+	t.mu.Unlock()
 	wd.cleanup = func() {
-		s.removeInProgress(wd.token)
+		t.mu.Lock()
+		delete(t.inProgress, token)
+		t.mu.Unlock()
 	}
 	err := wd.client.Progress(ctx, &protocol.ProgressParams{
 		Token: wd.token,
@@ -91,8 +101,44 @@
 	return wd
 }
 
-// Progress reports an update on WorkDone progress back to the client.
-func (wd *WorkDone) Progress(ctx context.Context, message string, percentage float64) error {
+func (t *progressTracker) cancel(ctx context.Context, token protocol.ProgressToken) error {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	wd, ok := t.inProgress[token]
+	if !ok {
+		return errors.Errorf("token %q not found in progress", token)
+	}
+	if wd.cancel == nil {
+		return errors.Errorf("work %q is not cancellable", token)
+	}
+	wd.cancel()
+	return nil
+}
+
+// newProgressWriter returns an io.WriterCloser that can be used
+// to report progress on a command based on the client capabilities.
+func (t *progressTracker) newWriter(ctx context.Context, title, beginMsg, msg string, token protocol.ProgressToken, cancel func()) io.WriteCloser {
+	if t.supportsWorkDoneProgress {
+		wd := t.start(ctx, title, beginMsg, token, cancel)
+		return &workDoneWriter{ctx, wd}
+	}
+	mw := &messageWriter{ctx, cancel, t.client}
+	mw.start(msg)
+	return mw
+}
+
+// workDone represents a unit of work that is reported to the client via the
+// progress API.
+type workDone struct {
+	client   protocol.Client
+	startErr error
+	token    protocol.ProgressToken
+	cancel   func()
+	cleanup  func()
+}
+
+// report reports an update on WorkDone report back to the client.
+func (wd *workDone) report(ctx context.Context, message string, percentage float64) error {
 	if wd.startErr != nil {
 		return wd.startErr
 	}
@@ -110,14 +156,14 @@
 	})
 }
 
-// End reports a workdone completion back to the client.
-func (wd *WorkDone) End(ctx context.Context, message string) error {
+// end reports a workdone completion back to the client.
+func (wd *workDone) end(ctx context.Context, message string) error {
 	if wd.startErr != nil {
 		return wd.startErr
 	}
 	err := wd.client.Progress(ctx, &protocol.ProgressParams{
 		Token: wd.token,
-		Value: protocol.WorkDoneProgressEnd{
+		Value: &protocol.WorkDoneProgressEnd{
 			Kind:    "end",
 			Message: message,
 		},
@@ -141,18 +187,6 @@
 	return len(p), nil
 }
 
-// newProgressWriter returns an io.WriterCloser that can be used
-// to report progress on a command based on the client capabilities.
-func (s *Server) newProgressWriter(ctx context.Context, title, beginMsg, msg string, token protocol.ProgressToken, cancel func()) io.WriteCloser {
-	if s.supportsWorkDoneProgress {
-		wd := s.StartWork(ctx, title, beginMsg, token, cancel)
-		return &workDoneWriter{ctx, wd}
-	}
-	mw := &messageWriter{ctx, cancel, s.client}
-	mw.start(msg)
-	return mw
-}
-
 // messageWriter implements progressWriter and only tells the user that
 // a command has started through window/showMessage, but does not report
 // anything afterwards. This is because each log shows up as a separate window
@@ -201,13 +235,13 @@
 // be rendered done.
 type workDoneWriter struct {
 	ctx context.Context
-	wd  *WorkDone
+	wd  *workDone
 }
 
 func (wdw *workDoneWriter) Write(p []byte) (n int, err error) {
-	return len(p), wdw.wd.Progress(wdw.ctx, string(p), 0)
+	return len(p), wdw.wd.report(wdw.ctx, string(p), 0)
 }
 
 func (wdw *workDoneWriter) Close() error {
-	return wdw.wd.End(wdw.ctx, "finished")
+	return wdw.wd.end(wdw.ctx, "finished")
 }
diff --git a/internal/lsp/progress_test.go b/internal/lsp/progress_test.go
new file mode 100644
index 0000000..470127e
--- /dev/null
+++ b/internal/lsp/progress_test.go
@@ -0,0 +1,132 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lsp
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	"golang.org/x/tools/internal/lsp/protocol"
+)
+
+type fakeClient struct {
+	protocol.Client
+
+	token protocol.ProgressToken
+
+	created, begun, reported, ended int
+}
+
+func (c *fakeClient) checkToken(token protocol.ProgressToken) {
+	if token == nil {
+		panic("nil token in progress message")
+	}
+	if c.token != nil && c.token != token {
+		panic(fmt.Errorf("invalid token in progress message: got %v, want %v", token, c.token))
+	}
+}
+
+func (c *fakeClient) WorkDoneProgressCreate(ctx context.Context, params *protocol.WorkDoneProgressCreateParams) error {
+	c.checkToken(params.Token)
+	c.created++
+	return nil
+}
+
+func (c *fakeClient) Progress(ctx context.Context, params *protocol.ProgressParams) error {
+	c.checkToken(params.Token)
+	switch params.Value.(type) {
+	case *protocol.WorkDoneProgressBegin:
+		c.begun++
+	case *protocol.WorkDoneProgressReport:
+		c.reported++
+	case *protocol.WorkDoneProgressEnd:
+		c.ended++
+	default:
+		panic(fmt.Errorf("unknown progress value %T", params.Value))
+	}
+	return nil
+}
+
+func setup(token protocol.ProgressToken) (context.Context, *progressTracker, *fakeClient) {
+	c := &fakeClient{}
+	tracker := newProgressTracker(c)
+	tracker.supportsWorkDoneProgress = true
+	return context.Background(), tracker, c
+}
+
+func TestProgressTracker_Reporting(t *testing.T) {
+	for _, test := range []struct {
+		name                                            string
+		supported                                       bool
+		token                                           protocol.ProgressToken
+		wantReported, wantCreated, wantBegun, wantEnded int
+	}{
+		{
+			name: "unsupported",
+		},
+		{
+			name:         "random token",
+			supported:    true,
+			wantCreated:  1,
+			wantBegun:    1,
+			wantReported: 1,
+			wantEnded:    1,
+		},
+		{
+			name:         "string token",
+			supported:    true,
+			token:        "token",
+			wantBegun:    1,
+			wantReported: 1,
+			wantEnded:    1,
+		},
+		{
+			name:         "numeric token",
+			supported:    true,
+			token:        1,
+			wantReported: 1,
+			wantBegun:    1,
+			wantEnded:    1,
+		},
+	} {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			ctx, tracker, client := setup(test.token)
+			tracker.supportsWorkDoneProgress = test.supported
+			work := tracker.start(ctx, "work", "message", test.token, nil)
+			if got := client.created; got != test.wantCreated {
+				t.Errorf("got %d created tokens, want %d", got, test.wantCreated)
+			}
+			if got := client.begun; got != test.wantBegun {
+				t.Errorf("got %d work begun, want %d", got, test.wantBegun)
+			}
+			// Ignore errors: this is just testing the reporting behavior.
+			work.report(ctx, "report", 50)
+			if got := client.reported; got != test.wantReported {
+				t.Errorf("got %d progress reports, want %d", got, test.wantCreated)
+			}
+			work.end(ctx, "done")
+			if got := client.ended; got != test.wantEnded {
+				t.Errorf("got %d ended reports, want %d", got, test.wantEnded)
+			}
+		})
+	}
+}
+
+func TestProgressTracker_Cancellation(t *testing.T) {
+	for _, token := range []protocol.ProgressToken{nil, 1, "a"} {
+		ctx, tracker, _ := setup(token)
+		var cancelled bool
+		cancel := func() { cancelled = true }
+		work := tracker.start(ctx, "work", "message", token, cancel)
+		if err := tracker.cancel(ctx, work.token); err != nil {
+			t.Fatal(err)
+		}
+		if !cancelled {
+			t.Errorf("tracker.cancel(...): cancel not called")
+		}
+	}
+}
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index 3a733b7..cfcc72e 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -14,7 +14,6 @@
 	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/lsp/source"
 	"golang.org/x/tools/internal/span"
-	errors "golang.org/x/xerrors"
 )
 
 const concurrentAnalyses = 1
@@ -29,6 +28,7 @@
 		session:              session,
 		client:               client,
 		diagnosticsSema:      make(chan struct{}, concurrentAnalyses),
+		progress:             newProgressTracker(client),
 	}
 }
 
@@ -92,11 +92,7 @@
 	// diagnosticsSema limits the concurrency of diagnostics runs, which can be expensive.
 	diagnosticsSema chan struct{}
 
-	// supportsWorkDoneProgress is set in the initializeRequest
-	// to determine if the client can support progress notifications
-	supportsWorkDoneProgress bool
-	inProgressMu             sync.Mutex
-	inProgress               map[protocol.ProgressToken]*WorkDone
+	progress *progressTracker
 }
 
 // sentDiagnostics is used to cache diagnostics that have been sent for a given file.
@@ -139,32 +135,6 @@
 	return nil, notImplemented(method)
 }
 
-func (s *Server) workDoneProgressCancel(ctx context.Context, params *protocol.WorkDoneProgressCancelParams) error {
-	s.inProgressMu.Lock()
-	defer s.inProgressMu.Unlock()
-	wd, ok := s.inProgress[params.Token]
-	if !ok {
-		return errors.Errorf("token %q not found in progress", params.Token)
-	}
-	if wd.cancel == nil {
-		return errors.Errorf("work %q is not cancellable", params.Token)
-	}
-	wd.cancel()
-	return nil
-}
-
-func (s *Server) addInProgress(wd *WorkDone) {
-	s.inProgressMu.Lock()
-	s.inProgress[wd.token] = wd
-	s.inProgressMu.Unlock()
-}
-
-func (s *Server) removeInProgress(token protocol.ProgressToken) {
-	s.inProgressMu.Lock()
-	delete(s.inProgress, token)
-	s.inProgressMu.Unlock()
-}
-
 func notImplemented(method string) error {
 	return fmt.Errorf("%w: %q not yet implemented", jsonrpc2.ErrMethodNotFound, method)
 }
diff --git a/internal/lsp/server_gen.go b/internal/lsp/server_gen.go
index 6f9eeb8..471f1a4 100644
--- a/internal/lsp/server_gen.go
+++ b/internal/lsp/server_gen.go
@@ -205,5 +205,5 @@
 }
 
 func (s *Server) WorkDoneProgressCancel(ctx context.Context, params *protocol.WorkDoneProgressCancelParams) error {
-	return s.workDoneProgressCancel(ctx, params)
+	return s.progress.cancel(ctx, params.Token)
 }
diff --git a/internal/lsp/text_synchronization.go b/internal/lsp/text_synchronization.go
index 3131a4e..0b207e2 100644
--- a/internal/lsp/text_synchronization.go
+++ b/internal/lsp/text_synchronization.go
@@ -197,11 +197,11 @@
 	// modification.
 	var diagnosticWG sync.WaitGroup
 	if s.session.Options().VerboseWorkDoneProgress {
-		work := s.StartWork(ctx, DiagnosticWorkTitle(cause), "Calculating file diagnostics...", nil, nil)
+		work := s.progress.start(ctx, DiagnosticWorkTitle(cause), "Calculating file diagnostics...", nil, nil)
 		defer func() {
 			go func() {
 				diagnosticWG.Wait()
-				work.End(ctx, "Done.")
+				work.end(ctx, "Done.")
 			}()
 		}()
 	}