internal/lsp: clean up some of the extract function code

This CL creates a struct that simplifies some of the extract function
logic. Also, add a test for extraction with an underscore in the
selection (Josh mentioned that this might not work, but it seems too).

Change-Id: If917614a5824e84fb79a07def3eb75f48f10a5b9
Reviewed-on: https://go-review.googlesource.com/c/tools/+/253277
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
diff --git a/internal/lsp/source/command.go b/internal/lsp/source/command.go
index 66d2f1d..2bc3c77 100644
--- a/internal/lsp/source/command.go
+++ b/internal/lsp/source/command.go
@@ -131,7 +131,7 @@
 		Title:          "Extract to function",
 		suggestedFixFn: extractFunction,
 		appliesFn: func(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) bool {
-			_, _, _, _, _, ok, _ := canExtractFunction(fset, rng, src, file, info)
+			_, ok, _ := canExtractFunction(fset, rng, src, file, info)
 			return ok
 		},
 	}
diff --git a/internal/lsp/source/extract.go b/internal/lsp/source/extract.go
index 84679dc..411b465 100644
--- a/internal/lsp/source/extract.go
+++ b/internal/lsp/source/extract.go
@@ -180,11 +180,12 @@
 // of the function and insert this call as well as the extracted function into
 // their proper locations.
 func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
-	tok, path, rng, outer, start, ok, err := canExtractFunction(fset, rng, src, file, info)
+	p, ok, err := canExtractFunction(fset, rng, src, file, info)
 	if !ok {
 		return nil, fmt.Errorf("extractFunction: cannot extract %s: %v",
 			fset.Position(rng.Start), err)
 	}
