// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package inline_test

import (
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/gob"
"fmt"
"go/ast"
"go/parser"
"go/token"
"go/types"
"os"
"path/filepath"
"reflect"
"regexp"
"strings"
"testing"
"unsafe"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/expect"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/internal/diff"
"golang.org/x/tools/internal/refactor/inline"
"golang.org/x/tools/internal/testenv"
"golang.org/x/tools/txtar"
)

// TestData executes test scenarios specified by files in testdata/*.txtar.
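// Each note has the form @inline(rx, want) and is written in a Go
// comment within one of the archive's source files; see doInlineNote
// for the precise semantics. For instance, a test file might contain a
// line such as (illustrative only, not taken from the testdata):
//
//	var _ = f(1) //@ inline(re"f\(1\)", fone)
//
// which requests inlining of the call f(1) and compares the
// transformed file against the archive entry named fone; the second
// argument may instead be a re"..." regexp that the expected error
// message must match.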
func TestData(t *testing.T) {
testenv.NeedsGoPackages(t)
files, err := filepath.Glob("testdata/*.txtar")
if err != nil {
t.Fatal(err)
}
for _, file := range files {
file := file
t.Run(filepath.Base(file), func(t *testing.T) {
t.Parallel()
// Extract archive to temporary tree.
ar, err := txtar.ParseFile(file)
if err != nil {
t.Fatal(err)
}
dir := t.TempDir()
if err := extractTxtar(ar, dir); err != nil {
t.Fatal(err)
}
// Load packages.
cfg := &packages.Config{
Dir: dir,
Mode: packages.LoadAllSyntax,
Env: append(os.Environ(),
"GO111MODULES=on",
"GOPATH=",
"GOWORK=off",
"GOPROXY=off"),
}
pkgs, err := packages.Load(cfg, "./...")
if err != nil {
t.Errorf("Load: %v", err)
}
// Report parse/type errors; they may be benign.
packages.Visit(pkgs, nil, func(pkg *packages.Package) {
for _, err := range pkg.Errors {
t.Log(err)
}
})
// Process @inline notes in comments in initial packages.
for _, pkg := range pkgs {
for _, file := range pkg.Syntax {
// Read file content (for @inline regexp, and inliner).
content, err := os.ReadFile(pkg.Fset.File(file.Pos()).Name())
if err != nil {
t.Error(err)
continue
}
// Read and process @inline notes.
notes, err := expect.ExtractGo(pkg.Fset, file)
if err != nil {
t.Errorf("parsing notes in %q: %v", pkg.Fset.File(file.Pos()).Name(), err)
continue
}
for _, note := range notes {
posn := pkg.Fset.Position(note.Pos)
if note.Name != "inline" {
t.Errorf("%s: invalid marker @%s", posn, note.Name)
continue
}
if nargs := len(note.Args); nargs != 2 {
t.Errorf("@inline: want 2 args, got %d", nargs)
continue
}
pattern, ok := note.Args[0].(*regexp.Regexp)
if !ok {
t.Errorf("%s: @inline(rx, want): want regular expression rx", posn)
continue
}
// want is a []byte (success) or *Regexp (failure)
var want any
switch x := note.Args[1].(type) {
case string, expect.Identifier:
for _, file := range ar.Files {
if file.Name == fmt.Sprint(x) {
want = file.Data
break
}
}
if want == nil {
t.Errorf("%s: @inline(rx, want): archive entry %q not found", posn, x)
continue
}
case *regexp.Regexp:
want = x
default:
t.Errorf("%s: @inline(rx, want): want file name (to assert success) or error message regexp (to assert failure)", posn)
continue
}
if err := doInlineNote(t.Logf, pkg, file, content, pattern, posn, want); err != nil {
t.Errorf("%s: @inline(%v, %v): %v", posn, note.Args[0], note.Args[1], err)
continue
}
}
}
}
})
}
}

// doInlineNote executes an assertion specified by a single
// @inline(re"pattern", want) note in a comment. It finds the first
// match of regular expression 'pattern' on the same line, finds the
// innermost enclosing CallExpr, and inlines it.
//
// Finally it checks that, on success, the transformed file is equal
// to want (a []byte), or on failure that the error message matches
// want (a *Regexp).
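//
// The pattern may contain a single subgroup; if present, the range of
// the subgroup's match, rather than of the whole match, is used to
// locate the enclosing call. For example (illustrative),
// re"f\((x)\)" selects only the argument x, which is useful when the
// pattern needs surrounding context to match uniquely.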
func doInlineNote(logf func(string, ...any), pkg *packages.Package, file *ast.File, content []byte, pattern *regexp.Regexp, posn token.Position, want any) error {
// Find extent of pattern match within commented line.
var startPos, endPos token.Pos
{
tokFile := pkg.Fset.File(file.Pos())
lineStartOffset := int(tokFile.LineStart(posn.Line)) - tokFile.Base()
line := content[lineStartOffset:]
if i := bytes.IndexByte(line, '\n'); i >= 0 {
line = line[:i]
}
matches := pattern.FindSubmatchIndex(line)
var start, end int // offsets
switch len(matches) {
case 2:
// no subgroups: return the range of the whole match
start, end = matches[0], matches[1]
case 4:
// one subgroup: return its range
start, end = matches[2], matches[3]
default:
return fmt.Errorf("invalid location regexp %q: expect either 0 or 1 subgroups, got %d",
pattern, len(matches)/2-1)
}
startPos = tokFile.Pos(lineStartOffset + start)
endPos = tokFile.Pos(lineStartOffset + end)
}
// Find innermost call enclosing the pattern match.
var caller *inline.Caller
{
path, _ := astutil.PathEnclosingInterval(file, startPos, endPos)
for _, n := range path {
if call, ok := n.(*ast.CallExpr); ok {
caller = &inline.Caller{
Fset: pkg.Fset,
Types: pkg.Types,
Info: pkg.TypesInfo,
File: file,
Call: call,
Content: content,
}
break
}
}
if caller == nil {
return fmt.Errorf("no enclosing call")
}
}
// Is it a static function call?
fn := typeutil.StaticCallee(caller.Info, caller.Call)
if fn == nil {
return fmt.Errorf("cannot inline: not a static call")
}
// Find callee function.
var (
calleePkg *packages.Package
calleeDecl *ast.FuncDecl
)
{
var same func(*ast.FuncDecl) bool
// Is the callee declared in the caller's package?
if fn.Pkg() == caller.Types {
calleePkg = pkg // same as caller
same = func(decl *ast.FuncDecl) bool {
return decl.Name.Pos() == fn.Pos()
}
} else {
// Different package. Load it now.
// (The primary load loaded all dependencies,
// but we choose to load it again, with
// a distinct token.FileSet and types.Importer,
// to keep the implementation honest.)
cfg := &packages.Config{
// TODO(adonovan): get the original module root more cleanly
Dir: filepath.Dir(filepath.Dir(pkg.GoFiles[0])),
Fset: token.NewFileSet(),
Mode: packages.LoadSyntax,
}
roots, err := packages.Load(cfg, fn.Pkg().Path())
if err != nil {
return fmt.Errorf("loading callee package: %v", err)
}
if packages.PrintErrors(roots) > 0 {
return fmt.Errorf("callee package had errors") // (see log)
}
calleePkg = roots[0]
posn := caller.Fset.Position(fn.Pos()) // callee posn wrt caller package
same = func(decl *ast.FuncDecl) bool {
// We can't rely on columns in export data:
// some variants replace them with 1.
// Nor can we expect file names to share a prefix:
// export data for go1.20 std packages records $GOROOT
// literally, so the recorded path may not locate the
// source on disk, and the answer is highly
// client-specific anyway. Matching the file name and
// line number is the best we can do here.
posn2 := calleePkg.Fset.Position(decl.Name.Pos())
return posn.Filename == posn2.Filename &&
posn.Line == posn2.Line
}
}
for _, file := range calleePkg.Syntax {
for _, decl := range file.Decls {
if decl, ok := decl.(*ast.FuncDecl); ok && same(decl) {
calleeDecl = decl
goto found
}
}
}
return fmt.Errorf("can't find FuncDecl for callee") // can't happen?
found:
}
// Do the inlining. For the purposes of the test,
// AnalyzeCallee and Inline are a single operation.
got, err := func() ([]byte, error) {
filename := calleePkg.Fset.File(calleeDecl.Pos()).Name()
content, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
callee, err := inline.AnalyzeCallee(
calleePkg.Fset,
calleePkg.Types,
calleePkg.TypesInfo,
calleeDecl,
content)
if err != nil {
return nil, err
}
if err := checkTranscode(callee); err != nil {
return nil, err
}
check := checkNoMutation(caller.File)
defer check()
return inline.Inline(logf, caller, callee)
}()
if err != nil {
if wantRE, ok := want.(*regexp.Regexp); ok {
if !wantRE.MatchString(err.Error()) {
return fmt.Errorf("Inline failed with wrong error: %v (want error matching %q)", err, want)
}
return nil // expected error
}
return fmt.Errorf("Inline failed: %v", err) // success was expected
}
// Inline succeeded.
if want, ok := want.([]byte); ok {
got = append(bytes.TrimSpace(got), '\n')
want = append(bytes.TrimSpace(want), '\n')
if diff := diff.Unified("want", "got", string(want), string(got)); diff != "" {
return fmt.Errorf("Inline returned wrong output:\n%s\nWant:\n%s\nDiff:\n%s",
got, want, diff)
}
return nil
}
return fmt.Errorf("Inline succeeded unexpectedly: want error matching %q, got <<%s>>", want, got)
}

// TestTable is a table-driven test, enabling more compact expression
// of single-package test cases than is possible with the txtar notation.
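//
// A want value of the form "error: rx" asserts that inlining fails
// with an error message matching the regular expression rx; any other
// value is compared against the changed portion of the caller file.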
func TestTable(t *testing.T) {
// Each callee must declare a function or method named f,
// and each caller must call it.
const funcName = "f"
var tests = []struct {
descr string
callee, caller string // Go source files (sans package decl) of callee and caller, respectively
want string // expected new portion of caller file, or "error: regexp"
}{
{
"Basic",
`func f(x int) int { return x }`,
`var _ = f(0)`,
`var _ = (0)`,
},
{
"Empty body, no arg effects.",
`func f(x, y int) {}`,
`func _() { f(1, 2) }`,
`func _() {}`,
},
{
"Empty body, some arg effects.",
`func f(x, y, z int) {}`,
`func _() { f(1, recover().(int), 3) }`,
`func _() { _ = recover().(int) }`,
},
{
"Tail call.",
`func f() int { return 1 }`,
`func _() int { return f() }`,
`func _() int { return (1) }`,
},
{
"Void tail call.",
`func f() { println() }`,
`func _() { f() }`,
`func _() { println() }`,
},
{
"Void tail call with defer.", // => literalized
`func f() { defer f(); println() }`,
`func _() { f() }`,
`func _() { func() { defer f(); println() }() }`,
},
{
"Edge case: cannot literalize spread method call.",
`type I int
func g() (I, I)
func (r I) f(x, y I) I {
defer g() // force literalization
return x + y + r
}`,
`func _() I { return recover().(I).f(g()) }`,
`error: can't yet inline spread call to method`,
},
{
"Variadic cancellation (basic).",
`func f(args ...any) { defer f(&args); println(args) }`,
`func _(slice []any) { f(slice...) }`,
`func _(slice []any) { func(args []any) { defer f(&args); println(args) }(slice) }`,
},
{
"Variadic cancellation (literalization with parameter elimination).",
`func f(args ...any) { defer f(); println(args) }`,
`func _(slice []any) { f(slice...) }`,
`func _(slice []any) { func() { defer f(); println(slice) }() }`,
},
{
"Variadic cancellation (reduction).",
`func f(args ...any) { println(args) }`,
`func _(slice []any) { f(slice...) }`,
`func _(slice []any) { println(slice) }`,
},
// TODO(adonovan): improve coverage of the cross
// product of each strategy with the checklist of
// concerns enumerated in the package doc comment.
}
for _, test := range tests {
test := test
t.Run(test.descr, func(t *testing.T) {
fset := token.NewFileSet()
mustParse := func(filename string, content any) *ast.File {
f, err := parser.ParseFile(fset, filename, content, parser.ParseComments|parser.SkipObjectResolution)
if err != nil {
t.Fatalf("ParseFile: %v", err)
}
return f
}
// Parse callee file and find first func decl named f.
calleeContent := "package p\n" + test.callee
calleeFile := mustParse("callee.go", calleeContent)
var decl *ast.FuncDecl
for _, d := range calleeFile.Decls {
if d, ok := d.(*ast.FuncDecl); ok && d.Name.Name == funcName {
decl = d
break
}
}
if decl == nil {
t.Fatalf("declaration of func %s not found: %s", funcName, test.callee)
}
// Parse caller file and find first call to f().
callerContent := "package p\n" + test.caller
callerFile := mustParse("caller.go", callerContent)
var call *ast.CallExpr
ast.Inspect(callerFile, func(n ast.Node) bool {
if n, ok := n.(*ast.CallExpr); ok {
switch fun := n.Fun.(type) {
case *ast.SelectorExpr:
if fun.Sel.Name == funcName {
call = n
}
case *ast.Ident:
if fun.Name == funcName {
call = n
}
}
}
return call == nil
})
if call == nil {
t.Fatalf("call to %s not found: %s", funcName, test.caller)
}
// Type check both files as one package.
info := &types.Info{
Defs: make(map[*ast.Ident]types.Object),
Uses: make(map[*ast.Ident]types.Object),
Types: make(map[ast.Expr]types.TypeAndValue),
Implicits: make(map[ast.Node]types.Object),
Selections: make(map[*ast.SelectorExpr]*types.Selection),
Scopes: make(map[ast.Node]*types.Scope),
}
conf := &types.Config{Error: func(err error) { t.Error(err) }}
pkg, err := conf.Check("p", fset, []*ast.File{callerFile, calleeFile}, info)
if err != nil {
t.Fatal(err)
}
// Analyze callee and inline call.
doIt := func() ([]byte, error) {
callee, err := inline.AnalyzeCallee(fset, pkg, info, decl, []byte(calleeContent))
if err != nil {
return nil, err
}
if err := checkTranscode(callee); err != nil {
t.Fatal(err)
}
caller := &inline.Caller{
Fset: fset,
Types: pkg,
Info: info,
File: callerFile,
Call: call,
Content: []byte(callerContent),
}
check := checkNoMutation(caller.File)
defer check()
return inline.Inline(t.Logf, caller, callee)
}
gotContent, err := doIt()
// Want error?
if rest := strings.TrimPrefix(test.want, "error: "); rest != test.want {
if err == nil {
t.Fatalf("unexpected sucess: want error matching %q", rest)
}
msg := err.Error()
if ok, err := regexp.MatchString(rest, msg); err != nil {
t.Fatalf("invalid regexp: %v", err)
} else if !ok {
t.Fatalf("wrong error: %s (want match for %q)", msg, rest)
}
return
}
// Want success.
if err != nil {
t.Fatal(err)
}
// Compute a single-hunk line-based diff.
srcLines := strings.Split(callerContent, "\n")
gotLines := strings.Split(string(gotContent), "\n")
for len(srcLines) > 0 && len(gotLines) > 0 &&
srcLines[0] == gotLines[0] {
srcLines = srcLines[1:]
gotLines = gotLines[1:]
}
for len(srcLines) > 0 && len(gotLines) > 0 &&
srcLines[len(srcLines)-1] == gotLines[len(gotLines)-1] {
srcLines = srcLines[:len(srcLines)-1]
gotLines = gotLines[:len(gotLines)-1]
}
got := strings.Join(gotLines, "\n")
if strings.TrimSpace(got) != strings.TrimSpace(test.want) {
t.Errorf("\nInlining this call:\t%s\nof this callee: \t%s\nproduced:\n%s\nWant:\n\n%s",
test.caller,
test.callee,
got,
test.want)
}
// Check that resulting code type-checks.
newCallerFile := mustParse("newcaller.go", gotContent)
if _, err := conf.Check("p", fset, []*ast.File{newCallerFile, calleeFile}, nil); err != nil {
t.Fatalf("modified source failed to typecheck: <<%s>>", gotContent)
}
})
}
}

// -- helpers --

// checkNoMutation returns a function that, when called,
// asserts that file was not modified since the checkNoMutation call.
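//
// Typical usage (as in doInlineNote and TestTable above):
//
//	check := checkNoMutation(caller.File)
//	defer check()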
func checkNoMutation(file *ast.File) func() {
pre := deepHash(file)
return func() {
post := deepHash(file)
if pre != post {
panic("Inline mutated caller.File")
}
}
}

// checkTranscode replaces *callee by the results of gob-encoding and
// then decoding it, to test that these operations are lossless.
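//
// Typical usage (as in doInlineNote above):
//
//	if err := checkTranscode(callee); err != nil {
//		return nil, err
//	}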
func checkTranscode(callee *inline.Callee) error {
// Perform Gob transcoding so that it is exercised by the test.
var enc bytes.Buffer
if err := gob.NewEncoder(&enc).Encode(callee); err != nil {
return fmt.Errorf("internal error: gob encoding failed: %v", err)
}
*callee = inline.Callee{}
if err := gob.NewDecoder(&enc).Decode(callee); err != nil {
return fmt.Errorf("internal error: gob decoding failed: %v", err)
}
return nil
}

// extractTxtar writes each file in the archive to the corresponding path beneath dir.
//
// TODO(adonovan): publish this as a helper (#61386).
func extractTxtar(ar *txtar.Archive, dir string) error {
for _, file := range ar.Files {
name := filepath.Join(dir, file.Name)
if err := os.MkdirAll(filepath.Dir(name), 0777); err != nil {
return err
}
if err := os.WriteFile(name, file.Data, 0666); err != nil {
return err
}
}
return nil
}

// deepHash computes a cryptographic hash of an ast.Node so that
// if the data structure is mutated, the hash changes.
// It assumes Go variables do not change address.
//
// TODO(adonovan): consider publishing this in the astutil package.
//
// TODO(adonovan): consider a variant that reports where in the tree
// the mutation occurred (obviously at a cost in space).
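//
// A sketch of typical usage (see checkNoMutation above):
//
//	pre := deepHash(file)
//	// ... operations that must not mutate file ...
//	if deepHash(file) != pre {
//		panic("file was mutated")
//	}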
func deepHash(n ast.Node) any {
seen := make(map[unsafe.Pointer]bool) // to break cycles
hasher := sha256.New()
le := binary.LittleEndian
writeUint64 := func(v uint64) {
var bs [8]byte
le.PutUint64(bs[:], v)
hasher.Write(bs[:])
}
var visit func(reflect.Value)
visit = func(v reflect.Value) {
switch v.Kind() {
case reflect.Ptr:
ptr := v.UnsafePointer()
writeUint64(uint64(uintptr(ptr)))
if !v.IsNil() {
if !seen[ptr] {
seen[ptr] = true
// Skip types we can't handle but don't care about.
switch v.Interface().(type) {
case *ast.Scope:
return // involves a map
}
visit(v.Elem())
}
}
case reflect.Struct:
for i := 0; i < v.Type().NumField(); i++ {
visit(v.Field(i))
}
case reflect.Slice:
ptr := v.UnsafePointer()
// We may encounter different slices at the same address,
// so don't mark ptr as "seen".
writeUint64(uint64(uintptr(ptr)))
writeUint64(uint64(v.Len()))
writeUint64(uint64(v.Cap()))
for i := 0; i < v.Len(); i++ {
visit(v.Index(i))
}
case reflect.Interface:
if v.IsNil() {
writeUint64(0)
} else {
rtype := reflect.ValueOf(v.Type()).UnsafePointer()
writeUint64(uint64(uintptr(rtype)))
visit(v.Elem())
}
case reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.UnsafePointer:
panic(v) // unreachable in AST
default: // bool, string, number
if v.Kind() == reflect.String {
// binary.Write rejects strings, so frame the
// length and write the bytes explicitly.
writeUint64(uint64(v.Len()))
hasher.Write([]byte(v.String()))
} else {
binary.Write(hasher, le, v.Interface())
}
}
}
visit(reflect.ValueOf(n))
var hash [sha256.Size]byte
hasher.Sum(hash[:0])
return hash
}