vulncheck: add forward slicing logic
This code is copied from vulndb/internal/audit. The logic computes the set
of functions forward-reachable from the program entry points. This set is a
conservative estimate of the program subset relevant to vulncheck, and it
is what gets passed to the VTA call graph analysis.
Cherry-picked: https://go-review.googlesource.com/c/exp/+/363663
Change-Id: I99d8af3825a4dc85a3c91559488c43128af25e17
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395040
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/slicing_test.go b/vulncheck/slicing_test.go
new file mode 100644
index 0000000..ae8bfb8
--- /dev/null
+++ b/vulncheck/slicing_test.go
@@ -0,0 +1,118 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+ "path"
+ "reflect"
+ "testing"
+
+ "golang.org/x/tools/go/callgraph/cha"
+ "golang.org/x/tools/go/packages/packagestest"
+ "golang.org/x/tools/go/ssa"
+ "golang.org/x/tools/go/ssa/ssautil"
+)
+
+// funcNames returns the set of dbFuncName(f) for every key f of funcs.
+func funcNames(funcs map[*ssa.Function]bool) map[string]bool {
+	names := make(map[string]bool, len(funcs)) // at most one name per function
+	for f := range funcs {
+		names[dbFuncName(f)] = true
+	}
+	return names
+}
+
+// TestSlicing checks the functions forwardReachableFrom finds from Apply and Do under CHA.
+func TestSlicing(t *testing.T) {
+	p := `
+package slice
+
+func X() {}
+func Y() {}
+
+// not reachable
+func id(i int) int {
+	return i
+}
+
+// not reachable
+func inc(i int) int {
+	return i + 1
+}
+
+func Apply(b bool, h func()) {
+	if b {
+		func() {
+			print("applied")
+		}()
+		return
+	}
+	h()
+}
+
+type I interface {
+	Foo()
+}
+
+type A struct{}
+
+func (a A) Foo() {}
+
+// not reachable
+func (a A) Bar() {}
+
+type B struct{}
+
+func (b B) Foo() {}
+
+func debug(s string) {
+	print(s)
+}
+
+func Do(i I, input string) {
+	debug(input)
+
+	i.Foo()
+
+	func(x string) {
+		func(l int) {
+			print(l)
+		}(len(x))
+	}(input)
+}`
+
+	e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{{
+		Name:  "some/module",
+		Files: map[string]interface{}{"slice/slice.go": p},
+	}})
+
+	pkgs, err := loadPackages(e, path.Join(e.Temp(), "/module/slice"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
+	prog.Build()
+	if len(ssaPkgs) == 0 || ssaPkgs[0] == nil { // nil when SSA could not be built (e.g. type errors)
+		t.Fatal("no SSA package was built for the test program")
+	}
+	pkg := ssaPkgs[0]
+	sources := map[*ssa.Function]bool{pkg.Func("Apply"): true, pkg.Func("Do"): true}
+	fs := funcNames(forwardReachableFrom(sources, cha.CallGraph(prog)))
+	want := map[string]bool{
+		"Apply":   true,
+		"Apply$1": true,
+		"X":       true,
+		"Y":       true,
+		"Do":      true,
+		"Do$1":    true,
+		"Do$1$1":  true,
+		"debug":   true,
+		"A.Foo":   true,
+		"B.Foo":   true,
+	}
+	if !reflect.DeepEqual(want, fs) {
+		t.Errorf("want %v; got %v", want, fs)
+	}
+}