internal/imports: refactor to split finding and applying fixes

A pass is responsible for fixing the imports of a given file. It now
finds the necessary changes to make without applying the result to the
ast, which may be desirable to give a user more control about what
changes will be applied to their program. This change splits the process
of finding the fixes from making the modifications to the ast to allow
this functionality to be easily possible.

Change-Id: Ibf8ca247c35539f91de4be90c634f0db9a939d07
Reviewed-on: https://go-review.googlesource.com/c/tools/+/184197
Reviewed-by: Heschi Kreinick <heschi@google.com>
Run-TryBot: Suzy Mueller <suzmue@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/internal/imports/fix.go b/internal/imports/fix.go
index 76a79e1..2c58d18 100644
--- a/internal/imports/fix.go
+++ b/internal/imports/fix.go
@@ -67,6 +67,19 @@
 	return 0
 }
 
+type importFixType int
+
+const (
+	addImport importFixType = iota
+	deleteImport
+	setImportName
+)
+
+type importFix struct {
+	info    importInfo
+	fixType importFixType
+}
+
 // An importInfo represents a single import statement.
 type importInfo struct {
 	importPath string // import path, e.g. "crypto/rand".
@@ -290,7 +303,7 @@
 // load reads in everything necessary to run a pass, and reports whether the
 // file already has all the imports it needs. It fills in p.missingRefs with the
 // file's missing symbols, if any, or removes unused imports if not.
-func (p *pass) load() bool {
+func (p *pass) load() ([]*importFix, bool) {
 	p.knownPackages = map[string]*packageInfo{}
 	p.missingRefs = references{}
 	p.existingImports = map[string]*importInfo{}
@@ -320,7 +333,7 @@
 			if p.env.Debug {
 				p.env.Logf("loading package names: %v", err)
 			}
-			return false
+			return nil, false
 		}
 	}
 	for _, imp := range imports {
@@ -339,16 +352,16 @@
 		}
 	}
 	if len(p.missingRefs) != 0 {
-		return false
+		return nil, false
 	}
 
 	return p.fix()
 }
 
 // fix attempts to satisfy missing imports using p.candidates. If it finds
