internal/lsp: add tests for references includeDeclaration setting

Make sure to test both modes, as this is the second time we've
accidentally broken this.

Fixes golang/go#36598.

Change-Id: I3993af3d106b18c76c44ada558b2c6cd9cbfcf17
Reviewed-on: https://go-review.googlesource.com/c/tools/+/215777
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/internal/lsp/cmd/test/references.go b/internal/lsp/cmd/test/references.go
index 055a37b..66d0d06 100644
--- a/internal/lsp/cmd/test/references.go
+++ b/internal/lsp/cmd/test/references.go
@@ -13,24 +13,37 @@
 )
 
 func (r *runner) References(t *testing.T, spn span.Span, itemList []span.Span) {
-	var itemStrings []string
-	for _, i := range itemList {
-		itemStrings = append(itemStrings, fmt.Sprint(i))
-	}
-	sort.Strings(itemStrings)
-	var expect string
-	for _, i := range itemStrings {
-		expect += i + "\n"
-	}
-	expect = r.Normalize(expect)
+	for _, includeDeclaration := range []bool{true, false} {
+		t.Run(fmt.Sprintf("refs-declaration-%v", includeDeclaration), func(t *testing.T) {
+			var itemStrings []string
+			for i, s := range itemList {
+				// We don't want the first result if we aren't including the declaration.
+				if i == 0 && !includeDeclaration {
+					continue
+				}
+				itemStrings = append(itemStrings, fmt.Sprint(s))
+			}
+			sort.Strings(itemStrings)
+			var expect string
+			for _, s := range itemStrings {
+				expect += s + "\n"
+			}
+			expect = r.Normalize(expect)
 
-	uri := spn.URI()
-	filename := uri.Filename()
-	target := filename + fmt.Sprintf(":%v:%v", spn.Start().Line(), spn.Start().Column())
-	got, stderr := r.NormalizeGoplsCmd(t, "references", "-d", target)
-	if stderr != "" {
-		t.Errorf("references failed for %s: %s", target, stderr)
-	} else if expect != got {
-		t.Errorf("references failed for %s expected:\n%s\ngot:\n%s", target, expect, got)
+			uri := spn.URI()
+			filename := uri.Filename()
+			target := filename + fmt.Sprintf(":%v:%v", spn.Start().Line(), spn.Start().Column())
+			args := []string{"references"}
+			if includeDeclaration {
+				args = append(args, "-d")
+			}
+			args = append(args, target)
+			got, stderr := r.NormalizeGoplsCmd(t, args...)
+			if stderr != "" {
+				t.Errorf("references failed for %s: %s", target, stderr)
+			} else if expect != got {
+				t.Errorf("references failed for %s expected:\n%s\ngot:\n%s", target, expect, got)
+			}
+		})
 	}
 }
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index 7130179..a47bfdd 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -545,38 +545,47 @@
 	if err != nil {
 		t.Fatalf("failed for %v: %v", src, err)
 	}
-	want := make(map[protocol.Location]bool)
-	for _, pos := range itemList {
-		m, err := r.data.Mapper(pos.URI())
-		if err != nil {
-			t.Fatal(err)
-		}
-		loc, err := m.Location(pos)
-		if err != nil {
-			t.Fatalf("failed for %v: %v", src, err)
-		}
-		want[loc] = true
-	}
-	params := &protocol.ReferenceParams{
-		TextDocumentPositionParams: protocol.TextDocumentPositionParams{
-			TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI},
-			Position:     loc.Range.Start,
-		},
-		Context: protocol.ReferenceContext{
-			IncludeDeclaration: true,
-		},
-	}
-	got, err := r.server.References(r.ctx, params)
-	if err != nil {
-		t.Fatalf("failed for %v: %v", src, err)
-	}
-	if len(got) != len(want) {
-		t.Errorf("references failed: different lengths got %v want %v", len(got), len(want))
-	}
-	for _, loc := range got {
-		if !want[loc] {
-			t.Errorf("references failed: incorrect references got %v want %v", loc, want)
-		}
+	for _, includeDeclaration := range []bool{true, false} {
+		t.Run(fmt.Sprintf("refs-declaration-%v", includeDeclaration), func(t *testing.T) {
+			want := make(map[protocol.Location]bool)
+			for i, pos := range itemList {
+				// We don't want the first result if we aren't including the declaration.
+				if i == 0 && !includeDeclaration {
+					continue
+				}
+				m, err := r.data.Mapper(pos.URI())
+				if err != nil {
+					t.Fatal(err)
+				}
+				loc, err := m.Location(pos)
+				if err != nil {
+					t.Fatalf("failed for %v: %v", src, err)
+				}
+				want[loc] = true
+			}
+			params := &protocol.ReferenceParams{
+				TextDocumentPositionParams: protocol.TextDocumentPositionParams{
+					TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI},
+					Position:     loc.Range.Start,
+				},
+				Context: protocol.ReferenceContext{
+					IncludeDeclaration: includeDeclaration,
+				},
+			}
+			got, err := r.server.References(r.ctx, params)
+			if err != nil {
+				t.Fatalf("failed for %v: %v", src, err)
+			}
+			if len(got) != len(want) {
+				t.Errorf("references failed: different lengths got %v want %v", len(got), len(want))
+			}
+			for _, loc := range got {
+				if !want[loc] {
+					t.Errorf("references failed: incorrect references got %v want %v", loc, want)
+				}
+			}
+		})
+
 	}
 }
 
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index e2bdabf..892d46d 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -617,33 +617,42 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	fh, err := r.view.Snapshot().GetFile(src.URI())
+	snapshot := r.view.Snapshot()
+	fh, err := snapshot.GetFile(src.URI())
 	if err != nil {
 		t.Fatal(err)
 	}
-	want := make(map[span.Span]bool)
-	for _, pos := range itemList {
-		want[pos] = true
-	}
-	refs, err := source.References(ctx, r.view.Snapshot(), fh, srcRng.Start, true)
-	if err != nil {
-		t.Fatalf("failed for %v: %v", src, err)
-	}
-	got := make(map[span.Span]bool)
-	for _, refInfo := range refs {
-		refSpan, err := refInfo.Span()
-		if err != nil {
-			t.Fatal(err)
-		}
-		got[refSpan] = true
-	}
-	if len(got) != len(want) {
-		t.Errorf("references failed: different lengths got %v want %v", len(got), len(want))
-	}
-	for spn := range got {
-		if !want[spn] {
-			t.Errorf("references failed: incorrect references got %v want locations %v", got, want)
-		}
+	for _, includeDeclaration := range []bool{true, false} {
+		t.Run(fmt.Sprintf("refs-declaration-%v", includeDeclaration), func(t *testing.T) {
+			want := make(map[span.Span]bool)
+			for i, pos := range itemList {
+				// We don't want the first result if we aren't including the declaration.
+				if i == 0 && !includeDeclaration {
+					continue
+				}
+				want[pos] = true
+			}
+			refs, err := source.References(ctx, snapshot, fh, srcRng.Start, includeDeclaration)
+			if err != nil {
+				t.Fatalf("failed for %s: %v", src, err)
+			}
+			got := make(map[span.Span]bool)
+			for _, refInfo := range refs {
+				refSpan, err := refInfo.Span()
+				if err != nil {
+					t.Fatal(err)
+				}
+				got[refSpan] = true
+			}
+			if len(got) != len(want) {
+				t.Errorf("references failed: different lengths got %v want %v", len(got), len(want))
+			}
+			for spn := range got {
+				if !want[spn] {
+					t.Errorf("references failed: incorrect references got %v want locations %v", got, want)
+				}
+			}
+		})
 	}
 }