internal/postgres: limit searches to 100 results

With https://golang.org/cl/258198, it's misleading to return a large
result count when deep result pages cannot be viewed. Pushing down this
limit into search simplifies things significantly, because it allows us
to drop the result count estimate.

Also impose a maximum page size, and lift up limit checking above the
database query to find a potential redirect path.

Change-Id: I977540f55cec52d35e5715fe04487fca62144956
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/258223
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
Trust: Robert Findley <rfindley@google.com>
diff --git a/internal/frontend/search.go b/internal/frontend/search.go
index 08eeb3e..b890c2a 100644
--- a/internal/frontend/search.go
+++ b/internal/frontend/search.go
@@ -46,7 +46,8 @@
 // fetchSearchPage fetches data matching the search query from the database and
 // returns a SearchPage.
 func fetchSearchPage(ctx context.Context, db *postgres.DB, query string, pageParams paginationParams) (*SearchPage, error) {
-	dbresults, err := db.Search(ctx, query, pageParams.limit, pageParams.offset())
+	maxResultCount := maxSearchOffset + pageParams.limit
+	dbresults, err := db.Search(ctx, query, pageParams.limit, pageParams.offset(), maxResultCount)
 	if err != nil {
 		return nil, err
 	}
@@ -100,16 +101,21 @@
 	return int(unit * math.Round(float64(estimate)/unit))
 }
 
-// maxSearchQueryLength represents the max number of characters that a search
-// query can be. For PostgreSQL 11, there is a max length of 2K bytes:
-// https://www.postgresql.org/docs/11/textsearch-limitations.html.
-// No valid searches on pkg.go.dev will need more than the
-// maxSearchQueryLength.
-const maxSearchQueryLength = 500
+// Search constraints.
+const (
+	// maxSearchQueryLength represents the max number of characters that a search
+	// query can be. For PostgreSQL 11, there is a max length of 2K bytes:
+	// https://www.postgresql.org/docs/11/textsearch-limitations.html. No valid
+	// searches on pkg.go.dev will need more than the maxSearchQueryLength.
+	maxSearchQueryLength = 500
 
-// maxSearchOffset is the maximum allowed. offset into the search results.
-// This prevents some very CPU-intensive queries from running.
-const maxSearchOffset = 100
+	// maxSearchOffset is the maximum allowed offset into the search results.
+	// This prevents some very CPU-intensive queries from running.
+	maxSearchOffset = 90
+
+	// maxSearchPageSize is the maximum allowed limit for search results.
+	maxSearchPageSize = 100
+)
 
 // serveSearch applies database data to the search template. Handles endpoint
 // /search?q=<query>. If <query> is an exact match for a package path, the user
@@ -139,21 +145,30 @@
 		http.Redirect(w, r, "/", http.StatusFound)
 		return nil
 	}
-	if path := searchRequestRedirectPath(ctx, ds, query); path != "" {
-		http.Redirect(w, r, path, http.StatusFound)
-		return nil
-	}
 	pageParams := newPaginationParams(r, defaultSearchLimit)
 	if pageParams.offset() > maxSearchOffset {
 		return &serverError{
 			status: http.StatusBadRequest,
 			epage: &errorPage{
 				messageTemplate: template.MakeTrustedTemplate(
-					`<h3 class="Error-message">Search offset too large.</h3>`),
+					`<h3 class="Error-message">Search page number too large.</h3>`),
+			},
+		}
+	}
+	if pageParams.limit > maxSearchPageSize {
+		return &serverError{
+			status: http.StatusBadRequest,
+			epage: &errorPage{
+				messageTemplate: template.MakeTrustedTemplate(
+					`<h3 class="Error-message">Search page size too large.</h3>`),
 			},
 		}
 	}
 
+	if path := searchRequestRedirectPath(ctx, ds, query); path != "" {
+		http.Redirect(w, r, path, http.StatusFound)
+		return nil
+	}
 	page, err := fetchSearchPage(ctx, db, query, pageParams)
 	if err != nil {
 		return fmt.Errorf("fetchSearchPage(ctx, db, %q): %v", query, err)
diff --git a/internal/postgres/benchmarks_test.go b/internal/postgres/benchmarks_test.go
index 6658d6c..3fe7f25 100644
--- a/internal/postgres/benchmarks_test.go
+++ b/internal/postgres/benchmarks_test.go
@@ -43,14 +43,14 @@
 		b.Fatal(err)
 	}
 	db := New(ddb)
