gopls/internal/vulncheck: use internal/govulncheck

Copy the x/vuln/cmd/govulncheck/internal/govulncheck package
and use it in internal/vulncheck.

Fixes golang/go#52985.

Change-Id: I3fb16b3d486ac462fca36aa53fd46e576041102d
Reviewed-on: https://go-review.googlesource.com/c/tools/+/407114
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/gopls/internal/govulncheck/source.go b/gopls/internal/govulncheck/source.go
new file mode 100644
index 0000000..752a831
--- /dev/null
+++ b/gopls/internal/govulncheck/source.go
@@ -0,0 +1,129 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.18
+// +build go1.18
+
+package govulncheck
+
+import (
+	"context"
+	"fmt"
+	"sort"
+	"strings"
+
+	"golang.org/x/tools/go/packages"
+	"golang.org/x/vuln/client"
+	"golang.org/x/vuln/vulncheck"
+)
+
+// A PackageError contains errors from loading a set of packages.
+type PackageError struct {
+	Errors []packages.Error
+}
+
+func (e *PackageError) Error() string {
+	var b strings.Builder
+	fmt.Fprintln(&b, "Packages contain errors:")
+	for _, e := range e.Errors {
+		fmt.Println(&b, e)
+	}
+	return b.String()
+}
+
+// LoadPackages loads the packages matching patterns using cfg, after setting
+// the cfg mode flags that vulncheck needs for analysis.
+// If the packages contain errors, a PackageError is returned containing a list of the errors,
+// along with the packages themselves.
+func LoadPackages(cfg *packages.Config, patterns ...string) ([]*vulncheck.Package, error) {
+	cfg.Mode |= packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles |
+		packages.NeedImports | packages.NeedTypes | packages.NeedTypesSizes |
+		packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedDeps |
+		packages.NeedModule
+
+	pkgs, err := packages.Load(cfg, patterns...)
+	vpkgs := vulncheck.Convert(pkgs)
+	if err != nil {
+		return nil, err
+	}
+	var perrs []packages.Error
+	packages.Visit(pkgs, nil, func(p *packages.Package) {
+		perrs = append(perrs, p.Errors...)
+	})
+	if len(perrs) > 0 {
+		err = &PackageError{perrs}
+	}
+	return vpkgs, err
+}
+
+// Source calls vulncheck.Source on the Go source in pkgs. It returns the result
+// with Vulns trimmed to those that are actually called.
+func Source(ctx context.Context, pkgs []*vulncheck.Package, c client.Client) (*vulncheck.Result, error) {
+	r, err := vulncheck.Source(ctx, pkgs, &vulncheck.Config{Client: c})
+	if err != nil {
+		return nil, err
+	}
+	// Keep only the vulns that are called.
+	var vulns []*vulncheck.Vuln
+	for _, v := range r.Vulns {
+		if v.CallSink != 0 {
+			vulns = append(vulns, v)
+		}
+	}
+	r.Vulns = vulns
+	return r, nil
+}
+
+// CallInfo is information about calls to vulnerable functions.
+type CallInfo struct {
+	CallStacks     map[*vulncheck.Vuln][]vulncheck.CallStack // all call stacks
+	VulnGroups     [][]*vulncheck.Vuln                       // vulns grouped by ID and package
+	ModuleVersions map[string]string                         // map from module paths to versions
+	TopPackages    map[string]bool                           // top-level packages
+}
+
+// GetCallInfo computes call stacks and related information from a vulncheck.Result.
+// I also makes a set of top-level packages from pkgs.
+func GetCallInfo(r *vulncheck.Result, pkgs []*vulncheck.Package) *CallInfo {
+	pset := map[string]bool{}
+	for _, p := range pkgs {
+		pset[p.PkgPath] = true
+	}
+	return &CallInfo{
+		CallStacks:     vulncheck.CallStacks(r),
+		VulnGroups:     groupByIDAndPackage(r.Vulns),
+		ModuleVersions: moduleVersionMap(r.Modules),
+		TopPackages:    pset,
+	}
+}
+
+func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
+	groups := map[[2]string][]*vulncheck.Vuln{}
+	for _, v := range vs {
+		key := [2]string{v.OSV.ID, v.PkgPath}
+		groups[key] = append(groups[key], v)
+	}
+
+	var res [][]*vulncheck.Vuln
+	for _, g := range groups {
+		res = append(res, g)
+	}
+	sort.Slice(res, func(i, j int) bool {
+		return res[i][0].PkgPath < res[j][0].PkgPath
+	})
+	return res
+}
+
+// moduleVersionMap builds a map from module paths to versions.
+func moduleVersionMap(mods []*vulncheck.Module) map[string]string {
+	moduleVersions := map[string]string{}
+	for _, m := range mods {
+		v := m.Version
+		if m.Replace != nil {
+			v = m.Replace.Version
+		}
+		moduleVersions[m.Path] = v
+	}
+	return moduleVersions
+}
diff --git a/gopls/internal/govulncheck/util.go b/gopls/internal/govulncheck/util.go
new file mode 100644
index 0000000..baa2d96
--- /dev/null
+++ b/gopls/internal/govulncheck/util.go
@@ -0,0 +1,109 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.18
+// +build go1.18
+
+package govulncheck
+
+import (
+	"fmt"
+	"strings"
+
+	"golang.org/x/mod/semver"
+	"golang.org/x/vuln/osv"
+	"golang.org/x/vuln/vulncheck"
+)
+
+// LatestFixed returns the latest fixed version in the list of affected ranges,
+// or the empty string if there are no fixed versions.
+func LatestFixed(as []osv.Affected) string {
+	v := ""
+	for _, a := range as {
+		for _, r := range a.Ranges {
+			if r.Type == osv.TypeSemver {
+				for _, e := range r.Events {
+					if e.Fixed != "" && (v == "" || semver.Compare(e.Fixed, v) > 0) {
+						v = e.Fixed
+					}
+				}
+			}
+		}
+	}
+	return v
+}
+
+// SummarizeCallStack returns a short description of the call stack.
+// It uses one of two forms, depending on what the lowest function F in topPkgs
+// calls:
+//   - If it calls a function V from the vulnerable package, then summarizeCallStack
+//     returns "F calls V".
+//   - If it calls a function G in some other package, which eventually calls V,
+//     it returns "F calls G, which eventually calls V".
+//
+// If it can't find any of these functions, summarizeCallStack returns the empty string.
+func SummarizeCallStack(cs vulncheck.CallStack, topPkgs map[string]bool, vulnPkg string) string {
+	// Find the lowest function in the top packages.
+	iTop := lowest(cs, func(e vulncheck.StackEntry) bool {
+		return topPkgs[PkgPath(e.Function)]
+	})
+	if iTop < 0 {
+		return ""
+	}
+	// Find the highest function in the vulnerable package that is below iTop.
+	iVuln := highest(cs[iTop+1:], func(e vulncheck.StackEntry) bool {
+		return PkgPath(e.Function) == vulnPkg
+	})
+	if iVuln < 0 {
+		return ""
+	}
+	iVuln += iTop + 1 // adjust for slice in call to highest.
+	topName := FuncName(cs[iTop].Function)
+	vulnName := FuncName(cs[iVuln].Function)
+	if iVuln == iTop+1 {
+		return fmt.Sprintf("%s calls %s", topName, vulnName)
+	}
+	return fmt.Sprintf("%s calls %s, which eventually calls %s",
+		topName, FuncName(cs[iTop+1].Function), vulnName)
+}
+
+// highest returns the highest (one with the smallest index) entry in the call
+// stack for which f returns true.
+func highest(cs vulncheck.CallStack, f func(e vulncheck.StackEntry) bool) int {
+	for i := 0; i < len(cs); i++ {
+		if f(cs[i]) {
+			return i
+		}
+	}
+	return -1
+}
+
+// lowest returns the lowest (one with the largets index) entry in the call
+// stack for which f returns true.
+func lowest(cs vulncheck.CallStack, f func(e vulncheck.StackEntry) bool) int {
+	for i := len(cs) - 1; i >= 0; i-- {
+		if f(cs[i]) {
+			return i
+		}
+	}
+	return -1
+}
+
+// PkgPath returns the package path from fn.
+func PkgPath(fn *vulncheck.FuncNode) string {
+	if fn.PkgPath != "" {
+		return fn.PkgPath
+	}
+	s := strings.TrimPrefix(fn.RecvType, "*")
+	if i := strings.LastIndexByte(s, '.'); i > 0 {
+		s = s[:i]
+	}
+	return s
+}
+
+// FuncName returns the function name from fn, adjusted
+// to remove pointer annotations.
+func FuncName(fn *vulncheck.FuncNode) string {
+	return strings.TrimPrefix(fn.String(), "*")
+}
diff --git a/gopls/internal/govulncheck/util_test.go b/gopls/internal/govulncheck/util_test.go
new file mode 100644
index 0000000..3288cd8
--- /dev/null
+++ b/gopls/internal/govulncheck/util_test.go
@@ -0,0 +1,85 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.18
+// +build go1.18
+
+package govulncheck
+
+import (
+	"strings"
+	"testing"
+
+	"golang.org/x/vuln/vulncheck"
+)
+
+func TestPkgPath(t *testing.T) {
+	for _, test := range []struct {
+		in   vulncheck.FuncNode
+		want string
+	}{
+		{
+			vulncheck.FuncNode{PkgPath: "math", Name: "Floor"},
+			"math",
+		},
+		{
+			vulncheck.FuncNode{RecvType: "a.com/b.T", Name: "M"},
+			"a.com/b",
+		},
+		{
+			vulncheck.FuncNode{RecvType: "*a.com/b.T", Name: "M"},
+			"a.com/b",
+		},
+	} {
+		got := PkgPath(&test.in)
+		if got != test.want {
+			t.Errorf("%+v: got %q, want %q", test.in, got, test.want)
+		}
+	}
+}
+
+func TestSummarizeCallStack(t *testing.T) {
+	topPkgs := map[string]bool{"t1": true, "t2": true}
+	vulnPkg := "v"
+
+	for _, test := range []struct {
+		in, want string
+	}{
+		{"a.F", ""},
+		{"t1.F", ""},
+		{"v.V", ""},
+		{
+			"t1.F v.V",
+			"t1.F calls v.V",
+		},
+		{
+			"t1.F t2.G v.V1 v.v2",
+			"t2.G calls v.V1",
+		},
+		{
+			"t1.F x.Y t2.G a.H b.I c.J v.V",
+			"t2.G calls a.H, which eventually calls v.V",
+		},
+	} {
+		in := stringToCallStack(test.in)
+		got := SummarizeCallStack(in, topPkgs, vulnPkg)
+		if got != test.want {
+			t.Errorf("%s:\ngot  %s\nwant %s", test.in, got, test.want)
+		}
+	}
+}
+
+func stringToCallStack(s string) vulncheck.CallStack {
+	var cs vulncheck.CallStack
+	for _, e := range strings.Fields(s) {
+		parts := strings.Split(e, ".")
+		cs = append(cs, vulncheck.StackEntry{
+			Function: &vulncheck.FuncNode{
+				PkgPath: parts[0],
+				Name:    parts[1],
+			},
+		})
+	}
+	return cs
+}
diff --git a/gopls/internal/vulncheck/command.go b/gopls/internal/vulncheck/command.go
index 459ecca..a89354f 100644
--- a/gopls/internal/vulncheck/command.go
+++ b/gopls/internal/vulncheck/command.go
@@ -17,7 +17,6 @@
 	gvc "golang.org/x/tools/gopls/internal/govulncheck"
 	"golang.org/x/tools/internal/lsp/command"
 	"golang.org/x/vuln/client"
