internal/lsp: add some basic tests for imports
This change adds a few simple tests for the goimports behavior of gopls.
There are still missing cases for non-standard library, but this is a
good start.
Change-Id: I2f9bc2cc876dcabf81413384b83fa3508517adf0
Reviewed-on: https://go-review.googlesource.com/c/tools/+/179918
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/internal/lsp/cmd/cmd_test.go b/internal/lsp/cmd/cmd_test.go
index 611d78a..8863145 100644
--- a/internal/lsp/cmd/cmd_test.go
+++ b/internal/lsp/cmd/cmd_test.go
@@ -61,6 +61,10 @@
//TODO: add command line link tests when it works
}
+func (r *runner) Import(t *testing.T, data tests.Imports) {
+ //TODO: add command line imports tests when it works
+}
+
func captureStdOut(t testing.TB, f func()) string {
r, out, err := os.Pipe()
if err != nil {
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index d4a0d54..f90f76a 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -322,6 +322,53 @@
}
}
+func (r *runner) Import(t *testing.T, data tests.Imports) {
+ ctx := context.Background()
+ for _, spn := range data {
+ uri := spn.URI()
+ filename, err := uri.Filename()
+ if err != nil {
+ t.Fatal(err)
+ }
+ goimported := string(r.data.Golden("goimports", filename, func() ([]byte, error) {
+ cmd := exec.Command("goimports", filename)
+ out, _ := cmd.Output() // ignore error, sometimes we have intentionally ungofmt-able files
+ return out, nil
+ }))
+
+ actions, err := r.server.CodeAction(context.Background(), &protocol.CodeActionParams{
+ TextDocument: protocol.TextDocumentIdentifier{
+ URI: protocol.NewURI(uri),
+ },
+ })
+ if err != nil {
+ if goimported != "" {
+ t.Error(err)
+ }
+ continue
+ }
+ _, m, err := getSourceFile(ctx, r.server.session.ViewOf(uri), uri)
+ if err != nil {
+ t.Error(err)
+ }
+ var edits []protocol.TextEdit
+ for _, a := range actions {
+ if a.Title == "Organize Imports" {
+ edits = (*a.Edit.Changes)[string(uri)]
+ }
+ }
+ sedits, err := FromProtocolEdits(m, edits)
+ if err != nil {
+ t.Error(err)
+ }
+ ops := source.EditsToDiff(sedits)
+ got := strings.Join(diff.ApplyEdits(diff.SplitLines(string(m.Content)), ops), "")
+ if goimported != got {
+ t.Errorf("import failed for %s, expected:\n%v\ngot:\n%v", filename, goimported, got)
+ }
+ }
+}
+
func (r *runner) Definition(t *testing.T, data tests.Definitions) {
for _, d := range data {
sm, err := r.mapper(d.Src.URI())
diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go
index b459e2a..d2971f0 100644
--- a/internal/lsp/source/format.go
+++ b/internal/lsp/source/format.go
@@ -64,7 +64,7 @@
if tok == nil {
return nil, fmt.Errorf("no token file for %s", f.URI())
}
- formatted, err := imports.Process(f.GetToken(ctx).Name(), fc.Data, nil)
+ formatted, err := imports.Process(tok.Name(), fc.Data, nil)
if err != nil {
return nil, err
}
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index 67ed7e3..28e767c 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -308,6 +308,47 @@
}
}
+func (r *runner) Import(t *testing.T, data tests.Imports) {
+ ctx := context.Background()
+ for _, spn := range data {
+ uri := spn.URI()
+ filename, err := uri.Filename()
+ if err != nil {
+ t.Fatal(err)
+ }
+ goimported := string(r.data.Golden("goimports", filename, func() ([]byte, error) {
+ cmd := exec.Command("goimports", filename)
+ out, _ := cmd.Output() // ignore error, sometimes we have intentionally ungofmt-able files
+ return out, nil
+ }))
+ f, err := r.view.GetFile(ctx, uri)
+ if err != nil {
+ t.Fatalf("failed for %v: %v", spn, err)
+ }
+ rng, err := spn.Range(span.NewTokenConverter(f.FileSet(), f.GetToken(ctx)))
+ if err != nil {
+ t.Fatalf("failed for %v: %v", spn, err)
+ }
+ edits, err := source.Imports(ctx, f.(source.GoFile), rng)
+ if err != nil {
+ if goimported != "" {
+ t.Error(err)
+ }
+ continue
+ }
+ ops := source.EditsToDiff(edits)
+ fc := f.Content(ctx)
+ if fc.Error != nil {
+ t.Error(err)
+ continue
+ }
+ got := strings.Join(diff.ApplyEdits(diff.SplitLines(string(fc.Data)), ops), "")
+ if goimported != got {
+ t.Errorf("import failed for %s, expected:\n%v\ngot:\n%v", filename, goimported, got)
+ }
+ }
+}
+
func (r *runner) Definition(t *testing.T, data tests.Definitions) {
ctx := context.Background()
for _, d := range data {
diff --git a/internal/lsp/testdata/imports/good_imports.go b/internal/lsp/testdata/imports/good_imports.go
new file mode 100644
index 0000000..667487c
--- /dev/null
+++ b/internal/lsp/testdata/imports/good_imports.go
@@ -0,0 +1,7 @@
+package imports //@import("package")
+
+import "fmt"
+
+func _() {
+ fmt.Println("")
+}
\ No newline at end of file
diff --git a/internal/lsp/testdata/imports/good_imports.go.golden b/internal/lsp/testdata/imports/good_imports.go.golden
new file mode 100644
index 0000000..eabb5b8
--- /dev/null
+++ b/internal/lsp/testdata/imports/good_imports.go.golden
@@ -0,0 +1,11 @@
+-- goimports --
+package imports //@import("package")
+
+import "fmt"
+
+func _() {
+ fmt.Println("")
+}
+
+-- goimports-d --
+
diff --git a/internal/lsp/testdata/imports/needs_imports.go b/internal/lsp/testdata/imports/needs_imports.go
new file mode 100644
index 0000000..949d56a
--- /dev/null
+++ b/internal/lsp/testdata/imports/needs_imports.go
@@ -0,0 +1,6 @@
+package imports //@import("package")
+
+func goodbye() {
+ fmt.Printf("HI")
+ log.Printf("byeeeee")
+}
diff --git a/internal/lsp/testdata/imports/needs_imports.go.golden b/internal/lsp/testdata/imports/needs_imports.go.golden
new file mode 100644
index 0000000..d09104b
--- /dev/null
+++ b/internal/lsp/testdata/imports/needs_imports.go.golden
@@ -0,0 +1,27 @@
+-- goimports --
+package imports //@import("package")
+
+import (
+ "fmt"
+ "log"
+)
+
+func goodbye() {
+ fmt.Printf("HI")
+ log.Printf("byeeeee")
+}
+
+-- goimports-d --
+--- imports/needs_imports.go.orig
++++ imports/needs_imports.go
+@@ -1,5 +1,10 @@
+ package imports //@import("package")
+
++import (
++ "fmt"
++ "log"
++)
++
+ func goodbye() {
+ fmt.Printf("HI")
+ log.Printf("byeeeee")
\ No newline at end of file
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 301b847..a8dc349 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -11,9 +11,7 @@
"go/parser"
"go/token"
"io/ioutil"
- "os/exec"
"path/filepath"
- "runtime"
"sort"
"strings"
"testing"
@@ -32,6 +30,7 @@
ExpectedCompletionSnippetCount = 14
ExpectedDiagnosticsCount = 17
ExpectedFormatCount = 5
+ ExpectedImportCount = 2
ExpectedDefinitionsCount = 35
ExpectedTypeDefinitionsCount = 2
ExpectedHighlightsCount = 2
@@ -54,6 +53,7 @@
type Completions map[span.Span][]token.Pos
type CompletionSnippets map[span.Span]CompletionSnippet
type Formats []span.Span
+type Imports []span.Span
type Definitions map[span.Span]Definition
type Highlights map[string][]span.Span
type Symbols map[span.URI][]source.Symbol
@@ -69,6 +69,7 @@
Completions Completions
CompletionSnippets CompletionSnippets
Formats Formats
+ Imports Imports
Definitions Definitions
Highlights Highlights
Symbols Symbols
@@ -86,6 +87,7 @@
Diagnostics(*testing.T, Diagnostics)
Completion(*testing.T, Completions, CompletionSnippets, CompletionItems)
Format(*testing.T, Formats)
+ Import(*testing.T, Imports)
Definition(*testing.T, Definitions)
Highlight(*testing.T, Highlights)
Symbol(*testing.T, Symbols)
@@ -202,6 +204,7 @@
"item": data.collectCompletionItems,
"complete": data.collectCompletions,
"format": data.collectFormats,
+ "import": data.collectImports,
"godef": data.collectDefinitions,
"typdef": data.collectTypeDefinitions,
"hover": data.collectHoverDefinitions,
@@ -256,20 +259,20 @@
t.Run("Format", func(t *testing.T) {
t.Helper()
- if _, err := exec.LookPath("gofmt"); err != nil {
- switch runtime.GOOS {
- case "android":
- t.Skip("gofmt is not installed")
- default:
- t.Fatal(err)
- }
- }
if len(data.Formats) != ExpectedFormatCount {
t.Errorf("got %v formats expected %v", len(data.Formats), ExpectedFormatCount)
}
tests.Format(t, data.Formats)
})
+ t.Run("Import", func(t *testing.T) {
+ t.Helper()
+ if len(data.Imports) != ExpectedImportCount {
+ t.Errorf("got %v imports expected %v", len(data.Imports), ExpectedImportCount)
+ }
+ tests.Import(t, data.Imports)
+ })
+
t.Run("Definition", func(t *testing.T) {
t.Helper()
if len(data.Definitions) != ExpectedDefinitionsCount {
@@ -416,6 +419,10 @@
data.Formats = append(data.Formats, spn)
}
+func (data *Data) collectImports(spn span.Span) {
+ data.Imports = append(data.Imports, spn)
+}
+
func (data *Data) collectDefinitions(src, target span.Span) {
data.Definitions[src] = Definition{
Src: src,