+	tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
 	fileScope := info.Scopes[file]
 	if fileScope == nil {
 		return nil, fmt.Errorf("extractFunction: file scope is empty")
@@ -229,8 +230,10 @@
 	// we must determine the signature of the extracted function. We will then replace
 	// the block with an assignment statement that calls the extracted function with
 	// the appropriate parameters and return values.
-	free, vars, assigned, defined := collectFreeVars(
-		info, file, fileScope, pkgScope, rng, path[0])
+	variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
+	if err != nil {
+		return nil, err
+	}
 
 	var (
 		params, returns         []ast.Expr     // used when calling the extracted function
@@ -269,42 +272,38 @@
 	// variable in the extracted function. Determine the outcome(s) for each variable
 	// based on whether it is free, altered within the selected block, and used outside
 	// of the selected block.
-	for _, obj := range vars {
-		if _, ok := seenVars[obj]; ok {
+	for _, v := range variables {
+		if _, ok := seenVars[v.obj]; ok {
 			continue
 		}
-		typ := analysisinternal.TypeExpr(fset, file, pkg, obj.Type())
+		typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type())
 		if typ == nil {
-			return nil, fmt.Errorf("nil AST expression for type: %v", obj.Name())
+			return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name())
 		}
-		seenVars[obj] = typ
-		identifier := ast.NewIdent(obj.Name())
+		seenVars[v.obj] = typ
+		identifier := ast.NewIdent(v.obj.Name())
 		// An identifier must meet three conditions to become a return value of the
 		// extracted function. (1) its value must be defined or reassigned within
 		// the selection (isAssigned), (2) it must be used at least once after the
 		// selection (isUsed), and (3) its first use after the selection
 		// cannot be its own reassignment or redefinition (objOverriden).
-		if obj.Parent() == nil {
+		if v.obj.Parent() == nil {
 			return nil, fmt.Errorf("parent nil")
 		}
-		isUsed, firstUseAfter :=
-			objUsed(info, span.NewRange(fset, rng.End, obj.Parent().End()), obj)
-		_, isAssigned := assigned[obj]
-		_, isFree := free[obj]
-		if isAssigned && isUsed && !varOverridden(info, firstUseAfter, obj, isFree, outer) {
+		isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj)
+		if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) {
 			returnTypes = append(returnTypes, &ast.Field{Type: typ})
 			returns = append(returns, identifier)
-			if !isFree {
-				uninitialized = append(uninitialized, obj)
-			} else if obj.Parent().Pos() == startParent.Pos() {
+			if !v.free {
+				uninitialized = append(uninitialized, v.obj)
+			} else if v.obj.Parent().Pos() == startParent.Pos() {
 				canRedefineCount++
 			}
 		}
-		_, isDefined := defined[obj]
 		// An identifier must meet two conditions to become a parameter of the
 		// extracted function. (1) it must be free (isFree), and (2) its first
 		// use within the selection cannot be its own definition (isDefined).
-		if isFree && !isDefined {
+		if v.free && !v.defined {
 			params = append(params, identifier)
 			paramTypes = append(paramTypes, &ast.Field{
 				Names: []*ast.Ident{identifier},
@@ -409,8 +408,7 @@
 		// statements in the selection. Update the type signature of the extracted
 		// function and construct the if statement that will be inserted in the enclosing
 		// function.
-		retVars, ifReturn, err = generateReturnInfo(
-			enclosing, pkg, path, file, info, fset, rng.Start)
+		retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start)
 		if err != nil {
 			return nil, err
 		}
@@ -500,13 +498,11 @@
 	fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function
 
 	return &analysis.SuggestedFix{
-		TextEdits: []analysis.TextEdit{
-			{
-				Pos:     outer.Pos(),
-				End:     outer.End(),
-				NewText: []byte(fullReplacement.String()),
-			},
-		},
+		TextEdits: []analysis.TextEdit{{
+			Pos:     outer.Pos(),
+			End:     outer.End(),
+			NewText: []byte(fullReplacement.String()),
+		}},
 	}, nil
 }
 
@@ -561,15 +557,28 @@
 	return parent
 }
 
+// variable describes the status of a variable within a selection.
+type variable struct {
+	obj types.Object
+
+	// free reports whether the variable is a free variable, meaning it should
+	// be a parameter to the extracted function.
+	free bool
+
+	// assigned reports whether the variable is assigned to in the selection.
+	assigned bool
+
+	// defined reports whether the variable is defined in the selection.
+	defined bool
+}
+
 // collectFreeVars maps each identifier in the given range to whether it is "free."
 // Given a range, a variable in that range is defined as "free" if it is declared
 // outside of the range and neither at the file scope nor package scope. These free
 // variables will be used as arguments in the extracted function. It also returns a
 // list of identifiers that may need to be returned by the extracted function.
 // Some of the code in this function has been adapted from tools/cmd/guru/freevars.go.
-func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope,
-	pkgScope *types.Scope, rng span.Range, node ast.Node) (map[types.Object]struct{},
-	[]types.Object, map[types.Object]struct{}, map[types.Object]struct{}) {
+func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) {
 	// id returns non-nil if n denotes an object that is referenced by the span
 	// and defined either within the span or in the lexical environment. The bool
 	// return value acts as an indicator for where it was defined.
@@ -612,7 +621,7 @@
 		}
 		return nil, false
 	}
-	free := make(map[types.Object]struct{})
+	seen := make(map[types.Object]*variable)
 	firstUseIn := make(map[types.Object]token.Pos)
 	var vars []types.Object
 	ast.Inspect(node, func(n ast.Node) bool {
@@ -630,15 +639,16 @@
 				prune = true
 			}
 			if obj != nil {
-				if isFree {
-					free[obj] = struct{}{}
+				seen[obj] = &variable{
+					obj:  obj,
+					free: isFree,
 				}
+				vars = append(vars, obj)
 				// Find the first time that the object is used in the selection.
 				first, ok := firstUseIn[obj]
 				if !ok || n.Pos() < first {
 					firstUseIn[obj] = n.Pos()
 				}
-				vars = append(vars, obj)
 				if prune {
 					return false
 				}
@@ -657,8 +667,6 @@
 	// 3: y := 3
 	// 4: z := x + a
 	//
-	assigned := make(map[types.Object]struct{})
-	defined := make(map[types.Object]struct{})
 	ast.Inspect(node, func(n ast.Node) bool {
 		if n == nil {
 			return false
@@ -677,7 +685,10 @@
 				if obj == nil {
 					continue
 				}
-				assigned[obj] = struct{}{}
+				if _, ok := seen[obj]; !ok {
+					continue
+				}
+				seen[obj].assigned = true
 				if n.Tok != token.DEFINE {
 					continue
 				}
@@ -697,7 +708,10 @@
 					if referencesObj(info, expr, obj) {
 						continue
 					}
-					defined[obj] = struct{}{}
+					if _, ok := seen[obj]; !ok {
+						continue
+					}
+					seen[obj].defined = true
 					break
 				}
 			}
@@ -717,7 +731,10 @@
 					if obj == nil {
 						continue
 					}
-					assigned[obj] = struct{}{}
+					if _, ok := seen[obj]; !ok {
+						continue
+					}
+					seen[obj].assigned = true
 				}
 			}
 			return false
@@ -727,12 +744,23 @@
 			} else if obj, _ := id(ident); obj == nil {
 				return false
 			} else {
-				assigned[obj] = struct{}{}
+				if _, ok := seen[obj]; !ok {
+					return false
+				}
+				seen[obj].assigned = true
 			}
 		}
 		return true
 	})
-	return free, vars, assigned, defined
+	var variables []*variable
+	for _, obj := range vars {
+		v, ok := seen[obj]
+		if !ok {
+			return nil, fmt.Errorf("no seen types.Object for %v", obj)
+		}
+		variables = append(variables, v)
+	}
+	return variables, nil
 }
 
 // referencesObj checks whether the given object appears in the given expression.
@@ -756,29 +784,34 @@
 	return hasObj
 }
 
