internal/frontend, internal/vuln: replace getVulnEntries with vuln.Client

Instead of passing around a function, getVulnEntries, pass the actual
vuln client and call it directly.

Update the TestClient to implement the GetByModules function so that
tests can use it.

The purpose of this change is to further isolate calls to the vulndb
Client to the internal/vuln package, and to make the code easier to
understand by removing a function parameter.

For golang/go#58928

Change-Id: I8bef528034a1caa44b99da2f185990338ec9cd5f
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/474537
Reviewed-by: Jamal Carvalho <jamal@golang.org>
Run-TryBot: Tatiana Bradley <tatianabradley@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
diff --git a/internal/frontend/search.go b/internal/frontend/search.go
index 7711d29..5a0778c 100644
--- a/internal/frontend/search.go
+++ b/internal/frontend/search.go
@@ -129,11 +129,7 @@
 	if len(filters) > 0 {
 		symbol = filters[0]
 	}
-	var getVulnEntries vuln.VulnEntriesFunc
-	if vulnClient != nil {
-		getVulnEntries = vulnClient.ByModule
-	}
-	page, err := fetchSearchPage(ctx, db, cq, symbol, pageParams, mode == searchModeSymbol, getVulnEntries)
+	page, err := fetchSearchPage(ctx, db, cq, symbol, pageParams, mode == searchModeSymbol, vulnClient)
 	if err != nil {
 		// Instead of returning a 500, return a 408, since symbol searches may
 		// timeout for very popular symbols.
@@ -236,7 +232,7 @@
 // fetchSearchPage fetches data matching the search query from the database and
 // returns a SearchPage.
 func fetchSearchPage(ctx context.Context, db *postgres.DB, cq, symbol string,
-	pageParams paginationParams, searchSymbols bool, getVulnEntries vuln.VulnEntriesFunc) (*SearchPage, error) {
+	pageParams paginationParams, searchSymbols bool, vulnClient *vuln.Client) (*SearchPage, error) {
 	maxResultCount := maxSearchOffset + pageParams.limit
 
 	// Pageless search: always start from the beginning.
@@ -258,8 +254,8 @@
 		results = append(results, sr)
 	}
 
-	if getVulnEntries != nil {
-		addVulns(ctx, results, getVulnEntries)
+	if vulnClient != nil {
+		addVulns(ctx, results, vulnClient)
 	}
 
 	var numResults int
@@ -400,13 +396,13 @@
 	}, nil
 }
 
-func searchVulnAlias(ctx context.Context, mode, cq string, vulnClient *vuln.Client) (_ *searchAction, err error) {
+func searchVulnAlias(ctx context.Context, mode, cq string, vc *vuln.Client) (_ *searchAction, err error) {
 	defer derrors.Wrap(&err, "searchVulnAlias(%q, %q)", mode, cq)
 
 	if mode != searchModeVuln || !isVulnAlias(cq) {
 		return nil, nil
 	}
-	aliasEntries, err := vulnClient.ByAlias(ctx, cq)
+	aliasEntries, err := vc.ByAlias(ctx, cq)
 	if err != nil {
 		return nil, err
 	}
@@ -607,7 +603,7 @@
 
 // addVulns adds vulnerability information to search results by consulting the
 // vulnerability database.
-func addVulns(ctx context.Context, rs []*SearchResult, getVulnEntries vuln.VulnEntriesFunc) {
+func addVulns(ctx context.Context, rs []*SearchResult, vc *vuln.Client) {
 	// Get all vulns concurrently.
 	var wg sync.WaitGroup
 	// TODO(golang/go#48223): throttle concurrency?
@@ -616,7 +612,7 @@
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
-			r.Vulns = vuln.VulnsForPackage(ctx, r.ModulePath, r.Version, r.PackagePath, getVulnEntries)
+			r.Vulns = vuln.VulnsForPackage(ctx, r.ModulePath, r.Version, r.PackagePath, vc)
 		}()
 	}
 	wg.Wait()
diff --git a/internal/frontend/search_test.go b/internal/frontend/search_test.go
index 3306fe3..200d13f 100644
--- a/internal/frontend/search_test.go
+++ b/internal/frontend/search_test.go
@@ -312,12 +312,7 @@
 			}},
 		}}
 