-	searchers := map[string]func(context.Context, string, int, int) ([]*internal.SearchResult, error){
+	searchers := map[string]func(context.Context, string, int, int, int) ([]*internal.SearchResult, error){
 		"db.Search": db.Search,
 	}
 	for name, search := range searchers {
 		for _, query := range testQueries {
 			b.Run(name+":"+query, func(b *testing.B) {
 				for i := 0; i < b.N; i++ {
-					if _, err := search(ctx, query, 10, 0); err != nil {
+					if _, err := search(ctx, query, 10, 0, 100); err != nil {
 						b.Fatal(err)
 					}
 				}
diff --git a/internal/postgres/search.go b/internal/postgres/search.go
index dbdaf35..9bc2956 100644
--- a/internal/postgres/search.go
+++ b/internal/postgres/search.go
@@ -66,11 +66,6 @@
 	// err indicates a technical failure of the search query, or that results are
 	// not provably complete.
 	err error
-	// uncounted reports whether this response is missing total result counts. If
-	// uncounted is true, search will wait for either the hyperloglog count
-	// estimate, or for an alternate search method to return with
-	// uncounted=false.
-	uncounted bool
 }
 
 // searchEvent is used to log structured information about search events for
@@ -86,7 +81,7 @@
 }
 
 // A searcher is used to execute a single search request.
-type searcher func(db *DB, ctx context.Context, q string, limit, offset int) searchResponse
+type searcher func(db *DB, ctx context.Context, q string, limit, offset, maxResultCount int) searchResponse
 
 // The searchers used by Search.
 var searchers = map[string]searcher{
@@ -117,9 +112,9 @@
 // The gap in this optimization is search terms that are very frequent, but
 // rarely relevant: "int" or "package", for example. In these cases we'll pay
 // the penalty of a deep search that scans nearly every package.
-func (db *DB) Search(ctx context.Context, q string, limit, offset int) (_ []*internal.SearchResult, err error) {
+func (db *DB) Search(ctx context.Context, q string, limit, offset, maxResultCount int) (_ []*internal.SearchResult, err error) {
 	defer derrors.Wrap(&err, "DB.Search(ctx, %q, %d, %d)", q, limit, offset)
-	resp, err := db.hedgedSearch(ctx, q, limit, offset, searchers, nil)
+	resp, err := db.hedgedSearch(ctx, q, limit, offset, maxResultCount, searchers, nil)
 	if err != nil {
 		return nil, err
 	}
@@ -170,7 +165,7 @@
 // available result.
 // The optional guardTestResult func may be used to allow tests to control the
 // order in which search results are returned.
-func (db *DB) hedgedSearch(ctx context.Context, q string, limit, offset int, searchers map[string]searcher, guardTestResult func(string) func()) (*searchResponse, error) {
+func (db *DB) hedgedSearch(ctx context.Context, q string, limit, offset, maxResultCount int, searchers map[string]searcher, guardTestResult func(string) func()) (*searchResponse, error) {
 	searchStart := time.Now()
 	responses := make(chan searchResponse, len(searchers))
 	// cancel all unfinished searches when a result (or error) is returned. The
@@ -178,28 +173,12 @@
 	searchCtx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	// Asynchronously query for the estimated result count.
-	estimateChan := make(chan estimateResponse, 1)
-	go func() {
-		start := time.Now()
-		estimateResp := db.estimateResultsCount(searchCtx, q)
-		log.Debug(ctx, searchEvent{
-			Type:    "estimate",
-			Latency: time.Since(start),
-			Err:     estimateResp.err,
-		})
-		if guardTestResult != nil {
-			defer guardTestResult("estimate")()
-		}
-		estimateChan <- estimateResp
-	}()
-
 	// Fan out our search requests.
 	for _, s := range searchers {
 		s := s
 		go func() {
 			start := time.Now()
-			resp := s(db, searchCtx, q, limit, offset)
+			resp := s(db, searchCtx, q, limit, offset, maxResultCount)
 			log.Debug(ctx, searchEvent{
 				Type:    resp.source,
 				Latency: time.Since(start),
@@ -218,42 +197,6 @@
 	if resp.err != nil {
 		return nil, fmt.Errorf("%q search failed: %v", resp.source, resp.err)
 	}
-	if resp.uncounted {
-		// Since the response is uncounted, we should wait for either the count
-		// estimate to return, or for the first counted response.
-	loop:
-		for {
-			select {
-			case nextResp := <-responses:
-				switch {
-				case nextResp.err != nil:
-					// There are alternatives here: we could continue waiting for the
-					// estimate. But on the principle that errors are most likely to be
-					// caused by Postgres overload, we exit early to cancel the estimate.
-					return nil, fmt.Errorf("while waiting for count, got error from searcher %q: %v", nextResp.source, nextResp.err)
-				case !nextResp.uncounted:
-					log.Infof(ctx, "using counted search results from searcher %s", nextResp.source)
-					// use this response since it is counted.
-					resp = nextResp
-					break loop
-				}
-			case estr := <-estimateChan:
-				if estr.err != nil {
-					return nil, fmt.Errorf("error getting estimated count: %v", estr.err)
-				}
-				log.Debug(ctx, "using count estimate")
-				for _, r := range resp.results {
-					// TODO: change the return signature of search to separate
-					// result-level data from this query-level metadata.
-					r.NumResults = estr.estimate
-					r.Approximate = true
-				}
-				break loop
-			case <-ctx.Done():
-				return nil, fmt.Errorf("context deadline exceeded while waiting for estimated result count")
-			}
-		}
-	}
 	// cancel proactively here: we've got the search result we need.
 	cancel()
 	// latency is only recorded for valid search results, as fast failures could
@@ -276,88 +219,9 @@
 
 const hllRegisterCount = 128
 
-// hllQuery estimates search result counts using the hyperloglog algorithm.
-// https://en.wikipedia.org/wiki/HyperLogLog
-//
-// Here's how this works:
-//   1) Search documents have been partitioned ~evenly into hllRegisterCount
-//   registers, using the hll_register column. For each hll_register, compute
-//   the maximum number of leading zeros of any element in the register
-//   matching our search query. This is the slowest part of the query, but
-//   since we have an index on (hll_register, hll_leading_zeros desc), we can
-//   parallelize this and it should be very quick if the density of search
-//   results is high.  To achieve this parallelization, we use a trick of
-//   selecting a subselected value from generate_series(0, hllRegisterCount-1).
-//
-//   If there are NO search results in a register, the 'zeros' column will be
-//   NULL.
-//
-//   2) From the results of (1), proceed following the 'Practical
-//   Considerations' in the wikipedia page above:
-//     https://en.wikipedia.org/wiki/HyperLogLog#Practical_Considerations
-//   Specifically, use linear counting when E < (5/2)m and there are empty
-//   registers.
-//
-//   This should work for any register count >= 128. If we are to decrease this
-//   register count, we should adjust the estimate for a_m below according to
-//   the formulas in the wikipedia article above.
-var hllQuery = fmt.Sprintf(`
-	WITH hll_data AS (
-		SELECT (
-			SELECT * FROM (
-				SELECT hll_leading_zeros
-				FROM search_documents
-				WHERE (
-					%[2]s *
-					CASE WHEN tsv_search_tokens @@ websearch_to_tsquery($1) THEN 1 ELSE 0 END
-				) > 0.1
-				AND hll_register=generate_series
-				ORDER BY hll_leading_zeros DESC
-			) t
-			LIMIT 1
-		) zeros
-		FROM generate_series(0,%[1]d-1)
-	),
-	nonempty_registers as (SELECT zeros FROM hll_data WHERE zeros IS NOT NULL)
-	SELECT
-		-- use linear counting when there are not enough results, and there is at
-		-- least one empty register, per 'Practical Considerations'.
-		CASE WHEN result_count < 2.5 * %[1]d AND empty_register_count > 0
-		THEN ((0.7213 / (1 + 1.079 / %[1]d)) * (%[1]d *
-				log(2, (%[1]d::numeric) / empty_register_count)))::int
-		ELSE result_count END AS approx_count
-	FROM (
-		SELECT
-			(
-				(0.7213 / (1 + 1.079 / %[1]d)) *  -- estimate for a_m
-				pow(%[1]d, 2) *                   -- m^2
-				(1/((%[1]d - count(1)) + SUM(POW(2, -1 * (zeros+1)))))  -- Z
-			)::int AS result_count,
-			%[1]d - count(1) AS empty_register_count
-		FROM nonempty_registers
-	) d`, hllRegisterCount, scoreExpr)
-
-type estimateResponse struct {
-	estimate uint64
-	err      error
-}
-
-// EstimateResultsCount uses the hyperloglog algorithm to estimate the number
-// of results for the given search term.
-func (db *DB) estimateResultsCount(ctx context.Context, q string) estimateResponse {
-	row := db.db.QueryRow(ctx, hllQuery, q)
-	var estimate sql.NullInt64
-	if err := row.Scan(&estimate); err != nil {
-		return estimateResponse{err: fmt.Errorf("row.Scan(): %v", err)}
-	}
-	// If estimate is NULL, then we didn't find *any* results, so should return
-	// zero (the default).
-	return estimateResponse{estimate: uint64(estimate.Int64)}
-}
-
 // deepSearch searches all packages for the query. It is slower, but results
 // are always valid.
-func (db *DB) deepSearch(ctx context.Context, q string, limit, offset int) searchResponse {
+func (db *DB) deepSearch(ctx context.Context, q string, limit, offset, maxResultCount int) searchResponse {
 	query := fmt.Sprintf(`
 		SELECT *, COUNT(*) OVER() AS total
 		FROM (
@@ -393,6 +257,11 @@
 	if err != nil {
 		results = nil
 	}
+	if len(results) > 0 && results[0].NumResults > uint64(maxResultCount) {
+		for _, r := range results {
+			r.NumResults = uint64(maxResultCount)
+		}
+	}
 	return searchResponse{
 		source:  "deep",
 		results: results,
@@ -400,7 +269,7 @@
 	}
 }
 
-func (db *DB) popularSearch(ctx context.Context, searchQuery string, limit, offset int) searchResponse {
+func (db *DB) popularSearch(ctx context.Context, searchQuery string, limit, offset, maxResultCount int) searchResponse {
 	query := `
 		SELECT
 			package_path,
@@ -424,18 +293,28 @@
 	if err != nil {
 		results = nil
 	}
+	numResults := maxResultCount
+	if offset+limit > maxResultCount || len(results) < limit {
+		// It is practically impossible that len(results) < limit, because popular
+		// search will never linearly scan everything before deep search completes,
+		// but just to be slightly more theoretically correct, if our search
+		// results are partial we know that we have exhausted all results.
+		numResults = offset + len(results)
+	}
+	for _, r := range results {
+		r.NumResults = uint64(numResults)
+	}
 	return searchResponse{
-		source:    "popular",
-		results:   results,
-		err:       err,
-		uncounted: true,
+		source:  "popular",
+		results: results,
+		err:     err,
 	}
 }
 
 // addPackageDataToSearchResults adds package information to SearchResults that is not stored
 // in the search_documents table.
 func (db *DB) addPackageDataToSearchResults(ctx context.Context, results []*internal.SearchResult) (err error) {
-	defer derrors.Wrap(&err, "DB.enrichResults(results)")
+	defer derrors.Wrap(&err, "DB.addPackageDataToSearchResults(results)")
 	if len(results) == 0 {
 		return nil
 	}
diff --git a/internal/postgres/search_test.go b/internal/postgres/search_test.go
index a93a51f..36d5b83 100644
--- a/internal/postgres/search_test.go
+++ b/internal/postgres/search_test.go
@@ -188,18 +188,6 @@
 		}
 	}
 	guardTestResult := func(source string) func() {
-		// This test is inherently racy as 'estimate' results are are on a
-		// separate channel, and therefore even after guarding still race to
-		// the select statement.
-		//
-		// Since this is a concern only for testing, and since this test is
-		// rather slow anyway, just wait for a healthy amount of time in order
-		// to de-flake the test. If the test still proves flaky, we can either
-		// increase this sleep or refactor so that all asynchronous results
-		// arrive on the same channel.
-		if source == "estimate" {
-			time.Sleep(100 * time.Millisecond)
-		}
 		if await, ok := waitFor[source]; ok {
 			<-done[await]
 		}
@@ -218,42 +206,50 @@
 		wantTotal   uint64
 	}{
 		{
-			label:       "single package",
+			label:       "single package from popular",
 			modules:     importGraph("foo.com/A", "", 0),
-			resultOrder: []string{"popular", "estimate", "deep"},
+			resultOrder: []string{"popular", "deep"},
 			wantSource:  "popular",
 			wantResults: []string{"foo.com/A"},
 			wantTotal:   1,
 		},
 		{
+			label:       "single package from deep",
+			modules:     importGraph("foo.com/A", "", 0),
+			resultOrder: []string{"deep", "popular"},
+			wantSource:  "deep",
+			wantResults: []string{"foo.com/A"},
+			wantTotal:   1,
+		},
+		{
 			label:       "empty results",
 			modules:     []*internal.Module{},
-			resultOrder: []string{"deep", "estimate", "popular"},
+			resultOrder: []string{"deep", "popular"},
 			wantSource:  "deep",
 			wantResults: nil,
 		},
 		{
 			label:       "both popular and unpopular results",
 			modules:     importGraph("foo.com/popular", "bar.com/foo", 10),
-			resultOrder: []string{"popular", "estimate", "deep"},
+			resultOrder: []string{"popular", "deep"},
 			wantSource:  "popular",
 			wantResults: []string{"foo.com/popular", "bar.com/foo/importer0"},
-			wantTotal:   11, // HLL result count (happens to be right in this case)
+			wantTotal:   100, // popular assumes 100 results
 		},
 		{
-			label: "popular results, estimate before deep",
+			label: "popular before deep",
 			modules: append(importGraph("foo.com/popularA", "bar.com", 60),
 				importGraph("foo.com/popularB", "baz.com/foo", 70)...),
-			resultOrder: []string{"popular", "estimate", "deep"},
+			resultOrder: []string{"popular", "deep"},
 			wantSource:  "popular",
 			wantResults: []string{"foo.com/popularB", "foo.com/popularA"},
-			wantTotal:   76, // HLL result count (actual count is 72)
+			wantTotal:   100, // popular assumes 100 results
 		},
 		{
-			label: "popular results, deep before estimate",
+			label: "deep before popular",
 			modules: append(importGraph("foo.com/popularA", "bar.com/foo", 60),
 				importGraph("foo.com/popularB", "bar.com/foo", 70)...),
-			resultOrder: []string{"popular", "deep", "estimate"},
+			resultOrder: []string{"deep", "popular"},
 			wantSource:  "deep",
 			wantResults: []string{"foo.com/popularB", "foo.com/popularA"},
 			wantTotal:   72,
@@ -299,7 +295,7 @@
 				t.Fatal(err)
 			}
 			guardTestResult := resultGuard(test.resultOrder)
-			resp, err := testDB.hedgedSearch(ctx, "foo", 2, 0, searchers, guardTestResult)
+			resp, err := testDB.hedgedSearch(ctx, "foo", 2, 0, 100, searchers, guardTestResult)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -334,7 +330,7 @@
 		for name, search := range searchers {
 			if name == searcherName {
 				name := name
-				newSearchers[name] = func(*DB, context.Context, string, int, int) searchResponse {
+				newSearchers[name] = func(*DB, context.Context, string, int, int, int) searchResponse {
 					return searchResponse{
 						source: name,
 						err:    errors.New("bad"),
@@ -357,25 +353,19 @@
 		{
 			label:       "error in first result",
 			searchers:   errorIn("popular"),
-			resultOrder: []string{"popular", "estimate", "deep"},
+			resultOrder: []string{"popular", "deep"},
 			wantErr:     true,
 		},
 		{
 			label:       "return before error",
 			searchers:   errorIn("deep"),
-			resultOrder: []string{"popular", "estimate", "deep"},
+			resultOrder: []string{"popular", "deep"},
 			wantSource:  "popular",
 		},
 		{
-			label:       "error waiting for count",
-			searchers:   errorIn("deep"),
-			resultOrder: []string{"popular", "deep", "estimate"},
-			wantErr:     true,
-		},
-		{
 			label:       "counted result before error",
 			searchers:   errorIn("popular"),
-			resultOrder: []string{"deep", "popular", "estimate"},
+			resultOrder: []string{"deep", "popular"},
 			wantSource:  "deep",
 		},
 	}
@@ -394,7 +384,7 @@
 				t.Fatal(err)
 			}
 			guardTestResult := resultGuard(test.resultOrder)
-			resp, err := testDB.hedgedSearch(ctx, "foo", 2, 0, test.searchers, guardTestResult)
+			resp, err := testDB.hedgedSearch(ctx, "foo", 2, 0, 100, test.searchers, guardTestResult)
 			if (err != nil) != test.wantErr {
 				t.Fatalf("hedgedSearch(): got error %v, want error: %t", err, test.wantErr)
 			}
@@ -548,7 +538,7 @@
 					tc.limit = 10
 				}
 
-				got := searcher(testDB, ctx, tc.searchQuery, tc.limit, tc.offset)
+				got := searcher(testDB, ctx, tc.searchQuery, tc.limit, tc.offset, 100)
 				if got.err != nil {
 					t.Fatal(got.err)
 				}
@@ -603,7 +593,7 @@
 
 	for method, searcher := range searchers {
 		t.Run(method, func(t *testing.T) {
-			res := searcher(testDB, ctx, "foo", 10, 0)
+			res := searcher(testDB, ctx, "foo", 10, 0, 100)
 			if res.err != nil {
 				t.Fatal(res.err)
 			}
@@ -638,7 +628,7 @@
 		t.Fatal(err)
 	}
 	// Search for both packages.
-	gotResults, err := testDB.Search(ctx, domain, 10, 0)
+	gotResults, err := testDB.Search(ctx, domain, 10, 0, 100)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -670,7 +660,7 @@
 		{testDB, true},
 		{bypassDB, false},
 	} {
-		rs, err := test.db.Search(ctx, m.ModulePath, 10, 0)
+		rs, err := test.db.Search(ctx, m.ModulePath, 10, 0, 100)
 		if err != nil {
 			t.Fatal(err)
 		}