internal/lsp: extract highlighted selection to variable

I add a code action that triggers upon request of the user. A variable
name is generated manually for the extracted code because the LSP does
not support a user's ability to provide a name.

Change-Id: Id1ec19b49562b7cfbc2cd416378bec9bd021d82f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/240182
Run-TryBot: Josh Baum <joshbaum@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go
index c7f5cd5..311fbfd 100644
--- a/internal/analysisinternal/analysis.go
+++ b/internal/analysisinternal/analysis.go
@@ -124,3 +124,77 @@
 	NoResultValues TypeErrorPass = "noresultvalues"
 	UndeclaredName TypeErrorPass = "undeclaredname"
 )
+
+// StmtToInsertVarBefore returns the ast.Stmt before which we can safely insert a new variable.
+// Some examples:
+//
+// Basic Example:
+// z := 1
+// y := z + x
+// If x is undeclared, then this function would return `y := z + x`, so that we
+// can insert `x := ` on the line before `y := z + x`.
+//
+// If stmt example:
+// if z == 1 {
+// } else if z == y {}
+// If y is undeclared, then this function would return `if z == 1 {`, because we cannot
+// insert a statement between an if and an else if statement. As a result, we need to find
+// the top of the if chain to insert `y := ` before.
+func StmtToInsertVarBefore(path []ast.Node) ast.Stmt {
+	enclosingIndex := -1
+	for i, p := range path {
+		if _, ok := p.(ast.Stmt); ok {
+			enclosingIndex = i
+			break
+		}
+	}
+	if enclosingIndex == -1 {
+		return nil
+	}
+	enclosingStmt := path[enclosingIndex]
+	switch enclosingStmt.(type) {
+	case *ast.IfStmt:
+		// The enclosingStmt is inside of the if declaration,
+		// We need to check if we are in an else-if stmt and
+		// get the base if statement.
+		return baseIfStmt(path, enclosingIndex)
+	case *ast.CaseClause:
+		// Get the enclosing switch stmt if the enclosingStmt is
+		// inside of the case statement.
+		for i := enclosingIndex + 1; i < len(path); i++ {
+			if node, ok := path[i].(*ast.SwitchStmt); ok {
+				return node
+			} else if node, ok := path[i].(*ast.TypeSwitchStmt); ok {
+				return node
+			}
+		}
+	}
+	if len(path) <= enclosingIndex+1 {
+		return enclosingStmt.(ast.Stmt)
+	}
+	// Check if the enclosing statement is inside another node.
+	switch expr := path[enclosingIndex+1].(type) {
+	case *ast.IfStmt:
+		// Get the base if statement.
+		return baseIfStmt(path, enclosingIndex+1)
+	case *ast.ForStmt:
+		if expr.Init == enclosingStmt || expr.Post == enclosingStmt {
+			return expr
+		}
+	}
+	return enclosingStmt.(ast.Stmt)
+}
+
+// baseIfStmt walks up the if/else-if chain until we get to
+// the top of the current if chain.
+func baseIfStmt(path []ast.Node, index int) ast.Stmt {
+	stmt := path[index]
+	for i := index + 1; i < len(path); i++ {
+		if node, ok := path[i].(*ast.IfStmt); ok && node.Else == stmt {
+			stmt = node
+			continue
+		}
+		break
+	}
+	return stmt.(ast.Stmt)
+}
diff --git a/internal/lsp/analysis/undeclaredname/undeclared.go b/internal/lsp/analysis/undeclaredname/undeclared.go
index 84a96d2..fa239be 100644
--- a/internal/lsp/analysis/undeclaredname/undeclared.go
+++ b/internal/lsp/analysis/undeclaredname/undeclared.go
@@ -70,20 +70,9 @@
 		if _, ok := path[1].(*ast.CallExpr); ok {
 			continue
 		}
-		// Get the enclosing statement.
-		enclosingIndex := -1
-		for i, p := range path {
-			if _, ok := p.(ast.Stmt); ok && enclosingIndex == -1 {
-				enclosingIndex = i
-				break
-			}
-		}
-		if enclosingIndex == -1 {
-			continue
-		}
 
 		// Get the place to insert the new statement.
