vulncheck: add forward slicing logic
This code is copied from vulndb/internal/audit. The logic is used to
compute a forward reachable set of functions starting from program entry
points. This set is effectively a conservative estimate of the program
subset relevant for vulncheck that is passed to vta call graph analysis.
Change-Id: I83c213c137976d5a3877eb478338affb9c72960b
Reviewed-on: https://go-review.googlesource.com/c/exp/+/363663
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/vulncheck/slicing.go b/vulncheck/slicing.go
new file mode 100644
index 0000000..7b96283
--- /dev/null
+++ b/vulncheck/slicing.go
@@ -0,0 +1,54 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+ "golang.org/x/tools/go/callgraph"
+ "golang.org/x/tools/go/ssa"
+)
+
+// forwardReachableFrom computes the set of functions forward reachable from `sources`.
+// A function f is reachable from a function g if f is an anonymous function defined
+// in g or a function called in g as given by the callgraph `cg`.
+func forwardReachableFrom(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool {
+ m := make(map[*ssa.Function]bool)
+ for s := range sources {
+ forward(s, cg, m)
+ }
+ return m
+}
+
+func forward(f *ssa.Function, cg *callgraph.Graph, seen map[*ssa.Function]bool) {
+ if seen[f] {
+ return
+ }
+ seen[f] = true
+ var buf [10]*ssa.Value // avoid alloc in common case
+ for _, b := range f.Blocks {
+ for _, instr := range b.Instrs {
+ switch i := instr.(type) {
+ case ssa.CallInstruction:
+ for _, c := range siteCallees(i, cg) {
+ forward(c, cg, seen)
+ }
+ default:
+ for _, op := range i.Operands(buf[:0]) {
+ if fn, ok := (*op).(*ssa.Function); ok {
+ forward(fn, cg, seen)
+ }
+ }
+ }
+ }
+ }
+}
+
+// pruneSlice removes functions in `slice` that are in `toPrune`.
+func pruneSlice(slice map[*ssa.Function]bool, toPrune map[*ssa.Function]bool) {
+ for f := range slice {
+ if toPrune[f] {
+ delete(slice, f)
+ }
+ }
+}
diff --git a/vulncheck/slicing_test.go b/vulncheck/slicing_test.go
new file mode 100644
index 0000000..ae8bfb8
--- /dev/null
+++ b/vulncheck/slicing_test.go
@@ -0,0 +1,118 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+ "path"
+ "reflect"
+ "testing"
+
+ "golang.org/x/tools/go/callgraph/cha"
+ "golang.org/x/tools/go/packages/packagestest"
+ "golang.org/x/tools/go/ssa"
+ "golang.org/x/tools/go/ssa/ssautil"
+)
+
+// funcNames returns a set of function names for `funcs`.
+func funcNames(funcs map[*ssa.Function]bool) map[string]bool {
+ fs := make(map[string]bool)
+ for f := range funcs {
+ fs[dbFuncName(f)] = true
+ }
+ return fs
+}
+
+func TestSlicing(t *testing.T) {
+ // test program
+ p := `
+package slice
+
+func X() {}
+func Y() {}
+
+// not reachable
+func id(i int) int {
+ return i
+}
+
+// not reachable
+func inc(i int) int {
+ return i + 1
+}
+
+func Apply(b bool, h func()) {
+ if b {
+ func() {
+ print("applied")
+ }()
+ return
+ }
+ h()
+}
+
+type I interface {
+ Foo()
+}
+
+type A struct{}
+
+func (a A) Foo() {}
+
+// not reachable
+func (a A) Bar() {}
+
+type B struct{}
+
+func (b B) Foo() {}
+
+func debug(s string) {
+ print(s)
+}
+
+func Do(i I, input string) {
+ debug(input)
+
+ i.Foo()
+
+ func(x string) {
+ func(l int) {
+ print(l)
+ }(len(x))
+ }(input)
+}`
+
+ e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
+ {
+ Name: "some/module",
+ Files: map[string]interface{}{"slice/slice.go": p},
+ },
+ })
+
+ pkgs, err := loadPackages(e, path.Join(e.Temp(), "/module/slice"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
+ prog.Build()
+
+ pkg := ssaPkgs[0]
+ sources := map[*ssa.Function]bool{pkg.Func("Apply"): true, pkg.Func("Do"): true}
+ fs := funcNames(forwardReachableFrom(sources, cha.CallGraph(prog)))
+ want := map[string]bool{
+ "Apply": true,
+ "Apply$1": true,
+ "X": true,
+ "Y": true,
+ "Do": true,
+ "Do$1": true,
+ "Do$1$1": true,
+ "debug": true,
+ "A.Foo": true,
+ "B.Foo": true,
+ }
+ if !reflect.DeepEqual(want, fs) {
+ t.Errorf("want %v; got %v", want, fs)
+ }
+}
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
new file mode 100644
index 0000000..861a251
--- /dev/null
+++ b/vulncheck/utils.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+ "go/types"
+ "strings"
+
+ "golang.org/x/tools/go/callgraph"
+ "golang.org/x/tools/go/types/typeutil"
+
+ "golang.org/x/tools/go/ssa"
+)
+
+// siteCallees computes a set of callees for call site `call` given program `callgraph`.
+func siteCallees(call ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function {
+ var matches []*ssa.Function
+
+ node := callgraph.Nodes[call.Parent()]
+ if node == nil {
+ return nil
+ }
+
+ for _, edge := range node.Out {
+ // Some callgraph analyses, such as CHA, might return synthetic (interface)
+ // methods as well as the concrete methods. Skip such synthetic functions.
+ if edge.Site == call {
+ matches = append(matches, edge.Callee.Func)
+ }
+ }
+ return matches
+}
+
+// dbFuncName computes a function name consistent with the namings used in vulnerability
+// databases. Effectively, a qualified name of a function local to its enclosing package.
+// If a receiver is a pointer, this information is not encoded in the resulting name. The
+// name of anonymous functions is simply "". The function names are unique subject to the
+// enclosing package, but not globally.
+//
+// Examples:
+// func (a A) foo (...) {...} -> A.foo
+// func foo(...) {...} -> foo
+// func (b *B) bar (...) {...} -> B.bar
+func dbFuncName(f *ssa.Function) string {
+ var typeFormat func(t types.Type) string
+ typeFormat = func(t types.Type) string {
+ switch tt := t.(type) {
+ case *types.Pointer:
+ return typeFormat(tt.Elem())
+ case *types.Named:
+ return tt.Obj().Name()
+ default:
+ return types.TypeString(t, func(p *types.Package) string { return "" })
+ }
+ }
+ selectBound := func(f *ssa.Function) types.Type {
+ // If f is a "bound" function introduced by ssa for a given type, return the type.
+ // When "f" is a "bound" function, it will have 1 free variable of that type within
+ // the function. This is subject to change when ssa changes.
+ if len(f.FreeVars) == 1 && strings.HasPrefix(f.Synthetic, "bound ") {
+ return f.FreeVars[0].Type()
+ }
+ return nil
+ }
+ selectThunk := func(f *ssa.Function) types.Type {
+ // If f is a "thunk" function introduced by ssa for a given type, return the type.
+ // When "f" is a "thunk" function, the first parameter will have that type within
+ // the function. This is subject to change when ssa changes.
+ params := f.Signature.Params() // params.Len() == 1 then params != nil.
+ if strings.HasPrefix(f.Synthetic, "thunk ") && params.Len() >= 1 {
+ if first := params.At(0); first != nil {
+ return first.Type()
+ }
+ }
+ return nil
+ }
+ var qprefix string
+ if recv := f.Signature.Recv(); recv != nil {
+ qprefix = typeFormat(recv.Type())
+ } else if btype := selectBound(f); btype != nil {
+ qprefix = typeFormat(btype)
+ } else if ttype := selectThunk(f); ttype != nil {
+ qprefix = typeFormat(ttype)
+ }
+
+ if qprefix == "" {
+ return f.Name()
+ }
+ return qprefix + "." + f.Name()
+}
+
+// memberFuncs returns functions associated with the `member`:
+// 1) `member` itself if `member` is a function
+// 2) `member` methods if `member` is a type
+// 3) empty list otherwise
+func memberFuncs(member ssa.Member, prog *ssa.Program) []*ssa.Function {
+ switch t := member.(type) {
+ case *ssa.Type:
+ methods := typeutil.IntuitiveMethodSet(t.Type(), &prog.MethodSets)
+ var funcs []*ssa.Function
+ for _, m := range methods {
+ if f := prog.MethodValue(m); f != nil {
+ funcs = append(funcs, f)
+ }
+ }
+ return funcs
+ case *ssa.Function:
+ return []*ssa.Function{t}
+ default:
+ return nil
+ }
+}