internal/lsp: add some basic tests for imports

This change adds a few simple tests for the goimports behavior of gopls.
There are still missing cases for non-standard library, but this is a
good start.

Change-Id: I2f9bc2cc876dcabf81413384b83fa3508517adf0
Reviewed-on: https://go-review.googlesource.com/c/tools/+/179918
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/internal/lsp/cmd/cmd_test.go b/internal/lsp/cmd/cmd_test.go
index 611d78a..8863145 100644
--- a/internal/lsp/cmd/cmd_test.go
+++ b/internal/lsp/cmd/cmd_test.go
@@ -61,6 +61,10 @@
 	//TODO: add command line link tests when it works
 }
 
+func (r *runner) Import(t *testing.T, data tests.Imports) {
+	//TODO: add command line imports tests when it works
+}
+
 func captureStdOut(t testing.TB, f func()) string {
 	r, out, err := os.Pipe()
 	if err != nil {
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index d4a0d54..f90f76a 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -322,6 +322,53 @@
 	}
 }
 
+func (r *runner) Import(t *testing.T, data tests.Imports) {
+	ctx := context.Background()
+	for _, spn := range data {
+		uri := spn.URI()
+		filename, err := uri.Filename()
+		if err != nil {
+			t.Fatal(err)
+		}
+		goimported := string(r.data.Golden("goimports", filename, func() ([]byte, error) {
+			cmd := exec.Command("goimports", filename)
+			out, _ := cmd.Output() // ignore error, sometimes we have intentionally ungofmt-able files
+			return out, nil
+		}))
+
+		actions, err := r.server.CodeAction(context.Background(), &protocol.CodeActionParams{
+			TextDocument: protocol.TextDocumentIdentifier{
+				URI: protocol.NewURI(uri),
+			},
+		})
+		if err != nil {
+			if goimported != "" {
+				t.Error(err)
+			}
+			continue
+		}
+		_, m, err := getSourceFile(ctx, r.server.session.ViewOf(uri), uri)
+		if err != nil {
+			t.Error(err)
+		}
+		var edits []protocol.TextEdit
+		for _, a := range actions {
+			if a.Title == "Organize Imports" {
+				edits = (*a.Edit.Changes)[string(uri)]
+			}
+		}
+		sedits, err := FromProtocolEdits(m, edits)
+		if err != nil {
+			t.Error(err)
+		}
+		ops := source.EditsToDiff(sedits)
+		got := strings.Join(diff.ApplyEdits(diff.SplitLines(string(m.Content)), ops), "")
+		if goimported != got {
+			t.Errorf("import failed for %s, expected:\n%v\ngot:\n%v", filename, goimported, got)
+		}
+	}
+}
+
 func (r *runner) Definition(t *testing.T, data tests.Definitions) {
 	for _, d := range data {
 		sm, err := r.mapper(d.Src.URI())
diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go
index b459e2a..d2971f0 100644
--- a/internal/lsp/source/format.go
+++ b/internal/lsp/source/format.go
@@ -64,7 +64,7 @@
 	if tok == nil {
 		return nil, fmt.Errorf("no token file for %s", f.URI())
 	}
-	formatted, err := imports.Process(f.GetToken(ctx).Name(), fc.Data, nil)
+	formatted, err := imports.Process(tok.Name(), fc.Data, nil)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index 67ed7e3..28e767c 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -308,6 +308,47 @@
 	}
 }
 