-		insertBeforeStmt := stmtToInsertVarBefore(path, enclosingIndex)
+		insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
 		if insertBeforeStmt == nil {
 			continue
 		}
@@ -121,70 +110,6 @@
 	return nil, nil
 }
 
-// stmtToInsertVarBefore returns the ast.Stmt before which we can safely insert a new variable.
-// Some examples:
-//
-// Basic Example:
-// z := 1
-// y := z + x
-// If x is undeclared, then this function would return `y := z + x`, so that we
-// can insert `x := ` on the line before `y := z + x`.
-//
-// If stmt example:
-// if z == 1 {
-// } else if z == y {}
-// If y is undeclared, then this function would return `if z == 1 {`, because we cannot
-// insert a statement between an if and an else if statement. As a result, we need to find
-// the top of the if chain to insert `y := ` before.
-func stmtToInsertVarBefore(path []ast.Node, enclosingIndex int) ast.Stmt {
-	enclosingStmt := path[enclosingIndex]
-	switch enclosingStmt.(type) {
-	case *ast.IfStmt:
-		// The enclosingStmt is inside of the if declaration,
-		// We need to check if we are in an else-if stmt and
-		// get the base if statement.
-		return baseIfStmt(path, enclosingIndex)
-	case *ast.CaseClause:
-		// Get the enclosing switch stmt if the enclosingStmt is
-		// inside of the case statement.
-		for i := enclosingIndex + 1; i < len(path); i++ {
-			if node, ok := path[i].(*ast.SwitchStmt); ok {
-				return node
-			} else if node, ok := path[i].(*ast.TypeSwitchStmt); ok {
-				return node
-			}
-		}
-	}
-	if len(path) <= enclosingIndex+1 {
-		return enclosingStmt.(ast.Stmt)
-	}
-	// Check if the enclosing statement is inside another node.
-	switch expr := path[enclosingIndex+1].(type) {
-	case *ast.IfStmt:
-		// Get the base if statement.
-		return baseIfStmt(path, enclosingIndex+1)
-	case *ast.ForStmt:
-		if expr.Init == enclosingStmt || expr.Post == enclosingStmt {
-			return expr
-		}
-	}
-	return enclosingStmt.(ast.Stmt)
-}
-
-// baseIfStmt walks up the if/else-if chain until we get to
-// the top of the current if chain.
-func baseIfStmt(path []ast.Node, index int) ast.Stmt {
-	stmt := path[index]
-	for i := index + 1; i < len(path); i++ {
-		if node, ok := path[i].(*ast.IfStmt); ok && node.Else == stmt {
-			stmt = node
-			continue
-		}
-		break
-	}
-	return stmt.(ast.Stmt)
-}
-
 func FixesError(msg string) bool {
 	return strings.HasPrefix(msg, undeclaredNamePrefix)
 }
diff --git a/internal/lsp/cmd/suggested_fix.go b/internal/lsp/cmd/suggested_fix.go
index d80e066..5738747 100644
--- a/internal/lsp/cmd/suggested_fix.go
+++ b/internal/lsp/cmd/suggested_fix.go
@@ -76,6 +76,10 @@
 		}
 	}
 
+	rng, err := file.mapper.Range(from)
+	if err != nil {
+		return err
+	}
 	p := protocol.CodeActionParams{
 		TextDocument: protocol.TextDocumentIdentifier{
 			URI: protocol.URIFromSpanURI(uri),
@@ -84,6 +88,7 @@
 			Only:        codeActionKinds,
 			Diagnostics: file.diagnostics,
 		},
+		Range: rng,
 	}
 	actions, err := conn.CodeAction(ctx, &p)
 	if err != nil {
@@ -118,6 +123,15 @@
 				break
 			}
 		}
