cmd/govulncheck: show entry symbols

Show the symbols at the top of the call chains, rather than the
vulnerable symbols. These are the functions in the module that
govulncheck is running on.

Also, group Vulns by ID and package. vulncheck will return a Vuln for
each OSV Entry ID, package, and symbol.

Change-Id: I9bd3d7ef710ce016e8f93da26bdf97919e1441a0
Reviewed-on: https://go-review.googlesource.com/c/exp/+/391520
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/cmd/govulncheck/main.go b/cmd/govulncheck/main.go
index e89b661..cfa4b4d 100644
--- a/cmd/govulncheck/main.go
+++ b/cmd/govulncheck/main.go
@@ -22,6 +22,7 @@
 	"go/build"
 	"log"
 	"os"
+	"sort"
 	"strings"
 
 	"golang.org/x/exp/vulncheck"
@@ -83,9 +84,7 @@
 	if err != nil {
 		die("govulncheck: %s", err)
 	}
-	vcfg := &vulncheck.Config{
-		Client: dbClient,
-	}
+	vcfg := &vulncheck.Config{Client: dbClient}
 	ctx := context.Background()
 
 	patterns := flag.Args()
@@ -148,24 +147,45 @@
 			moduleVersions[m.Path] = m.Version
 		}
 	})
+	callStacks := vulncheck.CallStacks(r)
 
 	const labelWidth = 16
-
 	line := func(label, text string) {
 		fmt.Printf("%-*s%s\n", labelWidth, label, text)
 	}
 
-	for _, v := range r.Vulns {
-		current := moduleVersions[v.ModPath]
-		fixed := "v" + latestFixed(v.OSV.Affected)
-		ref := fmt.Sprintf("https://pkg.go.dev/vuln/%s", v.OSV.ID)
-		line("package:", v.PkgPath)
+	vulnGroups := groupByIDAndPackage(r.Vulns)
+	for _, vg := range vulnGroups {
+		// All the vulns in vg have the same PkgPath, ModPath and OSV.
+		// All have a non-zero CallSink.
+		v0 := vg[0]
+
+		current := moduleVersions[v0.ModPath]
+		fixed := "v" + latestFixed(v0.OSV.Affected)
+		ref := fmt.Sprintf("https://pkg.go.dev/vuln/%s", v0.OSV.ID)
+
+		// Collect unique top of call stacks.
+		fns := map[*vulncheck.FuncNode]bool{}
+		for _, v := range vg {
+			for _, cs := range callStacks[v] {
+				fns[cs[0].Function] = true
+			}
+		}
+		// Use first top of first vuln as representative.
+		rep := funcName(callStacks[v0][0][0].Function)
+		var syms string
+		if len(fns) == 1 {
+			syms = rep
+		} else {
+			syms = fmt.Sprintf("%s and %d others", rep, len(fns)-1)
+		}
+		line("package:", v0.PkgPath)
 		line("your version:", current)
 		line("fixed version:", fixed)
-		line("symbol:", v.Symbol)
+		line("symbols:", syms)
 		line("reference:", ref)
 
-		desc := strings.Split(wrap(v.OSV.Details, 80-labelWidth), "\n")
+		desc := strings.Split(wrap(v0.OSV.Details, 80-labelWidth), "\n")
 		for i, l := range desc {
 			if i == 0 {
 				line("description:", l)
@@ -177,6 +197,28 @@
 	}
 }
 
+func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
+	groups := map[[2]string][]*vulncheck.Vuln{}
+	for _, v := range vs {
+		if v.CallSink == 0 {
+			// Skip this vuln because although it appears in the
+			// import graph, there are no calls to it.
+			continue
+		}
+		key := [2]string{v.OSV.ID, v.PkgPath}
+		groups[key] = append(groups[key], v)
+	}
+
+	var res [][]*vulncheck.Vuln
+	for _, g := range groups {
+		res = append(res, g)
+	}
+	sort.Slice(res, func(i, j int) bool {
+		return res[i][0].PkgPath < res[j][0].PkgPath
+	})
+	return res
+}
+
 func isFile(path string) bool {
 	s, err := os.Stat(path)
 	if err != nil {
@@ -220,6 +262,10 @@
 	return v
 }
 
+func funcName(fn *vulncheck.FuncNode) string {
+	return strings.TrimPrefix(fn.String(), "*")
+}
+
 func die(format string, args ...interface{}) {
 	fmt.Fprintf(os.Stderr, format+"\n", args...)
 	os.Exit(1)