internal/analysisinternal: AddImport helper for inserting imports
This CL defines a new helper for inserting import declarations,
as needed, when a refactoring introduces a reference to an
imported symbol.
Also, a test.
Change-Id: Icba17e6f76e67d2dad8f68b312db7111f4df817a
Reviewed-on: https://go-review.googlesource.com/c/tools/+/592277
Reviewed-by: Robert Findley <rfindley@google.com>
Reviewed-by: Tim King <taking@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/analysisinternal/addimport_test.go b/internal/analysisinternal/addimport_test.go
new file mode 100644
index 0000000..9871b5b
--- /dev/null
+++ b/internal/analysisinternal/addimport_test.go
@@ -0,0 +1,233 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package analysisinternal_test
+
+import (
+ "fmt"
+ "go/ast"
+ "go/importer"
+ "go/parser"
+ "go/token"
+ "go/types"
+ "runtime"
+ "strings"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "golang.org/x/tools/internal/analysisinternal"
+)
+
+func TestAddImport(t *testing.T) {
+ descr := func(s string) string {
+ if _, _, line, ok := runtime.Caller(1); ok {
+ return fmt.Sprintf("L%d %s", line, s)
+ }
+ panic("runtime.Caller failed")
+ }
+
+ // Each test case contains a «name pkgpath»
+ // section to be replaced with a reference
+ // to a valid import of pkgpath,
+ // ideally of the specified name.
+ for _, test := range []struct {
+ descr, src, want string
+ }{
+ {
+ descr: descr("simple add import"),
+ src: `package a
+func _() {
+ «fmt fmt»
+}`,
+ want: `package a
+import "fmt"
+
+func _() {
+ fmt
+}`,
+ },
+ {
+ descr: descr("existing import"),
+ src: `package a
+
+import "fmt"
+
+func _(fmt.Stringer) {
+ «fmt fmt»
+}`,
+ want: `package a
+
+import "fmt"
+
+func _(fmt.Stringer) {
+ fmt
+}`,
+ },
+ {
+ descr: descr("existing blank import"),
+ src: `package a
+
+import _ "fmt"
+
+func _() {
+ «fmt fmt»
+}`,
+ want: `package a
+
+import "fmt"
+
+import _ "fmt"
+
+func _() {
+ fmt
+}`,
+ },
+ {
+ descr: descr("existing renaming import"),
+ src: `package a
+
+import fmtpkg "fmt"
+
+var fmt int
+
+func _(fmtpkg.Stringer) {
+ «fmt fmt»
+}`,
+ want: `package a
+
+import fmtpkg "fmt"
+
+var fmt int
+
+func _(fmtpkg.Stringer) {
+ fmtpkg
+}`,
+ },
+ {
+ descr: descr("existing import is shadowed"),
+ src: `package a
+
+import "fmt"
+
+var _ fmt.Stringer
+
+func _(fmt int) {
+ «fmt fmt»
+}`,
+ want: `package a
+
+import fmt0 "fmt"
+
+import "fmt"
+
+var _ fmt.Stringer
+
+func _(fmt int) {
+ fmt0
+}`,
+ },
+ {
+ descr: descr("preferred name is shadowed"),
+ src: `package a
+
+import "fmt"
+
+func _(fmt fmt.Stringer) {
+ «fmt fmt»
+}`,
+ want: `package a
+
+import fmt0 "fmt"
+
+import "fmt"
+
+func _(fmt fmt.Stringer) {
+ fmt0
+}`,
+ },
+ {
+ descr: descr("import inserted before doc comments"),
+ src: `package a
+
+// hello
+import ()
+
+// world
+func _() {
+ «fmt fmt»
+}`,
+ want: `package a
+
+import "fmt"
+
+// hello
+import ()
+
+// world
+func _() {
+ fmt
+}`,
+ },
+ {
+ descr: descr("arbitrary preferred name => renaming import"),
+ src: `package a
+
+func _() {
+ «foo encoding/json»
+}`,
+ want: `package a
+
+import foo "encoding/json"
+
+func _() {
+ foo
+}`,
+ },
+ } {
+ t.Run(test.descr, func(t *testing.T) {
+ // splice marker
+ before, mid, ok1 := strings.Cut(test.src, "«")
+ mid, after, ok2 := strings.Cut(mid, "»")
+ if !ok1 || !ok2 {
+ t.Fatal("no «name path» marker")
+ }
+ src := before + "/*!*/" + after
+ name, path, _ := strings.Cut(mid, " ")
+
+ // parse
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, "a.go", src, parser.ParseComments)
+ if err != nil {
+ t.Log(err)
+ }
+ pos := fset.File(f.Pos()).Pos(len(before))
+
+ // type-check
+ info := &types.Info{
+ Types: make(map[ast.Expr]types.TypeAndValue),
+ Scopes: make(map[ast.Node]*types.Scope),
+ Defs: make(map[*ast.Ident]types.Object),
+ Implicits: make(map[ast.Node]types.Object),
+ }
+ conf := &types.Config{
+ Error: func(err error) { t.Log(err) },
+ Importer: importer.Default(),
+ }
+ conf.Check(f.Name.Name, fset, []*ast.File{f}, info)
+
+ // add import
+ name, edit := analysisinternal.AddImport(info, f, pos, path, name)
+
+ // apply patch
+ start := fset.Position(edit.Pos)
+ end := fset.Position(edit.End)
+ output := src[:start.Offset] + string(edit.NewText) + src[end.Offset:]
+ output = strings.ReplaceAll(output, "/*!*/", name)
+ if output != test.want {
+ t.Errorf("\n--got--\n%s\n--want--\n%s\n--diff--\n%s",
+ output, test.want, cmp.Diff(test.want, output))
+ }
+ })
+ }
+}
diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go
index db639c1..4000d27 100644
--- a/internal/analysisinternal/analysis.go
+++ b/internal/analysisinternal/analysis.go
@@ -13,6 +13,7 @@
"go/token"
"go/types"
"os"
+ pathpkg "path"
"strconv"
"golang.org/x/tools/go/analysis"
@@ -432,3 +433,81 @@
}
return false
}
+
+// AddImport checks whether this file already imports pkgpath and
+// that import is in scope at pos. If so, it returns the name under
+// which it was imported and a zero edit. Otherwise, it adds a new
+// import of pkgpath, using a name derived from the preferred name,
+// and returns the chosen name along with the edit for the new import.
+//
+// It does not mutate its arguments.
+func AddImport(info *types.Info, file *ast.File, pos token.Pos, pkgpath, preferredName string) (name string, newImport analysis.TextEdit) {
+ // Find innermost enclosing lexical block.
+ scope := info.Scopes[file].Innermost(pos)
+ if scope == nil {
+ panic("no enclosing lexical block")
+ }
+
+ // Is there an existing import of this package?
+ // If so, are we in its scope? (not shadowed)
+ for _, spec := range file.Imports {
+ pkgname, ok := importedPkgName(info, spec)
+ if ok && pkgname.Imported().Path() == pkgpath {
+ if _, obj := scope.LookupParent(pkgname.Name(), pos); obj == pkgname {
+ return pkgname.Name(), analysis.TextEdit{}
+ }
+ }
+ }
+
+ // We must add a new import.
+ // Ensure we have a fresh name.
+ newName := preferredName
+ for i := 0; ; i++ {
+ if _, obj := scope.LookupParent(newName, pos); obj == nil {
+ break // fresh
+ }
+ newName = fmt.Sprintf("%s%d", preferredName, i)
+ }
+
+ // For now, keep it real simple: create a new import
+ // declaration before the first existing declaration (which
+ // must exist), including its comments, and let goimports tidy it up.
+ //
+ // Use a renaming import whenever the preferred name is not
+ // available, or the chosen name does not match the last
+ // segment of its path.
+ newText := fmt.Sprintf("import %q\n\n", pkgpath)
+ if newName != preferredName || newName != pathpkg.Base(pkgpath) {
+ newText = fmt.Sprintf("import %s %q\n\n", newName, pkgpath)
+ }
+ decl0 := file.Decls[0]
+ var before ast.Node = decl0
+ switch decl0 := decl0.(type) {
+ case *ast.GenDecl:
+ if decl0.Doc != nil {
+ before = decl0.Doc
+ }
+ case *ast.FuncDecl:
+ if decl0.Doc != nil {
+ before = decl0.Doc
+ }
+ }
+ return newName, analysis.TextEdit{
+ Pos: before.Pos(),
+ End: before.Pos(),
+ NewText: []byte(newText),
+ }
+}
+
+// importedPkgName returns the PkgName object declared by an ImportSpec.
+// TODO(adonovan): use go1.22's Info.PkgNameOf.
+func importedPkgName(info *types.Info, imp *ast.ImportSpec) (*types.PkgName, bool) {
+ var obj types.Object
+ if imp.Name != nil {
+ obj = info.Defs[imp.Name]
+ } else {
+ obj = info.Implicits[imp]
+ }
+ pkgname, ok := obj.(*types.PkgName)
+ return pkgname, ok
+}