imports: prefer paths imported by sibling files.
Adds an Imports field to packageInfo with the imports used by sibling
files, and uses it preferentially if it matches a missing import.
Example: if foo/foo.go imports "local/log", it's a reasonable assumption
that foo/bar.go will also want "local/log" instead of "log".
Change-Id: Ifb504ed5e00ff18459f19d8598cc2c94099ae563
Reviewed-on: https://go-review.googlesource.com/43454
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/imports/fix.go b/imports/fix.go
index 61a5f06..e772b6d 100644
--- a/imports/fix.go
+++ b/imports/fix.go
@@ -68,9 +68,16 @@
return 0
}
+// importInfo is a summary of information about one import.
+type importInfo struct {
+ Path string // full import path (e.g. "crypto/rand")
+ Alias string // import alias, if present (e.g. "crand")
+}
+
// packageInfo is a summary of features found in a package.
type packageInfo struct {
- Globals map[string]bool // symbol => true
+ Globals map[string]bool // symbol => true
+ Imports map[string]importInfo // pkg base name or alias => info
}
// dirPackageInfo exposes the dirPackageInfoFile function so that it can be overridden.
@@ -94,7 +101,7 @@
return nil, err
}
- info := &packageInfo{Globals: make(map[string]bool)}
+ info := &packageInfo{Globals: make(map[string]bool), Imports: make(map[string]importInfo)}
for _, fi := range packageFileInfos {
if fi.Name() == fileBase || !strings.HasSuffix(fi.Name(), ".go") {
continue
@@ -123,6 +130,16 @@
info.Globals[valueSpec.Names[0].Name] = true
}
}
+
+ for _, imp := range root.Imports {
+ impInfo := importInfo{Path: strings.Trim(imp.Path.Value, `"`)}
+ name := path.Base(impInfo.Path)
+ if imp.Name != nil {
+ name = strings.Trim(imp.Name.Name, `"`)
+ impInfo.Alias = name
+ }
+ info.Imports[name] = impInfo
+ }
}
return info, nil
}
@@ -217,6 +234,16 @@
}
}
+ // Fast path, all references already imported.
+ if len(refs) == 0 {
+ return nil, nil
+ }
+
+ // Can assume this will be necessary in all cases now.
+ if !loadedPackageInfo {
+ packageInfo, _ = dirPackageInfo(f.Name.Name, srcDir, filename)
+ }
+
// Search for imports matching potential package references.
searches := 0
type result struct {
@@ -227,6 +254,11 @@
results := make(chan result)
for pkgName, symbols := range refs {
go func(pkgName string, symbols map[string]bool) {
+ sibling := packageInfo.Imports[pkgName]
+ if sibling.Path != "" {
+ results <- result{ipath: sibling.Path, name: sibling.Alias}
+ return
+ }
ipath, rename, err := findImport(pkgName, symbols, filename)
r := result{ipath: ipath, err: err}
if rename {
diff --git a/imports/fix_test.go b/imports/fix_test.go
index 048b9c3..2026d5c 100644
--- a/imports/fix_test.go
+++ b/imports/fix_test.go
@@ -1536,6 +1536,57 @@
})
}
+// Tests that sibling files - other files in the same package - can provide an
+// import that may not be the default one otherwise.
+func TestSiblingImports(t *testing.T) {
+
+ // provide is the sibling file that provides the desired import.
+ const provide = `package siblingimporttest
+
+import "local/log"
+
+func LogSomething() {
+ log.Print("Something")
+}
+`
+
+ // need is the file being tested that needs the import.
+ const need = `package siblingimporttest
+
+func LogSomethingElse() {
+ log.Print("Something else")
+}
+`
+
+ // want is the expected result file
+ const want = `package siblingimporttest
+
+import "local/log"
+
+func LogSomethingElse() {
+ log.Print("Something else")
+}
+`
+
+ const pkg = "siblingimporttest"
+ const siblingFile = pkg + "/needs_import.go"
+ testConfig{
+ gopathFiles: map[string]string{
+ siblingFile: need,
+ pkg + "/provides_import.go": provide,
+ },
+ }.test(t, func(t *goimportTest) {
+ buf, err := Process(
+ t.gopath+"/src/"+siblingFile, []byte(need), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != want {
+ t.Errorf("wrong output.\ngot:\n%q\nwant:\n%q\n", buf, want)
+ }
+ })
+}
+
func strSet(ss []string) map[string]bool {
m := make(map[string]bool)
for _, s := range ss {