vulncheck: add support for callgraph source vuln detection

Cherry-picked: https://go-review.googlesource.com/c/exp/+/365114

Change-Id: Ic0611c071c9c8fa1ea2cf8818431b923de13b0f5
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395043
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/entries.go b/vulncheck/entries.go
new file mode 100644
index 0000000..834d590
--- /dev/null
+++ b/vulncheck/entries.go
@@ -0,0 +1,50 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+	"strings"
+
+	"golang.org/x/tools/go/ssa"
+)
+
+func entryPoints(topPackages []*ssa.Package) []*ssa.Function {
+	var entries []*ssa.Function
+	for _, pkg := range topPackages {
+		if pkg.Pkg.Name() == "main" {
+			// for "main" packages the only valid entry points are the "main"
+			// function and any "init#" functions, even if there are other
+			// exported functions or types. similarly to isEntry it should be
+			// safe to ignore the validity of the main or init# signatures,
+			// since the compiler will reject malformed definitions,
+			// and the init function is synthetic
+			entries = append(entries, memberFuncs(pkg.Members["main"], pkg.Prog)...)
+			for name, member := range pkg.Members {
+				if strings.HasPrefix(name, "init#") || name == "init" {
+					entries = append(entries, memberFuncs(member, pkg.Prog)...)
+				}
+			}
+			continue
+		}
+		for _, member := range pkg.Members {
+			for _, f := range memberFuncs(member, pkg.Prog) {
+				if isEntry(f) {
+					entries = append(entries, f)
+				}
+			}
+		}
+	}
+	return entries
+}
+
+func isEntry(f *ssa.Function) bool {
+	// it should be safe to ignore checking that the signature of the "init" function
+	// is valid, since it is synthetic
+	if f.Name() == "init" && f.Synthetic == "package initializer" {
+		return true
+	}
+
+	return f.Synthetic == "" && f.Object() != nil && f.Object().Exported()
+}
diff --git a/vulncheck/helpers_test.go b/vulncheck/helpers_test.go
index 453d47a..58afe38 100644
--- a/vulncheck/helpers_test.go
+++ b/vulncheck/helpers_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"fmt"
+	"sort"
 
 	"golang.org/x/tools/go/packages"
 	"golang.org/x/tools/go/packages/packagestest"
@@ -79,6 +80,8 @@
 			m[pred.Path] = append(m[pred.Path], n.Path)
 		}
 	}
+
+	sortStrMap(m)
 	return m
 }
 
@@ -90,9 +93,52 @@
 			m[pred.Path] = append(m[pred.Path], n.Path)
 		}
 	}
+
+	sortStrMap(m)
 	return m
 }
 
+// callGraphToStrMap flattens cg into a map from each caller's name to the
+// sorted names of its callees, listing every caller-callee pair only once.
+func callGraphToStrMap(cg *CallGraph) map[string][]string {
+	// edge identifies a call graph edge by the IDs of its end nodes.
+	type edge struct{ src, dst int }
+	// seen edges, to avoid repetitions
+	seen := make(map[edge]bool)
+
+	funcName := func(fn *FuncNode) string {
+		if fn.RecvType == "" {
+			return fmt.Sprintf("%s.%s", fn.PkgPath, fn.Name)
+		}
+		return fmt.Sprintf("%s.%s", fn.RecvType, fn.Name)
+	}
+
+	m := make(map[string][]string)
+	for _, n := range cg.Funcs {
+		fName := funcName(n)
+		for _, callsite := range n.CallSites {
+			e := edge{src: callsite.Parent, dst: n.ID}
+			if seen[e] {
+				continue
+			}
+			// Several call sites may connect the same two functions.
+			seen[e] = true
+			callerName := funcName(cg.Funcs[e.src])
+			m[callerName] = append(m[callerName], fName)
+		}
+	}
+
+	sortStrMap(m)
+	return m
+}
+
+// sortStrMap sorts the map string slice values to make them deterministic.
+func sortStrMap(m map[string][]string) {
+	for _, strs := range m {
+		sort.Strings(strs)
+	}
+}
+
 func loadPackages(e *packagestest.Exported, patterns ...string) ([]*packages.Package, error) {
 	e.Config.Mode |= packages.NeedModule | packages.NeedName | packages.NeedFiles |
 		packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedTypes |
diff --git a/vulncheck/slicing.go b/vulncheck/slicing.go
index 7b96283..7f76787 100644
--- a/vulncheck/slicing.go
+++ b/vulncheck/slicing.go
@@ -44,11 +44,11 @@
 	}
 }
 
