internal/vuln: refactor client and improve ByPackagePrefix

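Refactor the vuln client so that ByPackage and ByPackagePrefix share
the module-index filtering and entry-fetching code, and improve the
algorithm used by ByPackagePrefix. Instead of downloading every OSV
entry, ByPackagePrefix now uses the modules index to select the
modules that could match the prefix and downloads only the entries
listed for those modules.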

Change-Id: I4c62aa38b8224207d1774e6d87277bbd36d18710
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/486457
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
Run-TryBot: Tatiana Bradley <tatianabradley@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
diff --git a/go.mod b/go.mod
index 906c1a9..a8d01b7 100644
--- a/go.mod
+++ b/go.mod
@@ -36,11 +36,12 @@
 	github.com/yuin/goldmark v1.4.13
 	github.com/yuin/goldmark-emoji v1.0.1
 	go.opencensus.io v0.23.0
-	golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4
+	golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
+	golang.org/x/mod v0.6.0
 	golang.org/x/net v0.7.0
 	golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4
 	golang.org/x/text v0.7.0
-	golang.org/x/tools v0.1.13-0.20220928184430-f80e98464e27
+	golang.org/x/tools v0.2.0
 	google.golang.org/api v0.63.0
 	google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa
 	google.golang.org/grpc v1.43.0
@@ -98,7 +99,7 @@
 	github.com/xanzy/ssh-agent v0.3.0 // indirect
 	github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da // indirect
 	go.uber.org/atomic v1.6.0 // indirect
-	golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b // indirect
+	golang.org/x/crypto v0.1.0 // indirect
 	golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect
 	golang.org/x/sys v0.5.0 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
diff --git a/go.sum b/go.sum
index 73ea600..99e0cd2 100644
--- a/go.sum
+++ b/go.sum
@@ -1129,8 +1129,8 @@
 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b h1:Qwe1rC8PSniVfAFPFJeyUkB+zcysC3RgJBAGk7eqBEU=
-golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
+golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
 golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -1145,6 +1145,8 @@
 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
 golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
+golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
+golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -1178,8 +1180,8 @@
 golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s=
-golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
+golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I=
+golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
 golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -1481,8 +1483,8 @@
 golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
 golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.1.13-0.20220928184430-f80e98464e27 h1:mOqz7ZhDqMSA3LafrO1Q+1yLQ/KCnCy2/5xiFQVkCWQ=
-golang.org/x/tools v0.1.13-0.20220928184430-f80e98464e27/go.mod h1:VsjNM1dMo+Ofkp5d7y7fOdQZD8MTXSQ4w3EPk65AvKU=
+golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE=
+golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
 golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/internal/vuln/client.go b/internal/vuln/client.go
index 5a425b5..e03e714 100644
--- a/internal/vuln/client.go
+++ b/internal/vuln/client.go
@@ -8,14 +8,14 @@
 	"bytes"
 	"context"
 	"encoding/json"
-	"fmt"
 	"path/filepath"
 	"sort"
 	"strings"
-	"sync"
 
+	"golang.org/x/exp/slices"
 	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/osv"
+	"golang.org/x/pkgsite/internal/stdlib"
 	"golang.org/x/sync/errgroup"
 )
 
@@ -65,6 +65,41 @@
 func (c *Client) ByPackage(ctx context.Context, req *PackageRequest) (_ []*osv.Entry, err error) {
 	derrors.Wrap(&err, "ByPackage(%v)", req)
 
+	// Find the metadata for the module with the given module path.
+	ms, err := c.modulesFilter(ctx, func(m *ModuleMeta) bool {
+		return m.Path == req.Module
+	}, 1)
+	if err != nil {
+		return nil, err
+	}
+	if len(ms) == 0 {
+		return nil, nil
+	}
+
+	// Figure out which vulns we actually need to download.
+	var ids []string
+	for _, v := range ms[0].Vulns {
+		// We need to download the full entry if there is no fix,
+		// or the requested version is less than the vuln's
+		// highest fixed version.
+		if v.Fixed == "" || osv.LessSemver(req.Version, v.Fixed) {
+			ids = append(ids, v.ID)
+		}
+	}
+	if len(ids) == 0 {
+		return nil, nil
+	}
+
+	return c.byIDsFilter(ctx, ids, func(e *osv.Entry) bool {
+		return isAffected(e, req)
+	})
+}
+
+func (c *Client) modulesFilter(ctx context.Context, filter func(*ModuleMeta) bool, n int) ([]*ModuleMeta, error) {
+	if n == 0 {
+		return nil, nil
+	}
+
 	b, err := c.modules(ctx)
 	if err != nil {
 		return nil, err
@@ -75,67 +110,26 @@
 		return nil, err
 	}
 
-	var ids []string
+	ms := make([]*ModuleMeta, 0)
 	for dec.More() {
 		var m ModuleMeta
 		err := dec.Decode(&m)
 		if err != nil {
 			return nil, err
 		}
-		if m.Path == req.Module {
-			for _, v := range m.Vulns {
-				// We need to download the full entry if there is no fix,
-				// or the requested version is less than the vuln's
-				// highest fixed version.
-				if v.Fixed == "" || osv.LessSemver(req.Version, v.Fixed) {
-					ids = append(ids, v.ID)
-				}
+		if filter(&m) {
+			ms = append(ms, &m)
+			if len(ms) == n {
+				return ms, nil
 			}
-			// We found the requested module, so skip the rest.
-			break
 		}
 	}
 
-	if len(ids) == 0 {
+	if len(ms) == 0 {
 		return nil, nil
 	}
 
-	// Fetch all the entries in parallel, and create a slice
-	// containing all the actually affected entries.
-	g, gctx := errgroup.WithContext(ctx)
-	var mux sync.Mutex
-	g.SetLimit(10)
-	entries := make([]*osv.Entry, 0, len(ids))
-	for _, id := range ids {
-		id := id
-		g.Go(func() error {
-			entry, err := c.ByID(gctx, id)
-			if err != nil {
-				return err
-			}
-
-			if entry == nil {
-				return fmt.Errorf("vulnerability %s was found in %s but could not be retrieved", id, modulesEndpoint)
-			}
-
-			if isAffected(entry, req) {
-				mux.Lock()
-				entries = append(entries, entry)
-				mux.Unlock()
-			}
-
-			return nil
-		})
-	}
-	if err := g.Wait(); err != nil {
-		return nil, err
-	}
-
-	sort.SliceStable(entries, func(i, j int) bool {
-		return entries[i].ID < entries[j].ID
-	})
-
-	return entries, nil
+	return ms, nil
 }
 
 func isAffected(e *osv.Entry, req *PackageRequest) bool {
@@ -229,8 +223,7 @@
 	if err != nil {
 		return nil, err
 	}
-
-	sort.Slice(ids, func(i, j int) bool { return ids[i] > ids[j] })
+	sortIDs(ids)
 
 	if n >= 0 && len(ids) > n {
 		ids = ids[:n]
@@ -239,6 +232,11 @@
 	return c.byIDs(ctx, ids)
 }
 
+func sortIDs(ids []string) {
+	sort.Slice(ids, func(i, j int) bool { return ids[i] > ids[j] })
+
+}
+
 // ByPackagePrefix returns all the OSV entries that match the given
 // package prefix, in descending order by ID, or (nil, nil) if there
 // are none.
@@ -249,25 +247,35 @@
 //     interpreted as a full path. (E.g. "example.com/module/package" matches
 //     the prefix "example.com/module" but not "example.com/mod")
 func (c *Client) ByPackagePrefix(ctx context.Context, prefix string) (_ []*osv.Entry, err error) {
-	allEntries, err := c.Entries(ctx, -1)
-	if err != nil {
-		return nil, err
-	}
+	derrors.Wrap(&err, "ByPackagePrefix(%s)", prefix)
 
 	prefix = strings.TrimSuffix(prefix, "/")
-	match := func(s string) bool {
-		return s == prefix || strings.HasPrefix(s, prefix+"/")
+	prefixPath := prefix + "/"
+	prefixMatch := func(s string) bool {
+		return s == prefix || strings.HasPrefix(s, prefixPath)
 	}
 
-	// Returns whether any of the affected modules or packages of the
-	// entry start with the prefix.
-	matchesQuery := func(e *osv.Entry) bool {
+	moduleMatch := func(m *ModuleMeta) bool {
+		// If the prefix possibly refers to a standard library package,
+		// always look at the stdlib and toolchain modules.
+		if stdlib.Contains(prefix) &&
+			(m.Path == osv.GoStdModulePath || m.Path == osv.GoCmdModulePath) {
+			return true
+		}
+		// Look at the module if its path falls under the query prefix,
+		// or if its path is itself a prefix of the query.
+		// (The latter case catches queries that are prefixes of the package
+		// path but longer than the module path).
+		return prefixMatch(m.Path) || strings.HasPrefix(prefix, m.Path)
+	}
+
+	entryMatch := func(e *osv.Entry) bool {
 		for _, aff := range e.Affected {
-			if match(aff.Module.Path) {
+			if prefixMatch(aff.Module.Path) {
 				return true
 			}
 			for _, pkg := range aff.EcosystemSpecific.Packages {
-				if match(pkg.Path) {
+				if prefixMatch(pkg.Path) {
 					return true
 				}
 			}
@@ -275,20 +283,48 @@
 		return false
 	}
 
-	var entries []*osv.Entry
-	for _, entry := range allEntries {
-		if matchesQuery(entry) {
-			entries = append(entries, entry)
-		}
+	ms, err := c.modulesFilter(ctx, moduleMatch, -1)
+	if err != nil {
+		return nil, err
+	}
+	if len(ms) == 0 {
+		return nil, nil
 	}
 
-	return entries, nil
+	var ids []string
+	for _, m := range ms {
+		for _, vs := range m.Vulns {
+			ids = append(ids, vs.ID)
+		}
+	}
+	sortIDs(ids)
+	// Remove any duplicates.
+	ids = slices.Compact(ids)
+
+	return c.byIDsFilter(ctx, ids, entryMatch)
+}
+
+func (c *Client) byIDsFilter(ctx context.Context, ids []string, filter func(*osv.Entry) bool) (_ []*osv.Entry, err error) {
+	entries, err := c.byIDs(ctx, ids)
+	if err != nil {
+		return nil, err
+	}
+	var filtered []*osv.Entry
+	for _, entry := range entries {
+		if filter(entry) {
+			filtered = append(filtered, entry)
+		}
+	}
+	if len(filtered) == 0 {
+		return nil, nil
+	}
+	return filtered, nil
 }
 
 func (c *Client) byIDs(ctx context.Context, ids []string) (_ []*osv.Entry, err error) {
 	entries := make([]*osv.Entry, len(ids))
 	g, gctx := errgroup.WithContext(ctx)
-	g.SetLimit(4)
+	g.SetLimit(10)
 	for i, id := range ids {
 		i, id := i, id
 		g.Go(func() error {
diff --git a/internal/vuln/client_test.go b/internal/vuln/client_test.go
index e48cb14..24b2984 100644
--- a/internal/vuln/client_test.go
+++ b/internal/vuln/client_test.go
@@ -458,7 +458,23 @@
 			},
 		},
 	}
-	vc, err := NewInMemoryClient([]*osv.Entry{stdlibCrypto, stdlibNet, thirdParty})
+	// Entry containing two modules with a common prefix.
+	commonPrefix := &osv.Entry{
+		ID: "4-COMMON-PREFIX",
+		Affected: []osv.Affected{
+			{
+				Module: osv.Module{
+					Path: "example.com/module",
+				},
+			},
+			{
+				Module: osv.Module{
+					Path: "example.com/module/inner",
+				},
+			},
+		},
+	}
+	vc, err := NewInMemoryClient([]*osv.Entry{stdlibCrypto, stdlibNet, thirdParty, commonPrefix})
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -517,6 +533,13 @@
 			query: "golang.org/x",
 			want:  []*osv.Entry{thirdParty, stdlibCrypto, stdlibNet},
 		},
+		{
+			name: "entries not duplicated",
+			// Query is both an exact match and a prefix for another
+			// module, but entry should only show up once.
+			query: "example.com/module",
+			want:  []*osv.Entry{commonPrefix},
+		},
 	} {
 		t.Run(tc.name, func(t *testing.T) {
 			got, err := vc.ByPackagePrefix(context.Background(), tc.query)
diff --git a/internal/vuln/vulns.go b/internal/vuln/vulns.go
index 6edf372..05906eb 100644
--- a/internal/vuln/vulns.go
+++ b/internal/vuln/vulns.go
@@ -11,10 +11,9 @@
 	"go/token"
 	"strings"
 
-	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/osv"
 	"golang.org/x/pkgsite/internal/stdlib"
-	"golang.org/x/pkgsite/internal/version"
+	vers "golang.org/x/pkgsite/internal/version"
 )
 
 // A Vuln contains information to display about a vulnerability.
@@ -33,16 +32,6 @@
 		return nil
 	}
 
-	vs, err := vulnsForPackage(ctx, modulePath, version, packagePath, vc)
-	if err != nil {
-		return []Vuln{{Details: fmt.Sprintf("could not get vulnerability data: %v", err)}}
-	}
-	return vs
-}
-
-func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string, vc *Client) (_ []Vuln, err error) {
-	defer derrors.Wrap(&err, "vulnsForPackage(%q, %q, %q)", modulePath, vers, packagePath)
-
 	// Handle special module paths.
 	if modulePath == stdlib.ModulePath {
 		// Stdlib pages requested at master will map to a pseudo version
@@ -51,8 +40,8 @@
 		// is the best we can do. The result is vulns won't be reported for a
 		// pseudoversion that refers to a commit that is in a vulnerable range.
 		switch {
-		case version.IsPseudo(vers):
-			return nil, nil
+		case vers.IsPseudo(version):
+			return nil
 		case strings.HasPrefix(packagePath, "cmd/"):
 			modulePath = osv.GoCmdModulePath
 		default:
@@ -61,12 +50,12 @@
 	}
 
 	// Get all the vulns for this package/version.
-	entries, err := vc.ByPackage(ctx, &PackageRequest{Module: modulePath, Package: packagePath, Version: vers})
+	entries, err := vc.ByPackage(ctx, &PackageRequest{Module: modulePath, Package: packagePath, Version: version})
 	if err != nil {
-		return nil, err
+		return []Vuln{{Details: fmt.Sprintf("could not get vulnerability data: %v", err)}}
 	}
 
-	return toVulns(entries), nil
+	return toVulns(entries)
 }
 
 func toVulns(entries []*osv.Entry) []Vuln {