| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| //go:build go1.23 |
| |
| package cursor_test |
| |
| import ( |
| "fmt" |
| "go/ast" |
| "go/build" |
| "go/parser" |
| "go/token" |
| "iter" |
| "log" |
| "math/rand" |
| "path/filepath" |
| "reflect" |
| "slices" |
| "strings" |
| "testing" |
| |
| "golang.org/x/tools/go/ast/inspector" |
| "golang.org/x/tools/internal/astutil/cursor" |
| "golang.org/x/tools/internal/astutil/edge" |
| ) |
| |
| // net/http package |
| var ( |
| netFset = token.NewFileSet() |
| netFiles []*ast.File |
| netInspect *inspector.Inspector |
| ) |
| |
| func init() { |
| files, err := parseNetFiles() |
| if err != nil { |
| log.Fatal(err) |
| } |
| netFiles = files |
| netInspect = inspector.New(netFiles) |
| } |
| |
| func parseNetFiles() ([]*ast.File, error) { |
| pkg, err := build.Default.Import("net", "", 0) |
| if err != nil { |
| return nil, err |
| } |
| var files []*ast.File |
| for _, filename := range pkg.GoFiles { |
| filename = filepath.Join(pkg.Dir, filename) |
| f, err := parser.ParseFile(netFset, filename, nil, 0) |
| if err != nil { |
| return nil, err |
| } |
| files = append(files, f) |
| } |
| return files, nil |
| } |
| |
// compare calls t.Error if !slices.Equal(nodesA, nodesB).
//
// On a length mismatch it reports only the two lengths. Otherwise it
// reports each index whose elements differ, along with their dynamic types.
// It is marked as a test helper so that failures are attributed to the
// calling line.
func compare[N comparable](t *testing.T, nodesA, nodesB []N) {
	t.Helper()
	if len(nodesA) != len(nodesB) {
		t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
		return
	}
	for i := range nodesA {
		if a, b := nodesA[i], nodesB[i]; a != b {
			t.Errorf("node %d is inconsistent: %T, %T", i, a, b)
		}
	}
}
| |
| // firstN(n, seq), returns a slice of up to n elements of seq. |
| func firstN[T any](n int, seq iter.Seq[T]) (res []T) { |
| for x := range seq { |
| res = append(res, x) |
| if len(res) == n { |
| break |
| } |
| } |
| return res |
| } |
| |
| func TestCursor_Preorder(t *testing.T) { |
| inspect := netInspect |
| |
| nodeFilter := []ast.Node{(*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)} |
| |
| // reference implementation |
| var want []ast.Node |
| for cur := range cursor.Root(inspect).Preorder(nodeFilter...) { |
| want = append(want, cur.Node()) |
| } |
| |
| // Check entire sequence. |
| got := slices.Collect(inspect.PreorderSeq(nodeFilter...)) |
| compare(t, got, want) |
| |
| // Check that break works. |
| got = got[:0] |
| for _, c := range firstN(10, cursor.Root(inspect).Preorder(nodeFilter...)) { |
| got = append(got, c.Node()) |
| } |
| compare(t, got, want[:10]) |
| } |
| |
| func TestCursor_nestedTraversal(t *testing.T) { |
| const src = `package a |
| func f() { |
| print("hello") |
| } |
| func g() { |
| print("goodbye") |
| panic("oops") |
| } |
| ` |
| fset := token.NewFileSet() |
| f, _ := parser.ParseFile(fset, "a.go", src, 0) |
| inspect := inspector.New([]*ast.File{f}) |
| |
| var ( |
| funcDecls = []ast.Node{(*ast.FuncDecl)(nil)} |
| callExprs = []ast.Node{(*ast.CallExpr)(nil)} |
| nfuncs = 0 |
| ncalls = 0 |
| ) |
| |
| for curFunc := range cursor.Root(inspect).Preorder(funcDecls...) { |
| _ = curFunc.Node().(*ast.FuncDecl) |
| |
| // Check edge and index. |
| if k, idx := curFunc.ParentEdge(); k != edge.File_Decls || idx != nfuncs { |
| t.Errorf("%v.ParentEdge() = (%v, %d), want edge.File_Decls, %d", curFunc, k, idx, nfuncs) |
| } |
| |
| nfuncs++ |
| stack := slices.Collect(curFunc.Enclosing()) |
| |
| // Stacks are convenient to print! |
| if got, want := fmt.Sprint(stack), "[*ast.FuncDecl *ast.File]"; got != want { |
| t.Errorf("curFunc.Enclosing() = %q, want %q", got, want) |
| } |
| |
| // Parent, iterated, is Enclosing stack. |
| i := 0 |
| for c := curFunc; c.Node() != nil; c = c.Parent() { |
| if got, want := stack[i], c; got != want { |
| t.Errorf("Enclosing[%d] = %v; Parent()^%d = %v", i, got, i, want) |
| } |
| i++ |
| } |
| |
| wantStack := "[*ast.CallExpr *ast.ExprStmt *ast.BlockStmt *ast.FuncDecl *ast.File]" |
| |
| // nested Preorder traversal |
| preorderCount := 0 |
| for curCall := range curFunc.Preorder(callExprs...) { |
| _ = curCall.Node().(*ast.CallExpr) |
| preorderCount++ |
| stack := slices.Collect(curCall.Enclosing()) |
| if got := fmt.Sprint(stack); got != wantStack { |
| t.Errorf("curCall.Enclosing() = %q, want %q", got, wantStack) |
| } |
| } |
| |
| // nested Inspect traversal |
| inspectCount := 0 |
| curFunc.Inspect(callExprs, func(curCall cursor.Cursor) (proceed bool) { |
| _ = curCall.Node().(*ast.CallExpr) |
| inspectCount++ |
| stack := slices.Collect(curCall.Enclosing()) |
| if got := fmt.Sprint(stack); got != wantStack { |
| t.Errorf("curCall.Enclosing() = %q, want %q", got, wantStack) |
| } |
| return true |
| }) |
| |
| if inspectCount != preorderCount { |
| t.Errorf("Inspect (%d) and Preorder (%d) events are not consistent", inspectCount, preorderCount) |
| } |
| |
| ncalls += preorderCount |
| } |
| |
| if nfuncs != 2 { |
| t.Errorf("Found %d FuncDecls, want 2", nfuncs) |
| } |
| if ncalls != 3 { |
| t.Errorf("Found %d CallExprs, want 3", ncalls) |
| } |
| } |
| |
| func TestCursor_Children(t *testing.T) { |
| inspect := netInspect |
| |
| // Assert that Cursor.Children agrees with |
| // reference implementation for every node. |
| var want, got []ast.Node |
| for c := range cursor.Root(inspect).Preorder() { |
| |
| // reference implementation |
| want = want[:0] |
| { |
| parent := c.Node() |
| ast.Inspect(parent, func(n ast.Node) bool { |
| if n != nil && n != parent { |
| want = append(want, n) |
| } |
| return n == parent // descend only into parent |
| }) |
| } |
| |
| // Check cursor-based implementation |
| // (uses FirstChild+NextSibling). |
| got = got[:0] |
| for child := range c.Children() { |
| got = append(got, child.Node()) |
| } |
| |
| if !slices.Equal(got, want) { |
| t.Errorf("For %v\n"+ |
| "Using FirstChild+NextSibling: %v\n"+ |
| "Using ast.Inspect: %v", |
| c, sliceTypes(got), sliceTypes(want)) |
| } |
| |
| // Second cursor-based implementation |
| // using LastChild+PrevSibling+reverse. |
| got = got[:0] |
| for c, ok := c.LastChild(); ok; c, ok = c.PrevSibling() { |
| got = append(got, c.Node()) |
| } |
| slices.Reverse(got) |
| |
| if !slices.Equal(got, want) { |
| t.Errorf("For %v\n"+ |
| "Using LastChild+PrevSibling: %v\n"+ |
| "Using ast.Inspect: %v", |
| c, sliceTypes(got), sliceTypes(want)) |
| } |
| } |
| } |
| |
| func TestCursor_Inspect(t *testing.T) { |
| inspect := netInspect |
| |
| // In all three loops, we'll gather both kinds of type switches, |
| // but we'll prune the traversal from descending into (value) switches. |
| switches := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)} |
| |
| // reference implementation (ast.Inspect) |
| var nodesA []ast.Node |
| for _, f := range netFiles { |
| ast.Inspect(f, func(n ast.Node) (proceed bool) { |
| switch n.(type) { |
| case *ast.SwitchStmt, *ast.TypeSwitchStmt: |
| nodesA = append(nodesA, n) |
| return !is[*ast.SwitchStmt](n) // descend only into TypeSwitchStmt |
| } |
| return true |
| }) |
| } |
| |
| // Test Cursor.Inspect implementation. |
| var nodesB []ast.Node |
| cursor.Root(inspect).Inspect(switches, func(c cursor.Cursor) (proceed bool) { |
| n := c.Node() |
| nodesB = append(nodesB, n) |
| return !is[*ast.SwitchStmt](n) // descend only into TypeSwitchStmt |
| return false |
| }) |
| compare(t, nodesA, nodesB) |
| |
| // Test WithStack implementation. |
| var nodesC []ast.Node |
| inspect.WithStack(switches, func(n ast.Node, push bool, stack []ast.Node) (proceed bool) { |
| if push { |
| nodesC = append(nodesC, n) |
| return !is[*ast.SwitchStmt](n) // descend only into TypeSwitchStmt |
| } |
| return false |
| }) |
| compare(t, nodesA, nodesC) |
| } |
| |
| func TestCursor_FindNode(t *testing.T) { |
| inspect := netInspect |
| |
| // Enumerate all nodes of a particular type, |
| // then check that FindPos can find them, |
| // starting at the root. |
| // |
| // (We use BasicLit because they are numerous.) |
| root := cursor.Root(inspect) |
| for c := range root.Preorder((*ast.BasicLit)(nil)) { |
| node := c.Node() |
| got, ok := root.FindNode(node) |
| if !ok { |
| t.Errorf("root.FindNode failed") |
| } else if got != c { |
| t.Errorf("root.FindNode returned %v, want %v", got, c) |
| } |
| } |
| |
| // Same thing, but searching only within subtrees (each FuncDecl). |
| for funcDecl := range root.Preorder((*ast.FuncDecl)(nil)) { |
| for c := range funcDecl.Preorder((*ast.BasicLit)(nil)) { |
| node := c.Node() |
| got, ok := funcDecl.FindNode(node) |
| if !ok { |
| t.Errorf("funcDecl.FindNode failed") |
| } else if got != c { |
| t.Errorf("funcDecl.FindNode returned %v, want %v", got, c) |
| } |
| |
| // Also, check that we cannot find the BasicLit |
| // beneath a different FuncDecl. |
| if prevFunc, ok := funcDecl.PrevSibling(); ok { |
| got, ok := prevFunc.FindNode(node) |
| if ok { |
| t.Errorf("prevFunc.FindNode succeeded unexpectedly: %v", got) |
| } |
| } |
| } |
| } |
| } |
| |
| // TestCursor_FindPos_order ensures that FindPos does not assume files are in Pos order. |
| func TestCursor_FindPos_order(t *testing.T) { |
| // Pick an arbitrary decl. |
| target := netFiles[7].Decls[0] |
| |
| // Find the target decl by its position. |
| cur, ok := cursor.Root(netInspect).FindByPos(target.Pos(), target.End()) |
| if !ok || cur.Node() != target { |
| t.Fatalf("unshuffled: FindPos(%T) = (%v, %t)", target, cur, ok) |
| } |
| |
| // Shuffle the files out of Pos order. |
| files := slices.Clone(netFiles) |
| rand.Shuffle(len(files), func(i, j int) { |
| files[i], files[j] = files[j], files[i] |
| }) |
| |
| // Find it again. |
| inspect := inspector.New(files) |
| cur, ok = cursor.Root(inspect).FindByPos(target.Pos(), target.End()) |
| if !ok || cur.Node() != target { |
| t.Fatalf("shuffled: FindPos(%T) = (%v, %t)", target, cur, ok) |
| } |
| } |
| |
| func TestCursor_Edge(t *testing.T) { |
| root := cursor.Root(netInspect) |
| for cur := range root.Preorder() { |
| if cur == root { |
| continue // root node |
| } |
| |
| var ( |
| parent = cur.Parent() |
| e, idx = cur.ParentEdge() |
| ) |
| |
| // ast.File, child of root? |
| if parent.Node() == nil { |
| if e != edge.Invalid || idx != -1 { |
| t.Errorf("%v.Edge = (%v, %d), want (Invalid, -1)", cur, e, idx) |
| } |
| continue |
| } |
| |
| // Check Edge.NodeType matches type of Parent.Node. |
| if e.NodeType() != reflect.TypeOf(parent.Node()) { |
| t.Errorf("Edge.NodeType = %v, Parent.Node has type %T", |
| e.NodeType(), parent.Node()) |
| } |
| |
| // Check c.Edge.Get(c.Parent.Node) == c.Node. |
| if got := e.Get(parent.Node(), idx); got != cur.Node() { |
| t.Errorf("cur=%v@%s: %s.Get(cur.Parent().Node(), %d) = %T@%s, want cur.Node()", |
| cur, netFset.Position(cur.Node().Pos()), e, idx, got, netFset.Position(got.Pos())) |
| } |
| |
| // Check c.Parent.ChildAt(c.ParentEdge()) == c. |
| if got := parent.ChildAt(e, idx); got != cur { |
| t.Errorf("cur=%v@%s: cur.Parent().ChildAt(%v, %d) = %T@%s, want cur", |
| cur, netFset.Position(cur.Node().Pos()), e, idx, got.Node(), netFset.Position(got.Node().Pos())) |
| } |
| |
| // Check that reflection on the parent finds the current node. |
| fv := reflect.ValueOf(parent.Node()).Elem().FieldByName(e.FieldName()) |
| if idx >= 0 { |
| fv = fv.Index(idx) // element of []ast.Node |
| } |
| if fv.Kind() == reflect.Interface { |
| fv = fv.Elem() // e.g. ast.Expr -> *ast.Ident |
| } |
| got := fv.Interface().(ast.Node) |
| if got != cur.Node() { |
| t.Errorf("%v.Edge = (%v, %d); FieldName/Index reflection gave %T@%s, not original node", |
| cur, e, idx, got, netFset.Position(got.Pos())) |
| } |
| |
| // Check that Cursor.Child is the reverse of Parent. |
| if cur.Parent().Child(cur.Node()) != cur { |
| t.Errorf("Cursor.Parent.Child = %v, want %v", cur.Parent().Child(cur.Node()), cur) |
| } |
| |
| // Check invariants of Contains: |
| |
| // A cursor contains itself. |
| if !cur.Contains(cur) { |
| t.Errorf("!cur.Contains(cur): %v", cur) |
| } |
| // A parent contains its child, but not the inverse. |
| if !parent.Contains(cur) { |
| t.Errorf("!cur.Parent().Contains(cur): %v", cur) |
| } |
| if cur.Contains(parent) { |
| t.Errorf("cur.Contains(cur.Parent()): %v", cur) |
| } |
| // A grandparent contains its grandchild, but not the inverse. |
| if grandparent := cur.Parent(); grandparent.Node() != nil { |
| if !grandparent.Contains(cur) { |
| t.Errorf("!cur.Parent().Parent().Contains(cur): %v", cur) |
| } |
| if cur.Contains(grandparent) { |
| t.Errorf("cur.Contains(cur.Parent().Parent()): %v", cur) |
| } |
| } |
| // A cursor and its uncle/aunt do not contain each other. |
| if uncle, ok := parent.NextSibling(); ok { |
| if uncle.Contains(cur) { |
| t.Errorf("cur.Parent().NextSibling().Contains(cur): %v", cur) |
| } |
| if cur.Contains(uncle) { |
| t.Errorf("cur.Contains(cur.Parent().NextSibling()): %v", cur) |
| } |
| } |
| } |
| } |
| |
// is reports whether the dynamic value of x is of type T. It is false
// when x is a nil interface value.
func is[T any](x any) bool {
	if _, matches := x.(T); matches {
		return true
	}
	return false
}
| |
// sliceTypes is a debugging helper that renders a slice as a bracketed,
// space-separated list of each element's %T formatting, e.g. "[int string]".
func sliceTypes[T any](slice []T) string {
	names := make([]string, len(slice))
	for i := range slice {
		names[i] = fmt.Sprintf("%T", slice[i])
	}
	return "[" + strings.Join(names, " ") + "]"
}
| |
| // (partially duplicates benchmark in go/ast/inspector) |
| func BenchmarkInspectCalls(b *testing.B) { |
| inspect := netInspect |
| b.ResetTimer() |
| |
| // Measure marginal cost of traversal. |
| |
| callExprs := []ast.Node{(*ast.CallExpr)(nil)} |
| |
| b.Run("Preorder", func(b *testing.B) { |
| var ncalls int |
| for range b.N { |
| inspect.Preorder(callExprs, func(n ast.Node) { |
| _ = n.(*ast.CallExpr) |
| ncalls++ |
| }) |
| } |
| }) |
| |
| b.Run("WithStack", func(b *testing.B) { |
| var ncalls int |
| for range b.N { |
| inspect.WithStack(callExprs, func(n ast.Node, push bool, stack []ast.Node) (proceed bool) { |
| _ = n.(*ast.CallExpr) |
| if push { |
| ncalls++ |
| } |
| return true |
| }) |
| } |
| }) |
| |
| b.Run("Cursor", func(b *testing.B) { |
| var ncalls int |
| for range b.N { |
| for cur := range cursor.Root(inspect).Preorder(callExprs...) { |
| _ = cur.Node().(*ast.CallExpr) |
| ncalls++ |
| } |
| } |
| }) |
| |
| b.Run("CursorEnclosing", func(b *testing.B) { |
| var ncalls int |
| for range b.N { |
| for cur := range cursor.Root(inspect).Preorder(callExprs...) { |
| _ = cur.Node().(*ast.CallExpr) |
| for range cur.Enclosing() { |
| } |
| ncalls++ |
| } |
| } |
| }) |
| } |
| |
| // This benchmark compares methods for finding a known node in a tree. |
| func BenchmarkCursor_FindNode(b *testing.B) { |
| root := cursor.Root(netInspect) |
| |
| callExprs := []ast.Node{(*ast.CallExpr)(nil)} |
| |
| // Choose a needle in the haystack to use as the search target: |
| // a CallExpr not too near the start nor at too shallow a depth. |
| var needle cursor.Cursor |
| { |
| count := 0 |
| found := false |
| for c := range root.Preorder(callExprs...) { |
| count++ |
| if count >= 1000 && iterlen(c.Enclosing()) >= 6 { |
| needle = c |
| found = true |
| break |
| } |
| } |
| if !found { |
| b.Fatal("can't choose needle") |
| } |
| } |
| |
| b.ResetTimer() |
| |
| b.Run("Cursor.Preorder", func(b *testing.B) { |
| needleNode := needle.Node() |
| for range b.N { |
| var found cursor.Cursor |
| for c := range root.Preorder(callExprs...) { |
| if c.Node() == needleNode { |
| found = c |
| break |
| } |
| } |
| if found != needle { |
| b.Errorf("Preorder search failed: got %v, want %v", found, needle) |
| } |
| } |
| }) |
| |
| // This method is about 10-15% faster than Cursor.Preorder. |
| b.Run("Cursor.FindNode", func(b *testing.B) { |
| for range b.N { |
| found, ok := root.FindNode(needle.Node()) |
| if !ok || found != needle { |
| b.Errorf("FindNode search failed: got %v, want %v", found, needle) |
| } |
| } |
| }) |
| |
| // This method is about 100x (!) faster than Cursor.Preorder. |
| b.Run("Cursor.FindPos", func(b *testing.B) { |
| needleNode := needle.Node() |
| for range b.N { |
| found, ok := root.FindByPos(needleNode.Pos(), needleNode.End()) |
| if !ok || found != needle { |
| b.Errorf("FindPos search failed: got %v, want %v", found, needle) |
| } |
| } |
| }) |
| } |
| |
| func iterlen[T any](seq iter.Seq[T]) (len int) { |
| for range seq { |
| len++ |
| } |
| return |
| } |