-// pruneSlice removes functions in `slice` that are in `toPrune`.
-func pruneSlice(slice map[*ssa.Function]bool, toPrune map[*ssa.Function]bool) {
-	for f := range slice {
-		if toPrune[f] {
-			delete(slice, f)
+// pruneSet keeps in `set` only the functions that are also in `toPrune`; all others are deleted.
+func pruneSet(set, toPrune map[*ssa.Function]bool) {
+	for f := range set {
+		if !toPrune[f] {
+			delete(set, f)
 		}
 	}
 }
diff --git a/vulncheck/source.go b/vulncheck/source.go
index cf4961e..4ec495c 100644
--- a/vulncheck/source.go
+++ b/vulncheck/source.go
@@ -5,7 +5,11 @@
 package vulncheck
 
 import (
+	"golang.org/x/tools/go/callgraph"
 	"golang.org/x/tools/go/packages"
+	"golang.org/x/tools/go/ssa"
+	"golang.org/x/tools/go/ssa/ssautil"
+	"golang.org/x/vulndb/osv"
 )
 
 // Source detects vulnerabilities in pkgs and computes slices of
@@ -16,10 +20,6 @@
 //  - call graph leading to the use of a known vulnerable function
 //    or method
 func Source(pkgs []*packages.Package, cfg *Config) (*Result, error) {
-	if !cfg.ImportsOnly {
-		panic("call graph feature is currently unsupported")
-	}
-
 	modVulns, err := fetchVulnerabilities(cfg.Client, extractModules(pkgs))
 	if err != nil {
 		return nil, err
@@ -28,8 +28,21 @@
 	result := &Result{
 		Imports:  &ImportGraph{Packages: make(map[int]*PkgNode)},
 		Requires: &RequireGraph{Modules: make(map[int]*ModNode)},
+		Calls:    &CallGraph{Funcs: make(map[int]*FuncNode)},
 	}
+
 	vulnPkgModSlice(pkgs, modVulns, result)
+
+	if cfg.ImportsOnly {
+		return result, nil
+	}
+
+	prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
+	prog.Build()
+	entries := entryPoints(ssaPkgs)
+	cg := callGraph(prog, entries)
+	vulnCallGraphSlice(entries, modVulns, cg, result)
+
 	return result, nil
 }
 
@@ -223,3 +236,128 @@
 	}
 	return id
 }
+
+func vulnCallGraphSlice(entries []*ssa.Function, modVulns moduleVulnerabilities, cg *callgraph.Graph, result *Result) {
+	// analyzedFuncs contains information on functions analyzed thus far.
+	// If a function is mapped to nil, this means it has been visited
+	// but it does not lead to a vulnerable call. Otherwise, a visited
+	// function is mapped to Calls function node.
+	analyzedFuncs := make(map[*ssa.Function]*FuncNode)
+	for _, entry := range entries {
+		// Top level entries that lead to vulnerable calls
+		// are stored as result.Calls graph entry points.
+		if e := vulnCallSlice(entry, modVulns, cg, result, analyzedFuncs); e != nil {
+			result.Calls.Entries = append(result.Calls.Entries, e)
+		}
+	}
+}
+
+// funID is an id counter for nodes of Calls graph.
+var funID int = 0
+
+func nextFunID() int {
+	funID++
+	return funID
+}
+
+// vulnCallSlice checks if f has some vulnerabilities or transitively calls
+// a function with known vulnerabilities. If so, populates result.Calls
+// graph with this reachability information and returns the result.Calls
+// function node. Otherwise, returns nil.
+func vulnCallSlice(f *ssa.Function, modVulns moduleVulnerabilities, cg *callgraph.Graph, result *Result, analyzed map[*ssa.Function]*FuncNode) *FuncNode {
+	if fn, ok := analyzed[f]; ok {
+		return fn
+	}
+
+	fn := cg.Nodes[f]
+	if fn == nil {
+		return nil
+	}
+
+	// Check if f has known vulnerabilities.
+	vulns := modVulns.VulnsForSymbol(f.Package().Pkg.Path(), dbFuncName(f))
+
+	var funNode *FuncNode
+	// If there are vulnerabilities for f, create node for f and
+	// save it immediately. This allows us to include F in the
+	// slice when analyzing chain V -> F -> V where V is vulnerable.
+	if len(vulns) > 0 {
+		funNode = funcNode(f)
+	}
+	analyzed[f] = funNode
+
+	// Recursively compute which callees lead to a call of a
+	// vulnerable function. Remember the nodes of such callees.
+	type siteNode struct {
+		call ssa.CallInstruction
+		fn   *FuncNode
+	}
+	var onSlice []siteNode
+	for _, edge := range fn.Out {
+		if calleeNode := vulnCallSlice(edge.Callee.Func, modVulns, cg, result, analyzed); calleeNode != nil {
+			onSlice = append(onSlice, siteNode{call: edge.Site, fn: calleeNode})
+		}
+	}
+
+	// If f is neither vulnerable nor transitively leads
+	// to vulnerable calls, jump out.
+	if len(onSlice) == 0 && len(vulns) == 0 {
+		return nil
+	}
+
+	// If f is not vulnerable, then at this point it has
+	// to be on the path leading to a vulnerable call.
+	if funNode == nil {
+		funNode = funcNode(f)
+		analyzed[f] = funNode
+	}
+	result.Calls.Funcs[funNode.ID] = funNode
+
+	// Save node predecessor information.
+	for _, calleeSliceInfo := range onSlice {
+		call, node := calleeSliceInfo.call, calleeSliceInfo.fn
+		cs := &CallSite{
+			Parent:   funNode.ID,
+			Name:     call.Common().Value.Name(),
+			RecvType: callRecvType(call),
+			Resolved: resolved(call),
+			Pos:      instrPosition(call),
+		}
+		node.CallSites = append(node.CallSites, cs)
+	}
+
+	// Mark f as the call sink only for the vulnerable symbol f itself
+	// defines: an OSV entry can list several symbols of the same package,
+	// and each of the others must be sunk by its own function node.
+	fSymbol := dbFuncName(f)
+	for _, osv := range vulns {
+		for _, affected := range osv.Affected {
+			if affected.Package.Name == funNode.PkgPath {
+				addCallSinkForVuln(funNode.ID, osv, fSymbol, funNode.PkgPath, result)
+			}
+		}
+	}
+	return funNode
+}
+
+func funcNode(f *ssa.Function) *FuncNode {
+	id := nextFunID()
+	return &FuncNode{
+		ID:       id,
+		Name:     f.Name(),
+		PkgPath:  f.Package().Pkg.Path(),
+		RecvType: funcRecvType(f),
+		Pos:      funcPosition(f),
+	}
+}
+
+// addCallSinkForVuln adds callID as call sink to vuln of result.Vulns
+// identified with <osv, symbol, pkg>.
+func addCallSinkForVuln(callID int, osv *osv.Entry, symbol, pkg string, result *Result) {
+	for _, vuln := range result.Vulns {
+		if vuln.OSV == osv && vuln.Symbol == symbol && vuln.PkgPath == pkg {
+			vuln.CallSink = callID
+			return
+		}
+	}
+}
diff --git a/vulncheck/source_test.go b/vulncheck/source_test.go
index 8e1effa..df0e272 100644
--- a/vulncheck/source_test.go
+++ b/vulncheck/source_test.go
@@ -158,3 +158,223 @@
 		t.Errorf("want %v requires graph; got %v", wantRequires, rgStrMap)
 	}
 }
