blob: 9e539189680501801ce063ce26594b977d3f5a71 [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package inspector_test
import (
"go/ast"
"go/build"
"go/parser"
"go/token"
"log"
"path/filepath"
"reflect"
"strconv"
"strings"
"testing"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/internal/typeparams"
)
// netFiles holds the parsed (non-test) Go source files of the standard
// library "net" package. The tests and benchmarks in this file use it as a
// reasonably large corpus of real syntax trees.
var netFiles []*ast.File

// init loads netFiles, terminating the test binary via log.Fatal if the
// "net" package cannot be located or any of its files fails to parse.
func init() {
	var err error
	if netFiles, err = parseNetFiles(); err != nil {
		log.Fatal(err)
	}
}
// parseNetFiles locates the standard library "net" package using
// build.Default and parses each of its GoFiles (the non-test files selected
// for the current build configuration) into a shared FileSet.
//
// It returns the parsed files in GoFiles order, or the first error
// encountered while importing the package or parsing a file.
func parseNetFiles() ([]*ast.File, error) {
	netPkg, err := build.Default.Import("net", "", 0)
	if err != nil {
		return nil, err
	}

	fset := token.NewFileSet()
	var parsed []*ast.File
	for i := range netPkg.GoFiles {
		path := filepath.Join(netPkg.Dir, netPkg.GoFiles[i])
		file, perr := parser.ParseFile(fset, path, nil, 0)
		if perr != nil {
			return nil, perr
		}
		parsed = append(parsed, file)
	}
	return parsed, nil
}
// TestInspectAllNodes compares Inspector against ast.Inspect: an unfiltered
// Inspector.Nodes traversal of netFiles must visit exactly the same nodes,
// in the same order, as ast.Inspect applied to each file in turn.
func TestInspectAllNodes(t *testing.T) {
	inspect := inspector.New(netFiles)

	// Record nodes only on push; the callback is also invoked with
	// push == false when leaving a node.
	var nodesA []ast.Node
	inspect.Nodes(nil, func(n ast.Node, push bool) bool {
		if push {
			nodesA = append(nodesA, n)
		}
		return true
	})

	// ast.Inspect also calls the visitor with nil after a node's children
	// (see its documentation); those calls carry no node and are skipped.
	var nodesB []ast.Node
	for _, f := range netFiles {
		ast.Inspect(f, func(n ast.Node) bool {
			if n != nil {
				nodesB = append(nodesB, n)
			}
			return true
		})
	}

	compare(t, nodesA, nodesB)
}
// TestInspectGenericNodes checks that Inspector handles syntax introduced for
// type parameters. Every identifier i0..i15 in a small generic source file
// must be visited by an unfiltered Preorder traversal. Filtering on
// IndexListExpr must then revisit every IndexListExpr node that the
// unfiltered traversal found.
func TestInspectGenericNodes(t *testing.T) {
	if !typeparams.Enabled {
		t.Skip("type parameters are not supported at this Go version")
	}
	// src is using the 16 identifiers i0, i1, ... i15 so
	// we can easily verify that we've found all of them.
	const src = `package a
type I interface { ~i0|i1 }
type T[i2, i3 interface{ ~i4 }] struct {}
func f[i5, i6 any]() {
_ = f[i7, i8]
var x T[i9, i10]
}
func (*T[i11, i12]) m()
var _ i13[i14, i15]
`
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "a.go", src, 0)
	if err != nil {
		// A partial AST would make the "missed identifier" checks below
		// misleading, so fail up front.
		t.Fatal(err)
	}
	inspect := inspector.New([]*ast.File{f})
	found := make([]bool, 16)
	indexListExprs := make(map[*typeparams.IndexListExpr]bool)

	// Verify that we reach all i* identifiers, and collect IndexListExpr nodes.
	inspect.Preorder(nil, func(n ast.Node) {
		switch n := n.(type) {
		case *ast.Ident:
			if strings.HasPrefix(n.Name, "i") {
				index, err := strconv.Atoi(n.Name[1:])
				if err != nil {
					t.Fatal(err)
				}
				if index < 0 || index >= len(found) {
					t.Fatalf("unexpected identifier %s", n.Name)
				}
				found[index] = true
			}
		case *typeparams.IndexListExpr:
			indexListExprs[n] = false
		}
	})
	for i, v := range found {
		if !v {
			t.Errorf("missed identifier i%d", i)
		}
	}

	// Verify that we can filter to IndexListExprs that we found in the first
	// step.
	if len(indexListExprs) == 0 {
		t.Fatal("no index list exprs found")
	}
	inspect.Preorder([]ast.Node{&typeparams.IndexListExpr{}}, func(n ast.Node) {
		ix := n.(*typeparams.IndexListExpr)
		indexListExprs[ix] = true
	})
	for ix, v := range indexListExprs {
		if !v {
			t.Errorf("inspected node %v not filtered", ix)
		}
	}
}
// TestInspectPruning compares Inspector against ast.Inspect when both
// traversals prune descent into *ast.CallExpr nodes. The call node itself
// is recorded, but none of its children are. Both traversals must then
// produce identical node sequences over netFiles.
func TestInspectPruning(t *testing.T) {
	inspect := inspector.New(netFiles)

	var nodesA []ast.Node
	inspect.Nodes(nil, func(n ast.Node, push bool) bool {
		if push {
			nodesA = append(nodesA, n)
			_, isCall := n.(*ast.CallExpr)
			return !isCall // don't descend into function calls
		}
		// Pop event: the result is not used to control descent here.
		return false
	})

	var nodesB []ast.Node
	for _, f := range netFiles {
		ast.Inspect(f, func(n ast.Node) bool {
			if n != nil {
				nodesB = append(nodesB, n)
				_, isCall := n.(*ast.CallExpr)
				return !isCall // don't descend into function calls
			}
			// nil marks the end of a node's children; nothing to record.
			return false
		})
	}

	compare(t, nodesA, nodesB)
}
// compare reports a test failure unless nodesA and nodesB contain the same
// nodes (compared by identity) in the same order.
//
// A length mismatch is reported, and the common prefix is still examined so
// the first point of divergence is visible. Only the first differing index is
// described in detail, together with the total number of differing
// positions. Once two traversals drift apart, every later index usually
// differs too, and reporting each one would flood the log.
func compare(t *testing.T, nodesA, nodesB []ast.Node) {
	t.Helper()

	if len(nodesA) != len(nodesB) {
		t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
	}

	common := len(nodesA)
	if len(nodesB) < common {
		common = len(nodesB)
	}

	first, mismatches := -1, 0
	for i := 0; i < common; i++ {
		if nodesA[i] != nodesB[i] {
			if first < 0 {
				first = i
			}
			mismatches++
		}
	}
	if mismatches > 0 {
		t.Errorf("node %d is inconsistent: %T, %T (%d inconsistent positions in total)",
			first, nodesA[first], nodesB[first], mismatches)
	}
}
// TestTypeFiltering checks the order and filtering of Inspector callbacks on
// a tiny source file, for three cases:
//   - an unfiltered Nodes traversal visits every node in preorder;
//   - Nodes with a type filter reports only *ast.BasicLit and *ast.CallExpr;
//   - WithStack with the same filter passes the full ancestor stack,
//     ending in the current node.
func TestTypeFiltering(t *testing.T) {
	const src = `package a
func f() {
print("hi")
panic("oops")
}
`
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "a.go", src, 0)
	if err != nil {
		// The expected node lists below assume a complete AST.
		t.Fatal(err)
	}
	inspect := inspector.New([]*ast.File{f})

	var got []string
	fn := func(n ast.Node, push bool) bool {
		if push {
			got = append(got, typeOf(n))
		}
		return true
	}

	// no type filtering
	inspect.Nodes(nil, fn)
	if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) {
		t.Errorf("inspect: got %s, want %s", got, want)
	}

	// type filtering
	nodeTypes := []ast.Node{
		(*ast.BasicLit)(nil),
		(*ast.CallExpr)(nil),
	}
	got = nil
	inspect.Nodes(nodeTypes, fn)
	if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) {
		t.Errorf("inspect: got %s, want %s", got, want)
	}

	// inspect with stack
	got = nil
	inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool {
		if push {
			var line []string
			for _, n := range stack {
				line = append(line, typeOf(n))
			}
			got = append(got, strings.Join(line, " "))
		}
		return true
	})
	want := []string{
		"File FuncDecl BlockStmt ExprStmt CallExpr",
		"File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
		"File FuncDecl BlockStmt ExprStmt CallExpr",
		"File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
	}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("inspect: got %s, want %s", got, want)
	}
}
// typeOf returns the dynamic type name of n with any leading "*ast."
// removed, so that, for example, an *ast.CallExpr yields "CallExpr".
func typeOf(n ast.Node) string {
	full := reflect.TypeOf(n).String()
	return strings.TrimPrefix(full, "*ast.")
}
// The numbers show a marginal improvement (ASTInspect/Inspect) of 3.5x,
// but a break-even point (NewInspector/(ASTInspect-Inspect)) of about 5
// traversals.
//
// BenchmarkNewInspector 4.5 ms
// BenchmarkInspect      0.33ms
// BenchmarkASTInspect   1.2 ms
//
// BenchmarkNewInspector measures the one-time cost of constructing an
// Inspector over netFiles; the Inspector is discarded without traversal.
func BenchmarkNewInspector(b *testing.B) {
	// Measure one-time construction overhead.
	for i := 0; i < b.N; i++ {
		inspector.New(netFiles)
	}
}
// BenchmarkInspect measures the marginal cost of one unfiltered Preorder
// traversal of netFiles. The Inspector is built once, before timing starts.
// The callback counts function declarations and literals so that each
// visited node does a small amount of work.
func BenchmarkInspect(b *testing.B) {
	inspect := inspector.New(netFiles)
	// Exclude construction time and allocations from the measurement.
	b.ResetTimer()

	// Measure marginal cost of traversal.
	var ndecls, nlits int
	for i := 0; i < b.N; i++ {
		inspect.Preorder(nil, func(n ast.Node) {
			switch n.(type) {
			case *ast.FuncDecl:
				ndecls++
			case *ast.FuncLit:
				nlits++
			}
		})
	}
}
// BenchmarkASTInspect is the baseline for BenchmarkInspect. It traverses
// every file in netFiles with ast.Inspect directly, counting function
// declarations and function literals on the way.
func BenchmarkASTInspect(b *testing.B) {
	var funcDecls, funcLits int
	for iter := 0; iter < b.N; iter++ {
		for _, file := range netFiles {
			ast.Inspect(file, func(node ast.Node) bool {
				switch node.(type) {
				case *ast.FuncLit:
					funcLits++
				case *ast.FuncDecl:
					funcDecls++
				}
				// Always descend: this benchmark measures a full traversal.
				return true
			})
		}
	}
}