blob: 30bb4bd2648a7cada530e32e7879c449836929cd [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package source
import (
"bytes"
"context"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"go/types"
"strings"
"unicode"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/internal/analysisinternal"
"golang.org/x/tools/internal/lsp/protocol"
"golang.org/x/tools/internal/span"
)
// ExtractVariable returns the edits that extract the expression selected by
// protoRng into a new local variable: a ":=" declaration of the expression is
// inserted before the statement that encloses it, and the expression itself is
// replaced by the new variable's name.
//
// It returns (nil, nil) when the selection cannot be extracted: the range is
// empty, it does not coincide exactly with a single AST node, the node is not a
// supported kind of expression, or no enclosing statement is found to insert
// the declaration before. Internal failures are reported as errors.
func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, protoRng protocol.Range) ([]protocol.TextEdit, error) {
	pkg, pgh, err := getParsedFile(ctx, snapshot, fh, NarrowestPackageHandle)
	if err != nil {
		return nil, fmt.Errorf("ExtractVariable: %v", err)
	}
	file, _, m, _, err := pgh.Cached()
	if err != nil {
		return nil, err
	}
	spn, err := m.RangeSpan(protoRng)
	if err != nil {
		return nil, err
	}
	rng, err := spn.Range(m.Converter)
	if err != nil {
		return nil, err
	}
	if rng.Start == rng.End {
		return nil, nil
	}
	path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
	if len(path) == 0 {
		return nil, nil
	}
	fset := snapshot.View().Session().Cache().FileSet()
	node := path[0]
	tok := fset.File(node.Pos())
	if tok == nil {
		return nil, fmt.Errorf("ExtractVariable: no token.File for %s", fh.URI())
	}
	content, err := fh.Read()
	if err != nil {
		return nil, err
	}
	// Only a selection that coincides exactly with one AST node is extractable.
	if rng.Start != node.Pos() || rng.End != node.End() {
		return nil, nil
	}
	expr, ok := node.(ast.Expr)
	if !ok {
		return nil, nil
	}
	switch expr.(type) {
	case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr,
		*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: // TODO: stricter rules for selectorExpr.
		// Supported expression kinds.
	default:
		// TODO: support *ast.CallExpr by accounting for the number of return values.
		return nil, nil
	}
	// Pick a collision-free name only once the expression is known to be extractable.
	name := generateAvailableIdentifier(node.Pos(), pkg, path, file)
	// Create new AST node for extracted code.
	assignStmt := &ast.AssignStmt{
		Lhs: []ast.Expr{ast.NewIdent(name)},
		Tok: token.DEFINE,
		Rhs: []ast.Expr{expr},
	}
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, assignStmt); err != nil {
		return nil, err
	}
	assignment := buf.String()
	insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
	if insertBeforeStmt == nil {
		return nil, nil
	}
	// Convert token.Pos to protocol.Position. A failure here is an internal
	// error rather than "nothing to extract", so report it to the caller.
	rng = span.NewRange(fset, insertBeforeStmt.Pos(), insertBeforeStmt.End())
	spn, err = rng.Span()
	if err != nil {
		return nil, fmt.Errorf("ExtractVariable: %v", err)
	}
	beforeStmtStart, err := m.Position(spn.Start())
	if err != nil {
		return nil, fmt.Errorf("ExtractVariable: %v", err)
	}
	stmtBeforeRng := protocol.Range{
		Start: beforeStmtStart,
		End:   beforeStmtStart,
	}
	indent := calculateIndentation(content, tok, insertBeforeStmt)
	return []protocol.TextEdit{
		{
			// Declare the new variable on its own line before the statement.
			Range:   stmtBeforeRng,
			NewText: assignment + "\n" + indent,
		},
		{
			// Replace the selected expression with the variable.
			Range:   protoRng,
			NewText: name,
		},
	}, nil
}
// calculateIndentation returns the indentation of the line on which
// insertBeforeStmt begins, so that lines of code inserted immediately before
// that statement line up with it.
//
// Only the leading spaces and tabs of the line are returned. If other code
// precedes insertBeforeStmt on the same line (for example the "if c {" in
// "if c { f() }"), that code is not indentation. Copying it into the inserted
// text would duplicate source and produce invalid code.
func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string {
	line := tok.Line(insertBeforeStmt.Pos())
	lineOffset := tok.Offset(tok.LineStart(line))
	stmtOffset := tok.Offset(insertBeforeStmt.Pos())
	prefix := content[lineOffset:stmtOffset]
	// Keep only the run of blanks at the start of the line.
	rest := bytes.TrimLeft(prefix, " \t")
	return string(prefix[:len(prefix)-len(rest)])
}
// isValidName reports whether name can be declared without colliding with an
// object that already lives directly in one of scopes. Nil entries are ignored.
func isValidName(name string, scopes []*types.Scope) bool {
	for i := range scopes {
		if s := scopes[i]; s != nil && s.Lookup(name) != nil {
			return false
		}
	}
	return true
}
// ExtractFunction refactors the selected block of code into a new function. It also
// replaces the selected block of code with a call to the extracted function. First, we
// manually adjust the selection range. We remove trailing and leading whitespace
// characters to ensure the range is precisely bounded by AST nodes. Next, we
// determine the variables that will be the parameters and return values of the
// extracted function. Lastly, we construct the call of the function and insert
// this call as well as the extracted function into their proper locations.
//
// It returns (nil, nil) when the selection cannot be extracted, for example when
// the node enclosing it is not a statement, it is not inside a function
// declaration, or it contains a return statement.
func ExtractFunction(ctx context.Context, snapshot Snapshot, fh FileHandle, protoRng protocol.Range) ([]protocol.TextEdit, error) {
	pkg, pgh, err := getParsedFile(ctx, snapshot, fh, NarrowestPackageHandle)
	if err != nil {
		return nil, fmt.Errorf("ExtractFunction: %v", err)
	}
	file, _, m, _, err := pgh.Cached()
	if err != nil {
		return nil, err
	}
	spn, err := m.RangeSpan(protoRng)
	if err != nil {
		return nil, err
	}
	rng, err := spn.Range(m.Converter)
	if err != nil {
		return nil, err
	}
	if rng.Start == rng.End {
		return nil, nil
	}
	content, err := fh.Read()
	if err != nil {
		return nil, err
	}
	fset := snapshot.View().Session().Cache().FileSet()
	tok := fset.File(file.Pos())
	if tok == nil {
		return nil, fmt.Errorf("ExtractFunction: no token.File for %s", fh.URI())
	}
	rng = adjustRangeForWhitespace(content, tok, rng)
	path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
	if len(path) == 0 {
		return nil, nil
	}
	// Node that encloses selection must be a statement.
	// TODO: Support function extraction for an expression.
	if _, ok := path[0].(ast.Stmt); !ok {
		return nil, nil
	}
	info := pkg.GetTypesInfo()
	if info == nil {
		return nil, fmt.Errorf("nil TypesInfo")
	}
	fileScope := info.Scopes[file]
	if fileScope == nil {
		return nil, nil
	}
	pkgScope := fileScope.Parent()
	if pkgScope == nil {
		return nil, nil
	}
	// Find function enclosing the selection.
	var outer *ast.FuncDecl
	for _, p := range path {
		if p, ok := p.(*ast.FuncDecl); ok {
			outer = p
			break
		}
	}
	if outer == nil {
		return nil, nil
	}
	// At the moment, we don't extract selections containing return statements,
	// as they are more complex and need to be adjusted to maintain correctness.
	// TODO: Support extracting and rewriting code with return statements.
	var containsReturn bool
	ast.Inspect(outer, func(n ast.Node) bool {
		if n == nil {
			return true
		}
		if rng.Start <= n.Pos() && n.End() <= rng.End {
			if _, ok := n.(*ast.ReturnStmt); ok {
				containsReturn = true
				return false
			}
		}
		return n.Pos() <= rng.End
	})
	if containsReturn {
		return nil, nil
	}
	// Find the nodes at the start and end of the selection.
	var start, end ast.Node
	ast.Inspect(outer, func(n ast.Node) bool {
		if n == nil {
			return true
		}
		if n.Pos() == rng.Start && n.End() <= rng.End {
			start = n
		}
		if n.End() == rng.End && n.Pos() >= rng.Start {
			end = n
		}
		return n.Pos() <= rng.End
	})
	if start == nil || end == nil {
		return nil, nil
	}
	// Now that we have determined the correct range for the selection block,
	// we must determine the signature of the extracted function. We will then replace
	// the block with an assignment statement that calls the extracted function with
	// the appropriate parameters and return values.
	free, vars, assigned := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
	var (
		params, returns         []ast.Expr     // used when calling the extracted function
		paramTypes, returnTypes []*ast.Field   // used in the signature of the extracted function
		uninitialized           []types.Object // vars we will need to initialize before the call
	)
	// Avoid duplicates while traversing vars and uninitialized.
	seenVars := make(map[types.Object]ast.Expr)
	seenUninitialized := make(map[types.Object]struct{})
	// Each identifier in the selected block must become (1) a parameter to the
	// extracted function, (2) a return value of the extracted function, or (3) a local
	// variable in the extracted function. Determine the outcome(s) for each variable
	// based on whether it is free, altered within the selected block, and used outside
	// of the selected block.
	for _, obj := range vars {
		if _, ok := seenVars[obj]; ok {
			continue
		}
		typ := analysisinternal.TypeExpr(fset, file, pkg.GetTypes(), obj.Type())
		if typ == nil {
			return nil, fmt.Errorf("nil AST expression for type: %v", obj.Name())
		}
		seenVars[obj] = typ
		identifier := ast.NewIdent(obj.Name())
		// An identifier must meet two conditions to become a return value of the
		// extracted function. (1) it must be used at least once after the
		// selection (isUsed), and (2) its value must be initialized or reassigned
		// within the selection (isAssigned).
		//
		// Uses are searched up to the end of the enclosing function rather than
		// obj.Parent().End(): objects defined in the selection such as struct
		// fields or interface methods have no parent scope (obj.Parent() is nil),
		// and any use of a scoped object necessarily lies inside its scope, which
		// is itself inside outer.
		isUsed := objUsed(obj, info, rng.End, outer.End())
		_, isAssigned := assigned[obj]
		_, isFree := free[obj]
		if isUsed && isAssigned {
			returnTypes = append(returnTypes, &ast.Field{Type: typ})
			returns = append(returns, identifier)
			if !isFree {
				uninitialized = append(uninitialized, obj)
			}
		}
		// All free variables are parameters of and passed as arguments to the
		// extracted function.
		if isFree {
			params = append(params, identifier)
			paramTypes = append(paramTypes, &ast.Field{
				Names: []*ast.Ident{identifier},
				Type:  typ,
			})
		}
	}
	// Our preference is to replace the selected block with an "x, y, z := fn()" style
	// assignment statement. We can use this style when none of the variables in the
	// extracted function's return statement have already been initialized outside of the
	// selected block. However, for example, if z is already defined elsewhere, we
	// replace the selected block with:
	//
	// var x int
	// var y string
	// x, y, z = fn()
	//
	var initializations string
	if len(uninitialized) > 0 && len(uninitialized) != len(returns) {
		var declarations []ast.Stmt
		for _, obj := range uninitialized {
			if _, ok := seenUninitialized[obj]; ok {
				continue
			}
			seenUninitialized[obj] = struct{}{}
			valSpec := &ast.ValueSpec{
				Names: []*ast.Ident{ast.NewIdent(obj.Name())},
				Type:  seenVars[obj],
			}
			genDecl := &ast.GenDecl{
				Tok:   token.VAR,
				Specs: []ast.Spec{valSpec},
			}
			declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
		}
		var declBuf bytes.Buffer
		if err = format.Node(&declBuf, fset, declarations); err != nil {
			return nil, err
		}
		indent := calculateIndentation(content, tok, start)
		// Add proper indentation to each declaration. Also add formatting to
		// the line following the last initialization to ensure that subsequent
		// edits begin at the proper location.
		initializations = strings.ReplaceAll(declBuf.String(), "\n", "\n"+indent) +
			"\n" + indent
	}
	name := generateAvailableIdentifier(start.Pos(), pkg, path, file)
	var replace ast.Node
	if len(returns) > 0 {
		// If none of the variables on the left-hand side of the function call have
		// been initialized before the selection, we can use := instead of =.
		assignTok := token.ASSIGN
		if len(uninitialized) == len(returns) {
			assignTok = token.DEFINE
		}
		callExpr := &ast.CallExpr{
			Fun:  ast.NewIdent(name),
			Args: params,
		}
		replace = &ast.AssignStmt{
			Lhs: returns,
			Tok: assignTok,
			Rhs: []ast.Expr{callExpr},
		}
	} else {
		replace = &ast.CallExpr{
			Fun:  ast.NewIdent(name),
			Args: params,
		}
	}
	startOffset := tok.Offset(rng.Start)
	endOffset := tok.Offset(rng.End)
	selection := content[startOffset:endOffset]
	// Put selection in constructed file to parse and produce block statement. We can
	// then use the block statement to traverse and edit extracted function without
	// altering the original file.
	text := "package main\nfunc _() { " + string(selection) + " }"
	extract, err := parser.ParseFile(fset, "", text, 0)
	if err != nil {
		return nil, err
	}
	if len(extract.Decls) == 0 {
		return nil, fmt.Errorf("parsed file does not contain any declarations")
	}
	decl, ok := extract.Decls[0].(*ast.FuncDecl)
	if !ok {
		return nil, fmt.Errorf("parsed file does not contain expected function declaration")
	}
	// Add return statement to the end of the new function.
	if len(returns) > 0 {
		decl.Body.List = append(decl.Body.List,
			&ast.ReturnStmt{Results: returns},
		)
	}
	funcDecl := &ast.FuncDecl{
		Name: ast.NewIdent(name),
		Type: &ast.FuncType{
			Params:  &ast.FieldList{List: paramTypes},
			Results: &ast.FieldList{List: returnTypes},
		},
		Body: decl.Body,
	}
	var replaceBuf, newFuncBuf bytes.Buffer
	if err := format.Node(&replaceBuf, fset, replace); err != nil {
		return nil, err
	}
	if err := format.Node(&newFuncBuf, fset, funcDecl); err != nil {
		return nil, err
	}
	outerStart := tok.Offset(outer.Pos())
	outerEnd := tok.Offset(outer.End())
	// We're going to replace the whole enclosing function,
	// so preserve the text before and after the selected block.
	before := content[outerStart:startOffset]
	after := content[endOffset:outerEnd]
	var fullReplacement strings.Builder
	fullReplacement.Write(before)
	fullReplacement.WriteString(initializations) // add any initializations, if needed
	fullReplacement.Write(replaceBuf.Bytes())    // call the extracted function
	fullReplacement.Write(after)
	fullReplacement.WriteString("\n\n")       // add newlines after the enclosing function
	fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function
	// Convert enclosing function's span.Range to protocol.Range. Failures here
	// are internal errors, not "nothing to extract", so report them.
	rng = span.NewRange(fset, outer.Pos(), outer.End())
	spn, err = rng.Span()
	if err != nil {
		return nil, fmt.Errorf("ExtractFunction: %v", err)
	}
	startFunc, err := m.Position(spn.Start())
	if err != nil {
		return nil, fmt.Errorf("ExtractFunction: %v", err)
	}
	endFunc, err := m.Position(spn.End())
	if err != nil {
		return nil, fmt.Errorf("ExtractFunction: %v", err)
	}
	funcLoc := protocol.Range{
		Start: startFunc,
		End:   endFunc,
	}
	return []protocol.TextEdit{
		{
			Range:   funcLoc,
			NewText: fullReplacement.String(),
		},
	}, nil
}
// collectFreeVars maps each identifier in the given range to whether it is "free."
// Given a range, a variable in that range is defined as "free" if it is declared
// outside of the range and neither at the file scope nor package scope. These free
// variables will be used as arguments in the extracted function. It also returns a
// list of identifiers that may need to be returned by the extracted function.
// Some of the code in this function has been adapted from tools/cmd/guru/freevars.go.
//
// The results are:
//   - free: the set of free objects referenced in rng;
//   - vars: every object referenced in rng, in traversal order, with one entry
//     per occurrence (so it may contain duplicates);
//   - assigned: the objects whose value is set or changed by a statement that
//     lies entirely within rng (assignments, var declarations, ++/--, and
//     "for k, v = range ..." loops).
func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope,
	pkgScope *types.Scope, rng span.Range, node ast.Node) (map[types.Object]struct{}, []types.Object, map[types.Object]struct{}) {
	// id returns non-nil if n denotes an object that is referenced by the span
	// and defined either within the span or in the lexical environment. The bool
	// return value acts as an indicator for where it was defined.
	id := func(n *ast.Ident) (types.Object, bool) {
		obj := info.Uses[n]
		if obj == nil {
			return info.Defs[n], false
		}
		if _, ok := obj.(*types.PkgName); ok {
			return nil, false // imported package
		}
		if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) {
			return nil, false // not defined in this file
		}
		scope := obj.Parent()
		if scope == nil {
			return nil, false // e.g. interface method, struct field
		}
		if scope == fileScope || scope == pkgScope {
			return nil, false // defined at file or package scope
		}
		if rng.Start <= obj.Pos() && obj.Pos() <= rng.End {
			return obj, false // defined within selection => not free
		}
		return obj, true
	}
	// sel returns non-nil if n denotes a selection o.x.y that is referenced by the
	// span and defined either within the span or in the lexical environment. The bool
	// return value acts as an indicator for where it was defined.
	var sel func(n *ast.SelectorExpr) (types.Object, bool)
	sel = func(n *ast.SelectorExpr) (types.Object, bool) {
		switch x := astutil.Unparen(n.X).(type) {
		case *ast.SelectorExpr:
			return sel(x)
		case *ast.Ident:
			return id(x)
		}
		return nil, false
	}
	free := make(map[types.Object]struct{})
	var vars []types.Object
	ast.Inspect(node, func(n ast.Node) bool {
		if n == nil {
			return true
		}
		if rng.Start <= n.Pos() && n.End() <= rng.End {
			var obj types.Object
			var isFree, prune bool
			switch n := n.(type) {
			case *ast.Ident:
				obj, isFree = id(n)
			case *ast.SelectorExpr:
				obj, isFree = sel(n)
				prune = true
			}
			if obj != nil && obj.Name() != "_" {
				if isFree {
					free[obj] = struct{}{}
				}
				vars = append(vars, obj)
				if prune {
					return false
				}
			}
		}
		return n.Pos() <= rng.End
	})
	// Find identifiers that are initialized or whose values are altered at some
	// point in the selected block. For example, in a selected block from lines 2-4,
	// variables x, y, and z are included in assigned. However, in a selected block
	// from lines 3-4, only variables y and z are included in assigned.
	//
	// 1: var a int
	// 2: var x int
	// 3: y := 3
	// 4: z := x + a
	//
	assigned := make(map[types.Object]struct{})
	// markAssigned records e as assigned when it is an identifier denoting a
	// relevant object.
	markAssigned := func(e ast.Expr) {
		ident, ok := e.(*ast.Ident)
		if !ok {
			return
		}
		if obj, _ := id(ident); obj != nil {
			assigned[obj] = struct{}{}
		}
	}
	ast.Inspect(node, func(n ast.Node) bool {
		if n == nil {
			return true
		}
		if n.Pos() < rng.Start || n.End() > rng.End {
			return n.Pos() <= rng.End
		}
		switch n := n.(type) {
		case *ast.AssignStmt:
			for _, lhs := range n.Lhs {
				markAssigned(lhs)
			}
			return false
		case *ast.IncDecStmt:
			// x++ and x-- change x just as x += 1 does; without this, a mutated
			// variable used after the selection would not be returned.
			markAssigned(n.X)
			return false
		case *ast.RangeStmt:
			// "for k, v = range x" assigns to variables declared elsewhere.
			// Keep inspecting: the loop body may contain further assignments.
			if n.Tok == token.ASSIGN {
				markAssigned(n.Key)
				markAssigned(n.Value)
			}
			return true
		case *ast.DeclStmt:
			gen, ok := n.Decl.(*ast.GenDecl)
			if !ok {
				return true
			}
			for _, spec := range gen.Specs {
				vSpecs, ok := spec.(*ast.ValueSpec)
				if !ok {
					continue
				}
				for _, vSpec := range vSpecs.Names {
					markAssigned(vSpec)
				}
			}
			return false
		}
		return true
	})
	return free, vars, assigned
}
// generateAvailableIdentifier returns the first name of the form x0, x1, x2, ...
// that is neither declared in the file's scope nor found in any of the scopes
// collected for pos, so that it can name a new function or variable without
// colliding with existing functions or variables.
func generateAvailableIdentifier(pos token.Pos, pkg Package, path []ast.Node, file *ast.File) string {
	scopes := collectScopes(pkg, path, pos)
	taken := func(candidate string) bool {
		return file.Scope.Lookup(candidate) != nil || !isValidName(candidate, scopes)
	}
	for idx := 0; ; idx++ {
		if candidate := fmt.Sprintf("x%d", idx); !taken(candidate) {
			return candidate
		}
	}
}
// adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or
// trailing whitespace characters from selection. In the following example, each line
// of the if statement is indented once. There are also two extra spaces after the
// closing bracket before the line break.
//
// \tif (true) {
// \t _ = 1
// \t} \n
//
// By default, a valid range begins at 'if' and ends at the first whitespace character
// after the '}'. But, users are likely to highlight full lines rather than adjusting
// their cursors for whitespace. To support this use case, we must manually adjust the
// ranges to match the correct AST node. In this particular example, we would adjust
// rng.Start forward by one byte, and rng.End backwards by two bytes.
//
// Only ASCII whitespace is trimmed. Bytes are examined one at a time, and a byte
// above 0x7F belongs to a multi-byte UTF-8 sequence. Such a byte must not be
// treated as a rune: the trailing 0xA0 byte of "à" would look like U+00A0
// (no-break space). The two scans are bounded by each other, so the result always
// satisfies Start <= End; an all-whitespace selection collapses to an empty range.
func adjustRangeForWhitespace(content []byte, tok *token.File, rng span.Range) span.Range {
	isSpace := func(b byte) bool {
		return b <= unicode.MaxASCII && unicode.IsSpace(rune(b))
	}
	start := tok.Offset(rng.Start)
	end := tok.Offset(rng.End)
	// Move forwards to the first non-whitespace byte of the selection.
	for start < end && start < len(content) && isSpace(content[start]) {
		start++
	}
	// Move backwards to just after the last non-whitespace byte of the selection.
	for end > start && isSpace(content[end-1]) {
		end--
	}
	rng.Start = tok.Pos(start)
	rng.End = tok.Pos(end)
	return rng
}
// objUsed reports whether info records a use of obj that starts strictly after
// endSel and ends no later than endScope, i.e. a use after the selection that
// still falls within the given scope boundary.
func objUsed(obj types.Object, info *types.Info, endSel token.Pos, endScope token.Pos) bool {
	for ident, used := range info.Uses {
		if used != obj {
			continue
		}
		if ident.Pos() > endSel && ident.End() <= endScope {
			return true
		}
	}
	return false
}