vulncheck: include more references to functions
Includes references to functions in operators of CallInstructions
when forward slicing for VTA.
Additionally avoids allocating slices of callees when forward
slicing. On net/http benchmark overall makes callgraph construction
5% faster and consume 2% less memory.
Change-Id: I52c775c397fb8ae06d6129957fd27d2516b8e740
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/460422
Run-TryBot: Tim King <taking@google.com>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/vulncheck/slicing.go b/vulncheck/slicing.go
index 7f76787..71fffc4 100644
--- a/vulncheck/slicing.go
+++ b/vulncheck/slicing.go
@@ -9,39 +9,40 @@
"golang.org/x/tools/go/ssa"
)
-// forwardReachableFrom computes the set of functions forward reachable from `sources`.
-// A function f is reachable from a function g if f is an anonymous function defined
-// in g or a function called in g as given by the callgraph `cg`.
-func forwardReachableFrom(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool {
- m := make(map[*ssa.Function]bool)
- for s := range sources {
- forward(s, cg, m)
- }
- return m
-}
+// forwardSlice computes the transitive closure of functions forward reachable
+// via calls in cg or referred to in an instruction starting from `sources`.
+func forwardSlice(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool {
+ seen := make(map[*ssa.Function]bool)
+ var visit func(f *ssa.Function)
+ visit = func(f *ssa.Function) {
+ if seen[f] {
+ return
+ }
+ seen[f] = true
-func forward(f *ssa.Function, cg *callgraph.Graph, seen map[*ssa.Function]bool) {
- if seen[f] {
- return
- }
- seen[f] = true
- var buf [10]*ssa.Value // avoid alloc in common case
- for _, b := range f.Blocks {
- for _, instr := range b.Instrs {
- switch i := instr.(type) {
- case ssa.CallInstruction:
- for _, c := range siteCallees(i, cg) {
- forward(c, cg, seen)
+ if n := cg.Nodes[f]; n != nil {
+ for _, e := range n.Out {
+ if e.Site != nil {
+ visit(e.Callee.Func)
}
- default:
- for _, op := range i.Operands(buf[:0]) {
+ }
+ }
+
+ var buf [10]*ssa.Value // avoid alloc in common case
+ for _, b := range f.Blocks {
+ for _, instr := range b.Instrs {
+ for _, op := range instr.Operands(buf[:0]) {
if fn, ok := (*op).(*ssa.Function); ok {
- forward(fn, cg, seen)
+ visit(fn)
}
}
}
}
}
+ for source := range sources {
+ visit(source)
+ }
+ return seen
}
// pruneSet removes functions in `set` that are in `toPrune`.
diff --git a/vulncheck/slicing_test.go b/vulncheck/slicing_test.go
index a7e1b0b..dfaa241 100644
--- a/vulncheck/slicing_test.go
+++ b/vulncheck/slicing_test.go
@@ -100,7 +100,7 @@
pkg := ssaPkgs[0]
sources := map[*ssa.Function]bool{pkg.Func("Apply"): true, pkg.Func("Do"): true}
- fs := funcNames(forwardReachableFrom(sources, cha.CallGraph(prog)))
+ fs := funcNames(forwardSlice(sources, cha.CallGraph(prog)))
want := map[string]bool{
"Apply": true,
"Apply$1": true,
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
index ee7ab9b..15dce72 100644
--- a/vulncheck/utils.go
+++ b/vulncheck/utils.go
@@ -70,7 +70,7 @@
initial := cha.CallGraph(prog)
allFuncs := ssautil.AllFunctions(prog)
- fslice := forwardReachableFrom(entrySlice, initial)
+ fslice := forwardSlice(entrySlice, initial)
// Keep only actually linked functions.
pruneSet(fslice, allFuncs)
@@ -81,7 +81,7 @@
// Repeat the process once more, this time using
// the produced VTA call graph as the base graph.
- fslice = forwardReachableFrom(entrySlice, vtaCg)
+ fslice = forwardSlice(entrySlice, vtaCg)
pruneSet(fslice, allFuncs)
if err := ctx.Err(); err != nil { // cancelled?
@@ -92,25 +92,6 @@
return cg, nil
}
-// siteCallees computes a set of callees for call site `call` given program `callgraph`.
-func siteCallees(call ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function {
- var matches []*ssa.Function
-
- node := callgraph.Nodes[call.Parent()]
- if node == nil {
- return nil
- }
-
- for _, edge := range node.Out {
- // Some callgraph analyses, such as CHA, might return synthetic (interface)
- // methods as well as the concrete methods. Skip such synthetic functions.
- if edge.Site == call {
- matches = append(matches, edge.Callee.Func)
- }
- }
- return matches
-}
-
// dbTypeFormat formats the name of t according how types
// are encoded in vulnerability database:
// - pointer designation * is skipped