vulncheck: add forward slicing logic

This code is copied from vulndb/internal/audit. The logic is used to
compute a forward reachable set of functions starting from program entry
points. This set is effectively a conservative estimate of the program
subset relevant for vulncheck that is passed to vta call graph analysis.

Change-Id: I83c213c137976d5a3877eb478338affb9c72960b
Reviewed-on: https://go-review.googlesource.com/c/exp/+/363663
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/vulncheck/slicing.go b/vulncheck/slicing.go
new file mode 100644
index 0000000..7b96283
--- /dev/null
+++ b/vulncheck/slicing.go
@@ -0,0 +1,54 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+	"golang.org/x/tools/go/callgraph"
+	"golang.org/x/tools/go/ssa"
+)
+
+// forwardReachableFrom returns the set of functions reached by walking forward from
+// the functions in `sources`. A function f is reached from a function g if g calls f
+// according to the callgraph `cg`, or if f is an operand of a non-call instruction in g.
+func forwardReachableFrom(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool {
+	reached := make(map[*ssa.Function]bool)
+	for src := range sources {
+		forward(src, cg, reached)
+	}
+	return reached
+}
+
+// forward adds f and every function reachable from f to `seen`, resolving calls via `cg`.
+func forward(f *ssa.Function, cg *callgraph.Graph, seen map[*ssa.Function]bool) {
+	if seen[f] {
+		return
+	}
+	seen[f] = true
+	var scratch [10]*ssa.Value // avoid alloc in common case
+	for _, block := range f.Blocks {
+		for _, instr := range block.Instrs {
+			if call, ok := instr.(ssa.CallInstruction); ok { // calls go through the call graph
+				for _, callee := range siteCallees(call, cg) {
+					forward(callee, cg, seen)
+				}
+				continue
+			}
+			for _, operand := range instr.Operands(scratch[:0]) { // function values used as operands
+				if fn, ok := (*operand).(*ssa.Function); ok {
+					forward(fn, cg, seen)
+				}
+			}
+		}
+	}
+}
+
+// pruneSlice deletes from `slice`, in place, every function that is marked true in `toPrune`.
+func pruneSlice(slice map[*ssa.Function]bool, toPrune map[*ssa.Function]bool) {
+	for fn, drop := range toPrune {
+		if drop {
+			delete(slice, fn)
+		}
+	}
+}
diff --git a/vulncheck/slicing_test.go b/vulncheck/slicing_test.go
new file mode 100644
index 0000000..ae8bfb8
--- /dev/null
+++ b/vulncheck/slicing_test.go
@@ -0,0 +1,118 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+	"path"
+	"reflect"
+	"testing"
+
+	"golang.org/x/tools/go/callgraph/cha"
+	"golang.org/x/tools/go/packages/packagestest"
+	"golang.org/x/tools/go/ssa"
+	"golang.org/x/tools/go/ssa/ssautil"
+)
+
+// funcNames returns the set of dbFuncName names of the functions that are keys of `funcs`.
+func funcNames(funcs map[*ssa.Function]bool) map[string]bool {
+	names := make(map[string]bool, len(funcs))
+	for fn := range funcs {
+		names[dbFuncName(fn)] = true
+	}
+	return names
+}
+
+func TestSlicing(t *testing.T) {
+	// p is the test program; its functions marked "not reachable" are left out of want below.
+	p := `
+package slice
+
+func X() {}
+func Y() {}
+
+// not reachable
+func id(i int) int {
+        return i
+}
+
+// not reachable
+func inc(i int) int {
+        return i + 1
+}
+
+func Apply(b bool, h func()) {
+        if b {
+                func() {
+                        print("applied")
+                }()
+                return
+        }
+        h()
+}
+
+type I interface {
+        Foo()
+}
+
+type A struct{}
+
+func (a A) Foo() {}
+
+// not reachable
+func (a A) Bar() {}
+
+type B struct{}
+
+func (b B) Foo() {}
+
+func debug(s string) {
+        print(s)
+}
+
+func Do(i I, input string) {
+        debug(input)
+
+        i.Foo()
+
+        func(x string) {
+                func(l int) {
+                        print(l)
+                }(len(x))
+        }(input)
+}`
+
+	e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
+		{
+			Name:  "some/module",
+			Files: map[string]interface{}{"slice/slice.go": p},
+		},
+	})
+
+	pkgs, err := loadPackages(e, path.Join(e.Temp(), "/module/slice"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
+	prog.Build()
+	// Slice forward from Apply and Do, resolving call sites with the CHA call graph.
+	pkg := ssaPkgs[0]
+	sources := map[*ssa.Function]bool{pkg.Func("Apply"): true, pkg.Func("Do"): true}
+	fs := funcNames(forwardReachableFrom(sources, cha.CallGraph(prog)))
+	want := map[string]bool{
+		"Apply":   true,
+		"Apply$1": true,
+		"X":       true,
+		"Y":       true,
+		"Do":      true,
+		"Do$1":    true,
+		"Do$1$1":  true,
+		"debug":   true,
+		"A.Foo":   true,
+		"B.Foo":   true,
+	}
+	if !reflect.DeepEqual(want, fs) {
+		t.Errorf("want %v; got %v", want, fs)
+	}
+}
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
new file mode 100644
index 0000000..861a251
--- /dev/null
+++ b/vulncheck/utils.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+	"go/types"
+	"strings"
+
+	"golang.org/x/tools/go/callgraph"
+	"golang.org/x/tools/go/types/typeutil"
+
+	"golang.org/x/tools/go/ssa"
+)
+
+// siteCallees returns the callees that program `callgraph` records for call site `call`.
+func siteCallees(call ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function {
+	caller := callgraph.Nodes[call.Parent()]
+	if caller == nil {
+		return nil
+	}
+
+	// Keep only the outgoing edges of the enclosing function that originate at `call`.
+	// The callees on those edges are returned as is; none of them is filtered out.
+	var callees []*ssa.Function
+	for _, out := range caller.Out {
+		if out.Site != call {
+			continue
+		}
+		callees = append(callees, out.Callee.Func)
+	}
+	return callees
+}
+
+// dbFuncName returns the name of f as used in vulnerability databases: the name
+// of f local to its enclosing package, prefixed with "T." when f is a method of T
+// or is synthesized by ssa as bound to or a thunk for T. A pointer T is named by
+// its element. Anonymous functions are named by f.Name(), e.g. Apply$1 in Apply.
+// The names are unique within the enclosing package, but not globally.
+//
+// Examples:
+//   func (a A) foo (...) {...}  -> A.foo
+//   func foo(...) {...}         -> foo
+//   func (b *B) bar (...) {...} -> B.bar
+func dbFuncName(f *ssa.Function) string {
+	// typeName formats t without a package qualifier, looking through pointers.
+	typeName := func(t types.Type) string {
+		for {
+			switch tt := t.(type) {
+			case *types.Pointer:
+				t = tt.Elem()
+			case *types.Named:
+				return tt.Obj().Name()
+			default:
+				return types.TypeString(t, func(*types.Package) string { return "" })
+			}
+		}
+	}
+	// The helpers below key off ssa's Synthetic description, which is subject to change.
+	// boundType is the type of the sole free variable of a "bound " function, else nil.
+	boundType := func() types.Type {
+		if strings.HasPrefix(f.Synthetic, "bound ") && len(f.FreeVars) == 1 {
+			return f.FreeVars[0].Type()
+		}
+		return nil
+	}
+	// thunkType is the type of the first parameter of a "thunk " function, else nil.
+	thunkType := func() types.Type {
+		if !strings.HasPrefix(f.Synthetic, "thunk ") {
+			return nil
+		}
+		if params := f.Signature.Params(); params.Len() >= 1 {
+			if first := params.At(0); first != nil {
+				return first.Type()
+			}
+		}
+		return nil
+	}
+	var qualifier string
+	if recv := f.Signature.Recv(); recv != nil {
+		qualifier = typeName(recv.Type())
+	} else if bound := boundType(); bound != nil {
+		qualifier = typeName(bound)
+	} else if thunk := thunkType(); thunk != nil {
+		qualifier = typeName(thunk)
+	}
+	if qualifier != "" {
+		return qualifier + "." + f.Name()
+	}
+	return f.Name()
+}
+
+// memberFuncs returns the functions associated with `member`:
+//   - for a function, the function itself;
+//   - for a type, each method in its intuitive method set for which
+//     prog.MethodValue yields a non-nil function;
+//   - for any other member, nil.
+func memberFuncs(member ssa.Member, prog *ssa.Program) []*ssa.Function {
+	if fn, ok := member.(*ssa.Function); ok {
+		return []*ssa.Function{fn}
+	}
+	typ, ok := member.(*ssa.Type)
+	if !ok {
+		return nil
+	}
+	var funcs []*ssa.Function
+	for _, sel := range typeutil.IntuitiveMethodSet(typ.Type(), &prog.MethodSets) {
+		if fn := prog.MethodValue(sel); fn != nil {
+			funcs = append(funcs, fn)
+		}
+	}
+	return funcs
+}