-		getVulnEntries = func(_ context.Context, modulePath string) ([]*osv.Entry, error) {
-			if modulePath == moduleFoo.ModulePath {
-				return vulnEntries, nil
-			}
-			return nil, nil
-		}
+		vc = vuln.NewTestClient(vulnEntries)
 	)
 
 	for _, m := range []*internal.Module{moduleFoo, moduleBar} {
@@ -392,7 +387,7 @@
 		},
 	} {
 		t.Run(test.name, func(t *testing.T) {
-			got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false, getVulnEntries)
+			got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false, vc)
 			if err != nil {
 				t.Fatalf("fetchSearchPage(db, %q): %v", test.query, err)
 			}
diff --git a/internal/frontend/tabs.go b/internal/frontend/tabs.go
index cd841ee..d7895dc 100644
--- a/internal/frontend/tabs.go
+++ b/internal/frontend/tabs.go
@@ -78,14 +78,14 @@
 // handler.
 func fetchDetailsForUnit(ctx context.Context, r *http.Request, tab string, ds internal.DataSource, um *internal.UnitMeta,
 	requestedVersion string, bc internal.BuildContext,
-	getVulnEntries vuln.VulnEntriesFunc) (_ any, err error) {
+	vc *vuln.Client) (_ any, err error) {
 	defer derrors.Wrap(&err, "fetchDetailsForUnit(r, %q, ds, um=%q,%q,%q)", tab, um.Path, um.ModulePath, um.Version)
 	switch tab {
 	case tabMain:
 		_, expandReadme := r.URL.Query()["readme"]
 		return fetchMainDetails(ctx, ds, um, requestedVersion, expandReadme, bc)
 	case tabVersions:
-		return fetchVersionsDetails(ctx, ds, um, getVulnEntries)
+		return fetchVersionsDetails(ctx, ds, um, vc)
 	case tabImports:
 		return fetchImportsDetails(ctx, ds, um.Path, um.ModulePath, um.Version)
 	case tabImportedBy:
diff --git a/internal/frontend/unit.go b/internal/frontend/unit.go
index c5194d8..af2f5bd 100644
--- a/internal/frontend/unit.go
+++ b/internal/frontend/unit.go
@@ -135,11 +135,7 @@
 	// It's also okay to provide just one (e.g. GOOS=windows), which will select
 	// the first doc with that value, ignoring the other one.
 	bc := internal.BuildContext{GOOS: r.FormValue("GOOS"), GOARCH: r.FormValue("GOARCH")}
-	var getVulnEntries vuln.VulnEntriesFunc
-	if s.vulnClient != nil {
-		getVulnEntries = s.vulnClient.ByModule
-	}
-	d, err := fetchDetailsForUnit(ctx, r, tab, ds, um, info.requestedVersion, bc, getVulnEntries)
+	d, err := fetchDetailsForUnit(ctx, r, tab, ds, um, info.requestedVersion, bc, s.vulnClient)
 	if err != nil {
 		return err
 	}
@@ -240,9 +236,8 @@
 	}
 
 	// Get vulnerability information.
-	if s.vulnClient != nil {
-		page.Vulns = vuln.VulnsForPackage(ctx, um.ModulePath, um.Version, um.Path, s.vulnClient.ByModule)
-	}
+	page.Vulns = vuln.VulnsForPackage(ctx, um.ModulePath, um.Version, um.Path, s.vulnClient)
+
 	s.servePage(ctx, w, tabSettings.TemplateName, page)
 	return nil
 }
diff --git a/internal/frontend/versions.go b/internal/frontend/versions.go
index 58d6a51..161ab78 100644
--- a/internal/frontend/versions.go
+++ b/internal/frontend/versions.go
@@ -85,7 +85,7 @@
 	Vulns               []vuln.Vuln
 }
 
