internal/frontend: add vulns to search results

For golang/go#48223

Change-Id: I6dd0adffa17c754c91dd952dd3f55d8a9c53a5de
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/348789
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/frontend/search.go b/internal/frontend/search.go
index 8eb4537..a1c12c4 100644
--- a/internal/frontend/search.go
+++ b/internal/frontend/search.go
@@ -12,6 +12,7 @@
 	"path"
 	"sort"
 	"strings"
+	"sync"
 	"time"
 	"unicode"
 	"unicode/utf8"
@@ -98,7 +99,7 @@
 		symbol = filters[0]
 	}
 	mode := searchMode(r)
-	page, err := fetchSearchPage(ctx, db, query, symbol, pageParams, mode == searchModeSymbol)
+	page, err := fetchSearchPage(ctx, db, query, symbol, pageParams, mode == searchModeSymbol, s.getVulnEntries)
 	if err != nil {
 		return fmt.Errorf("fetchSearchPage(ctx, db, %q): %v", query, err)
 	}
@@ -160,6 +161,7 @@
 	Name           string
 	PackagePath    string
 	ModulePath     string
+	Version        string
 	ChipText       string
 	Synopsis       string
 	DisplayVersion string
@@ -175,6 +177,7 @@
 	SymbolGOOS     string
 	SymbolGOARCH   string
 	SymbolLink     string
+	Vulns          []Vuln
 }
 
 type subResult struct {
@@ -185,7 +188,7 @@
 // fetchSearchPage fetches data matching the search query from the database and
 // returns a SearchPage.
 func fetchSearchPage(ctx context.Context, db *postgres.DB, query, symbol string,
-	pageParams paginationParams, searchSymbols bool) (*SearchPage, error) {
+	pageParams paginationParams, searchSymbols bool, getVulnEntries vulnEntriesFunc) (*SearchPage, error) {
 	maxResultCount := maxSearchOffset + pageParams.limit
 
 	offset := pageParams.offset()
@@ -210,6 +213,10 @@
 		results = append(results, sr)
 	}
 
+	if getVulnEntries != nil && experiment.IsActive(ctx, internal.ExperimentVulns) {
+		addVulns(results, getVulnEntries)
+	}
+
 	var numResults int
 	if len(dbresults) > 0 {
 		numResults = int(dbresults[0].NumResults)
@@ -250,6 +257,7 @@
 		Name:           name,
 		PackagePath:    r.PackagePath,
 		ModulePath:     r.ModulePath,
+		Version:        r.Version,
 		ChipText:       chipText,
 		Synopsis:       r.Synopsis,
 		DisplayVersion: displayVersion(r.ModulePath, r.Version, r.Version),
@@ -468,3 +476,21 @@
 
 	return absoluteTime(date)
 }
+
+// addVulns adds vulnerability information to search results by consulting the
+// vulnerability database.
+func addVulns(rs []*SearchResult, getVulnEntries vulnEntriesFunc) {
+	// Get all vulns concurrently.
+	var wg sync.WaitGroup
+	// TODO(golang/go#48223): throttle concurrency?
+	for _, r := range rs {
+		r := r
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			r.Vulns = Vulns(r.ModulePath, r.Version, r.PackagePath, getVulnEntries)
+		}()
+	}
+	wg.Wait()
+
+}
diff --git a/internal/frontend/search_test.go b/internal/frontend/search_test.go
index 614ecd1..5b24f20 100644
--- a/internal/frontend/search_test.go
+++ b/internal/frontend/search_test.go
@@ -20,6 +20,7 @@
 	"golang.org/x/pkgsite/internal/testing/sample"
 	"golang.org/x/text/language"
 	"golang.org/x/text/message"
+	"golang.org/x/vulndb/osv"
 )
 
 func TestSearchQueryAndMode(t *testing.T) {
@@ -69,6 +70,8 @@
 	defer cancel()
 	defer postgres.ResetTestDB(testDB, t)
 
+	ctx = experiment.NewContext(ctx, internal.ExperimentVulns)
+
 	var (
 		now       = sample.NowTruncated()
 		moduleFoo = &internal.Module{
@@ -139,7 +142,27 @@
 				},
 			},
 		}
+
+		vulnEntries = []*osv.Entry{{
+			ID:      "test",
+			Details: "vuln",
+			Affected: []osv.Affected{{
+				Package: osv.Package{Name: "github.com/mod/foo"},
+				Ranges: []osv.AffectsRange{{
+					Type:   osv.TypeSemver,
+					Events: []osv.RangeEvent{{Introduced: "1.0.0"}, {Fixed: "1.9.0"}},
+				}},
+			}},
+		}}
+
+		getVulnEntries = func(modulePath string) ([]*osv.Entry, error) {
+			if modulePath == moduleFoo.ModulePath {
+				return vulnEntries, nil
+			}
+			return nil, nil
+		}
 	)
+
 	for _, m := range []*internal.Module{moduleFoo, moduleBar} {
 		postgres.MustInsertModule(ctx, t, testDB, m)
 	}
@@ -169,6 +192,7 @@
 						Name:           moduleBar.Packages()[0].Name,
 						PackagePath:    moduleBar.Packages()[0].Path,
 						ModulePath:     moduleBar.ModulePath,
+						Version:        "v1.0.0",
 						Synopsis:       moduleBar.Packages()[0].Documentation[0].Synopsis,
 						DisplayVersion: moduleBar.Version,
 						Licenses:       []string{"MIT"},
@@ -197,17 +221,19 @@
 						Name:           moduleFoo.Packages()[0].Name,
 						PackagePath:    moduleFoo.Packages()[0].Path,
 						ModulePath:     moduleFoo.ModulePath,
+						Version:        "v1.0.0",
 						Synopsis:       moduleFoo.Packages()[0].Documentation[0].Synopsis,
 						DisplayVersion: moduleFoo.Version,
 						Licenses:       []string{"MIT"},
 						CommitTime:     elapsedTime(moduleFoo.CommitTime),
+						Vulns:          []Vuln{{ID: "test", Details: "vuln", FixedVersion: "v1.9.0"}},
 					},
 				},
 			},
 		},
 	} {
 		t.Run(test.name, func(t *testing.T) {
-			got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false)
+			got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false, getVulnEntries)
 			if err != nil {
 				t.Fatalf("fetchSearchPage(db, %q): %v", test.query, err)
 			}
@@ -246,6 +272,7 @@
 				Name:           "pkg",
 				PackagePath:    "m.com/pkg",
 				ModulePath:     "m.com",
+				Version:        "v1.0.0",
 				DisplayVersion: "v1.0.0",
 				NumImportedBy:  "3",
 			},
@@ -264,6 +291,7 @@
 				Name:           "cmd",
 				PackagePath:    "m.com/cmd",
 				ModulePath:     "m.com",
+				Version:        "v1.0.0",
 				DisplayVersion: "v1.0.0",
 				ChipText:       "command",
 				NumImportedBy:  "1,234",
@@ -282,6 +310,7 @@
 				Name:           "math",
 				PackagePath:    "math",
 				ModulePath:     "std",
+				Version:        "v1.14.0",
 				DisplayVersion: "go1.14",
 				ChipText:       "standard library",
 				NumImportedBy:  "0",
@@ -301,6 +330,7 @@
 				Name:           "pkg",
 				PackagePath:    "m.com/pkg",
 				ModulePath:     "m.com",
+				Version:        "v1.0.0",
 				DisplayVersion: "v1.0.0",
 				NumImportedBy:  "3.456",
 			},