-	"golang.org/x/vuln/vulncheck"
 )
 
 func init() {
@@ -73,53 +72,36 @@
 		packages.NeedTypesSizes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedDeps
 
 	log.Println("loading packages...")
-
-	loadedPkgs, err := packages.Load(cfg, patterns...)
+	loadedPkgs, err := gvc.LoadPackages(cfg, patterns...)
 	if err != nil {
 		log.Printf("package load failed: %v", err)
 		return nil, err
 	}
 	log.Printf("loaded %d packages\n", len(loadedPkgs))
 
-	pkgs := vulncheck.Convert(loadedPkgs)
-	r, err := vulncheck.Source(ctx, pkgs, &vulncheck.Config{
-		Client: c.Client,
-	})
+	r, err := gvc.Source(ctx, loadedPkgs, c.Client)
 	if err != nil {
 		return nil, err
 	}
-
-	// Skip vulns that are in the import graph but have no calls to them.
-	var vulns []*vulncheck.Vuln
-	for _, v := range r.Vulns {
-		if v.CallSink != 0 {
-			vulns = append(vulns, v)
-		}
-	}
-
-	callStacks := vulncheck.CallStacks(r)
-	// Create set of top-level packages, used to find representative symbols
-	topPackages := map[string]bool{}
-	for _, p := range pkgs {
-		topPackages[p.PkgPath] = true
-	}
-	vulnGroups := groupByIDAndPackage(vulns)
-	moduleVersions := moduleVersionMap(r.Modules)
-
-	return toVulns(callStacks, moduleVersions, topPackages, vulnGroups)
+	callInfo := gvc.GetCallInfo(r, loadedPkgs)
+	return toVulns(callInfo)
 	// TODO: add import graphs.
 }
 
