cmd/govulncheck: show entry symbols

Show the symbols at the top of the call chains, rather than the
vulnerable symbols. These are the functions in the module that
govulncheck is running on.

Also, group Vulns by ID and package. vulncheck will return a Vuln for
each OSV Entry ID, package, and symbol.

Cherry-picked: https://go-review.googlesource.com/c/exp/+/391520

Change-Id: I9bd3d7ef710ce016e8f93da26bdf97919e1441a0
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395238
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/cmd/govulncheck/main.go b/cmd/govulncheck/main.go
index a05e026..4b09e0b 100644
--- a/cmd/govulncheck/main.go
+++ b/cmd/govulncheck/main.go
@@ -25,6 +25,7 @@
 	"go/build"
 	"log"
 	"os"
+	"sort"
 	"strings"
 
 	"golang.org/x/exp/vulncheck"
@@ -86,9 +87,7 @@
 	if err != nil {
 		die("govulncheck: %s", err)
 	}
-	vcfg := &vulncheck.Config{
-		Client: dbClient,
-	}
+	vcfg := &vulncheck.Config{Client: dbClient}
 	ctx := context.Background()
 
 	patterns := flag.Args()
@@ -151,24 +150,45 @@
 			moduleVersions[m.Path] = m.Version
 		}
 	})
+	callStacks := vulncheck.CallStacks(r)
 
 	const labelWidth = 16
-
 	line := func(label, text string) {
 		fmt.Printf("%-*s%s\n", labelWidth, label, text)
 	}
 
-	for _, v := range r.Vulns {
-		current := moduleVersions[v.ModPath]
-		fixed := "v" + latestFixed(v.OSV.Affected)
-		ref := fmt.Sprintf("https://pkg.go.dev/vuln/%s", v.OSV.ID)
-		line("package:", v.PkgPath)
+	vulnGroups := groupByIDAndPackage(r.Vulns)
+	for _, vg := range vulnGroups {
+		// All the vulns in vg have the same PkgPath, ModPath and OSV.
+		// All have a non-zero CallSink.
+		v0 := vg[0]
+
+		current := moduleVersions[v0.ModPath]
+		fixed := "v" + latestFixed(v0.OSV.Affected)
+		ref := fmt.Sprintf("https://pkg.go.dev/vuln/%s", v0.OSV.ID)
+
+		// Collect unique top of call stacks.
+		fns := map[*vulncheck.FuncNode]bool{}
+		for _, v := range vg {
+			for _, cs := range callStacks[v] {
+				fns[cs[0].Function] = true
+			}
+		}
+		// Use first top of first vuln as representative.
+		rep := funcName(callStacks[v0][0][0].Function)
+		var syms string
+		if len(fns) == 1 {
+			syms = rep
+		} else {
+			syms = fmt.Sprintf("%s and %d others", rep, len(fns)-1)
+		}
+		line("package:", v0.PkgPath)
 		line("your version:", current)
 		line("fixed version:", fixed)
-		line("symbol:", v.Symbol)
+		line("symbols:", syms)
 		line("reference:", ref)
 
-		desc := strings.Split(wrap(v.OSV.Details, 80-labelWidth), "\n")
+		desc := strings.Split(wrap(v0.OSV.Details, 80-labelWidth), "\n")
 		for i, l := range desc {
 			if i == 0 {
 				line("description:", l)
@@ -180,6 +200,28 @@
 	}
 }
 
+func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
+	groups := map[[2]string][]*vulncheck.Vuln{}
+	for _, v := range vs {
+		if v.CallSink == 0 {
+			// Skip this vuln because although it appears in the
+			// import graph, there are no calls to it.
+			continue
+		}
+		key := [2]string{v.OSV.ID, v.PkgPath}
+		groups[key] = append(groups[key], v)
+	}
+
+	var res [][]*vulncheck.Vuln
+	for _, g := range groups {
+		res = append(res, g)
+	}
+	sort.Slice(res, func(i, j int) bool {
+		return res[i][0].PkgPath < res[j][0].PkgPath
+	})
+	return res
+}
+
 func isFile(path string) bool {
 	s, err := os.Stat(path)
 	if err != nil {
@@ -223,6 +265,10 @@
 	return v
 }
 
+func funcName(fn *vulncheck.FuncNode) string {
+	return strings.TrimPrefix(fn.String(), "*")
+}
+
 func die(format string, args ...interface{}) {
 	fmt.Fprintf(os.Stderr, format+"\n", args...)
 	os.Exit(1)