cmd/govulncheck: add GetCallInfo to internal package

Move more functionality to the internal/govulncheck package so
gopls can share it.

Change-Id: Iabb88b6e5af71cf22e54f1264d8b307cc719b9e1
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/406936
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
diff --git a/cmd/govulncheck/html.go b/cmd/govulncheck/html.go
index 3f29f43..090a464 100644
--- a/cmd/govulncheck/html.go
+++ b/cmd/govulncheck/html.go
@@ -13,13 +13,14 @@
 	"html/template"
 	"io"
 
+	"golang.org/x/vuln/cmd/govulncheck/internal/govulncheck"
 	"golang.org/x/vuln/vulncheck"
 )
 
 //go:embed static/*
 var staticContent embed.FS
 
-func html(w io.Writer, r *vulncheck.Result, callStacks map[*vulncheck.Vuln][]vulncheck.CallStack, moduleVersions map[string]string, topPackages map[string]bool, vulnGroups [][]*vulncheck.Vuln) error {
+func html(w io.Writer, r *vulncheck.Result, ci *govulncheck.CallInfo) error {
 	tmpl, err := template.New("govulncheck.tmpl").Funcs(template.FuncMap{
 		"funcName": funcName,
 	}).ParseFS(staticContent, "static/govulncheck.tmpl")
@@ -43,21 +44,21 @@
 	}
 
 	var vulns []*vuln
-	for _, vg := range vulnGroups {
+	for _, vg := range ci.VulnGroups {
 		v0 := vg[0]
 		vn := &vuln{
 			ID:             v0.OSV.ID,
 			PkgPath:        v0.PkgPath,
-			CurrentVersion: moduleVersions[v0.ModPath],
+			CurrentVersion: ci.ModuleVersions[v0.ModPath],
 			FixedVersion:   "v" + latestFixed(v0.OSV.Affected),
 			Reference:      fmt.Sprintf("https://pkg.go.dev/vuln/%s", v0.OSV.ID),
 			Details:        v0.OSV.Details,
 		}
 		// Keep first call stack for each vuln.
 		for _, v := range vg {
-			if css := callStacks[v]; len(css) > 0 {
+			if css := ci.CallStacks[v]; len(css) > 0 {
 				vn.Stacks = append(vn.Stacks, callstack{
-					Summary: summarizeCallStack(css[0], topPackages, v.PkgPath),
+					Summary: summarizeCallStack(css[0], ci.TopPackages, v.PkgPath),
 					Stack:   css[0],
 				})
 			}
diff --git a/cmd/govulncheck/internal/govulncheck/source.go b/cmd/govulncheck/internal/govulncheck/source.go
index b02fc47..9f91665 100644
--- a/cmd/govulncheck/internal/govulncheck/source.go
+++ b/cmd/govulncheck/internal/govulncheck/source.go
@@ -7,6 +7,7 @@
 import (
 	"context"
 	"fmt"
+	"sort"
 	"strings"
 
 	"golang.org/x/tools/go/packages"
@@ -70,3 +71,56 @@
 	r.Vulns = vulns
 	return r, nil
 }
+
+// CallInfo is information about calls to vulnerable functions.
+type CallInfo struct {
+	CallStacks     map[*vulncheck.Vuln][]vulncheck.CallStack // all call stacks
+	VulnGroups     [][]*vulncheck.Vuln                       // vulns grouped by ID and package
+	ModuleVersions map[string]string                         // map from module paths to versions
+	TopPackages    map[string]bool                           // top-level packages
+}
+
+// GetCallInfo computes call stacks and related information from a vulncheck.Result.
+// I also makes a set of top-level packages from pkgs.
+func GetCallInfo(r *vulncheck.Result, pkgs []*vulncheck.Package) *CallInfo {
+	pset := map[string]bool{}
+	for _, p := range pkgs {
+		pset[p.PkgPath] = true
+	}
+	return &CallInfo{
+		CallStacks:     vulncheck.CallStacks(r),
+		VulnGroups:     groupByIDAndPackage(r.Vulns),
+		ModuleVersions: moduleVersionMap(r.Modules),
+		TopPackages:    pset,
+	}
+}
+
+func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
+	groups := map[[2]string][]*vulncheck.Vuln{}
+	for _, v := range vs {
+		key := [2]string{v.OSV.ID, v.PkgPath}
+		groups[key] = append(groups[key], v)
+	}
+
+	var res [][]*vulncheck.Vuln
+	for _, g := range groups {
+		res = append(res, g)
+	}
+	sort.Slice(res, func(i, j int) bool {
+		return res[i][0].PkgPath < res[j][0].PkgPath
+	})
+	return res
+}
+
+// moduleVersionMap builds a map from module paths to versions.
+func moduleVersionMap(mods []*vulncheck.Module) map[string]string {
+	moduleVersions := map[string]string{}
+	for _, m := range mods {
+		v := m.Version
+		if m.Replace != nil {
+			v = m.Replace.Version
+		}
+		moduleVersions[m.Path] = v
+	}
+	return moduleVersions
+}
diff --git a/cmd/govulncheck/main.go b/cmd/govulncheck/main.go
index a7f8a82..050a142 100644
--- a/cmd/govulncheck/main.go
+++ b/cmd/govulncheck/main.go
@@ -128,20 +128,14 @@
 	if *jsonFlag {
 		writeJSON(r)
 	} else {
-		callStacks := vulncheck.CallStacks(r)
-		// Create set of top-level packages, used to find representative symbols
-		topPackages := map[string]bool{}
-		for _, p := range pkgs {
-			topPackages[p.PkgPath] = true
-		}
-		vulnGroups := groupByIDAndPackage(r.Vulns)
-		moduleVersions := moduleVersionMap(r.Modules)
+		// set of top-level packages, used to find representative symbols
+		ci := govulncheck.GetCallInfo(r, pkgs)
 		if *htmlFlag {
-			if err := html(os.Stdout, r, callStacks, moduleVersions, topPackages, vulnGroups); err != nil {
+			if err := html(os.Stdout, r, ci); err != nil {
 				die("writing HTML: %v", err)
 			}
 		} else {
-			writeText(r, callStacks, moduleVersions, topPackages, vulnGroups)
+			writeText(r, ci)
 		}
 	}
 	exitCode := 0
@@ -152,19 +146,6 @@
 	os.Exit(exitCode)
 }
 
-// moduleVersionMap builds a map from module paths to versions.
-func moduleVersionMap(mods []*vulncheck.Module) map[string]string {
-	moduleVersions := map[string]string{}
-	for _, m := range mods {
-		v := m.Version
-		if m.Replace != nil {
-			v = m.Replace.Version
-		}
-		moduleVersions[m.Path] = v
-	}
-	return moduleVersions
-}
-
 func writeJSON(r *vulncheck.Result) {
 	b, err := json.MarshalIndent(r, "", "\t")
 	if err != nil {
@@ -174,23 +155,23 @@
 	fmt.Println()
 }
 
-func writeText(r *vulncheck.Result, callStacks map[*vulncheck.Vuln][]vulncheck.CallStack, moduleVersions map[string]string, topPackages map[string]bool, vulnGroups [][]*vulncheck.Vuln) {
+func writeText(r *vulncheck.Result, ci *govulncheck.CallInfo) {
 
 	const labelWidth = 16
 	line := func(label, text string) {
 		fmt.Printf("%-*s%s\n", labelWidth, label, text)
 	}
-	for _, vg := range vulnGroups {
+	for _, vg := range ci.VulnGroups {
 		// All the vulns in vg have the same PkgPath, ModPath and OSV.
 		// All have a non-zero CallSink.
 		v0 := vg[0]
 		line("package:", v0.PkgPath)
-		line("your version:", moduleVersions[v0.ModPath])
+		line("your version:", ci.ModuleVersions[v0.ModPath])
 		line("fixed version:", "v"+latestFixed(v0.OSV.Affected))
 		var summaries []string
 		for _, v := range vg {
-			if css := callStacks[v]; len(css) > 0 {
-				if sum := summarizeCallStack(css[0], topPackages, v.PkgPath); sum != "" {
+			if css := ci.CallStacks[v]; len(css) > 0 {
+				if sum := summarizeCallStack(css[0], ci.TopPackages, v.PkgPath); sum != "" {
 					summaries = append(summaries, sum)
 				}
 			}
@@ -216,34 +197,6 @@
 	}
 }
 
-func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
-	groups := map[[2]string][]*vulncheck.Vuln{}
-	for _, v := range vs {
-		key := [2]string{v.OSV.ID, v.PkgPath}
-		groups[key] = append(groups[key], v)
-	}
-
-	var res [][]*vulncheck.Vuln
-	for _, g := range groups {
-		res = append(res, g)
-	}
-	sort.Slice(res, func(i, j int) bool {
-		return res[i][0].PkgPath < res[j][0].PkgPath
-	})
-	return res
-}
-
-func packageModule(p *packages.Package) *packages.Module {
-	m := p.Module
-	if m == nil {
-		return nil
-	}
-	if r := m.Replace; r != nil {
-		return r
-	}
-	return m
-}
-
 func isFile(path string) bool {
 	s, err := os.Stat(path)
 	if err != nil {