-func fetchVersionsDetails(ctx context.Context, ds internal.DataSource, um *internal.UnitMeta, getVulnEntries vuln.VulnEntriesFunc) (*VersionsDetails, error) {
+func fetchVersionsDetails(ctx context.Context, ds internal.DataSource, um *internal.UnitMeta, vc *vuln.Client) (*VersionsDetails, error) {
 	db, ok := ds.(*postgres.DB)
 	if !ok {
 		// The proxydatasource does not support the imported by page.
@@ -114,7 +114,7 @@
 		}
 		return constructUnitURL(versionPath, mi.ModulePath, linkVersion(mi.ModulePath, mi.Version, mi.Version))
 	}
-	return buildVersionDetails(ctx, um.ModulePath, um.Path, versions, sh, linkify, getVulnEntries), nil
+	return buildVersionDetails(ctx, um.ModulePath, um.Path, versions, sh, linkify, vc), nil
 }
 
 // pathInVersion constructs the full import path of the package corresponding
@@ -146,7 +146,7 @@
 	modInfos []*internal.ModuleInfo,
 	sh *internal.SymbolHistory,
 	linkify func(v *internal.ModuleInfo) string,
-	getVulnEntries vuln.VulnEntriesFunc,
+	vc *vuln.Client,
 ) *VersionsDetails {
 	// lists organizes versions by VersionListKey.
 	lists := make(map[VersionListKey]*VersionList)
@@ -201,7 +201,7 @@
 		if mi.ModulePath == stdlib.ModulePath {
 			pkg = packagePath
 		}
-		vs.Vulns = vuln.VulnsForPackage(ctx, mi.ModulePath, mi.Version, pkg, getVulnEntries)
+		vs.Vulns = vuln.VulnsForPackage(ctx, mi.ModulePath, mi.Version, pkg, vc)
 		vl := lists[key]
 		if vl == nil {
 			seenLists = append(seenLists, key)
diff --git a/internal/frontend/versions_test.go b/internal/frontend/versions_test.go
index 756c8b5..70c6212 100644
--- a/internal/frontend/versions_test.go
+++ b/internal/frontend/versions_test.go
@@ -107,12 +107,7 @@
 			},
 		}},
 	}
-	getVulnEntries := func(_ context.Context, m string) ([]*osv.Entry, error) {
-		if m == modulePath1 {
-			return []*osv.Entry{vulnEntry}, nil
-		}
-		return nil, nil
-	}
+	vc := vuln.NewTestClient([]*osv.Entry{vulnEntry})
 
 	for _, tc := range []struct {
 		name        string
@@ -201,7 +196,7 @@
 				postgres.MustInsertModule(ctx, t, testDB, v)
 			}
 
-			got, err := fetchVersionsDetails(ctx, testDB, &tc.pkg.UnitMeta, getVulnEntries)
+			got, err := fetchVersionsDetails(ctx, testDB, &tc.pkg.UnitMeta, vc)
 			if err != nil {
 				t.Fatalf("fetchVersionsDetails(ctx, db, %q, %q): %v", tc.pkg.Path, tc.pkg.ModulePath, err)
 			}
diff --git a/internal/vuln/test_client.go b/internal/vuln/test_client.go
index 463a9ae..5915da4 100644
--- a/internal/vuln/test_client.go
+++ b/internal/vuln/test_client.go
@@ -6,33 +6,38 @@
 
 import (
 	"context"
-	"errors"
 
 	vulnc "golang.org/x/vuln/client"
 	"golang.org/x/vuln/osv"
 )
 
+// NewTestClient creates an in-memory client for use in tests.
 func NewTestClient(entries []*osv.Entry) *Client {
 	c := &vulndbTestClient{
-		entries:    entries,
-		aliasToIDs: map[string][]string{},
+		entries:          entries,
+		aliasToIDs:       map[string][]string{},
+		modulesToEntries: map[string][]*osv.Entry{},
 	}
 	for _, e := range entries {
 		for _, a := range e.Aliases {
 			c.aliasToIDs[a] = append(c.aliasToIDs[a], e.ID)
 		}
+		for _, affected := range e.Affected {
+			c.modulesToEntries[affected.Package.Name] = append(c.modulesToEntries[affected.Package.Name], e)
+		}
 	}
 	return &Client{c: c}
 }
 
 type vulndbTestClient struct {
 	vulnc.Client
-	entries    []*osv.Entry
-	aliasToIDs map[string][]string
+	entries          []*osv.Entry
+	aliasToIDs       map[string][]string
+	modulesToEntries map[string][]*osv.Entry
 }
 