+
+		// If suggested fix is not a diagnostic, still must collect edits.
+		if len(a.Diagnostics) == 0 {
+			for _, c := range a.Edit.DocumentChanges {
+				if fileURI(c.TextDocument.URI) == uri {
+					edits = append(edits, c.Edits...)
+				}
+			}
+		}
 	}
 
 	sedits, err := source.FromProtocolEdits(file.mapper, edits)
diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go
index f611fc9..24151fa 100644
--- a/internal/lsp/code_action.go
+++ b/internal/lsp/code_action.go
@@ -162,6 +162,13 @@
 			}
 			codeActions = append(codeActions, fixes...)
 		}
+		if wanted[protocol.RefactorExtract] {
+			fixes, err := extractionFixes(ctx, snapshot, ph, uri, params.Range)
+			if err != nil {
+				return nil, err
+			}
+			codeActions = append(codeActions, fixes...)
+		}
 	default:
 		// Unsupported file kind for a code action.
 		return nil, nil
@@ -385,6 +392,29 @@
 	return codeActions, nil
 }
 
+func extractionFixes(ctx context.Context, snapshot source.Snapshot, ph source.PackageHandle, uri span.URI, rng protocol.Range) ([]protocol.CodeAction, error) {
+	fh, err := snapshot.GetFile(ctx, uri)
+	if err != nil {
+		return nil, nil
+	}
+	edits, err := source.ExtractVariable(ctx, snapshot, fh, rng)
+	if err != nil {
+		return nil, err
+	}
+	if len(edits) == 0 {
+		return nil, nil
+	}
+	return []protocol.CodeAction{
+		{
+			Title: "Extract to variable",
+			Kind:  protocol.RefactorExtract,
+			Edit: protocol.WorkspaceEdit{
+				DocumentChanges: documentChanges(fh, edits),
+			},
+		},
+	}, nil
+}
+
 func documentChanges(fh source.FileHandle, edits []protocol.TextEdit) []protocol.TextDocumentEdit {
 	return []protocol.TextDocumentEdit{
 		{
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index cd5757e..af84517 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -955,21 +955,7 @@
 
 // lexical finds completions in the lexical environment.
 func (c *completer) lexical(ctx context.Context) error {
-	var scopes []*types.Scope // scopes[i], where i<len(path), is the possibly nil Scope of path[i].
-	for _, n := range c.path {
-		// Include *FuncType scope if pos is inside the function body.
-		switch node := n.(type) {
-		case *ast.FuncDecl:
-			if node.Body != nil && nodeContains(node.Body, c.pos) {
-				n = node.Type
-			}
-		case *ast.FuncLit:
-			if node.Body != nil && nodeContains(node.Body, c.pos) {
-				n = node.Type
-			}
-		}
-		scopes = append(scopes, c.pkg.GetTypesInfo().Scopes[n])
-	}
+	scopes := collectScopes(c.pkg, c.path, c.pos)
 	scopes = append(scopes, c.pkg.GetTypes().Scope(), types.Universe)
 
 	var (
@@ -1106,6 +1092,26 @@
 	return nil
 }
 
+func collectScopes(pkg Package, path []ast.Node, pos token.Pos) []*types.Scope {
+	// scopes[i], where i<len(path), is the possibly nil Scope of path[i].
+	var scopes []*types.Scope
+	for _, n := range path {
+		// Include *FuncType scope if pos is inside the function body.
+		switch node := n.(type) {
+		case *ast.FuncDecl:
+			if node.Body != nil && nodeContains(node.Body, pos) {
+				n = node.Type
+			}
+		case *ast.FuncLit:
+			if node.Body != nil && nodeContains(node.Body, pos) {
+				n = node.Type
+			}
+		}
+		scopes = append(scopes, pkg.GetTypesInfo().Scopes[n])
+	}
+	return scopes
+}
+
 func (c *completer) unimportedPackages(ctx context.Context, seen map[string]struct{}) error {
 	var prefix string
 	if c.surrounding != nil {
diff --git a/internal/lsp/source/extract.go b/internal/lsp/source/extract.go
new file mode 100644
index 0000000..fb53e61
--- /dev/null
+++ b/internal/lsp/source/extract.go
@@ -0,0 +1,140 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package source
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"go/ast"
+	"go/format"
+	"go/token"
+	"go/types"
+
+	"golang.org/x/tools/go/ast/astutil"
+	"golang.org/x/tools/internal/analysisinternal"
+	"golang.org/x/tools/internal/lsp/protocol"
+	"golang.org/x/tools/internal/span"
+)
+
+func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, protoRng protocol.Range) ([]protocol.TextEdit, error) {
+	pkg, pgh, err := getParsedFile(ctx, snapshot, fh, NarrowestPackageHandle)
+	if err != nil {
+		return nil, fmt.Errorf("ExtractVariable: %v", err)
+	}
+	file, _, m, _, err := pgh.Cached()
+	if err != nil {
+		return nil, err
+	}
+	spn, err := m.RangeSpan(protoRng)
+	if err != nil {
+		return nil, err
+	}
+	rng, err := spn.Range(m.Converter)
+	if err != nil {
+		return nil, err
+	}
+	path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
+	if len(path) == 0 {
+		return nil, nil
+	}
+	fset := snapshot.View().Session().Cache().FileSet()
+	node := path[0]
+	tok := fset.File(node.Pos())
+	if tok == nil {
+		return nil, fmt.Errorf("ExtractVariable: no token.File for %s", fh.URI())
+	}
+	var content []byte
+	if content, err = fh.Read(); err != nil {
+		return nil, err
+	}
+	if rng.Start != node.Pos() || rng.End != node.End() {
+		return nil, nil
+	}
+
+	// Adjust new variable name until no collisons in scope.
+	scopes := collectScopes(pkg, path, node.Pos())
+	name := "x0"
+	idx := 0
+	for !isValidName(name, scopes) {
+		idx++
+		name = fmt.Sprintf("x%d", idx)
+	}
+
+	var assignment string
+	expr, ok := node.(ast.Expr)
+	if !ok {
+		return nil, nil
+	}
+	// Create new AST node for extracted code
+	switch expr.(type) {
+	case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr,
+		*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: // TODO: stricter rules for selectorExpr
+		assignStmt := &ast.AssignStmt{
+			Lhs: []ast.Expr{ast.NewIdent(name)},
+			Tok: token.DEFINE,
+			Rhs: []ast.Expr{expr},
+		}
+		var buf bytes.Buffer
+		if err = format.Node(&buf, fset, assignStmt); err != nil {
+			return nil, err
+		}
+		assignment = buf.String()
+	case *ast.CallExpr: // TODO: find number of return values and do according actions.
+		return nil, nil
+	default:
+		return nil, nil
+	}
+
+	insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
+	if insertBeforeStmt == nil {
+		return nil, nil
+	}
+
+	// Convert token.Pos to protcol.Position
+	rng = span.NewRange(fset, insertBeforeStmt.Pos(), insertBeforeStmt.End())
+	spn, err = rng.Span()
+	if err != nil {
+		return nil, nil
+	}
+	beforeStmtStart, err := m.Position(spn.Start())
+	if err != nil {
+		return nil, nil
+	}
+	stmtBeforeRng := protocol.Range{
+		Start: beforeStmtStart,
+		End:   beforeStmtStart,
+	}
+
+	// Calculate indentation for insertion
+	line := tok.Line(insertBeforeStmt.Pos())
+	lineOffset := tok.Offset(tok.LineStart(line))
+	stmtOffset := tok.Offset(insertBeforeStmt.Pos())
+	indent := content[lineOffset:stmtOffset] // space between these is indentation.
+
+	return []protocol.TextEdit{
+		{
+			Range:   stmtBeforeRng,
+			NewText: assignment + "\n" + string(indent),
+		},
+		{
+			Range:   protoRng,
+			NewText: name,
+		},
+	}, nil
+}
+
+// Check for variable collision in scope.
+func isValidName(name string, scopes []*types.Scope) bool {
+	for _, scope := range scopes {
+		if scope == nil {
+			continue
+		}
+		if scope.Lookup(name) != nil {
+			return false
+		}
+	}
+	return true
+}
diff --git a/internal/lsp/source/options.go b/internal/lsp/source/options.go
index cde7a7a..51c0199 100644
--- a/internal/lsp/source/options.go
+++ b/internal/lsp/source/options.go
@@ -94,6 +94,7 @@
 					protocol.SourceOrganizeImports: true,
 					protocol.QuickFix:              true,
 					protocol.RefactorRewrite:       true,
+					protocol.RefactorExtract:       true,
 				},
 				Mod: {
 					protocol.SourceOrganizeImports: true,
diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go
new file mode 100644
index 0000000..c49e5d6
--- /dev/null
+++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go
@@ -0,0 +1,6 @@
+package extract
+
+func _() {
+	var _ = 1 + 2 //@suggestedfix("1", "refactor.extract")
+	var _ = 3 + 4 //@suggestedfix("3 + 4", "refactor.extract")
+}
diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go.golden
new file mode 100644
index 0000000..202d378
--- /dev/null
+++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go.golden
@@ -0,0 +1,18 @@
+-- suggestedfix_extract_basic_lit_4_10 --
+package extract
+
+func _() {
+	x0 := 1
+	var _ = x0 + 2 //@suggestedfix("1", "refactor.extract")
+	var _ = 3 + 4 //@suggestedfix("3 + 4", "refactor.extract")
+}
+
+-- suggestedfix_extract_basic_lit_5_10 --
+package extract
+
+func _() {
+	var _ = 1 + 2 //@suggestedfix("1", "refactor.extract")
+	x0 := 3 + 4
+	var _ = x0 //@suggestedfix("3 + 4", "refactor.extract")
+}
+
diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go
new file mode 100644
index 0000000..5dfcc36
--- /dev/null
+++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go
@@ -0,0 +1,13 @@
+package extract
+
+import "go/ast"
+
+func _() {
+	x0 := 0
+	if true {
+		y := ast.CompositeLit{} //@suggestedfix("ast.CompositeLit{}", "refactor.extract")
+	}
+	if true {
+		x1 := !false //@suggestedfix("!false", "refactor.extract")
+	}
+}
diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go.golden
new file mode 100644
index 0000000..4ded99a
--- /dev/null
+++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go.golden
@@ -0,0 +1,32 @@
+-- suggestedfix_extract_scope_11_9 --
+package extract
+
+import "go/ast"
+
+func _() {
+	x0 := 0
+	if true {
+		y := ast.CompositeLit{} //@suggestedfix("ast.CompositeLit{}", "refactor.extract")
+	}
+	if true {
+		x2 := !false
+		x1 := x2 //@suggestedfix("!false", "refactor.extract")
+	}
+}
+
+-- suggestedfix_extract_scope_8_8 --
+package extract
+
+import "go/ast"
+
+func _() {
+	x0 := 0
+	if true {
+		x1 := ast.CompositeLit{}
+		y := x1 //@suggestedfix("ast.CompositeLit{}", "refactor.extract")
+	}
+	if true {
+		x1 := !false //@suggestedfix("!false", "refactor.extract")
+	}
+}
+
diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden
index a832d96..1563cb0 100644
--- a/internal/lsp/testdata/lsp/summary.txt.golden
+++ b/internal/lsp/testdata/lsp/summary.txt.golden
@@ -11,7 +11,7 @@
 FoldingRangesCount = 2
 FormatCount = 6
 ImportCount = 8
-SuggestedFixCount = 14
+SuggestedFixCount = 18
 DefinitionsCount = 53
 TypeDefinitionsCount = 2
 HighlightsCount = 69
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 850ec22..4bbdda6 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -217,6 +217,7 @@
 			protocol.SourceOrganizeImports: true,
 			protocol.QuickFix:              true,
 			protocol.RefactorRewrite:       true,
+			protocol.RefactorExtract:       true,
 			protocol.SourceFixAll:          true,
 		},
 		source.Mod: {