| // Copyright 2020 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package source |
| |
| import ( |
| "bytes" |
| "context" |
| "fmt" |
| "go/ast" |
| "go/format" |
| "go/parser" |
| "go/token" |
| "go/types" |
| "strings" |
| "unicode" |
| |
| "golang.org/x/tools/go/ast/astutil" |
| "golang.org/x/tools/internal/analysisinternal" |
| "golang.org/x/tools/internal/lsp/protocol" |
| "golang.org/x/tools/internal/span" |
| ) |
| |
| func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, protoRng protocol.Range) ([]protocol.TextEdit, error) { |
| pkg, pgh, err := getParsedFile(ctx, snapshot, fh, NarrowestPackageHandle) |
| if err != nil { |
| return nil, fmt.Errorf("ExtractVariable: %v", err) |
| } |
| file, _, m, _, err := pgh.Cached() |
| if err != nil { |
| return nil, err |
| } |
| spn, err := m.RangeSpan(protoRng) |
| if err != nil { |
| return nil, err |
| } |
| rng, err := spn.Range(m.Converter) |
| if err != nil { |
| return nil, err |
| } |
| if rng.Start == rng.End { |
| return nil, nil |
| } |
| path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) |
| if len(path) == 0 { |
| return nil, nil |
| } |
| fset := snapshot.View().Session().Cache().FileSet() |
| node := path[0] |
| tok := fset.File(node.Pos()) |
| if tok == nil { |
| return nil, fmt.Errorf("ExtractVariable: no token.File for %s", fh.URI()) |
| } |
| var content []byte |
| if content, err = fh.Read(); err != nil { |
| return nil, err |
| } |
| if rng.Start != node.Pos() || rng.End != node.End() { |
| return nil, nil |
| } |
| name := generateAvailableIdentifier(node.Pos(), pkg, path, file) |
| |
| var assignment string |
| expr, ok := node.(ast.Expr) |
| if !ok { |
| return nil, nil |
| } |
| // Create new AST node for extracted code. |
| switch expr.(type) { |
| case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, |
| *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: // TODO: stricter rules for selectorExpr. |
| assignStmt := &ast.AssignStmt{ |
| Lhs: []ast.Expr{ast.NewIdent(name)}, |
| Tok: token.DEFINE, |
| Rhs: []ast.Expr{expr}, |
| } |
| var buf bytes.Buffer |
| if err = format.Node(&buf, fset, assignStmt); err != nil { |
| return nil, err |
| } |
| assignment = buf.String() |
| case *ast.CallExpr: // TODO: find number of return values and do according actions. |
| return nil, nil |
| default: |
| return nil, nil |
| } |
| |
| insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path) |
| if insertBeforeStmt == nil { |
| return nil, nil |
| } |
| |
| // Convert token.Pos to protocol.Position. |
| rng = span.NewRange(fset, insertBeforeStmt.Pos(), insertBeforeStmt.End()) |
| spn, err = rng.Span() |
| if err != nil { |
| return nil, nil |
| } |
| beforeStmtStart, err := m.Position(spn.Start()) |
| if err != nil { |
| return nil, nil |
| } |
| stmtBeforeRng := protocol.Range{ |
| Start: beforeStmtStart, |
| End: beforeStmtStart, |
| } |
| indent := calculateIndentation(content, tok, insertBeforeStmt) |
| |
| return []protocol.TextEdit{ |
| { |
| Range: stmtBeforeRng, |
| NewText: assignment + "\n" + indent, |
| }, |
| { |
| Range: protoRng, |
| NewText: name, |
| }, |
| }, nil |
| } |
| |
| // Calculate indentation for insertion. |
| // When inserting lines of code, we must ensure that the lines have consistent |
| // formatting (i.e. the proper indentation). To do so, we observe the indentation on the |
| // line of code on which the insertion occurs. |
| func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string { |
| line := tok.Line(insertBeforeStmt.Pos()) |
| lineOffset := tok.Offset(tok.LineStart(line)) |
| stmtOffset := tok.Offset(insertBeforeStmt.Pos()) |
| return string(content[lineOffset:stmtOffset]) |
| } |
| |
| // Check for variable collision in scope. |
| func isValidName(name string, scopes []*types.Scope) bool { |
| for _, scope := range scopes { |
| if scope == nil { |
| continue |
| } |
| if scope.Lookup(name) != nil { |
| return false |
| } |
| } |
| return true |
| } |
| |
| // ExtractFunction refactors the selected block of code into a new function. It also |
| // replaces the selected block of code with a call to the extracted function. First, we |
| // manually adjust the selection range. We remove trailing and leading whitespace |
| // characters to ensure the range is precisely bounded by AST nodes. Next, we |
| // determine the variables that will be the paramters and return values of the |
| // extracted function. Lastly, we construct the call of the function and insert |
| // this call as well as the extracted function into their proper locations. |
| func ExtractFunction(ctx context.Context, snapshot Snapshot, fh FileHandle, protoRng protocol.Range) ([]protocol.TextEdit, error) { |
| pkg, pgh, err := getParsedFile(ctx, snapshot, fh, NarrowestPackageHandle) |
| if err != nil { |
| return nil, fmt.Errorf("ExtractFunction: %v", err) |
| } |
| file, _, m, _, err := pgh.Cached() |
| if err != nil { |
| return nil, err |
| } |
| spn, err := m.RangeSpan(protoRng) |
| if err != nil { |
| return nil, err |
| } |
| rng, err := spn.Range(m.Converter) |
| if err != nil { |
| return nil, err |
| } |
| if rng.Start == rng.End { |
| return nil, nil |
| } |
| content, err := fh.Read() |
| if err != nil { |
| return nil, err |
| } |
| fset := snapshot.View().Session().Cache().FileSet() |
| tok := fset.File(file.Pos()) |
| if tok == nil { |
| return nil, fmt.Errorf("ExtractFunction: no token.File for %s", fh.URI()) |
| } |
| rng = adjustRangeForWhitespace(content, tok, rng) |
| path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) |
| if len(path) == 0 { |
| return nil, nil |
| } |
| // Node that encloses selection must be a statement. |
| // TODO: Support function extraction for an expression. |
| if _, ok := path[0].(ast.Stmt); !ok { |
| return nil, nil |
| } |
| info := pkg.GetTypesInfo() |
| if info == nil { |
| return nil, fmt.Errorf("nil TypesInfo") |
| } |
| fileScope := info.Scopes[file] |
| if fileScope == nil { |
| return nil, nil |
| } |
| pkgScope := fileScope.Parent() |
| if pkgScope == nil { |
| return nil, nil |
| } |
| // Find function enclosing the selection. |
| var outer *ast.FuncDecl |
| for _, p := range path { |
| if p, ok := p.(*ast.FuncDecl); ok { |
| outer = p |
| break |
| } |
| } |
| if outer == nil { |
| return nil, nil |
| } |
| // At the moment, we don't extract selections containing return statements, |
| // as they are more complex and need to be adjusted to maintain correctness. |
| // TODO: Support extracting and rewriting code with return statements. |
| var containsReturn bool |
| ast.Inspect(outer, func(n ast.Node) bool { |
| if n == nil { |
| return true |
| } |
| if rng.Start <= n.Pos() && n.End() <= rng.End { |
| if _, ok := n.(*ast.ReturnStmt); ok { |
| containsReturn = true |
| return false |
| } |
| } |
| return n.Pos() <= rng.End |
| }) |
| if containsReturn { |
| return nil, nil |
| } |
| // Find the nodes at the start and end of the selection. |
| var start, end ast.Node |
| ast.Inspect(outer, func(n ast.Node) bool { |
| if n == nil { |
| return true |
| } |
| if n.Pos() == rng.Start && n.End() <= rng.End { |
| start = n |
| } |
| if n.End() == rng.End && n.Pos() >= rng.Start { |
| end = n |
| } |
| return n.Pos() <= rng.End |
| }) |
| if start == nil || end == nil { |
| return nil, nil |
| } |
| |
| // Now that we have determined the correct range for the selection block, |
| // we must determine the signature of the extracted function. We will then replace |
| // the block with an assignment statement that calls the extracted function with |
| // the appropriate parameters and return values. |
| free, vars, assigned := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0]) |
| |
| var ( |
| params, returns []ast.Expr // used when calling the extracted function |
| paramTypes, returnTypes []*ast.Field // used in the signature of the extracted function |
| uninitialized []types.Object // vars we will need to initialize before the call |
| ) |
| |
| // Avoid duplicates while traversing vars and uninitialzed. |
| seenVars := make(map[types.Object]ast.Expr) |
| seenUninitialized := make(map[types.Object]struct{}) |
| |
| // Each identifier in the selected block must become (1) a parameter to the |
| // extracted function, (2) a return value of the extracted function, or (3) a local |
| // variable in the extracted function. Determine the outcome(s) for each variable |
| // based on whether it is free, altered within the selected block, and used outside |
| // of the selected block. |
| for _, obj := range vars { |
| if _, ok := seenVars[obj]; ok { |
| continue |
| } |
| typ := analysisinternal.TypeExpr(fset, file, pkg.GetTypes(), obj.Type()) |
| if typ == nil { |
| return nil, fmt.Errorf("nil AST expression for type: %v", obj.Name()) |
| } |
| seenVars[obj] = typ |
| identifier := ast.NewIdent(obj.Name()) |
| // An identifier must meet two conditions to become a return value of the |
| // extracted function. (1) it must be used at least once after the |
| // selection (isUsed), and (2) its value must be initialized or reassigned |
| // within the selection (isAssigned). |
| isUsed := objUsed(obj, info, rng.End, obj.Parent().End()) |
| _, isAssigned := assigned[obj] |
| _, isFree := free[obj] |
| if isUsed && isAssigned { |
| returnTypes = append(returnTypes, &ast.Field{Type: typ}) |
| returns = append(returns, identifier) |
| if !isFree { |
| uninitialized = append(uninitialized, obj) |
| } |
| } |
| // All free variables are parameters of and passed as arguments to the |
| // extracted function. |
| if isFree { |
| params = append(params, identifier) |
| paramTypes = append(paramTypes, &ast.Field{ |
| Names: []*ast.Ident{identifier}, |
| Type: typ, |
| }) |
| } |
| } |
| |
| // Our preference is to replace the selected block with an "x, y, z := fn()" style |
| // assignment statement. We can use this style when none of the variables in the |
| // extracted function's return statement have already be initialized outside of the |
| // selected block. However, for example, if z is already defined elsewhere, we |
| // replace the selected block with: |
| // |
| // var x int |
| // var y string |
| // x, y, z = fn() |
| // |
| var initializations string |
| if len(uninitialized) > 0 && len(uninitialized) != len(returns) { |
| var declarations []ast.Stmt |
| for _, obj := range uninitialized { |
| if _, ok := seenUninitialized[obj]; ok { |
| continue |
| } |
| seenUninitialized[obj] = struct{}{} |
| valSpec := &ast.ValueSpec{ |
| Names: []*ast.Ident{ast.NewIdent(obj.Name())}, |
| Type: seenVars[obj], |
| } |
| genDecl := &ast.GenDecl{ |
| Tok: token.VAR, |
| Specs: []ast.Spec{valSpec}, |
| } |
| declarations = append(declarations, &ast.DeclStmt{Decl: genDecl}) |
| } |
| var declBuf bytes.Buffer |
| if err = format.Node(&declBuf, fset, declarations); err != nil { |
| return nil, err |
| } |
| indent := calculateIndentation(content, tok, start) |
| // Add proper indentation to each declaration. Also add formatting to |
| // the line following the last initialization to ensure that subsequent |
| // edits begin at the proper location. |
| initializations = strings.ReplaceAll(declBuf.String(), "\n", "\n"+indent) + |
| "\n" + indent |
| } |
| |
| name := generateAvailableIdentifier(start.Pos(), pkg, path, file) |
| var replace ast.Node |
| if len(returns) > 0 { |
| // If none of the variables on the left-hand side of the function call have |
| // been initialized before the selection, we can use := instead of =. |
| assignTok := token.ASSIGN |
| if len(uninitialized) == len(returns) { |
| assignTok = token.DEFINE |
| } |
| callExpr := &ast.CallExpr{ |
| Fun: ast.NewIdent(name), |
| Args: params, |
| } |
| replace = &ast.AssignStmt{ |
| Lhs: returns, |
| Tok: assignTok, |
| Rhs: []ast.Expr{callExpr}, |
| } |
| } else { |
| replace = &ast.CallExpr{ |
| Fun: ast.NewIdent(name), |
| Args: params, |
| } |
| } |
| |
| startOffset := tok.Offset(rng.Start) |
| endOffset := tok.Offset(rng.End) |
| selection := content[startOffset:endOffset] |
| // Put selection in constructed file to parse and produce block statement. We can |
| // then use the block statement to traverse and edit extracted function without |
| // altering the original file. |
| text := "package main\nfunc _() { " + string(selection) + " }" |
| extract, err := parser.ParseFile(fset, "", text, 0) |
| if err != nil { |
| return nil, err |
| } |
| if len(extract.Decls) == 0 { |
| return nil, fmt.Errorf("parsed file does not contain any declarations") |
| } |
| decl, ok := extract.Decls[0].(*ast.FuncDecl) |
| if !ok { |
| return nil, fmt.Errorf("parsed file does not contain expected function declaration") |
| } |
| // Add return statement to the end of the new function. |
| if len(returns) > 0 { |
| decl.Body.List = append(decl.Body.List, |
| &ast.ReturnStmt{Results: returns}, |
| ) |
| } |
| funcDecl := &ast.FuncDecl{ |
| Name: ast.NewIdent(name), |
| Type: &ast.FuncType{ |
| Params: &ast.FieldList{List: paramTypes}, |
| Results: &ast.FieldList{List: returnTypes}, |
| }, |
| Body: decl.Body, |
| } |
| |
| var replaceBuf, newFuncBuf bytes.Buffer |
| if err := format.Node(&replaceBuf, fset, replace); err != nil { |
| return nil, err |
| } |
| if err := format.Node(&newFuncBuf, fset, funcDecl); err != nil { |
| return nil, err |
| } |
| |
| outerStart := tok.Offset(outer.Pos()) |
| outerEnd := tok.Offset(outer.End()) |
| // We're going to replace the whole enclosing function, |
| // so preserve the text before and after the selected block. |
| before := content[outerStart:startOffset] |
| after := content[endOffset:outerEnd] |
| var fullReplacement strings.Builder |
| fullReplacement.Write(before) |
| fullReplacement.WriteString(initializations) // add any initializations, if needed |
| fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function |
| fullReplacement.Write(after) |
| fullReplacement.WriteString("\n\n") // add newlines after the enclosing function |
| fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function |
| |
| // Convert enclosing function's span.Range to protocol.Range. |
| rng = span.NewRange(fset, outer.Pos(), outer.End()) |
| spn, err = rng.Span() |
| if err != nil { |
| return nil, nil |
| } |
| startFunc, err := m.Position(spn.Start()) |
| if err != nil { |
| return nil, nil |
| } |
| endFunc, err := m.Position(spn.End()) |
| if err != nil { |
| return nil, nil |
| } |
| funcLoc := protocol.Range{ |
| Start: startFunc, |
| End: endFunc, |
| } |
| return []protocol.TextEdit{ |
| { |
| Range: funcLoc, |
| NewText: fullReplacement.String(), |
| }, |
| }, nil |
| } |
| |
| // collectFreeVars maps each identifier in the given range to whether it is "free." |
| // Given a range, a variable in that range is defined as "free" if it is declared |
| // outside of the range and neither at the file scope nor package scope. These free |
| // variables will be used as arguments in the extracted function. It also returns a |
| // list of identifiers that may need to be returned by the extracted function. |
| // Some of the code in this function has been adapted from tools/cmd/guru/freevars.go. |
| func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, |
| pkgScope *types.Scope, rng span.Range, node ast.Node) (map[types.Object]struct{}, []types.Object, map[types.Object]struct{}) { |
| // id returns non-nil if n denotes an object that is referenced by the span |
| // and defined either within the span or in the lexical environment. The bool |
| // return value acts as an indicator for where it was defined. |
| id := func(n *ast.Ident) (types.Object, bool) { |
| obj := info.Uses[n] |
| if obj == nil { |
| return info.Defs[n], false |
| } |
| if _, ok := obj.(*types.PkgName); ok { |
| return nil, false // imported package |
| } |
| if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) { |
| return nil, false // not defined in this file |
| } |
| scope := obj.Parent() |
| if scope == nil { |
| return nil, false // e.g. interface method, struct field |
| } |
| if scope == fileScope || scope == pkgScope { |
| return nil, false // defined at file or package scope |
| } |
| if rng.Start <= obj.Pos() && obj.Pos() <= rng.End { |
| return obj, false // defined within selection => not free |
| } |
| return obj, true |
| } |
| // sel returns non-nil if n denotes a selection o.x.y that is referenced by the |
| // span and defined either within the span or in the lexical environment. The bool |
| // return value acts as an indicator for where it was defined. |
| var sel func(n *ast.SelectorExpr) (types.Object, bool) |
| sel = func(n *ast.SelectorExpr) (types.Object, bool) { |
| switch x := astutil.Unparen(n.X).(type) { |
| case *ast.SelectorExpr: |
| return sel(x) |
| case *ast.Ident: |
| return id(x) |
| } |
| return nil, false |
| } |
| free := make(map[types.Object]struct{}) |
| var vars []types.Object |
| ast.Inspect(node, func(n ast.Node) bool { |
| if n == nil { |
| return true |
| } |
| if rng.Start <= n.Pos() && n.End() <= rng.End { |
| var obj types.Object |
| var isFree, prune bool |
| switch n := n.(type) { |
| case *ast.Ident: |
| obj, isFree = id(n) |
| case *ast.SelectorExpr: |
| obj, isFree = sel(n) |
| prune = true |
| } |
| if obj != nil && obj.Name() != "_" { |
| if isFree { |
| free[obj] = struct{}{} |
| } |
| vars = append(vars, obj) |
| if prune { |
| return false |
| } |
| } |
| } |
| return n.Pos() <= rng.End |
| }) |
| |
| // Find identifiers that are initialized or whose values are altered at some |
| // point in the selected block. For example, in a selected block from lines 2-4, |
| // variables x, y, and z are included in assigned. However, in a selected block |
| // from lines 3-4, only variables y and z are included in assigned. |
| // |
| // 1: var a int |
| // 2: var x int |
| // 3: y := 3 |
| // 4: z := x + a |
| // |
| assigned := make(map[types.Object]struct{}) |
| ast.Inspect(node, func(n ast.Node) bool { |
| if n == nil { |
| return true |
| } |
| if n.Pos() < rng.Start || n.End() > rng.End { |
| return n.Pos() <= rng.End |
| } |
| switch n := n.(type) { |
| case *ast.AssignStmt: |
| for _, assignment := range n.Lhs { |
| if assignment, ok := assignment.(*ast.Ident); ok { |
| obj, _ := id(assignment) |
| if obj == nil { |
| continue |
| } |
| assigned[obj] = struct{}{} |
| } |
| } |
| return false |
| case *ast.DeclStmt: |
| gen, ok := n.Decl.(*ast.GenDecl) |
| if !ok { |
| return true |
| } |
| for _, spec := range gen.Specs { |
| vSpecs, ok := spec.(*ast.ValueSpec) |
| if !ok { |
| continue |
| } |
| for _, vSpec := range vSpecs.Names { |
| obj, _ := id(vSpec) |
| if obj == nil { |
| continue |
| } |
| assigned[obj] = struct{}{} |
| } |
| } |
| return false |
| } |
| return true |
| }) |
| return free, vars, assigned |
| } |
| |
| // Adjust new function name until no collisons in scope. Possible collisions include |
| // other function and variable names. |
| func generateAvailableIdentifier(pos token.Pos, pkg Package, path []ast.Node, file *ast.File) string { |
| scopes := collectScopes(pkg, path, pos) |
| var idx int |
| name := "x0" |
| for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) { |
| idx++ |
| name = fmt.Sprintf("x%d", idx) |
| } |
| return name |
| } |
| |
| // adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or |
| // trailing whitespace characters from selection. In the following example, each line |
| // of the if statement is indented once. There are also two extra spaces after the |
| // closing bracket before the line break. |
| // |
| // \tif (true) { |
| // \t _ = 1 |
| // \t} \n |
| // |
| // By default, a valid range begins at 'if' and ends at the first whitespace character |
| // after the '}'. But, users are likely to highlight full lines rather than adjusting |
| // their cursors for whitespace. To support this use case, we must manually adjust the |
| // ranges to match the correct AST node. In this particular example, we would adjust |
| // rng.Start forward by one byte, and rng.End backwards by two bytes. |
| func adjustRangeForWhitespace(content []byte, tok *token.File, rng span.Range) span.Range { |
| offset := tok.Offset(rng.Start) |
| for offset < len(content) { |
| if !unicode.IsSpace(rune(content[offset])) { |
| break |
| } |
| // Move forwards one byte to find a non-whitespace character. |
| offset += 1 |
| } |
| rng.Start = tok.Pos(offset) |
| |
| offset = tok.Offset(rng.End) |
| for offset-1 >= 0 { |
| if !unicode.IsSpace(rune(content[offset-1])) { |
| break |
| } |
| // Move backwards one byte to find a non-whitespace character. |
| offset -= 1 |
| } |
| rng.End = tok.Pos(offset) |
| return rng |
| } |
| |
| // objUsed checks if the object is used after the selection but within |
| // the scope of the enclosing function. |
| func objUsed(obj types.Object, info *types.Info, endSel token.Pos, endScope token.Pos) bool { |
| for id, ob := range info.Uses { |
| if obj == ob && endSel < id.Pos() && id.End() <= endScope { |
| return true |
| } |
| } |
| return false |
| } |