+func (r *runner) Import(t *testing.T, data tests.Imports) {
+	ctx := context.Background()
+	for _, spn := range data {
+		uri := spn.URI()
+		filename, err := uri.Filename()
+		if err != nil {
+			t.Fatal(err)
+		}
+		goimported := string(r.data.Golden("goimports", filename, func() ([]byte, error) {
+			cmd := exec.Command("goimports", filename)
+			out, _ := cmd.Output() // ignore error, sometimes we have intentionally ungofmt-able files
+			return out, nil
+		}))
+		f, err := r.view.GetFile(ctx, uri)
+		if err != nil {
+			t.Fatalf("failed for %v: %v", spn, err)
+		}
+		rng, err := spn.Range(span.NewTokenConverter(f.FileSet(), f.GetToken(ctx)))
+		if err != nil {
+			t.Fatalf("failed for %v: %v", spn, err)
+		}
+		edits, err := source.Imports(ctx, f.(source.GoFile), rng)
+		if err != nil {
+			if goimported != "" {
+				t.Error(err)
+			}
+			continue
+		}
+		ops := source.EditsToDiff(edits)
+		fc := f.Content(ctx)
+		if fc.Error != nil {
+			t.Error(err)
+			continue
+		}
+		got := strings.Join(diff.ApplyEdits(diff.SplitLines(string(fc.Data)), ops), "")
+		if goimported != got {
+			t.Errorf("import failed for %s, expected:\n%v\ngot:\n%v", filename, goimported, got)
+		}
+	}
+}
+
 func (r *runner) Definition(t *testing.T, data tests.Definitions) {
 	ctx := context.Background()
 	for _, d := range data {
diff --git a/internal/lsp/testdata/imports/good_imports.go b/internal/lsp/testdata/imports/good_imports.go
new file mode 100644
index 0000000..667487c
--- /dev/null
+++ b/internal/lsp/testdata/imports/good_imports.go
@@ -0,0 +1,7 @@
+package imports //@import("package")
+
+import "fmt"
+
+func _() {
+	fmt.Println("")
+}
\ No newline at end of file
diff --git a/internal/lsp/testdata/imports/good_imports.go.golden b/internal/lsp/testdata/imports/good_imports.go.golden
new file mode 100644
index 0000000..eabb5b8
--- /dev/null
+++ b/internal/lsp/testdata/imports/good_imports.go.golden
@@ -0,0 +1,11 @@
+-- goimports --
+package imports //@import("package")
+
+import "fmt"
+
+func _() {
+	fmt.Println("")
+}
+
+-- goimports-d --
+
diff --git a/internal/lsp/testdata/imports/needs_imports.go b/internal/lsp/testdata/imports/needs_imports.go
new file mode 100644
index 0000000..949d56a
--- /dev/null
+++ b/internal/lsp/testdata/imports/needs_imports.go
@@ -0,0 +1,6 @@
+package imports //@import("package")
+
+func goodbye() {
+	fmt.Printf("HI")
+	log.Printf("byeeeee")
+}
diff --git a/internal/lsp/testdata/imports/needs_imports.go.golden b/internal/lsp/testdata/imports/needs_imports.go.golden
new file mode 100644
index 0000000..d09104b
--- /dev/null
+++ b/internal/lsp/testdata/imports/needs_imports.go.golden
@@ -0,0 +1,27 @@
+-- goimports --
+package imports //@import("package")
+
+import (
+	"fmt"
+	"log"
+)
+
+func goodbye() {
+	fmt.Printf("HI")
+	log.Printf("byeeeee")
+}
+
+-- goimports-d --
+--- imports/needs_imports.go.orig
++++ imports/needs_imports.go
+@@ -1,5 +1,10 @@
+ package imports //@import("package")
+ 
++import (
++	"fmt"
++	"log"
++)
++
+ func goodbye() {
+ 	fmt.Printf("HI")
+ 	log.Printf("byeeeee")
\ No newline at end of file
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 301b847..a8dc349 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -11,9 +11,7 @@
 	"go/parser"
 	"go/token"
 	"io/ioutil"
-	"os/exec"
 	"path/filepath"
-	"runtime"
 	"sort"
 	"strings"
 	"testing"
@@ -32,6 +30,7 @@
 	ExpectedCompletionSnippetCount = 14
 	ExpectedDiagnosticsCount       = 17
 	ExpectedFormatCount            = 5
+	ExpectedImportCount            = 2
 	ExpectedDefinitionsCount       = 35
 	ExpectedTypeDefinitionsCount   = 2
 	ExpectedHighlightsCount        = 2
@@ -54,6 +53,7 @@
 type Completions map[span.Span][]token.Pos
 type CompletionSnippets map[span.Span]CompletionSnippet
 type Formats []span.Span
+type Imports []span.Span
 type Definitions map[span.Span]Definition
 type Highlights map[string][]span.Span
 type Symbols map[span.URI][]source.Symbol
@@ -69,6 +69,7 @@
 	Completions        Completions
 	CompletionSnippets CompletionSnippets
 	Formats            Formats
+	Imports            Imports
 	Definitions        Definitions
 	Highlights         Highlights
 	Symbols            Symbols
@@ -86,6 +87,7 @@
 	Diagnostics(*testing.T, Diagnostics)
 	Completion(*testing.T, Completions, CompletionSnippets, CompletionItems)
 	Format(*testing.T, Formats)
+	Import(*testing.T, Imports)
 	Definition(*testing.T, Definitions)
 	Highlight(*testing.T, Highlights)
 	Symbol(*testing.T, Symbols)
@@ -202,6 +204,7 @@
 		"item":      data.collectCompletionItems,
 		"complete":  data.collectCompletions,
 		"format":    data.collectFormats,
+		"import":    data.collectImports,
 		"godef":     data.collectDefinitions,
 		"typdef":    data.collectTypeDefinitions,
 		"hover":     data.collectHoverDefinitions,
@@ -256,20 +259,20 @@
 
 	t.Run("Format", func(t *testing.T) {
 		t.Helper()
-		if _, err := exec.LookPath("gofmt"); err != nil {
-			switch runtime.GOOS {
-			case "android":
-				t.Skip("gofmt is not installed")
-			default:
-				t.Fatal(err)
-			}
-		}
 		if len(data.Formats) != ExpectedFormatCount {
 			t.Errorf("got %v formats expected %v", len(data.Formats), ExpectedFormatCount)
 		}
 		tests.Format(t, data.Formats)
 	})
 
+	t.Run("Import", func(t *testing.T) {
+		t.Helper()
+		if len(data.Imports) != ExpectedImportCount {
+			t.Errorf("got %v imports expected %v", len(data.Imports), ExpectedImportCount)
+		}
+		tests.Import(t, data.Imports)
+	})
+
 	t.Run("Definition", func(t *testing.T) {
 		t.Helper()
 		if len(data.Definitions) != ExpectedDefinitionsCount {
@@ -416,6 +419,10 @@
 	data.Formats = append(data.Formats, spn)
 }
 
+func (data *Data) collectImports(spn span.Span) {
+	data.Imports = append(data.Imports, spn)
+}
+
 func (data *Data) collectDefinitions(src, target span.Span) {
 	data.Definitions[src] = Definition{
 		Src: src,