+
+// TestCallGraph checks for call graph vuln slicing correctness.
+// The inlined test code has the following call graph
+//
+//          x.X
+//        /  |  \
+//       /  d.D1 avuln.VulnData.Vuln1
+//      /  /  |
+//     c.C1  d.internal.Vuln1
+//      |
+//    avuln.VulnData.Vuln2
+//
+//         --------------------y.Y-------------------------------
+//        /           /              \         \         \       \
+//       /           /                \         \         \       \
+//      /           /                  \         \         \       \
+//    c.C4 c.vulnWrap.V.Vuln1(=nil)   c.C2   bvuln.Vuln   c.C3   c.C3$1
+//      |                                       | |
+//  y.benign                                    e.E
+//
+// and this slice
+//
+//          x.X
+//        /  |  \
+//       /  d.D1 avuln.VulnData.Vuln1
+//      /  /
+//     c.C1
+//      |
+//    avuln.VulnData.Vuln2
+//
+//     y.Y
+//      |
+//  bvuln.Vuln
+//     | |
+//     e.E
+// related to avuln.VulnData.{Vuln1, Vuln2} and bvuln.Vuln vulnerabilities.
+func TestCallGraph(t *testing.T) {
+	e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
+		{
+			Name: "golang.org/entry",
+			Files: map[string]interface{}{
+				"x/x.go": `
+			package x
+
+			import (
+				"golang.org/cmod/c"
+				"golang.org/dmod/d"
+			)
+
+			func X(x bool) {
+				if x {
+					c.C1().Vuln1() // vuln use: Vuln1
+				} else {
+					d.D1() // no vuln use
+				}
+			}
+			`,
+				"y/y.go": `
+			package y
+
+			import (
+				"golang.org/cmod/c"
+			)
+
+			func Y(y bool) {
+				if y {
+					c.C2()() // vuln use: bvuln.Vuln
+				} else {
+					c.C3()()
+					w := c.C4(benign)
+					w.V.Vuln1() // no vuln use: Vuln1 does not belong to vulnerable type
+				}
+			}
+
+			func benign(i c.I) {}
+		`}},
+		{
+			Name: "golang.org/cmod@v1.1.3",
+			Files: map[string]interface{}{"c/c.go": `
+			package c
+
+			import (
+				"golang.org/amod/avuln"
+				"golang.org/bmod/bvuln"
+			)
+
+			type I interface {
+				Vuln1()
+			}
+
+			func C1() I {
+				v := avuln.VulnData{}
+				v.Vuln2() // vuln use
+				return v
+			}
+
+			func C2() func() {
+				return bvuln.Vuln
+			}
+
+			func C3() func() {
+				return func() {}
+			}
+
+			type vulnWrap struct {
+				V I
+			}
+
+			func C4(f func(i I)) vulnWrap {
+				f(avuln.VulnData{})
+				return vulnWrap{}
+			}
+			`},
+		},
+		{
+			Name: "golang.org/dmod@v0.5.0",
+			Files: map[string]interface{}{"d/d.go": `
+			package d
+
+			import (
+				"golang.org/cmod/c"
+			)
+
+			type internal struct{}
+
+			func (i internal) Vuln1() {}
+
+			func D1() {
+				c.C1() // transitive vuln use
+				var i c.I
+				i = internal{}
+				i.Vuln1() // no vuln use
+			}
+			`},
+		},
+		{
+			Name: "golang.org/amod@v1.1.3",
+			Files: map[string]interface{}{"avuln/avuln.go": `
+			package avuln
+
+			type VulnData struct {}
+			func (v VulnData) Vuln1() {}
+			func (v VulnData) Vuln2() {}
+			`},
+		},
+		{
+			Name: "golang.org/bmod@v0.5.0",
+			Files: map[string]interface{}{"bvuln/bvuln.go": `
+			package bvuln
+
+			import (
+				"golang.org/emod/e"
+			)
+
+			func Vuln() {
+				e.E(Vuln)
+			}
+			`},
+		},
+		{
+			Name: "golang.org/emod@v1.5.0",
+			Files: map[string]interface{}{"e/e.go": `
+			package e
+
+			func E(f func()) {
+				f()
+			}
+			`},
+		},
+	})
+	defer e.Cleanup()
+
+	// Make sure local vulns can be loaded.
+	fetchingInTesting = true
+	// Load x and y as entry packages.
+	pkgs, err := loadPackages(e, path.Join(e.Temp(), "entry/x"), path.Join(e.Temp(), "entry/y"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(pkgs) != 2 {
+		t.Fatal("failed to load x and y test packages")
+	}
+
+	cfg := &Config{
+		Client: testClient,
+	}
+	result, err := Source(pkgs, cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that we find the right number of vulnerabilities.
+	// There should be three entries as there are three vulnerable
+	// symbols in the two import-reachable OSVs.
+	if len(result.Vulns) != 3 {
+		t.Errorf("want 3 Vulns, got %d", len(result.Vulns))
+	}
+
+	// Check that vulnerabilities are connected to the call graph.
+	// For the test example, all vulns should have a call sink.
+	for _, v := range result.Vulns {
+		if v.CallSink == 0 {
+			t.Errorf("want CallSink !=0 for %v:%v; got 0", v.Symbol, v.PkgPath)
+		}
+	}
+
+	wantCalls := map[string][]string{
+		"golang.org/entry/x.X":       {"golang.org/amod/avuln.VulnData.Vuln1", "golang.org/cmod/c.C1", "golang.org/dmod/d.D1"},
+		"golang.org/cmod/c.C1":       {"golang.org/amod/avuln.VulnData.Vuln2"},
+		"golang.org/dmod/d.D1":       {"golang.org/cmod/c.C1"},
+		"golang.org/entry/y.Y":       {"golang.org/bmod/bvuln.Vuln"},
+		"golang.org/bmod/bvuln.Vuln": {"golang.org/emod/e.E"},
+		"golang.org/emod/e.E":        {"golang.org/bmod/bvuln.Vuln"},
+	}
+
+	if callStrMap := callGraphToStrMap(result.Calls); !reflect.DeepEqual(wantCalls, callStrMap) {
+		t.Errorf("want %v call graph; got %v", wantCalls, callStrMap)
+	}
+}
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
index 861a251..eea2078 100644
--- a/vulncheck/utils.go
+++ b/vulncheck/utils.go
@@ -5,15 +5,42 @@
 package vulncheck
 
 import (
+	"bytes"
+	"go/token"
 	"go/types"
 	"strings"
 
 	"golang.org/x/tools/go/callgraph"
+	"golang.org/x/tools/go/callgraph/cha"
+	"golang.org/x/tools/go/callgraph/vta"
+	"golang.org/x/tools/go/ssa/ssautil"
 	"golang.org/x/tools/go/types/typeutil"
 
 	"golang.org/x/tools/go/ssa"
 )
 
+// callGraph builds a call graph of prog based on VTA analysis.
+func callGraph(prog *ssa.Program, entries []*ssa.Function) *callgraph.Graph {
+	entrySlice := make(map[*ssa.Function]bool)
+	for _, e := range entries {
+		entrySlice[e] = true
+	}
+	initial := cha.CallGraph(prog)
+	allFuncs := ssautil.AllFunctions(prog)
+
+	fslice := forwardReachableFrom(entrySlice, initial)
+	// Keep only actually linked functions.
+	pruneSet(fslice, allFuncs)
+	vtaCg := vta.CallGraph(fslice, initial)
+
+	// Repeat the process once more, this time using
+	// the produced VTA call graph as the base graph.
+	fslice = forwardReachableFrom(entrySlice, vtaCg)
+	pruneSet(fslice, allFuncs)
+
+	return vta.CallGraph(fslice, vtaCg)
+}
+
 // siteCallees computes a set of callees for call site `call` given program `callgraph`.
 func siteCallees(call ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function {
 	var matches []*ssa.Function
@@ -112,3 +139,43 @@
 		return nil
 	}
 }
+
+// funcPosition gives the position of `f`. Returns a pointer to the zero
+// token.Position if no file information on `f` is available.
+func funcPosition(f *ssa.Function) *token.Position {
+	pos := f.Prog.Fset.Position(f.Pos())
+	return &pos
+}
+
+// instrPosition gives the position of `instr`. Returns a pointer to the zero
+// token.Position if no file information on `instr` is available.
+func instrPosition(instr ssa.Instruction) *token.Position {
+	pos := instr.Parent().Prog.Fset.Position(instr.Pos())
+	return &pos
+}
+
+func resolved(call ssa.CallInstruction) bool {
+	if call == nil {
+		return true
+	}
+	return call.Common().StaticCallee() != nil
+}
+
+func callRecvType(call ssa.CallInstruction) string {
+	if !call.Common().IsInvoke() {
+		return ""
+	}
+	buf := new(bytes.Buffer)
+	types.WriteType(buf, call.Common().Value.Type(), nil)
+	return buf.String()
+}
+
+func funcRecvType(f *ssa.Function) string {
+	v := f.Signature.Recv()
+	if v == nil {
+		return ""
+	}
+	buf := new(bytes.Buffer)
+	types.WriteType(buf, v.Type(), nil)
+	return buf.String()
+}