| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package inspector_test |
| |
| import ( |
| "fmt" |
| "go/ast" |
| "go/parser" |
| "go/token" |
| "iter" |
| "math/rand" |
| "reflect" |
| "slices" |
| "strings" |
| "testing" |
| |
| "golang.org/x/tools/go/ast/edge" |
| "golang.org/x/tools/go/ast/inspector" |
| ) |
| |
| func TestCursor_Preorder(t *testing.T) { |
| inspect := netInspect |
| |
| nodeFilter := []ast.Node{(*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)} |
| |
| // reference implementation |
| var want []ast.Node |
| for cur := range inspect.Root().Preorder(nodeFilter...) { |
| want = append(want, cur.Node()) |
| } |
| |
| // Check entire sequence. |
| got := slices.Collect(inspect.PreorderSeq(nodeFilter...)) |
| compare(t, got, want) |
| |
| // Check that break works. |
| got = got[:0] |
| for _, c := range firstN(10, inspect.Root().Preorder(nodeFilter...)) { |
| got = append(got, c.Node()) |
| } |
| compare(t, got, want[:10]) |
| } |
| |
| func TestCursor_nestedTraversal(t *testing.T) { |
| const src = `package a |
| func f() { |
| print("hello") |
| } |
| func g() { |
| print("goodbye") |
| panic("oops") |
| } |
| ` |
| fset := token.NewFileSet() |
| f, _ := parser.ParseFile(fset, "a.go", src, 0) |
| inspect := inspector.New([]*ast.File{f}) |
| |
| var ( |
| funcDecls = []ast.Node{(*ast.FuncDecl)(nil)} |
| callExprs = []ast.Node{(*ast.CallExpr)(nil)} |
| nfuncs = 0 |
| ncalls = 0 |
| ) |
| |
| for curFunc := range inspect.Root().Preorder(funcDecls...) { |
| _ = curFunc.Node().(*ast.FuncDecl) |
| |
| // Check edge and index. |
| if k, idx := curFunc.ParentEdge(); k != edge.File_Decls || idx != nfuncs { |
| t.Errorf("%v.ParentEdge() = (%v, %d), want edge.File_Decls, %d", curFunc, k, idx, nfuncs) |
| } |
| |
| nfuncs++ |
| stack := slices.Collect(curFunc.Enclosing()) |
| |
| // Stacks are convenient to print! |
| if got, want := fmt.Sprint(stack), "[*ast.FuncDecl *ast.File]"; got != want { |
| t.Errorf("curFunc.Enclosing() = %q, want %q", got, want) |
| } |
| |
| // Parent, iterated, is Enclosing stack. |
| i := 0 |
| for c := curFunc; c.Node() != nil; c = c.Parent() { |
| if got, want := stack[i], c; got != want { |
| t.Errorf("Enclosing[%d] = %v; Parent()^%d = %v", i, got, i, want) |
| } |
| i++ |
| } |
| |
| wantStack := "[*ast.CallExpr *ast.ExprStmt *ast.BlockStmt *ast.FuncDecl *ast.File]" |
| |
| // nested Preorder traversal |
| preorderCount := 0 |
| for curCall := range curFunc.Preorder(callExprs...) { |
| _ = curCall.Node().(*ast.CallExpr) |
| preorderCount++ |
| stack := slices.Collect(curCall.Enclosing()) |
| if got := fmt.Sprint(stack); got != wantStack { |
| t.Errorf("curCall.Enclosing() = %q, want %q", got, wantStack) |
| } |
| } |
| |
| // nested Inspect traversal |
| inspectCount := 0 |
| curFunc.Inspect(callExprs, func(curCall inspector.Cursor) (proceed bool) { |
| _ = curCall.Node().(*ast.CallExpr) |
| inspectCount++ |
| stack := slices.Collect(curCall.Enclosing()) |
| if got := fmt.Sprint(stack); got != wantStack { |
| t.Errorf("curCall.Enclosing() = %q, want %q", got, wantStack) |
| } |
| return true |
| }) |
| |
| if inspectCount != preorderCount { |
| t.Errorf("Inspect (%d) and Preorder (%d) events are not consistent", inspectCount, preorderCount) |
| } |
| |
| ncalls += preorderCount |
| } |
| |
| if nfuncs != 2 { |
| t.Errorf("Found %d FuncDecls, want 2", nfuncs) |
| } |
| if ncalls != 3 { |
| t.Errorf("Found %d CallExprs, want 3", ncalls) |
| } |
| } |
| |
| func TestCursor_Children(t *testing.T) { |
| inspect := netInspect |
| |
| // Assert that Cursor.Children agrees with |
| // reference implementation for every node. |
| var want, got []ast.Node |
| for c := range inspect.Root().Preorder() { |
| |
| // reference implementation |
| want = want[:0] |
| { |
| parent := c.Node() |
| ast.Inspect(parent, func(n ast.Node) bool { |
| if n != nil && n != parent { |
| want = append(want, n) |
| } |
| return n == parent // descend only into parent |
| }) |
| } |
| |
| // Check cursor-based implementation |
| // (uses FirstChild+NextSibling). |
| got = got[:0] |
| for child := range c.Children() { |
| got = append(got, child.Node()) |
| } |
| |
| if !slices.Equal(got, want) { |
| t.Errorf("For %v\n"+ |
| "Using FirstChild+NextSibling: %v\n"+ |
| "Using ast.Inspect: %v", |
| c, sliceTypes(got), sliceTypes(want)) |
| } |
| |
| // Second cursor-based implementation |
| // using LastChild+PrevSibling+reverse. |
| got = got[:0] |
| for c, ok := c.LastChild(); ok; c, ok = c.PrevSibling() { |
| got = append(got, c.Node()) |
| } |
| slices.Reverse(got) |
| |
| if !slices.Equal(got, want) { |
| t.Errorf("For %v\n"+ |
| "Using LastChild+PrevSibling: %v\n"+ |
| "Using ast.Inspect: %v", |
| c, sliceTypes(got), sliceTypes(want)) |
| } |
| } |
| } |
| |
| func TestCursor_Inspect(t *testing.T) { |
| inspect := netInspect |
| |
| // In all three loops, we'll gather both kinds of type switches, |
| // but we'll prune the traversal from descending into (value) switches. |
| switches := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)} |
| |
| // reference implementation (ast.Inspect) |
| var nodesA []ast.Node |
| for _, f := range netFiles { |
| ast.Inspect(f, func(n ast.Node) (proceed bool) { |
| switch n.(type) { |
| case *ast.SwitchStmt, *ast.TypeSwitchStmt: |
| nodesA = append(nodesA, n) |
| return !is[*ast.SwitchStmt](n) // descend only into TypeSwitchStmt |
| } |
| return true |
| }) |
| } |
| |
| // Test Cursor.Inspect implementation. |
| var nodesB []ast.Node |
| inspect.Root().Inspect(switches, func(c inspector.Cursor) (proceed bool) { |
| n := c.Node() |
| nodesB = append(nodesB, n) |
| return !is[*ast.SwitchStmt](n) // descend only into TypeSwitchStmt |
| }) |
| compare(t, nodesA, nodesB) |
| |
| // Test WithStack implementation. |
| var nodesC []ast.Node |
| inspect.WithStack(switches, func(n ast.Node, push bool, stack []ast.Node) (proceed bool) { |
| if push { |
| nodesC = append(nodesC, n) |
| return !is[*ast.SwitchStmt](n) // descend only into TypeSwitchStmt |
| } |
| return false |
| }) |
| compare(t, nodesA, nodesC) |
| } |
| |
| func TestCursor_FindNode(t *testing.T) { |
| inspect := netInspect |
| |
| // Enumerate all nodes of a particular type, |
| // then check that FindPos can find them, |
| // starting at the root. |
| // |
| // (We use BasicLit because they are numerous.) |
| root := inspect.Root() |
| for c := range root.Preorder((*ast.BasicLit)(nil)) { |
| node := c.Node() |
| got, ok := root.FindNode(node) |
| if !ok { |
| t.Errorf("root.FindNode failed") |
| } else if got != c { |
| t.Errorf("root.FindNode returned %v, want %v", got, c) |
| } |
| } |
| |
| // Same thing, but searching only within subtrees (each FuncDecl). |
| for funcDecl := range root.Preorder((*ast.FuncDecl)(nil)) { |
| for c := range funcDecl.Preorder((*ast.BasicLit)(nil)) { |
| node := c.Node() |
| got, ok := funcDecl.FindNode(node) |
| if !ok { |
| t.Errorf("funcDecl.FindNode failed") |
| } else if got != c { |
| t.Errorf("funcDecl.FindNode returned %v, want %v", got, c) |
| } |
| |
| // Also, check that we cannot find the BasicLit |
| // beneath a different FuncDecl. |
| if prevFunc, ok := funcDecl.PrevSibling(); ok { |
| got, ok := prevFunc.FindNode(node) |
| if ok { |
| t.Errorf("prevFunc.FindNode succeeded unexpectedly: %v", got) |
| } |
| } |
| } |
| } |
| } |
| |
| // TestCursor_FindPos_order ensures that FindPos does not assume files are in Pos order. |
| func TestCursor_FindPos_order(t *testing.T) { |
| // Pick an arbitrary decl. |
| target := netFiles[7].Decls[0] |
| |
| // Find the target decl by its position. |
| cur, ok := netInspect.Root().FindByPos(target.Pos(), target.End()) |
| if !ok || cur.Node() != target { |
| t.Fatalf("unshuffled: FindPos(%T) = (%v, %t)", target, cur, ok) |
| } |
| |
| // Shuffle the files out of Pos order. |
| files := slices.Clone(netFiles) |
| rand.Shuffle(len(files), func(i, j int) { |
| files[i], files[j] = files[j], files[i] |
| }) |
| |
| // Find it again. |
| inspect := inspector.New(files) |
| cur, ok = inspect.Root().FindByPos(target.Pos(), target.End()) |
| if !ok || cur.Node() != target { |
| t.Fatalf("shuffled: FindPos(%T) = (%v, %t)", target, cur, ok) |
| } |
| } |
| |
| func TestCursor_Edge(t *testing.T) { |
| root := netInspect.Root() |
| for cur := range root.Preorder() { |
| if cur == root { |
| continue // root node |
| } |
| |
| var ( |
| parent = cur.Parent() |
| e, idx = cur.ParentEdge() |
| ) |
| |
| // ast.File, child of root? |
| if parent.Node() == nil { |
| if e != edge.Invalid || idx != -1 { |
| t.Errorf("%v.Edge = (%v, %d), want (Invalid, -1)", cur, e, idx) |
| } |
| continue |
| } |
| |
| // Check Edge.NodeType matches type of Parent.Node. |
| if e.NodeType() != reflect.TypeOf(parent.Node()) { |
| t.Errorf("Edge.NodeType = %v, Parent.Node has type %T", |
| e.NodeType(), parent.Node()) |
| } |
| |
| // Check c.Edge.Get(c.Parent.Node) == c.Node. |
| if got := e.Get(parent.Node(), idx); got != cur.Node() { |
| t.Errorf("cur=%v@%s: %s.Get(cur.Parent().Node(), %d) = %T@%s, want cur.Node()", |
| cur, netFset.Position(cur.Node().Pos()), e, idx, got, netFset.Position(got.Pos())) |
| } |
| |
| // Check c.Parent.ChildAt(c.ParentEdge()) == c. |
| if got := parent.ChildAt(e, idx); got != cur { |
| t.Errorf("cur=%v@%s: cur.Parent().ChildAt(%v, %d) = %T@%s, want cur", |
| cur, netFset.Position(cur.Node().Pos()), e, idx, got.Node(), netFset.Position(got.Node().Pos())) |
| } |
| |
| // Check that reflection on the parent finds the current node. |
| fv := reflect.ValueOf(parent.Node()).Elem().FieldByName(e.FieldName()) |
| if idx >= 0 { |
| fv = fv.Index(idx) // element of []ast.Node |
| } |
| if fv.Kind() == reflect.Interface { |
| fv = fv.Elem() // e.g. ast.Expr -> *ast.Ident |
| } |
| got := fv.Interface().(ast.Node) |
| if got != cur.Node() { |
| t.Errorf("%v.Edge = (%v, %d); FieldName/Index reflection gave %T@%s, not original node", |
| cur, e, idx, got, netFset.Position(got.Pos())) |
| } |
| |
| // Check that Cursor.Child is the reverse of Parent. |
| if cur.Parent().Child(cur.Node()) != cur { |
| t.Errorf("Cursor.Parent.Child = %v, want %v", cur.Parent().Child(cur.Node()), cur) |
| } |
| |
| // Check invariants of Contains: |
| |
| // A cursor contains itself. |
| if !cur.Contains(cur) { |
| t.Errorf("!cur.Contains(cur): %v", cur) |
| } |
| // A parent contains its child, but not the inverse. |
| if !parent.Contains(cur) { |
| t.Errorf("!cur.Parent().Contains(cur): %v", cur) |
| } |
| if cur.Contains(parent) { |
| t.Errorf("cur.Contains(cur.Parent()): %v", cur) |
| } |
| // A grandparent contains its grandchild, but not the inverse. |
| if grandparent := cur.Parent(); grandparent.Node() != nil { |
| if !grandparent.Contains(cur) { |
| t.Errorf("!cur.Parent().Parent().Contains(cur): %v", cur) |
| } |
| if cur.Contains(grandparent) { |
| t.Errorf("cur.Contains(cur.Parent().Parent()): %v", cur) |
| } |
| } |
| // A cursor and its uncle/aunt do not contain each other. |
| if uncle, ok := parent.NextSibling(); ok { |
| if uncle.Contains(cur) { |
| t.Errorf("cur.Parent().NextSibling().Contains(cur): %v", cur) |
| } |
| if cur.Contains(uncle) { |
| t.Errorf("cur.Contains(cur.Parent().NextSibling()): %v", cur) |
| } |
| } |
| } |
| } |
| |
| // Regression test for FuncDecl.Type irregularity in FindByPos (#75997). |
| func TestCursor_FindByPos(t *testing.T) { |
| // Observe that the range of FuncType has a hole between |
| // the "func" token and the start of Type.Params. |
| // The hole contains FuncDecl.{Recv,Name}. |
| // |
| // ~~~~~~~~~~~~FuncDecl~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // ~Recv~ ~Name~ |
| // ~~~~--------------~~~~FuncType~~~~~~ |
| // ~Params~ ~Results~ |
| const src = `package a; func (recv) method(params) (results) { body }` |
| var ( |
| fset = token.NewFileSet() |
| f, _ = parser.ParseFile(fset, "a.go", src, 0) // ignore parse errors |
| tokFile = fset.File(f.FileStart) |
| inspect = inspector.New([]*ast.File{f}) |
| ) |
| format := func(start, end token.Pos) string { |
| var ( |
| startOffset = tokFile.Offset(start) |
| endOffset = tokFile.Offset(end) |
| ) |
| return fmt.Sprintf("%s<<%s>>%s", src[:startOffset], src[startOffset:endOffset], src[endOffset:]) |
| } |
| |
| d := f.Decls[0].(*ast.FuncDecl) |
| |
| // Each test case specifies a [pos-end) range for |
| // FindByPos and the syntax node it should find. |
| for _, test := range []struct { |
| start, end token.Pos |
| want ast.Node |
| }{ |
| // pure subtrees |
| {d.Pos(), d.End(), d}, // decl |
| {d.Recv.Pos(), d.Recv.End(), d.Recv}, // recv |
| {d.Name.Pos(), d.Name.End(), d.Name}, // name |
| // (A FuncDecl can't have both Recv and TypeParams, so skip this one.) |
| // {d.Type.TypeParams.Pos(), d.Type.TypeParams.End(), d.Type.TypeParams}, |
| {d.Type.Params.Pos(), d.Type.Params.End(), d.Type.Params}, // params |
| {d.Type.Results.Pos(), d.Type.Results.End(), d.Type.Results}, // results |
| {d.Body.Pos(), d.Body.End(), d.Body}, // body |
| |
| // single tokens |
| { |
| // "func" |
| d.Type.Func, d.Type.Func + 4, |
| d, // arguably this should be d.Type |
| }, |
| { |
| // "(" FieldList |
| d.Recv.Pos(), d.Recv.Pos() + 1, |
| d.Recv, |
| }, |
| { |
| // "recv" Ident |
| d.Recv.List[0].Pos(), d.Recv.List[0].Pos() + 1, |
| d.Recv.List[0].Type, |
| }, |
| { |
| // "name" Ident |
| d.Name.Pos(), d.Name.Pos() + 1, |
| d.Name, |
| }, |
| { |
| // "(" FieldList |
| d.Type.Params.Pos(), d.Type.Params.Pos() + 1, |
| d.Type.Params, |
| }, |
| { |
| // "params" Ident |
| d.Type.Params.List[0].Pos(), d.Type.Params.List[0].Pos() + 1, |
| d.Type.Params.List[0].Type, |
| }, |
| { |
| // "(" FieldList |
| d.Type.Results.Pos(), d.Type.Results.Pos() + 1, |
| d.Type.Results, |
| }, |
| { |
| // "results" Ident |
| d.Type.Results.List[0].Pos(), d.Type.Results.List[0].Pos() + 1, |
| d.Type.Results.List[0].Type, |
| }, |
| { |
| // "{" BlockStmt |
| d.Body.Pos(), d.Body.Pos() + 1, |
| d.Body, |
| }, |
| { |
| // "body" Ident |
| d.Body.List[0].Pos(), d.Body.List[0].Pos() + 1, |
| d.Body.List[0].(*ast.ExprStmt).X, |
| }, |
| } { |
| cur, ok := inspect.Root().FindByPos(test.start, test.end) |
| if !ok || cur.Node() == nil { |
| t.Errorf("%s: FindByPos failed", format(test.start, test.end)) |
| continue |
| } |
| got := cur.Node() |
| if got != test.want { |
| t.Errorf("FindByPos:\ninput:\t%s\ngot:\t%s (%T)\nwant:\t%s (%T)", |
| format(test.start, test.end), |
| format(got.Pos(), got.End()), got, |
| format(test.want.Pos(), test.want.End()), test.want) |
| } |
| } |
| } |
| |
// is reports whether x can be type-asserted to T, that is, whether x is
// non-nil and its dynamic type is T (or implements T, for an interface T).
func is[T any](x any) bool {
	switch x.(type) {
	case T:
		return true
	default:
		return false
	}
}
| |
// sliceTypes is a debugging helper that renders the dynamic type (%T) of
// each element of slice, separated by spaces and enclosed in brackets.
func sliceTypes[T any](slice []T) string {
	names := make([]string, len(slice))
	for i := range slice {
		names[i] = fmt.Sprintf("%T", slice[i])
	}
	return "[" + strings.Join(names, " ") + "]"
}
| |
| func BenchmarkInspectCalls(b *testing.B) { |
| inspect := netInspect |
| |
| // Measure marginal cost of traversal. |
| |
| callExprs := []ast.Node{(*ast.CallExpr)(nil)} |
| |
| b.Run("Preorder", func(b *testing.B) { |
| var ncalls int |
| for b.Loop() { |
| inspect.Preorder(callExprs, func(n ast.Node) { |
| _ = n.(*ast.CallExpr) |
| ncalls++ |
| }) |
| } |
| }) |
| |
| b.Run("WithStack", func(b *testing.B) { |
| var ncalls int |
| for b.Loop() { |
| inspect.WithStack(callExprs, func(n ast.Node, push bool, stack []ast.Node) (proceed bool) { |
| _ = n.(*ast.CallExpr) |
| if push { |
| ncalls++ |
| } |
| return true |
| }) |
| } |
| }) |
| |
| b.Run("Cursor", func(b *testing.B) { |
| var ncalls int |
| for b.Loop() { |
| for cur := range inspect.Root().Preorder(callExprs...) { |
| _ = cur.Node().(*ast.CallExpr) |
| ncalls++ |
| } |
| } |
| }) |
| |
| b.Run("CursorEnclosing", func(b *testing.B) { |
| var ncalls int |
| for b.Loop() { |
| for cur := range inspect.Root().Preorder(callExprs...) { |
| _ = cur.Node().(*ast.CallExpr) |
| for range cur.Enclosing() { |
| } |
| ncalls++ |
| } |
| } |
| }) |
| } |
| |
| // This benchmark compares methods for finding a known node in a tree. |
| func BenchmarkCursor_FindNode(b *testing.B) { |
| root := netInspect.Root() |
| |
| callExprs := []ast.Node{(*ast.CallExpr)(nil)} |
| |
| // Choose a needle in the haystack to use as the search target: |
| // a CallExpr not too near the start nor at too shallow a depth. |
| var needle inspector.Cursor |
| { |
| count := 0 |
| found := false |
| for c := range root.Preorder(callExprs...) { |
| count++ |
| if count >= 1000 && iterlen(c.Enclosing()) >= 6 { |
| needle = c |
| found = true |
| break |
| } |
| } |
| if !found { |
| b.Fatal("can't choose needle") |
| } |
| } |
| |
| b.ResetTimer() |
| |
| b.Run("Cursor.Preorder", func(b *testing.B) { |
| needleNode := needle.Node() |
| for b.Loop() { |
| var found inspector.Cursor |
| for c := range root.Preorder(callExprs...) { |
| if c.Node() == needleNode { |
| found = c |
| break |
| } |
| } |
| if found != needle { |
| b.Errorf("Preorder search failed: got %v, want %v", found, needle) |
| } |
| } |
| }) |
| |
| // This method is about 10-15% faster than Cursor.Preorder. |
| b.Run("Cursor.FindNode", func(b *testing.B) { |
| for b.Loop() { |
| found, ok := root.FindNode(needle.Node()) |
| if !ok || found != needle { |
| b.Errorf("FindNode search failed: got %v, want %v", found, needle) |
| } |
| } |
| }) |
| |
| // This method is about 100x (!) faster than Cursor.Preorder. |
| b.Run("Cursor.FindPos", func(b *testing.B) { |
| needleNode := needle.Node() |
| for b.Loop() { |
| found, ok := root.FindByPos(needleNode.Pos(), needleNode.End()) |
| if !ok || found != needle { |
| b.Errorf("FindPos search failed: got %v, want %v", found, needle) |
| } |
| } |
| }) |
| } |
| |
| func iterlen[T any](seq iter.Seq[T]) (len int) { |
| for range seq { |
| len++ |
| } |
| return |
| } |