-func (c *vulndbTestClient) GetByModule(context.Context, string) ([]*osv.Entry, error) {
-	return nil, errors.New("unimplemented")
+func (c *vulndbTestClient) GetByModule(_ context.Context, module string) ([]*osv.Entry, error) {
+	return c.modulesToEntries[module], nil
 }
 
 func (c *vulndbTestClient) GetByID(_ context.Context, id string) (*osv.Entry, error) {
diff --git a/internal/vuln/vulns.go b/internal/vuln/vulns.go
index 4263bdb..8d32c43 100644
--- a/internal/vuln/vulns.go
+++ b/internal/vuln/vulns.go
@@ -34,27 +34,24 @@
 	Details string
 }
 
-type VulnEntriesFunc func(context.Context, string) ([]*osv.Entry, error)
-
 // VulnsForPackage obtains vulnerability information for the given package.
 // If packagePath is empty, it returns all entries for the module at version.
-// The getVulnEntries function should retrieve all entries for the given module path.
-// It is passed to facilitate testing.
 // If there is an error, VulnsForPackage returns a single Vuln that describes the error.
-func VulnsForPackage(ctx context.Context, modulePath, version, packagePath string, getVulnEntries VulnEntriesFunc) []Vuln {
-	vs, err := vulnsForPackage(ctx, modulePath, version, packagePath, getVulnEntries)
+func VulnsForPackage(ctx context.Context, modulePath, version, packagePath string, vc *Client) []Vuln {
+	if vc == nil {
+		return nil
+	}
+
+	vs, err := vulnsForPackage(ctx, modulePath, version, packagePath, vc)
 	if err != nil {
 		return []Vuln{{Details: fmt.Sprintf("could not get vulnerability data: %v", err)}}
 	}
 	return vs
 }
 
-func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string, getVulnEntries VulnEntriesFunc) (_ []Vuln, err error) {
-	defer derrors.Wrap(&err, "vulns(%q, %q, %q)", modulePath, vers, packagePath)
+func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string, vc *Client) (_ []Vuln, err error) {
+	defer derrors.Wrap(&err, "vulnsForPackage(%q, %q, %q)", modulePath, vers, packagePath)
 
-	if getVulnEntries == nil {
-		return nil, nil
-	}
 	// Stdlib pages requested at master will map to a pseudo version that puts
 	// all vulns in range. We can't really tell you're at master so version.IsPseudo
 	// is the best we can do. The result is vulns won't be reported for a pseudoversion
@@ -68,7 +65,7 @@
 		modulePath = vulnStdlibModulePath
 	}
 	// Get all the vulns for this module.
-	entries, err := getVulnEntries(ctx, modulePath)
+	entries, err := vc.ByModule(ctx, modulePath)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/vuln/vulns_test.go b/internal/vuln/vulns_test.go
index c723d12..d3ce181 100644
--- a/internal/vuln/vulns_test.go
+++ b/internal/vuln/vulns_test.go
@@ -6,7 +6,6 @@
 
 import (
 	"context"
-	"fmt"
 	"reflect"
 	"testing"
 
@@ -60,18 +59,7 @@
 		}},
 	}
 
-	get := func(_ context.Context, modulePath string) ([]*osv.Entry, error) {
-		switch modulePath {
-		case "good.com":
-			return nil, nil
-		case "bad.com", "unfixable.com":
-			return []*osv.Entry{&e}, nil
-		case "stdlib":
-			return []*osv.Entry{&stdlib}, nil
-		default:
-			return nil, fmt.Errorf("unknown module %q", modulePath)
-		}
-	}
+	vc := NewTestClient([]*osv.Entry{&e, &stdlib})
 
 	testCases := []struct {
 		mod, pkg, version string
@@ -118,7 +106,7 @@
 		},
 	}
 	for _, tc := range testCases {
-		got := VulnsForPackage(ctx, tc.mod, tc.version, tc.pkg, get)
+		got := VulnsForPackage(ctx, tc.mod, tc.version, tc.pkg, vc)
 		if diff := cmp.Diff(tc.want, got); diff != "" {
 			t.Errorf("VulnsForPackage(%q, %q, %q) = %+v, mismatch (-want, +got):\n%s", tc.mod, tc.version, tc.pkg, tc.want, diff)
 		}