go/ast/inspector: skip ranges that do not contain a node type

Skips inspecting ranges of nodes that do not contain any nodes of
the requested types. Computes a bitmask of the node types that occur
between each push event and its corresponding pop event. Skips
forward to the pop event when the nodes in between cannot match the
type mask.
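
For example, a caller that filters on a node type occurring in only
a few subtrees now jumps over the rest. A minimal sketch of such a
caller, assuming the parsed files are supplied elsewhere (the package
and helper names below are illustrative only):

	package example // hypothetical package; sketch only

	import (
		"go/ast"

		"golang.org/x/tools/go/ast/inspector"
	)

	// countFuncLits returns the number of function literals in files.
	// With this change, subtrees whose pop events record no
	// *ast.FuncLit bit are jumped over rather than visited
	// event by event.
	func countFuncLits(files []*ast.File) int {
		in := inspector.New(files)
		n := 0
		in.Preorder([]ast.Node{(*ast.FuncLit)(nil)}, func(ast.Node) { n++ })
		return n
	}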

Benchmarking against the previous implementation on the "net"
package (see the example go test invocation below):
- Preorder traversal filtered by *ast.FuncDecl and *ast.FuncLit
  is 11x faster.
- Preorder traversal filtered by *ast.CallExpr is 10% faster.
- Unfiltered traversal is 5% slower.
- Constructing events is 3% slower.
- Break-even for the additional computation is 5 *ast.CallExpr
  filtered traversals or 1 *ast.Func{Decl,Lit} filtered
  traversal.
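
The figures come from the benchmarks in inspector_test.go (updated
below) and should be reproducible, modulo machine variance, with a
standard benchmark invocation along these lines:

	go test -run=NONE -bench=Inspect golang.org/x/tools/go/ast/inspector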

Change-Id: If4cb566474b84186ff42fb80ed7e1ebb0f692cc2
Reviewed-on: https://go-review.googlesource.com/c/tools/+/458075
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Tim King <taking@google.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/go/ast/inspector/inspector.go b/go/ast/inspector/inspector.go
index af5e17f..3fbfebf 100644
--- a/go/ast/inspector/inspector.go
+++ b/go/ast/inspector/inspector.go
@@ -53,10 +53,13 @@
 // of an ast.Node during a traversal.
 type event struct {
 	node  ast.Node
-	typ   uint64 // typeOf(node)
-	index int    // 1 + index of corresponding pop event, or 0 if this is a pop
+	typ   uint64 // typeOf(node) on push event, or union of typ strictly between push and pop events on pop events
+	index int    // index of corresponding push or pop event
 }
 
+// TODO: Experiment with storing only the second word of event.node (unsafe.Pointer).
+// Type can be recovered from the sole bit in typ.
+
 // Preorder visits all the nodes of the files supplied to New in
 // depth-first order. It calls f(n) for each node n before it visits
 // n's children.
@@ -72,10 +75,17 @@
 	mask := maskOf(types)
 	for i := 0; i < len(in.events); {
 		ev := in.events[i]
-		if ev.typ&mask != 0 {
-			if ev.index > 0 {
+		if ev.index > i {
+			// push
+			if ev.typ&mask != 0 {
 				f(ev.node)
 			}
+			pop := ev.index
+			if in.events[pop].typ&mask == 0 {
+				// Subtrees do not contain types: skip them and pop.
+				i = pop + 1
+				continue
+			}
 		}
 		i++
 	}
@@ -94,15 +104,24 @@
 	mask := maskOf(types)
 	for i := 0; i < len(in.events); {
 		ev := in.events[i]
-		if ev.typ&mask != 0 {
-			if ev.index > 0 {
-				// push
+		if ev.index > i {
+			// push
+			pop := ev.index
+			if ev.typ&mask != 0 {
 				if !f(ev.node, true) {
-					i = ev.index // jump to corresponding pop + 1
+					i = pop + 1 // jump to corresponding pop + 1
 					continue
 				}
-			} else {
-				// pop
+			}
+			if in.events[pop].typ&mask == 0 {
+				// Subtrees do not contain types: skip them.
+				i = pop
+				continue
+			}
+		} else {
+			// pop
+			push := ev.index
+			if in.events[push].typ&mask != 0 {
 				f(ev.node, false)
 			}
 		}
@@ -119,19 +138,26 @@
 	var stack []ast.Node
 	for i := 0; i < len(in.events); {
 		ev := in.events[i]
-		if ev.index > 0 {
+		if ev.index > i {
 			// push
+			pop := ev.index
 			stack = append(stack, ev.node)
 			if ev.typ&mask != 0 {
 				if !f(ev.node, true, stack) {
-					i = ev.index
+					i = pop + 1
 					stack = stack[:len(stack)-1]
 					continue
 				}
 			}
+			if in.events[pop].typ&mask == 0 {
+				// Subtrees do not contain types: skip them.
+				i = pop
+				continue
+			}
 		} else {
 			// pop
-			if ev.typ&mask != 0 {
+			push := ev.index
+			if in.events[push].typ&mask != 0 {
 				f(ev.node, false, stack)
 			}
 			stack = stack[:len(stack)-1]
@@ -157,25 +183,31 @@
 	events := make([]event, 0, capacity)
 
 	var stack []event
+	stack = append(stack, event{}) // include an extra event so file nodes have a parent
 	for _, f := range files {
 		ast.Inspect(f, func(n ast.Node) bool {
 			if n != nil {
 				// push
 				ev := event{
 					node:  n,
-					typ:   typeOf(n),
+					typ:   0,           // temporarily used to accumulate type bits of subtree
 					index: len(events), // push event temporarily holds own index
 				}
 				stack = append(stack, ev)
 				events = append(events, ev)
 			} else {
 				// pop
-				ev := stack[len(stack)-1]
-				stack = stack[:len(stack)-1]
+				top := len(stack) - 1
+				ev := stack[top]
+				typ := typeOf(ev.node)
+				push := ev.index
+				parent := top - 1
 
-				events[ev.index].index = len(events) + 1 // make push refer to pop
+				events[push].typ = typ            // set type of push
+				stack[parent].typ |= typ | ev.typ // parent's typ contains push and pop's typs.
+				events[push].index = len(events)  // make push refer to pop
 
-				ev.index = 0 // turn ev into a pop event
+				stack = stack[:top]
 				events = append(events, ev)
 			}
 			return true
diff --git a/go/ast/inspector/inspector_test.go b/go/ast/inspector/inspector_test.go
index 9e53918..e88d584 100644
--- a/go/ast/inspector/inspector_test.go
+++ b/go/ast/inspector/inspector_test.go
@@ -244,9 +244,11 @@
 // but a break-even point (NewInspector/(ASTInspect-Inspect)) of about 5
 // traversals.
 //
-// BenchmarkNewInspector   4.5 ms
-// BenchmarkNewInspect	   0.33ms
-// BenchmarkASTInspect    1.2  ms
+// BenchmarkASTInspect     1.0 ms
+// BenchmarkNewInspector   2.2 ms
+// BenchmarkInspect        0.39ms
+// BenchmarkInspectFilter  0.01ms
+// BenchmarkInspectCalls   0.14ms
 
 func BenchmarkNewInspector(b *testing.B) {
 	// Measure one-time construction overhead.
@@ -274,6 +276,42 @@
 	}
 }
 
+func BenchmarkInspectFilter(b *testing.B) {
+	b.StopTimer()
+	inspect := inspector.New(netFiles)
+	b.StartTimer()
+
+	// Measure marginal cost of traversal.
+	nodeFilter := []ast.Node{(*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)}
+	var ndecls, nlits int
+	for i := 0; i < b.N; i++ {
+		inspect.Preorder(nodeFilter, func(n ast.Node) {
+			switch n.(type) {
+			case *ast.FuncDecl:
+				ndecls++
+			case *ast.FuncLit:
+				nlits++
+			}
+		})
+	}
+}
+
+func BenchmarkInspectCalls(b *testing.B) {
+	b.StopTimer()
+	inspect := inspector.New(netFiles)
+	b.StartTimer()
+
+	// Measure marginal cost of traversal.
+	nodeFilter := []ast.Node{(*ast.CallExpr)(nil)}
+	var ncalls int
+	for i := 0; i < b.N; i++ {
+		inspect.Preorder(nodeFilter, func(n ast.Node) {
+			_ = n.(*ast.CallExpr)
+			ncalls++
+		})
+	}
+}
+
 func BenchmarkASTInspect(b *testing.B) {
 	var ndecls, nlits int
 	for i := 0; i < b.N; i++ {