cmd/govulncheck: select representative symbols more carefully

Instead of using the entries (top of call stacks) as the symbols to
show to the user, use the lowest symbols on the call stacks from the
packages under analysis. This can greatly reduce the number of symbols.

For example, in k8s.io/kubernetes, many functions call
k8s.io/kubernetes/pkg/util/selinux.SELinuxEnabled, which then calls a
vulnerable symbol in github.com/opencontainers/selinux/go-selinux.

In this particular case, this CL reduces the number of
symbols from 2,384 to 2.

Cherry-picked: https://go-review.googlesource.com/c/exp/+/391894

Change-Id: Ib191cb8ec6a09e607673af7ccdcb34ea121a5b69
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395240
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/cmd/govulncheck/main.go b/cmd/govulncheck/main.go
index 308b081..7e9afc3 100644
--- a/cmd/govulncheck/main.go
+++ b/cmd/govulncheck/main.go
@@ -147,11 +147,7 @@
 	// Build a map from module paths to versions.
 	moduleVersions := map[string]string{}
 	packages.Visit(pkgs, nil, func(p *packages.Package) {
-		m := p.Module
-		if m != nil {
-			if m.Replace != nil {
-				m = m.Replace
-			}
+		if m := packageModule(p); m != nil {
 			moduleVersions[m.Path] = m.Version
 		}
 	})
@@ -162,6 +158,12 @@
 		fmt.Printf("%-*s%s\n", labelWidth, label, text)
 	}
 
+	// Create set of top-level packages, used to find
+	// representative symbols
+	topPackages := map[string]bool{}
+	for _, p := range pkgs {
+		topPackages[p.PkgPath] = true
+	}
 	vulnGroups := groupByIDAndPackage(r.Vulns)
 	for _, vg := range vulnGroups {
 		// All the vulns in vg have the same PkgPath, ModPath and OSV.
@@ -172,15 +174,9 @@
 		fixed := "v" + latestFixed(v0.OSV.Affected)
 		ref := fmt.Sprintf("https://pkg.go.dev/vuln/%s", v0.OSV.ID)
 
-		// Collect unique top of call stacks.
-		fns := map[*vulncheck.FuncNode]bool{}
-		for _, v := range vg {
-			for _, cs := range callStacks[v] {
-				fns[cs[0].Function] = true
-			}
-		}
-		// Use first top of first vuln as representative.
-		rep := funcName(callStacks[v0][0][0].Function)
+		fns := representativeFuncs(vg, topPackages, callStacks)
+		// Use first as representative.
+		rep := funcName(fns[0])
 		var syms string
 		if len(fns) == 1 {
 			syms = rep
@@ -227,6 +223,17 @@
 	return res
 }
 
+func packageModule(p *packages.Package) *packages.Module {
+	m := p.Module
+	if m == nil {
+		return nil
+	}
+	if r := m.Replace; r != nil {
+		return r
+	}
+	return m
+}
+
 func isFile(path string) bool {
 	s, err := os.Stat(path)
 	if err != nil {
@@ -270,6 +277,42 @@
 	return v
 }
 
+// representativeFuncs collects representative functions for the group
+// of vulns from the the call stacks.
+func representativeFuncs(vg []*vulncheck.Vuln, topPkgs map[string]bool, callStacks map[*vulncheck.Vuln][]vulncheck.CallStack) []*vulncheck.FuncNode {
+	// Collect unique top of call stacks.
+	fns := map[*vulncheck.FuncNode]bool{}
+	for _, v := range vg {
+		for _, cs := range callStacks[v] {
+			// Find the lowest function in the stack that is in
+			// one of the top packages.
+			for i := len(cs) - 1; i > 0; i-- {
+				pkg := pkgPath(cs[i].Function)
+				if topPkgs[pkg] {
+					fns[cs[i].Function] = true
+					break
+				}
+			}
+		}
+	}
+	var res []*vulncheck.FuncNode
+	for fn := range fns {
+		res = append(res, fn)
+	}
+	return res
+}
+
+func pkgPath(fn *vulncheck.FuncNode) string {
+	if fn.PkgPath != "" {
+		return fn.PkgPath
+	}
+	s := strings.TrimPrefix(fn.RecvType, "*")
+	if i := strings.LastIndexByte(s, '.'); i > 0 {
+		s = s[:i]
+	}
+	return s
+}
+
 func funcName(fn *vulncheck.FuncNode) string {
 	return strings.TrimPrefix(fn.String(), "*")
 }
diff --git a/cmd/govulncheck/main_test.go b/cmd/govulncheck/main_test.go
index 624e311..7b2a43e 100644
--- a/cmd/govulncheck/main_test.go
+++ b/cmd/govulncheck/main_test.go
@@ -10,6 +10,7 @@
 import (
 	"testing"
 
+	"golang.org/x/exp/vulncheck"
 	"golang.org/x/vuln/osv"
 )
 
@@ -84,3 +85,28 @@
 		})
 	}
 }
+
+func TestPkgPath(t *testing.T) {
+	for _, test := range []struct {
+		in   vulncheck.FuncNode
+		want string
+	}{
+		{
+			vulncheck.FuncNode{PkgPath: "math", Name: "Floor"},
+			"math",
+		},
+		{
+			vulncheck.FuncNode{RecvType: "a.com/b.T", Name: "M"},
+			"a.com/b",
+		},
+		{
+			vulncheck.FuncNode{RecvType: "*a.com/b.T", Name: "M"},
+			"a.com/b",
+		},
+	} {
+		got := pkgPath(&test.in)
+		if got != test.want {
+			t.Errorf("%+v: got %q, want %q", test.in, got, test.want)
+		}
+	}
+}