-// canExtractFunction reports whether the code in the given range can be extracted to a function.
-func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, info *types.Info) (*token.File, []ast.Node, span.Range, *ast.FuncDecl, ast.Node, bool, error) {
+type fnExtractParams struct {
+	tok   *token.File
+	path  []ast.Node
+	rng   span.Range
+	outer *ast.FuncDecl
+	start ast.Node
+}
+
+// canExtractFunction reports whether the code in the given range can be
+// extracted to a function.
+func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, info *types.Info) (*fnExtractParams, bool, error) {
 	if rng.Start == rng.End {
-		return nil, nil, span.Range{}, nil, nil, false,
-			fmt.Errorf("start and end are equal")
+		return nil, false, fmt.Errorf("start and end are equal")
 	}
 	tok := fset.File(file.Pos())
 	if tok == nil {
-		return nil, nil, span.Range{}, nil, nil, false,
-			fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
+		return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
 	}
 	rng = adjustRangeForWhitespace(rng, tok, src)
 	path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
 	if len(path) == 0 {
-		return nil, nil, span.Range{}, nil, nil, false,
-			fmt.Errorf("no path enclosing interval")
+		return nil, false, fmt.Errorf("no path enclosing interval")
 	}
 	// Node that encloses the selection must be a statement.
 	// TODO: Support function extraction for an expression.
 	_, ok := path[0].(ast.Stmt)
 	if !ok {
-		return nil, nil, span.Range{}, nil, nil, false,
-			fmt.Errorf("node is not a statement")
+		return nil, false, fmt.Errorf("node is not a statement")
 	}
 
 	// Find the function declaration that encloses the selection.
@@ -790,7 +823,7 @@
 		}
 	}
 	if outer == nil {
-		return nil, nil, span.Range{}, nil, nil, false, fmt.Errorf("no enclosing function")
+		return nil, false, fmt.Errorf("no enclosing function")
 	}
 
 	// Find the nodes at the start and end of the selection.
@@ -799,8 +832,8 @@
 		if n == nil {
 			return false
 		}
-		// Do not override 'start' with a node that begins at the same location but is
-		// nested further from 'outer'.
+		// Do not override 'start' with a node that begins at the same location
+		// but is nested further from 'outer'.
 		if start == nil && n.Pos() == rng.Start && n.End() <= rng.End {
 			start = n
 		}
@@ -810,10 +843,15 @@
 		return n.Pos() <= rng.End
 	})
 	if start == nil || end == nil {
-		return nil, nil, span.Range{}, nil, nil, false,
-			fmt.Errorf("range does not map to AST nodes")
+		return nil, false, fmt.Errorf("range does not map to AST nodes")
 	}
-	return tok, path, rng, outer, start, true, nil
+	return &fnExtractParams{
+		tok:   tok,
+		path:  path,
+		rng:   rng,
+		outer: outer,
+		start: start,
+	}, true, nil
 }
 
 // objUsed checks if the object is used within the range. It returns the first occurence of
diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go
index 0c38011..63d24df 100644
--- a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go
+++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go
@@ -5,6 +5,7 @@
 	a = 5     //@mark(exSt0, "a")
 	a = a + 2 //@mark(exEn0, "2")
 	//@extractfunc(exSt0, exEn0)
-	b := a * 2
-	_ = 3 + 4
+	b := a * 2 //@mark(exB, "	b")
+	_ = 3 + 4  //@mark(exEnd, "4")
+	//@extractfunc(exB, exEnd)
 }
diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden
index 04caef2..d31fcc1 100644
--- a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden
+++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden
@@ -5,8 +5,9 @@
 	a := 1
 	a = fn0(a) //@mark(exEn0, "2")
 	//@extractfunc(exSt0, exEn0)
-	b := a * 2
-	_ = 3 + 4
+	b := a * 2 //@mark(exB, "	b")
+	_ = 3 + 4  //@mark(exEnd, "4")
+	//@extractfunc(exB, exEnd)
 }
 
 func fn0(a int) int {
@@ -15,3 +16,20 @@
 	return a
 }
 
+-- functionextraction_extract_args_returns_8_1 --
+package extract
+
+func _() {
+	a := 1
+	a = 5     //@mark(exSt0, "a")
+	a = a + 2 //@mark(exEn0, "2")
+	//@extractfunc(exSt0, exEn0)
+	fn0(a)  //@mark(exEnd, "4")
+	//@extractfunc(exB, exEnd)
+}
+
+func fn0(a int) {
+	b := a * 2
+	_ = 3 + 4
+}
+
diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden
index f625017..e6e82d1 100644
--- a/internal/lsp/testdata/lsp/summary.txt.golden
+++ b/internal/lsp/testdata/lsp/summary.txt.golden
@@ -13,7 +13,7 @@
 FormatCount = 6
 ImportCount = 8
 SuggestedFixCount = 38
-FunctionExtractionCount = 11
+FunctionExtractionCount = 12
 DefinitionsCount = 63
 TypeDefinitionsCount = 2
 HighlightsCount = 69