vulncheck: add support for callgraph source vuln detection
Cherry-picked: https://go-review.googlesource.com/c/exp/+/365114
Change-Id: Ic0611c071c9c8fa1ea2cf8818431b923de13b0f5
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395043
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/entries.go b/vulncheck/entries.go
new file mode 100644
index 0000000..834d590
--- /dev/null
+++ b/vulncheck/entries.go
@@ -0,0 +1,50 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+ "strings"
+
+ "golang.org/x/tools/go/ssa"
+)
+
+// entryPoints collects the functions of topPackages that serve as roots
+// of the call graph analysis.
+//
+// For a package named "main", the roots are whatever memberFuncs yields
+// for the "main" member and for every member named "init" or prefixed
+// with "init#"; all other members are ignored, even exported ones. The
+// signatures of these functions are not validated: the compiler rejects
+// malformed main/init# definitions, and init is synthetic.
+//
+// For any other package, the roots are the functions that memberFuncs
+// yields for each member and that also satisfy isEntry.
+func entryPoints(topPackages []*ssa.Package) []*ssa.Function {
+	var roots []*ssa.Function
+	for _, p := range topPackages {
+		if p.Pkg.Name() != "main" {
+			for _, m := range p.Members {
+				for _, fn := range memberFuncs(m, p.Prog) {
+					if !isEntry(fn) {
+						continue
+					}
+					roots = append(roots, fn)
+				}
+			}
+			continue
+		}
+
+		roots = append(roots, memberFuncs(p.Members["main"], p.Prog)...)
+		for name, m := range p.Members {
+			if name != "init" && !strings.HasPrefix(name, "init#") {
+				continue
+			}
+			roots = append(roots, memberFuncs(m, p.Prog)...)
+		}
+	}
+	return roots
+}
+
+// isEntry reports whether f is treated as an entry point of a non-main
+// package: either the synthetic package initializer named "init", or a
+// non-synthetic function backed by an exported object.
+func isEntry(f *ssa.Function) bool {
+	switch {
+	case f.Name() == "init" && f.Synthetic == "package initializer":
+		// The signature of init needs no validation since it is synthetic.
+		return true
+	case f.Synthetic != "":
+		return false
+	}
+
+	obj := f.Object()
+	return obj != nil && obj.Exported()
+}
diff --git a/vulncheck/helpers_test.go b/vulncheck/helpers_test.go
index 453d47a..58afe38 100644
--- a/vulncheck/helpers_test.go
+++ b/vulncheck/helpers_test.go
@@ -6,6 +6,7 @@
import (
"fmt"
+ "sort"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/packages/packagestest"
@@ -79,6 +80,8 @@
m[pred.Path] = append(m[pred.Path], n.Path)
}
}
+
+ sortStrMap(m)
return m
}
@@ -90,9 +93,52 @@
m[pred.Path] = append(m[pred.Path], n.Path)
}
}
+
+ sortStrMap(m)
return m
}
+// callGraphToStrMap converts cg into a map from the name of each calling
+// function to the sorted names of the functions it calls. Functions are
+// named "<RecvType>.<Name>" when they have a receiver type and
+// "<PkgPath>.<Name>" otherwise. Each caller/callee pair is reported once,
+// no matter how many call sites connect the two functions.
+func callGraphToStrMap(cg *CallGraph) map[string][]string {
+	type edge struct {
+		// src and dst are ids of source and
+		// destination nodes in a callgraph edge.
+		src, dst int
+	}
+	// seen edges, to avoid repetitions
+	seen := make(map[edge]bool)
+
+	funcName := func(fn *FuncNode) string {
+		if fn.RecvType == "" {
+			return fmt.Sprintf("%s.%s", fn.PkgPath, fn.Name)
+		}
+		return fmt.Sprintf("%s.%s", fn.RecvType, fn.Name)
+	}
+
+	m := make(map[string][]string)
+	for _, n := range cg.Funcs {
+		fName := funcName(n)
+		for _, callsite := range n.CallSites {
+			e := edge{src: callsite.Parent, dst: n.ID}
+			if seen[e] {
+				continue
+			}
+			// Record the edge; without this the check above never
+			// fires and repeated call sites produce duplicate entries.
+			seen[e] = true
+
+			caller := cg.Funcs[e.src]
+			callerName := funcName(caller)
+			m[callerName] = append(m[callerName], fName)
+		}
+	}
+
+	sortStrMap(m)
+	return m
+}
+
+// sortStrMap sorts, in place, every string slice stored as a value of m
+// so that the contents of the map are deterministic.
+func sortStrMap(m map[string][]string) {
+	for key := range m {
+		sort.Strings(m[key])
+	}
+}
+
func loadPackages(e *packagestest.Exported, patterns ...string) ([]*packages.Package, error) {
e.Config.Mode |= packages.NeedModule | packages.NeedName | packages.NeedFiles |
packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedTypes |
diff --git a/vulncheck/slicing.go b/vulncheck/slicing.go
index 7b96283..7f76787 100644
--- a/vulncheck/slicing.go
+++ b/vulncheck/slicing.go
@@ -44,11 +44,11 @@
}
}
-// pruneSlice removes functions in `slice` that are in `toPrune`.
-func pruneSlice(slice map[*ssa.Function]bool, toPrune map[*ssa.Function]bool) {
- for f := range slice {
- if toPrune[f] {
- delete(slice, f)
+// pruneSet removes functions in `set` that are not in `toPrune`, i.e.,
+// it keeps in `set` only the functions that `toPrune` maps to true.
+func pruneSet(set, toPrune map[*ssa.Function]bool) {
+ for f := range set {
+ if !toPrune[f] {
+ delete(set, f)
}
}
}
diff --git a/vulncheck/source.go b/vulncheck/source.go
index cf4961e..4ec495c 100644
--- a/vulncheck/source.go
+++ b/vulncheck/source.go
@@ -5,7 +5,11 @@
package vulncheck
import (
+ "golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/packages"
+ "golang.org/x/tools/go/ssa"
+ "golang.org/x/tools/go/ssa/ssautil"
+ "golang.org/x/vulndb/osv"
)
// Source detects vulnerabilities in pkgs and computes slices of
@@ -16,10 +20,6 @@
// - call graph leading to the use of a known vulnerable function
// or method
func Source(pkgs []*packages.Package, cfg *Config) (*Result, error) {
- if !cfg.ImportsOnly {
- panic("call graph feature is currently unsupported")
- }
-
modVulns, err := fetchVulnerabilities(cfg.Client, extractModules(pkgs))
if err != nil {
return nil, err
@@ -28,8 +28,21 @@
result := &Result{
Imports: &ImportGraph{Packages: make(map[int]*PkgNode)},
Requires: &RequireGraph{Modules: make(map[int]*ModNode)},
+ Calls: &CallGraph{Funcs: make(map[int]*FuncNode)},
}
+
vulnPkgModSlice(pkgs, modVulns, result)
+
+ if cfg.ImportsOnly {
+ return result, nil
+ }
+
+ prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
+ prog.Build()
+ entries := entryPoints(ssaPkgs)
+ cg := callGraph(prog, entries)
+ vulnCallGraphSlice(entries, modVulns, cg, result)
+
return result, nil
}
@@ -223,3 +236,128 @@
}
return id
}
+
+// vulnCallGraphSlice runs vulnCallSlice on every function in entries and
+// appends the node of each entry that leads to a vulnerable call to
+// result.Calls.Entries.
+func vulnCallGraphSlice(entries []*ssa.Function, modVulns moduleVulnerabilities, cg *callgraph.Graph, result *Result) {
+	// visited is shared by all entries. A function mapped to nil has been
+	// visited but does not lead to a vulnerable call; any other value is
+	// the Calls function node created for that function.
+	visited := make(map[*ssa.Function]*FuncNode)
+	for _, fn := range entries {
+		node := vulnCallSlice(fn, modVulns, cg, result, visited)
+		if node == nil {
+			continue
+		}
+		// Top level entries that lead to vulnerable calls
+		// become entry points of the result.Calls graph.
+		result.Calls.Entries = append(result.Calls.Entries, node)
+	}
+}
+
+// funID is the most recently issued id for nodes of the Calls graph.
+// It is a plain package-level counter with no synchronization.
+var funID int
+
+// nextFunID advances funID by one and returns the new value, so the
+// first id handed out is 1.
+func nextFunID() int {
+	funID = funID + 1
+	return funID
+}
+
+// vulnCallSlice checks if f has some vulnerabilities or transitively calls
+// a function with known vulnerabilities. If so, it populates result.Calls
+// with this reachability information and returns the result.Calls function
+// node created for f. Otherwise, it returns nil.
+//
+// analyzed records every function visited so far. A function mapped to nil
+// either does not lead to a vulnerable call or is a non-vulnerable function
+// still being analyzed further up the recursion; any other value is the
+// function's Calls node.
+func vulnCallSlice(f *ssa.Function, modVulns moduleVulnerabilities, cg *callgraph.Graph, result *Result, analyzed map[*ssa.Function]*FuncNode) *FuncNode {
+	if fn, ok := analyzed[f]; ok {
+		return fn
+	}
+
+	fn := cg.Nodes[f]
+	if fn == nil {
+		return nil
+	}
+
+	// Check if f has known vulnerabilities. Some synthetic functions
+	// (e.g., wrappers) have no enclosing package; they cannot match a
+	// vulnerable symbol, so the lookup is skipped instead of
+	// dereferencing a nil package.
+	var (
+		symbol string
+		vulns  []*osv.Entry
+	)
+	if pkg := f.Package(); pkg != nil && pkg.Pkg != nil {
+		symbol = dbFuncName(f)
+		vulns = modVulns.VulnsForSymbol(pkg.Pkg.Path(), symbol)
+	}
+
+	var funNode *FuncNode
+	// If there are vulnerabilities for f, create node for f and
+	// save it immediately. This allows us to include F in the
+	// slice when analyzing chain V -> F -> V where V is vulnerable.
+	if len(vulns) > 0 {
+		funNode = funcNode(f)
+	}
+	analyzed[f] = funNode
+
+	// Recursively compute which callees lead to a call of a
+	// vulnerable function. Remember the nodes of such callees.
+	type siteNode struct {
+		call ssa.CallInstruction
+		fn   *FuncNode
+	}
+	var onSlice []siteNode
+	for _, edge := range fn.Out {
+		if calleeNode := vulnCallSlice(edge.Callee.Func, modVulns, cg, result, analyzed); calleeNode != nil {
+			onSlice = append(onSlice, siteNode{call: edge.Site, fn: calleeNode})
+		}
+	}
+
+	// If f is not vulnerable nor it transitively leads
+	// to vulnerable calls, jump out.
+	if len(onSlice) == 0 && len(vulns) == 0 {
+		return nil
+	}
+
+	// If f is not vulnerable, then at this point it has
+	// to be on the path leading to a vulnerable call.
+	if funNode == nil {
+		funNode = funcNode(f)
+		analyzed[f] = funNode
+	}
+	result.Calls.Funcs[funNode.ID] = funNode
+
+	// Save node predecessor information.
+	for _, calleeSliceInfo := range onSlice {
+		call, node := calleeSliceInfo.call, calleeSliceInfo.fn
+		cs := &CallSite{
+			Parent:   funNode.ID,
+			Name:     call.Common().Value.Name(),
+			RecvType: callRecvType(call),
+			Resolved: resolved(call),
+			Pos:      instrPosition(call),
+		}
+		node.CallSites = append(node.CallSites, cs)
+	}
+
+	// Populate the CallSink field of the vulnerabilities detected for
+	// f's own symbol only. Other symbols of the same package are
+	// implemented by other functions and get their sink when those
+	// functions are analyzed; assigning them here would point their
+	// CallSink at the wrong node.
+	for _, vuln := range vulns {
+		addCallSinkForVuln(funNode.ID, vuln, symbol, funNode.PkgPath, result)
+	}
+	return funNode
+}
+
+func funcNode(f *ssa.Function) *FuncNode {
+ id := nextFunID()
+ return &FuncNode{
+ ID: id,
+ Name: f.Name(),
+ PkgPath: f.Package().Pkg.Path(),
+ RecvType: funcRecvType(f),
+ Pos: funcPosition(f),
+ }
+}
+
+// addCallSinkForVuln sets callID as the call sink of the first entry of
+// result.Vulns identified by <osv, symbol, pkg>. It does nothing when no
+// entry matches.
+func addCallSinkForVuln(callID int, osv *osv.Entry, symbol, pkg string, result *Result) {
+	for _, v := range result.Vulns {
+		if v.OSV != osv || v.Symbol != symbol || v.PkgPath != pkg {
+			continue
+		}
+		v.CallSink = callID
+		return
+	}
+}
diff --git a/vulncheck/source_test.go b/vulncheck/source_test.go
index 8e1effa..df0e272 100644
--- a/vulncheck/source_test.go
+++ b/vulncheck/source_test.go
@@ -158,3 +158,223 @@
t.Errorf("want %v requires graph; got %v", wantRequires, rgStrMap)
}
}
+
+// TestCallGraph checks for call graph vuln slicing correctness.
+// The inlined test code has the following call graph
+//
+// x.X
+// / | \
+// / d.D1 avuln.VulnData.Vuln1
+// / / |
+// c.C1 d.internal.Vuln1
+// |
+// avuln.VulnData.Vuln2
+//
+// --------------------y.Y-------------------------------
+// / / \ \ \ \
+// / / \ \ \ \
+// / / \ \ \ \
+// c.C4 c.vulnWrap.V.Vuln1(=nil) c.C2 bvuln.Vuln c.C3 c.C3$1
+// | | |
+// y.benign e.E
+//
+// and this slice
+//
+// x.X
+// / | \
+// / d.D1 avuln.VulnData.Vuln1
+// / /
+// c.C1
+// |
+// avuln.VulnData.Vuln2
+//
+// y.Y
+// |
+// bvuln.Vuln
+// | |
+// e.E
+// related to avuln.VulnData.{Vuln1, Vuln2} and bvuln.Vuln vulnerabilities.
+func TestCallGraph(t *testing.T) {
+ e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
+ {
+ Name: "golang.org/entry",
+ Files: map[string]interface{}{
+ "x/x.go": `
+ package x
+
+ import (
+ "golang.org/cmod/c"
+ "golang.org/dmod/d"
+ )
+
+ func X(x bool) {
+ if x {
+ c.C1().Vuln1() // vuln use: Vuln1
+ } else {
+ d.D1() // no vuln use
+ }
+ }
+ `,
+ "y/y.go": `
+ package y
+
+ import (
+ "golang.org/cmod/c"
+ )
+
+ func Y(y bool) {
+ if y {
+ c.C2()() // vuln use: bvuln.Vuln
+ } else {
+ c.C3()()
+ w := c.C4(benign)
+ w.V.Vuln1() // no vuln use: Vuln1 does not belong to vulnerable type
+ }
+ }
+
+ func benign(i c.I) {}
+ `}},
+ {
+ Name: "golang.org/cmod@v1.1.3",
+ Files: map[string]interface{}{"c/c.go": `
+ package c
+
+ import (
+ "golang.org/amod/avuln"
+ "golang.org/bmod/bvuln"
+ )
+
+ type I interface {
+ Vuln1()
+ }
+
+ func C1() I {
+ v := avuln.VulnData{}
+ v.Vuln2() // vuln use
+ return v
+ }
+
+ func C2() func() {
+ return bvuln.Vuln
+ }
+
+ func C3() func() {
+ return func() {}
+ }
+
+ type vulnWrap struct {
+ V I
+ }
+
+ func C4(f func(i I)) vulnWrap {
+ f(avuln.VulnData{})
+ return vulnWrap{}
+ }
+ `},
+ },
+ {
+ Name: "golang.org/dmod@v0.5.0",
+ Files: map[string]interface{}{"d/d.go": `
+ package d
+
+ import (
+ "golang.org/cmod/c"
+ )
+
+ type internal struct{}
+
+ func (i internal) Vuln1() {}
+
+ func D1() {
+ c.C1() // transitive vuln use
+ var i c.I
+ i = internal{}
+ i.Vuln1() // no vuln use
+ }
+ `},
+ },
+ {
+ Name: "golang.org/amod@v1.1.3",
+ Files: map[string]interface{}{"avuln/avuln.go": `
+ package avuln
+
+ type VulnData struct {}
+ func (v VulnData) Vuln1() {}
+ func (v VulnData) Vuln2() {}
+ `},
+ },
+ {
+ Name: "golang.org/bmod@v0.5.0",
+ Files: map[string]interface{}{"bvuln/bvuln.go": `
+ package bvuln
+
+ import (
+ "golang.org/emod/e"
+ )
+
+ func Vuln() {
+ e.E(Vuln)
+ }
+ `},
+ },
+ {
+ Name: "golang.org/emod@v1.5.0",
+ Files: map[string]interface{}{"e/e.go": `
+ package e
+
+ func E(f func()) {
+ f()
+ }
+ `},
+ },
+ })
+ defer e.Cleanup()
+
+ // Make sure local vulns can be loaded.
+ fetchingInTesting = true
+ // Load x and y as entry packages.
+ pkgs, err := loadPackages(e, path.Join(e.Temp(), "entry/x"), path.Join(e.Temp(), "entry/y"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(pkgs) != 2 {
+ t.Fatal("failed to load x and y test packages")
+ }
+
+ // ImportsOnly is left at its zero value (false), so Source also
+ // computes the call graph slice that is checked below.
+ cfg := &Config{
+ Client: testClient,
+ }
+ result, err := Source(pkgs, cfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Check that we find the right number of vulnerabilities.
+ // There should be three entries as there are three vulnerable
+ // symbols in the two import-reachable OSVs.
+ if len(result.Vulns) != 3 {
+ t.Errorf("want 3 Vulns, got %d", len(result.Vulns))
+ }
+
+ // Check that vulnerabilities are connected to the call graph.
+ // For the test example, all vulns should have a call sink.
+ for _, v := range result.Vulns {
+ if v.CallSink == 0 {
+ t.Errorf("want CallSink !=0 for %v:%v; got 0", v.Symbol, v.PkgPath)
+ }
+ }
+
+ // Expected caller -> callees edges of the slice, in the form
+ // produced by callGraphToStrMap (callee names sorted).
+ wantCalls := map[string][]string{
+ "golang.org/entry/x.X": {"golang.org/amod/avuln.VulnData.Vuln1", "golang.org/cmod/c.C1", "golang.org/dmod/d.D1"},
+ "golang.org/cmod/c.C1": {"golang.org/amod/avuln.VulnData.Vuln2"},
+ "golang.org/dmod/d.D1": {"golang.org/cmod/c.C1"},
+ "golang.org/entry/y.Y": {"golang.org/bmod/bvuln.Vuln"},
+ "golang.org/bmod/bvuln.Vuln": {"golang.org/emod/e.E"},
+ "golang.org/emod/e.E": {"golang.org/bmod/bvuln.Vuln"},
+ }
+
+ if callStrMap := callGraphToStrMap(result.Calls); !reflect.DeepEqual(wantCalls, callStrMap) {
+ t.Errorf("want %v call graph; got %v", wantCalls, callStrMap)
+ }
+}
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
index 861a251..eea2078 100644
--- a/vulncheck/utils.go
+++ b/vulncheck/utils.go
@@ -5,15 +5,42 @@
package vulncheck
import (
+ "bytes"
+ "go/token"
"go/types"
"strings"
"golang.org/x/tools/go/callgraph"
+ "golang.org/x/tools/go/callgraph/cha"
+ "golang.org/x/tools/go/callgraph/vta"
+ "golang.org/x/tools/go/ssa/ssautil"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/go/ssa"
)
+// callGraph builds a call graph of prog with two rounds of VTA analysis.
+// The first round refines a CHA call graph; the second round refines the
+// result of the first. In each round VTA is run only on the functions
+// that are forward reachable from entries in the base graph and that are
+// also reported by ssautil.AllFunctions.
+func callGraph(prog *ssa.Program, entries []*ssa.Function) *callgraph.Graph {
+	roots := make(map[*ssa.Function]bool)
+	for _, fn := range entries {
+		roots[fn] = true
+	}
+
+	chaCg := cha.CallGraph(prog)
+	linked := ssautil.AllFunctions(prog)
+
+	// Round one: the CHA call graph is the base graph.
+	reachable := forwardReachableFrom(roots, chaCg)
+	// Keep only actually linked functions.
+	pruneSet(reachable, linked)
+	firstVta := vta.CallGraph(reachable, chaCg)
+
+	// Round two: the VTA call graph produced above is the base graph.
+	reachable = forwardReachableFrom(roots, firstVta)
+	pruneSet(reachable, linked)
+
+	return vta.CallGraph(reachable, firstVta)
+}
+
// siteCallees computes a set of callees for call site `call` given program `callgraph`.
func siteCallees(call ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function {
var matches []*ssa.Function
@@ -112,3 +139,43 @@
return nil
}
}
+
+// funcPosition resolves the position of `f` in the file set of its
+// program. The result is never nil; it points to an empty token.Position
+// if no file information on `f` is available.
+func funcPosition(f *ssa.Function) *token.Position {
+	p := new(token.Position)
+	*p = f.Prog.Fset.Position(f.Pos())
+	return p
+}
+
+// instrPosition resolves the position of `instr` in the file set of the
+// program that owns its parent function. The result is never nil; it
+// points to an empty token.Position if no file information on `instr`
+// is available.
+func instrPosition(instr ssa.Instruction) *token.Position {
+	fset := instr.Parent().Prog.Fset
+	p := fset.Position(instr.Pos())
+	return &p
+}
+
+// resolved reports whether `call` has a static callee. A nil `call` is
+// reported as resolved.
+func resolved(call ssa.CallInstruction) bool {
+	return call == nil || call.Common().StaticCallee() != nil
+}
+
+func callRecvType(call ssa.CallInstruction) string {
+ if !call.Common().IsInvoke() {
+ return ""
+ }
+ buf := new(bytes.Buffer)
+ types.WriteType(buf, call.Common().Value.Type(), nil)
+ return buf.String()
+}
+
+// funcRecvType returns the string form of the receiver type of `f`, or
+// the empty string when the signature of `f` has no receiver.
+func funcRecvType(f *ssa.Function) string {
+	recv := f.Signature.Recv()
+	if recv == nil {
+		return ""
+	}
+	return types.TypeString(recv.Type(), nil)
+}