internal/database: add ability to run queries incrementally

Postgres lets you define a cursor and fetch rows from it as top-level
statements, without needing to write any PL/pgSQL. See
https://www.postgresql.org/docs/11/sql-declare.html and
https://www.postgresql.org/docs/11/sql-fetch.html for details.
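
For illustration, driving a cursor by hand from Go looks roughly like
this. This sketch is not part of the CL; it uses database/sql directly,
and the "items" table is hypothetical:

    // readWithCursor reads all ids from a hypothetical items table,
    // 100 rows per FETCH. DECLARE and FETCH are ordinary SQL statements,
    // and the cursor lives only as long as the transaction.
    func readWithCursor(ctx context.Context, db *sql.DB) error {
        tx, err := db.BeginTx(ctx, nil)
        if err != nil {
            return err
        }
        defer tx.Rollback() // ending the transaction also closes the cursor
        if _, err := tx.ExecContext(ctx, `DECLARE c CURSOR FOR SELECT id FROM items`); err != nil {
            return err
        }
        for {
            rows, err := tx.QueryContext(ctx, `FETCH 100 FROM c`)
            if err != nil {
                return err
            }
            n := 0
            for rows.Next() {
                n++
                var id int64
                if err := rows.Scan(&id); err != nil {
                    rows.Close()
                    return err
                }
                // ... process id ...
            }
            rows.Close()
            if err := rows.Err(); err != nil {
                return err
            }
            if n == 0 { // no rows left; cursor exhausted
                return tx.Commit()
            }
        }
    }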

Use this feature to define RunQueryIncrementally, which repeatedly
fetches query rows in batches until there are no more rows or the
passed function returns io.EOF.
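
A caller looks roughly like this (a sketch: the function name, the
units query, and the numbers are made up; db is this package's *DB):

    // firstPaths collects up to 500 paths, fetching 100 rows at a time.
    func firstPaths(ctx context.Context, db *database.DB) ([]string, error) {
        var paths []string
        err := db.RunQueryIncrementally(ctx,
            `SELECT path FROM units ORDER BY path`, 100,
            func(rows *sql.Rows) error {
                var p string
                if err := rows.Scan(&p); err != nil {
                    return err
                }
                paths = append(paths, p)
                if len(paths) == 500 {
                    return io.EOF // stop; io.EOF is not reported as an error
                }
                return nil
            })
        if err != nil {
            return nil, err
        }
        return paths, nil
    }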

If grouping by module paths is enabled, use RunQueryIncrementally
with a very large limit to read rows until we've seen a page's worth
of module paths.
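
Condensed, the stopping rule looks like this (the full collect
function, which also scans and saves each row, is in search.go below;
the name sawResult is made up):

    // Stop a few rows after we've seen pageSize distinct module paths,
    // so the last module's results are reasonably complete.
    const pageSize = 10
    modulePaths := map[string]bool{}
    additionalRows := 10
    sawResult := func(r *internal.SearchResult) error {
        modulePaths[r.ModulePath] = true
        if len(modulePaths) >= pageSize {
            additionalRows--
            if additionalRows <= 0 {
                return io.EOF // tells RunQueryIncrementally to stop
            }
        }
        return nil
    }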

This CL doesn't handle pagination correctly. Any page after the first
is going to be wrong.

Change-Id: Idf8233160b0cf74412a688e1a6b95f4f2b720008
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/329469
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/database/database.go b/internal/database/database.go
index 30963f0..dbbde42 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -12,6 +12,7 @@
 	"database/sql"
 	"errors"
 	"fmt"
+	"io"
 	"regexp"
 	"strings"
 	"sync"
@@ -141,23 +142,59 @@
 	return db.db.PrepareContext(ctx, query)
 }
 
-// RunQuery executes query, then calls f on each row.
+// RunQuery executes query, then calls f on each row. It stops when there are no
+// more rows or f returns a non-nil error.
 func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) error, params ...interface{}) error {
 	rows, err := db.Query(ctx, query, params...)
 	if err != nil {
 		return err
 	}
-	return processRows(rows, f)
+	_, err = processRows(rows, f)
+	return err
 }
 
-func processRows(rows *sql.Rows, f func(*sql.Rows) error) error {
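+// processRows calls f on each row, closing rows when done. It returns the
+// number of rows processed.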
+func processRows(rows *sql.Rows, f func(*sql.Rows) error) (int, error) {
 	defer rows.Close()
+	n := 0
 	for rows.Next() {
+		n++
 		if err := f(rows); err != nil {
-			return err
+			return n, err
 		}
 	}
-	return rows.Err()
+	return n, rows.Err()
+}
+
+// RunQueryIncrementally executes query, then calls f on each row. It fetches
+// rows in groups of size batchSize. It stops when there are no more rows, or
+// when f returns io.EOF.
+func (db *DB) RunQueryIncrementally(ctx context.Context, query string, batchSize int, f func(*sql.Rows) error, params ...interface{}) (err error) {
+	// Run in a transaction, because cursors require one.
+	return db.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+		// Declare a cursor and associate it with the query.
+		// It will be closed when the transaction commits.
+		_, err := tx.Exec(ctx, fmt.Sprintf(`DECLARE c CURSOR FOR %s`, query), params...)
+		if err != nil {
+			return err
+		}
+		for {
+			// Fetch batchSize rows and process them.
+			rows, err := tx.Query(ctx, fmt.Sprintf(`FETCH %d FROM c`, batchSize))
+			if err != nil {
+				return err
+			}
+			n, err := processRows(rows, f)
+			// Stop if there were no rows, or the processing function returned io.EOF.
+			if n == 0 || err == io.EOF {
+				return nil
+			}
+			if err != nil {
+				return err
+			}
+		}
+	})
 }
 
 // Transact executes the given function in the context of a SQL transaction at
@@ -369,7 +404,7 @@
 			if err != nil {
 				return err
 			}
-			err = processRows(rows, scanFunc)
+			_, err = processRows(rows, scanFunc)
 		}
 		if err != nil {
 			return fmt.Errorf("running bulk insert query, values[%d:%d]): %w", leftBound, rightBound, err)
diff --git a/internal/database/database_test.go b/internal/database/database_test.go
index 911e8f8..8111258 100644
--- a/internal/database/database_test.go
+++ b/internal/database/database_test.go
@@ -9,6 +9,7 @@
 	"database/sql"
 	"errors"
 	"fmt"
+	"io"
 	"log"
 	"os"
 	"sort"
@@ -502,3 +503,56 @@
 		t.Errorf("got %v, wanted code %s", gerr, constraintViolationCode)
 	}
 }
+
+func TestRunQueryIncrementally(t *testing.T) {
+	ctx := context.Background()
+	for _, stmt := range []string{
+		`DROP TABLE IF EXISTS test_rqi`,
+	`CREATE TABLE test_rqi (i INTEGER PRIMARY KEY)`,
+		`INSERT INTO test_rqi (i) VALUES (1), (2), (3), (4), (5)`,
+	} {
+		if _, err := testDB.Exec(ctx, stmt); err != nil {
+			t.Fatal(err)
+		}
+	}
+	query := `SELECT i FROM test_rqi ORDER BY i LIMIT $1`
+	var got []int
+
+	// Run until all rows consumed.
+	err := testDB.RunQueryIncrementally(ctx, query, 2, func(rows *sql.Rows) error {
+		var i int
+		if err := rows.Scan(&i); err != nil {
+			return err
+		}
+		got = append(got, i)
+		return nil
+	}, 4)
+	if err != nil {
+		t.Fatal(err)
+	}
+	want := []int{1, 2, 3, 4}
+	if !cmp.Equal(got, want) {
+		t.Errorf("got %v, want %v", got, want)
+	}
+
+	// Stop early.
+	got = nil
+	err = testDB.RunQueryIncrementally(ctx, query, 2, func(rows *sql.Rows) error {
+		var i int
+		if err := rows.Scan(&i); err != nil {
+			return err
+		}
+		got = append(got, i)
+		if len(got) == 3 {
+			return io.EOF
+		}
+		return nil
+	}, 10)
+	if err != nil {
+		t.Fatal(err)
+	}
+	want = []int{1, 2, 3}
+	if !cmp.Equal(got, want) {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
diff --git a/internal/postgres/search.go b/internal/postgres/search.go
index 0d86680..a3981cd 100644
--- a/internal/postgres/search.go
+++ b/internal/postgres/search.go
@@ -8,6 +8,7 @@
 	"context"
 	"database/sql"
 	"fmt"
+	"io"
 	"math"
 	"sort"
 	"strings"
@@ -124,7 +125,9 @@
 	queryLimit := limit
 	if experiment.IsActive(ctx, internal.ExperimentSearchGrouping) {
 		// Gather extra results for better grouping by module and series.
-		queryLimit *= 5
+		// Since deep search is using incremental querying, we can make this large.
+		// TODO(jba): For performance, modify the popular_search stored procedure.
+		queryLimit *= 100
 	}
 
 	var searchers map[string]searcher
@@ -271,17 +274,47 @@
 		WHERE r.score > 0.1
 		LIMIT $2
 		OFFSET $3`, scoreExpr)
-	var results []*internal.SearchResult
-	collect := func(rows *sql.Rows) error {
-		var r internal.SearchResult
-		if err := rows.Scan(&r.PackagePath, &r.Version, &r.ModulePath, &r.CommitTime,
-			&r.NumImportedBy, &r.Score, &r.NumResults); err != nil {
-			return fmt.Errorf("rows.Scan(): %v", err)
+
+	var (
+		results []*internal.SearchResult
+		collect func(rows *sql.Rows) error
+		err     error
+	)
+	if experiment.IsActive(ctx, internal.ExperimentSearchGrouping) {
+		modulePaths := map[string]bool{}
+		const pageSize = 10  // TODO(jba): get from elsewhere
+		additionalRows := 10 // after reaching pageSize module paths
+		collect = func(rows *sql.Rows) error {
+			var r internal.SearchResult
+			if err := rows.Scan(&r.PackagePath, &r.Version, &r.ModulePath, &r.CommitTime,
+				&r.NumImportedBy, &r.Score, &r.NumResults); err != nil {
+				return fmt.Errorf("rows.Scan(): %v", err)
+			}
+			results = append(results, &r)
+			// Stop a few rows after we've seen pageSize module paths.
+			modulePaths[r.ModulePath] = true
+			if len(modulePaths) >= pageSize {
+				additionalRows--
+				if additionalRows <= 0 {
+					return io.EOF
+				}
+			}
+			return nil
 		}
-		results = append(results, &r)
-		return nil
+		const fetchSize = 10 // number of rows to fetch at a time
+		err = db.db.RunQueryIncrementally(ctx, query, fetchSize, collect, q, limit, offset)
+	} else {
+		collect = func(rows *sql.Rows) error {
+			var r internal.SearchResult
+			if err := rows.Scan(&r.PackagePath, &r.Version, &r.ModulePath, &r.CommitTime,
+				&r.NumImportedBy, &r.Score, &r.NumResults); err != nil {
+				return fmt.Errorf("rows.Scan(): %v", err)
+			}
+			results = append(results, &r)
+			return nil
+		}
+		err = db.db.RunQuery(ctx, query, collect, q, limit, offset)
 	}
-	err := db.db.RunQuery(ctx, query, collect, q, limit, offset)
 	if err != nil {
 		results = nil
 	}
diff --git a/internal/postgres/search_test.go b/internal/postgres/search_test.go
index 0ca5969..e28f601 100644
--- a/internal/postgres/search_test.go
+++ b/internal/postgres/search_test.go
@@ -209,6 +209,15 @@
 	// Cannot be run in parallel with other search tests, because it reads
 	// metrics before and after (see responseDelta below).
 	ctx := context.Background()
+	t.Run("no grouping", func(t *testing.T) {
+		testSearch(t, ctx)
+	})
+	t.Run("grouping", func(t *testing.T) {
+		testSearch(t, experiment.NewContext(ctx, internal.ExperimentSearchGrouping))
+	})
+}
+
+func testSearch(t *testing.T, ctx context.Context) {
 	tests := []struct {
 		label       string
 		modules     []*internal.Module