-func toVulns(callStacks map[*vulncheck.Vuln][]vulncheck.CallStack, moduleVersions map[string]string, topPackages map[string]bool, vulnGroups [][]*vulncheck.Vuln) ([]Vuln, error) {
+func toVulns(ci *gvc.CallInfo) ([]Vuln, error) {
 	var vulns []Vuln
 
-	for _, vg := range vulnGroups {
+	for _, vg := range ci.VulnGroups {
 		v0 := vg[0]
+		lf := gvc.LatestFixed(v0.OSV.Affected)
+		if lf != "" && lf[0] != 'v' {
+			lf = "v" + lf
+		}
 		vuln := Vuln{
 			ID:             v0.OSV.ID,
 			PkgPath:        v0.PkgPath,
-			CurrentVersion: moduleVersions[v0.ModPath],
-			FixedVersion:   latestFixed(v0.OSV.Affected),
+			CurrentVersion: ci.ModuleVersions[v0.ModPath],
+			FixedVersion:   lf,
 			Details:        v0.OSV.Details,
 
 			Aliases: v0.OSV.Aliases,
@@ -130,9 +112,9 @@
 
 		// Keep first call stack for each vuln.
 		for _, v := range vg {
-			if css := callStacks[v]; len(css) > 0 {
+			if css := ci.CallStacks[v]; len(css) > 0 {
 				vuln.CallStacks = append(vuln.CallStacks, toCallStack(css[0]))
-				vuln.CallStackSummaries = append(vuln.CallStackSummaries, summarizeCallStack(css[0], topPackages, v.PkgPath))
+				vuln.CallStackSummaries = append(vuln.CallStackSummaries, gvc.SummarizeCallStack(css[0], ci.TopPackages, v.PkgPath))
 			}
 		}
 		vulns = append(vulns, vuln)
diff --git a/gopls/internal/vulncheck/util.go b/gopls/internal/vulncheck/util.go
index e2a437b..c329461 100644
--- a/gopls/internal/vulncheck/util.go
+++ b/gopls/internal/vulncheck/util.go
@@ -10,134 +10,13 @@
 import (
 	"fmt"
 	"go/token"
-	"sort"
-	"strings"
 
-	"golang.org/x/mod/semver"
+	gvc "golang.org/x/tools/gopls/internal/govulncheck"
 	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/vuln/osv"
 	"golang.org/x/vuln/vulncheck"
 )
 
-// TODO(hyangah): automate copy of golang.org/x/vuln/cmd.
-
-// moduleVersionMap builds a map from module paths to versions.
-func moduleVersionMap(mods []*vulncheck.Module) map[string]string {
-	moduleVersions := map[string]string{}
-	for _, m := range mods {
-		v := m.Version
-		if m.Replace != nil {
-			v = m.Replace.Version
-		}
-		moduleVersions[m.Path] = v
-	}
-	return moduleVersions
-}
-
-func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
-	groups := map[[2]string][]*vulncheck.Vuln{}
-	for _, v := range vs {
-		key := [2]string{v.OSV.ID, v.PkgPath}
-		groups[key] = append(groups[key], v)
-	}
-
-	var res [][]*vulncheck.Vuln
-	for _, g := range groups {
-		res = append(res, g)
-	}
-	sort.Slice(res, func(i, j int) bool {
-		return res[i][0].PkgPath < res[j][0].PkgPath
-	})
-	return res
-}
-
-// latestFixed returns the latest fixed version in the list of affected ranges,
-// or the empty string if there are no fixed versions.
-func latestFixed(as []osv.Affected) string {
-	v := ""
-	for _, a := range as {
-		for _, r := range a.Ranges {
-			if r.Type == osv.TypeSemver {
-				for _, e := range r.Events {
-					if e.Fixed != "" && (v == "" || semver.Compare(e.Fixed, v) > 0) {
-						v = e.Fixed
-					}
-				}
-			}
-		}
-	}
-	if v == "" || v[0] == 'v' {
-		return v
-	}
-	return "v" + v
-}
-
-// summarizeCallStack returns a short description of the call stack.
-// It uses one of two forms, depending on what the lowest function F in topPkgs
-// calls:
-//   - If it calls a function V from the vulnerable package, then summarizeCallStack
-//     returns "F calls V".
-//   - If it calls a function G in some other package, which eventually calls V,
-//     it returns "F calls G, which eventually calls V".
-//
-// If it can't find any of these functions, summarizeCallStack returns the empty string.
-func summarizeCallStack(cs vulncheck.CallStack, topPkgs map[string]bool, vulnPkg string) string {
-	// Find the lowest function in the top packages.
-	iTop := lowest(cs, func(e vulncheck.StackEntry) bool {
-		return topPkgs[pkgPath(e.Function)]
-	})
-	if iTop < 0 {
-		return ""
-	}
-	// Find the highest function in the vulnerable package that is below iTop.
-	iVuln := highest(cs[iTop+1:], func(e vulncheck.StackEntry) bool {
-		return pkgPath(e.Function) == vulnPkg
-	})
-	if iVuln < 0 {
-		return ""
-	}
-	iVuln += iTop + 1 // adjust for slice in call to highest.
-	topName := funcName(cs[iTop].Function)
-	vulnName := funcName(cs[iVuln].Function)
-	if iVuln == iTop+1 {
-		return fmt.Sprintf("%s calls %s", topName, vulnName)
-	}
-	return fmt.Sprintf("%s calls %s, which eventually calls %s",
-		topName, funcName(cs[iTop+1].Function), vulnName)
-}
-
-// highest returns the highest (one with the smallest index) entry in the call
-// stack for which f returns true.
-func highest(cs vulncheck.CallStack, f func(e vulncheck.StackEntry) bool) int {
-	for i := 0; i < len(cs); i++ {
-		if f(cs[i]) {
-			return i
-		}
-	}
-	return -1
-}
-
-// lowest returns the lowest (one with the largets index) entry in the call
-// stack for which f returns true.
-func lowest(cs vulncheck.CallStack, f func(e vulncheck.StackEntry) bool) int {
-	for i := len(cs) - 1; i >= 0; i-- {
-		if f(cs[i]) {
-			return i
-		}
-	}
-	return -1
-}
-func pkgPath(fn *vulncheck.FuncNode) string {
-	if fn.PkgPath != "" {
-		return fn.PkgPath
-	}
-	s := strings.TrimPrefix(fn.RecvType, "*")
-	if i := strings.LastIndexByte(s, '.'); i > 0 {
-		s = s[:i]
-	}
-	return s
-}
-
 func toCallStack(src vulncheck.CallStack) CallStack {
 	var dest []StackEntry
 	for _, e := range src {
@@ -149,7 +28,7 @@
 func toStackEntry(src vulncheck.StackEntry) StackEntry {
 	f, call := src.Function, src.Call
 	pos := f.Pos
-	desc := funcName(f)
+	desc := gvc.FuncName(f)
 	if src.Call != nil {
 		pos = src.Call.Pos // Exact call site position is helpful.
 		if !call.Resolved {
@@ -166,10 +45,6 @@
 	}
 }
 
-func funcName(fn *vulncheck.FuncNode) string {
-	return strings.TrimPrefix(fn.String(), "*")
-}
-
 // href returns a URL embedded in the entry if any.
 // If no suitable URL is found, it returns a default entry in
 // pkg.go.dev/vuln.