internal/lsp: use subtests for all lsp categories

This makes it possible to run just one type of test if needed
Also add some verification that the right number of tests is being run
And finally collect all the expectations up front, including the completions.

Change-Id: Iee6045a8ad89fa399fefd03bc0712770701ec6f8
Reviewed-on: https://go-review.googlesource.com/c/149737
Run-TryBot: Ian Cottrell <iancottrell@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index 840966d..422c096 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -28,6 +28,9 @@
 
 func testLSP(t *testing.T, exporter packagestest.Exporter) {
 	const dir = "testdata"
+	const expectedCompletionsCount = 4
+	const expectedDiagnosticsCount = 7
+	const expectedFormatCount = 3
 
 	files := packagestest.MustCopyFileTree(dir)
 	for fragment, operation := range files {
@@ -48,9 +51,10 @@
 	dirs := make(map[string]bool)
 
 	// collect results for certain tests
-	expectedDiagnostics := make(map[string][]protocol.Diagnostic)
-	expectedCompletions := make(map[token.Position]*protocol.CompletionItem)
-	expectedFormat := make(map[string]string)
+	expectedDiagnostics := make(diagnostics)
+	completionItems := make(completionItems)
+	expectedCompletions := make(completions)
+	expectedFormat := make(formats)
 
 	s := &server{
 		view: source.NewView(),
@@ -81,70 +85,76 @@
 	}
 	// Collect any data that needs to be used by subsequent tests.
 	if err := exported.Expect(map[string]interface{}{
-		"diag": func(pos token.Position, msg string) {
-			collectDiagnostics(t, expectedDiagnostics, pos, msg)
-		},
-		"item": func(pos token.Position, label, detail, kind string) {
-			collectCompletionItems(expectedCompletions, pos, label, detail, kind)
-		},
-		"format": func(pos token.Position) {
-			collectFormat(expectedFormat, pos)
-		},
+		"diag":     expectedDiagnostics.collect,
+		"item":     completionItems.collect,
+		"complete": expectedCompletions.collect,
+		"format":   expectedFormat.collect,
 	}); err != nil {
 		t.Fatal(err)
 	}
 
-	// test completion
-	testCompletion(t, exported, s, expectedCompletions)
+	t.Run("Completion", func(t *testing.T) {
+		t.Helper()
+		if len(expectedCompletions) != expectedCompletionsCount {
+			t.Errorf("got %v completions expected %v", len(expectedCompletions), expectedCompletionsCount)
+		}
+		expectedCompletions.test(t, exported, s, completionItems)
+	})
 
-	// test diagnostics
-	var dirList []string
-	for dir := range dirs {
-		dirList = append(dirList, dir)
-	}
-	exported.Config.Mode = packages.LoadFiles
-	pkgs, err := packages.Load(exported.Config, dirList...)
-	if err != nil {
-		t.Fatal(err)
-	}
-	testDiagnostics(t, s.view, pkgs, expectedDiagnostics)
+	t.Run("Diagnostics", func(t *testing.T) {
+		t.Helper()
+		diagnosticsCount := expectedDiagnostics.test(t, exported, s.view, dirs)
+		if diagnosticsCount != expectedDiagnosticsCount {
+			t.Errorf("got %v diagnostics expected %v", diagnosticsCount, expectedDiagnosticsCount)
+		}
+	})
 
-	// test format
-	testFormat(t, s, expectedFormat)
+	t.Run("Format", func(t *testing.T) {
+		t.Helper()
+		if len(expectedFormat) != expectedFormatCount {
+			t.Errorf("got %v formats expected %v", len(expectedFormat), expectedFormatCount)
+		}
+		expectedFormat.test(t, s)
+	})
 }
 
-func testCompletion(t *testing.T, exported *packagestest.Exported, s *server, wants map[token.Position]*protocol.CompletionItem) {
-	if err := exported.Expect(map[string]interface{}{
-		"complete": func(src token.Position, expected []token.Position) {
-			var want []protocol.CompletionItem
-			for _, pos := range expected {
-				want = append(want, *wants[pos])
-			}
-			list, err := s.Completion(context.Background(), &protocol.CompletionParams{
-				TextDocumentPositionParams: protocol.TextDocumentPositionParams{
-					TextDocument: protocol.TextDocumentIdentifier{
-						URI: protocol.DocumentURI(source.ToURI(src.Filename)),
-					},
-					Position: protocol.Position{
-						Line:      float64(src.Line - 1),
-						Character: float64(src.Column - 1),
-					},
+type diagnostics map[string][]protocol.Diagnostic
+type completionItems map[token.Pos]*protocol.CompletionItem
+type completions map[token.Position][]token.Pos
+type formats map[string]string
+
+func (c completions) test(t *testing.T, exported *packagestest.Exported, s *server, items completionItems) {
+	for src, itemList := range c {
+		var want []protocol.CompletionItem
+		for _, pos := range itemList {
+			want = append(want, *items[pos])
+		}
+		list, err := s.Completion(context.Background(), &protocol.CompletionParams{
+			TextDocumentPositionParams: protocol.TextDocumentPositionParams{
+				TextDocument: protocol.TextDocumentIdentifier{
+					URI: protocol.DocumentURI(source.ToURI(src.Filename)),
 				},
-			})
-			if err != nil {
-				t.Fatal(err)
-			}
-			got := list.Items
-			if equal := reflect.DeepEqual(want, got); !equal {
-				t.Errorf("completion failed for %s:%v:%v: (expected: %v), (got: %v)", filepath.Base(src.Filename), src.Line, src.Column, want, got)
-			}
-		},
-	}); err != nil {
-		t.Fatal(err)
+				Position: protocol.Position{
+					Line:      float64(src.Line - 1),
+					Character: float64(src.Column - 1),
+				},
+			},
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		got := list.Items
+		if equal := reflect.DeepEqual(want, got); !equal {
+			t.Errorf("completion failed for %s:%v:%v: (expected: %v), (got: %v)", filepath.Base(src.Filename), src.Line, src.Column, want, got)
+		}
 	}
 }
 
-func collectCompletionItems(expectedCompletions map[token.Position]*protocol.CompletionItem, pos token.Position, label, detail, kind string) {
+func (c completions) collect(src token.Position, expected []token.Pos) {
+	c[src] = expected
+}
+
+func (i completionItems) collect(pos token.Pos, label, detail, kind string) {
 	var k protocol.CompletionItemKind
 	switch kind {
 	case "struct":
@@ -164,14 +174,26 @@
 	case "method":
 		k = protocol.MethodCompletion
 	}
-	expectedCompletions[pos] = &protocol.CompletionItem{
+	i[pos] = &protocol.CompletionItem{
 		Label:  label,
 		Detail: detail,
 		Kind:   float64(k),
 	}
 }
 
-func testDiagnostics(t *testing.T, v *source.View, pkgs []*packages.Package, wants map[string][]protocol.Diagnostic) {
+func (d diagnostics) test(t *testing.T, exported *packagestest.Exported, v *source.View, dirs map[string]bool) int {
+	// first trigger a load to get the diagnostics
+	var dirList []string
+	for dir := range dirs {
+		dirList = append(dirList, dir)
+	}
+	exported.Config.Mode = packages.LoadFiles
+	pkgs, err := packages.Load(exported.Config, dirList...)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// and now see if they match the expected ones
+	count := 0
 	for _, pkg := range pkgs {
 		for _, filename := range pkg.GoFiles {
 			f := v.GetFile(source.ToURI(filename))
@@ -183,7 +205,7 @@
 			sort.Slice(got, func(i int, j int) bool {
 				return got[i].Range.Start.Line < got[j].Range.Start.Line
 			})
-			want := wants[filename]
+			want := d[filename]
 			if equal := reflect.DeepEqual(want, got); !equal {
 				msg := &bytes.Buffer{}
 				fmt.Fprintf(msg, "diagnostics failed for %s: expected:\n", filepath.Base(filename))
@@ -196,11 +218,13 @@
 				}
 				t.Error(msg.String())
 			}
+			count += len(want)
 		}
 	}
+	return count
 }
 
-func collectDiagnostics(t *testing.T, expectedDiagnostics map[string][]protocol.Diagnostic, pos token.Position, msg string) {
+func (d diagnostics) collect(pos token.Position, msg string) {
 	line := float64(pos.Line - 1)
 	col := float64(pos.Column - 1)
 	want := protocol.Diagnostic{
@@ -218,15 +242,11 @@
 		Source:   "LSP",
 		Message:  msg,
 	}
-	if _, ok := expectedDiagnostics[pos.Filename]; ok {
-		expectedDiagnostics[pos.Filename] = append(expectedDiagnostics[pos.Filename], want)
-	} else {
-		t.Errorf("unexpected filename: %v", pos.Filename)
-	}
+	d[pos.Filename] = append(d[pos.Filename], want)
 }
 
-func testFormat(t *testing.T, s *server, expectedFormat map[string]string) {
-	for filename, gofmted := range expectedFormat {
+func (f formats) test(t *testing.T, s *server) {
+	for filename, gofmted := range f {
 		edits, err := s.Formatting(context.Background(), &protocol.DocumentFormattingParams{
 			TextDocument: protocol.TextDocumentIdentifier{
 				URI: protocol.DocumentURI(source.ToURI(filename)),
@@ -245,10 +265,10 @@
 	}
 }
 
-func collectFormat(expectedFormat map[string]string, pos token.Position) {
+func (f formats) collect(pos token.Position) {
 	cmd := exec.Command("gofmt", pos.Filename)
 	stdout := bytes.NewBuffer(nil)
 	cmd.Stdout = stdout
 	cmd.Run() // ignore error, sometimes we have intentionally ungofmt-able files
-	expectedFormat[pos.Filename] = stdout.String()
+	f[pos.Filename] = stdout.String()
 }