-// everything, or if p.lastTry is true, it adds the imports it found,
-// removes anything unused, and returns true.
-func (p *pass) fix() bool {
+// everything, or if p.lastTry is true, it updates fixes to add the imports it found,
+// delete anything unused, and update import names, and returns true.
+func (p *pass) fix() ([]*importFix, bool) {
 	// Find missing imports.
 	var selected []*importInfo
 	for left, rights := range p.missingRefs {
@@ -358,10 +371,11 @@
 	}
 
 	if !p.lastTry && len(selected) != len(p.missingRefs) {
-		return false
+		return nil, false
 	}
 
 	// Found everything, or giving up. Add the new imports and remove any unused.
+	var fixes []*importFix
 	for _, imp := range p.existingImports {
 		// We deliberately ignore globals here, because we can't be sure
 		// they're in the same package. People do things like put multiple
@@ -369,27 +383,77 @@
 		// remove imports if they happen to have the same name as a var in
 		// a different package.
 		if _, ok := p.allRefs[p.importIdentifier(imp)]; !ok {
-			astutil.DeleteNamedImport(p.fset, p.f, imp.name, imp.importPath)
+			fixes = append(fixes, &importFix{
+				info:    *imp,
+				fixType: deleteImport,
+			})
+			continue
+		}
+
+		// An existing import may need to update its import name to be correct.
+		if name := p.importSpecName(imp); name != imp.name {
+			fixes = append(fixes, &importFix{
+				info: importInfo{
+					name:       name,
+					importPath: imp.importPath,
+				},
+				fixType: setImportName,
+			})
 		}
 	}
 
 	for _, imp := range selected {
-		astutil.AddNamedImport(p.fset, p.f, imp.name, imp.importPath)
+		fixes = append(fixes, &importFix{
+			info: importInfo{
+				name:       p.importSpecName(imp),
+				importPath: imp.importPath,
+			},
+			fixType: addImport,
+		})
 	}
 
-	if p.loadRealPackageNames {
-		for _, imp := range p.f.Imports {
-			if imp.Name != nil {
-				continue
-			}
-			path := strings.Trim(imp.Path.Value, `""`)
-			ident := p.importIdentifier(&importInfo{importPath: path})
-			if ident != importPathToAssumedName(path) {
-				imp.Name = &ast.Ident{Name: ident, NamePos: imp.Pos()}
+	return fixes, true
+}
+
+// importSpecName gets the import name of imp in the import spec.
+//
+// When the import identifier matches the assumed import name, the import name does
+// not appear in the import spec.
+func (p *pass) importSpecName(imp *importInfo) string {
+	// If we did not load the real package names, or the name is already set,
+	// we just return the existing name.
+	if !p.loadRealPackageNames || imp.name != "" {
+		return imp.name
+	}
+
+	ident := p.importIdentifier(imp)
+	if ident == importPathToAssumedName(imp.importPath) {
+		return "" // ident not needed since the assumed and real names are the same.
+	}
+	return ident
+}
+
+// apply will perform the fixes on f in order.
+func apply(fset *token.FileSet, f *ast.File, fixes []*importFix) bool {
+	for _, fix := range fixes {
+		switch fix.fixType {
+		case deleteImport:
+			astutil.DeleteNamedImport(fset, f, fix.info.name, fix.info.importPath)
+		case addImport:
+			astutil.AddNamedImport(fset, f, fix.info.name, fix.info.importPath)
+		case setImportName:
+			// Find the matching import path and change the name.
+			for _, spec := range f.Imports {
+				path := strings.Trim(spec.Path.Value, `""`)
+				if path == fix.info.importPath {
+					spec.Name = &ast.Ident{
+						Name:    fix.info.name,
+						NamePos: spec.Pos(),
+					}
+				}
 			}
 		}
 	}
-
 	return true
 }
 
@@ -442,10 +506,21 @@
 var fixImports = fixImportsDefault
 
 func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) error {
-	abs, err := filepath.Abs(filename)
+	fixes, err := getFixes(fset, f, filename, env)
 	if err != nil {
 		return err
 	}
+	apply(fset, f, fixes)
+	return err
+}
+
+// getFixes gets the getFixes that need to be made to f in order to fix the imports.
+// It does not modify the ast.
+func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*importFix, error) {
+	abs, err := filepath.Abs(filename)
+	if err != nil {
+		return nil, err
+	}
 	srcDir := filepath.Dir(abs)
 	if env.Debug {
 		env.Logf("fixImports(filename=%q), abs=%q, srcDir=%q ...", filename, abs, srcDir)
@@ -456,8 +531,8 @@
 	// complete. We can't add any imports yet, because we don't know
 	// if missing references are actually package vars.
 	p := &pass{fset: fset, f: f, srcDir: srcDir}
-	if p.load() {
-		return nil
+	if fixes, done := p.load(); done {
+		return fixes, nil
 	}
 
 	otherFiles := parseOtherFiles(fset, srcDir, filename)
@@ -465,15 +540,15 @@
 	// Second pass: add information from other files in the same package,
 	// like their package vars and imports.
 	p.otherFiles = otherFiles
-	if p.load() {
-		return nil
+	if fixes, done := p.load(); done {
+		return fixes, nil
 	}
 
 	// Now we can try adding imports from the stdlib.
 	p.assumeSiblingImportsValid()
 	addStdlibCandidates(p, p.missingRefs)
-	if p.fix() {
-		return nil
+	if fixes, done := p.fix(); done {
+		return fixes, nil
 	}
 
 	// Third pass: get real package names where we had previously used
@@ -482,25 +557,25 @@
 	p = &pass{fset: fset, f: f, srcDir: srcDir, env: env}
 	p.loadRealPackageNames = true
 	p.otherFiles = otherFiles
-	if p.load() {
-		return nil
+	if fixes, done := p.load(); done {
+		return fixes, nil
 	}
 
 	addStdlibCandidates(p, p.missingRefs)
 	p.assumeSiblingImportsValid()
-	if p.fix() {
-		return nil
+	if fixes, done := p.fix(); done {
+		return fixes, nil
 	}
 
 	// Go look for candidates in $GOPATH, etc. We don't necessarily load
 	// the real exports of sibling imports, so keep assuming their contents.
 	if err := addExternalCandidates(p, p.missingRefs, filename); err != nil {
-		return err
+		return nil, err
 	}
 
 	p.lastTry = true
-	p.fix()
-	return nil
+	fixes, _ := p.fix()
+	return fixes, nil
 }
 
 // ProcessEnv contains environment variables and settings that affect the use of