internal/postgres: use new symbol search queries

The new symbol search queries created in CL 338051 are now used instead
of the legacy symbol search queries.

For golang/go#44142

Change-Id: I9edbdb75c6fd06df86bb9e6bce64a0499d1481b0
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/338052
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/postgres/symbolsearch.go b/internal/postgres/symbolsearch.go
index 4ed0632..db12448 100644
--- a/internal/postgres/symbolsearch.go
+++ b/internal/postgres/symbolsearch.go
@@ -167,6 +167,11 @@
 func runSymbolSearch(ctx context.Context, ddb *database.DB, st symbolsearch.SearchType, q string, limit int) (_ []*SearchResult, err error) {
 	defer derrors.Wrap(&err, "runSymbolSearch(ctx, ddb, query, %q, %d)", q, limit)
 
+	ids, err := fetchMatchingSymbolIDs(ctx, ddb, st, q)
+	if err != nil {
+		return nil, err
+	}
+
 	var results []*SearchResult
 	collect := func(rows *sql.Rows) error {
 		var r SearchResult
@@ -190,8 +195,34 @@
 		return nil
 	}
 	query := symbolsearch.Query(st)
-	if err := ddb.RunQuery(ctx, query, collect, q, limit); err != nil {
+	args := []interface{}{pq.Array(ids), limit}
+	if st != symbolsearch.SearchTypeSymbol {
+		args = append(args, q)
+	}
+	if err := ddb.RunQuery(ctx, query, collect, args...); err != nil {
 		return nil, err
 	}
 	return results, nil
 }
+
+// fetchMatchingSymbolIDs fetches the symbol ids to be used for a given
+// symbolsearch.SearchType. It runs the query returned by
+// symbolsearch.MatchingSymbolIDsQuery. The ids returned will be used by in
+// runSymbolSearch.
+func fetchMatchingSymbolIDs(ctx context.Context, ddb *database.DB, st symbolsearch.SearchType, q string) (_ []int, err error) {
+	defer derrors.Wrap(&err, "fetchMatchingSymbolIDs(ctx, ddb, %d, %q)", st, q)
+	var ids []int
+	collect := func(rows *sql.Rows) error {
+		var id int
+		if err := rows.Scan(&id); err != nil {
+			return err
+		}
+		ids = append(ids, id)
+		return nil
+	}
+	query := symbolsearch.MatchingSymbolIDsQuery(st)
+	if err := ddb.RunQuery(ctx, query, collect, q); err != nil {
+		return nil, err
+	}
+	return ids, nil
+}
diff --git a/internal/postgres/symbolsearch/content.go b/internal/postgres/symbolsearch/content.go
index d016a9a..5112947 100644
--- a/internal/postgres/symbolsearch/content.go
+++ b/internal/postgres/symbolsearch/content.go
@@ -39,25 +39,25 @@
 // the SearchType is SearchTypeMultiWord.
 %s
 
-// oldQuerySymbol - TODO(golang/go#44142): replace with querySearchSymbol.
+// legacyQuerySymbol - TODO(golang/go#44142): replace with querySearchSymbol.
 %s
 
-// oldQueryPackageDotSymbol - TODO(golang/go#44142): replace with
+// legacyQueryPackageDotSymbol - TODO(golang/go#44142): replace with
 // querySearchPackageDotSymbol.
 %s
 
-// oldQueryMultiWord - TODO(golang/go#44142): replace with queryMultiWord.
+// legacyQueryMultiWord - TODO(golang/go#44142): replace with queryMultiWord.
 %s
 `,
-	formatQuery("querySearchSymbol", newQuery(SearchTypeSymbol)),
-	formatQuery("querySearchPackageDotSymbol", newQuery(SearchTypePackageDotSymbol)),
-	formatQuery("querySearchMultiWord", newQuery(SearchTypeMultiWord)),
-	formatQuery("queryMatchingSymbolIDsSymbol", matchingIDsQuery(SearchTypeSymbol)),
-	formatQuery("queryMatchingSymbolIDsPackageDotSymbol", matchingIDsQuery(SearchTypePackageDotSymbol)),
-	formatQuery("queryMatchingSymbolIDsMultiWord", matchingIDsQuery(SearchTypeMultiWord)),
-	formatQuery("oldQuerySymbol", rawQuerySymbol),
-	formatQuery("oldQueryPackageDotSymbol", rawQueryPackageDotSymbol),
-	formatQuery("oldQueryMultiWord", rawQueryMultiWord))
+	formatQuery("querySearchSymbol", Query(SearchTypeSymbol)),
+	formatQuery("querySearchPackageDotSymbol", Query(SearchTypePackageDotSymbol)),
+	formatQuery("querySearchMultiWord", Query(SearchTypeMultiWord)),
+	formatQuery("queryMatchingSymbolIDsSymbol", MatchingSymbolIDsQuery(SearchTypeSymbol)),
+	formatQuery("queryMatchingSymbolIDsPackageDotSymbol", MatchingSymbolIDsQuery(SearchTypePackageDotSymbol)),
+	formatQuery("queryMatchingSymbolIDsMultiWord", MatchingSymbolIDsQuery(SearchTypeMultiWord)),
+	formatQuery("legacyQuerySymbol", rawLegacyQuerySymbol),
+	formatQuery("legacyQueryPackageDotSymbol", rawLegacyQueryPackageDotSymbol),
+	formatQuery("legacyQueryMultiWord", rawLegacyQueryMultiWord))
 
 func formatQuery(name, query string) string {
 	return fmt.Sprintf("const %s = `%s`", name, query)
diff --git a/internal/postgres/symbolsearch/legacy_symbolsearch.go b/internal/postgres/symbolsearch/legacy_symbolsearch.go
index 16c15ad..3057304 100644
--- a/internal/postgres/symbolsearch/legacy_symbolsearch.go
+++ b/internal/postgres/symbolsearch/legacy_symbolsearch.go
@@ -20,25 +20,11 @@
 const SymbolTextSearchConfiguration = "symbols"
 
 var (
-	rawQuerySymbol           = constructQuery(filterSymbol)
-	rawQueryPackageDotSymbol = constructQuery(filterPackageDotSymbol)
-	rawQueryMultiWord        = constructQuery(filterMultiWord)
+	rawLegacyQuerySymbol           = constructQuery(filterSymbol)
+	rawLegacyQueryPackageDotSymbol = constructQuery(filterPackageDotSymbol)
+	rawLegacyQueryMultiWord        = constructQuery(filterMultiWord)
 )
 
-// Query returns a search query to be used in internal/postgres for symbol
-// search.
-func Query(st SearchType) string {
-	switch st {
-	case SearchTypeSymbol:
-		return rawQuerySymbol
-	case SearchTypePackageDotSymbol:
-		return rawQueryPackageDotSymbol
-	case SearchTypeMultiWord:
-		return rawQueryMultiWord
-	}
-	return ""
-}
-
 // constructQuery is used to construct a symbol search query.
 func constructQuery(where string) string {
 	// When there is only one word in the query, popularity is the only score
diff --git a/internal/postgres/symbolsearch/query.gen.go b/internal/postgres/symbolsearch/query.gen.go
index 3c7df3c..8dfa7a5 100644
--- a/internal/postgres/symbolsearch/query.gen.go
+++ b/internal/postgres/symbolsearch/query.gen.go
@@ -151,8 +151,8 @@
 		FROM symbol_names
 		WHERE tsv_name_tokens @@ to_tsquery('symbols', replace(replace($1, '_', '-'), ' ', ' | '))`
 
-// oldQuerySymbol - TODO(golang/go#44142): replace with querySearchSymbol.
-const oldQuerySymbol = `
+// legacyQuerySymbol - TODO(golang/go#44142): replace with querySearchSymbol.
+const legacyQuerySymbol = `
 WITH results AS (
 	SELECT
 			s.name AS symbol_name,
@@ -197,9 +197,9 @@
 	package_path
 LIMIT $2;`
 
-// oldQueryPackageDotSymbol - TODO(golang/go#44142): replace with
+// legacyQueryPackageDotSymbol - TODO(golang/go#44142): replace with
 // querySearchPackageDotSymbol.
-const oldQueryPackageDotSymbol = `
+const legacyQueryPackageDotSymbol = `
 WITH results AS (
 	SELECT
 			s.name AS symbol_name,
@@ -246,8 +246,8 @@
 	package_path
 LIMIT $2;`
 
-// oldQueryMultiWord - TODO(golang/go#44142): replace with queryMultiWord.
-const oldQueryMultiWord = `
+// legacyQueryMultiWord - TODO(golang/go#44142): replace with queryMultiWord.
+const legacyQueryMultiWord = `
 WITH results AS (
 	SELECT
 			s.name AS symbol_name,
diff --git a/internal/postgres/symbolsearch/symbolsearch.go b/internal/postgres/symbolsearch/symbolsearch.go
index b5c7b78..629906b 100644
--- a/internal/postgres/symbolsearch/symbolsearch.go
+++ b/internal/postgres/symbolsearch/symbolsearch.go
@@ -8,7 +8,12 @@
 	"fmt"
 )
 
-func newQuery(st SearchType) string {
+// Query returns a symbol search query to be used in internal/postgres.
+// Each query that is returned accepts the following args:
+// $1 = ids
+// $2 = limit
+// $3 = search query input (not used by SearchTypeSymbol)
+func Query(st SearchType) string {
 	var filter string
 	switch st {
 	case SearchTypeMultiWord:
@@ -94,9 +99,9 @@
 INNER JOIN package_symbols ps ON ps.id=ssd.package_symbol_id
 ORDER BY score DESC;`
 
-// matchingIDsQuery returns a query to fetch the symbol ids that match the
+// MatchingSymbolIDsQuery returns a query to fetch the symbol ids that match the
 // search input, based on the SearchType.
-func matchingIDsQuery(st SearchType) string {
+func MatchingSymbolIDsQuery(st SearchType) string {
 	var filter string
 	switch st {
 	case SearchTypeSymbol:
diff --git a/internal/postgres/symbolsearch/symbolsearch_test.go b/internal/postgres/symbolsearch/symbolsearch_test.go
index e1d5c23..f5527a3 100644
--- a/internal/postgres/symbolsearch/symbolsearch_test.go
+++ b/internal/postgres/symbolsearch/symbolsearch_test.go
@@ -16,15 +16,15 @@
 	for _, test := range []struct {
 		name, q, want string
 	}{
-		{"querySearchSymbol", newQuery(SearchTypeSymbol), querySearchSymbol},
-		{"querySearchPackageDotSymbol", newQuery(SearchTypePackageDotSymbol), querySearchPackageDotSymbol},
-		{"querySearchMultiWord", newQuery(SearchTypeMultiWord), querySearchMultiWord},
-		{"queryMatchingSymbolIDsSymbol", matchingIDsQuery(SearchTypeSymbol), queryMatchingSymbolIDsSymbol},
-		{"queryMatchingSymbolIDsPackageDotSymbol", matchingIDsQuery(SearchTypePackageDotSymbol), queryMatchingSymbolIDsPackageDotSymbol},
-		{"queryMatchingSymbolIDsMultiWord", matchingIDsQuery(SearchTypeMultiWord), queryMatchingSymbolIDsMultiWord},
-		{"oldQuerySymbol", rawQuerySymbol, oldQuerySymbol},
-		{"oldQueryPackageDotSymbol", rawQueryPackageDotSymbol, oldQueryPackageDotSymbol},
-		{"oldQueryMultiWord", rawQueryMultiWord, oldQueryMultiWord},
+		{"querySearchSymbol", Query(SearchTypeSymbol), querySearchSymbol},
+		{"querySearchPackageDotSymbol", Query(SearchTypePackageDotSymbol), querySearchPackageDotSymbol},
+		{"querySearchMultiWord", Query(SearchTypeMultiWord), querySearchMultiWord},
+		{"queryMatchingSymbolIDsSymbol", MatchingSymbolIDsQuery(SearchTypeSymbol), queryMatchingSymbolIDsSymbol},
+		{"queryMatchingSymbolIDsPackageDotSymbol", MatchingSymbolIDsQuery(SearchTypePackageDotSymbol), queryMatchingSymbolIDsPackageDotSymbol},
+		{"queryMatchingSymbolIDsMultiWord", MatchingSymbolIDsQuery(SearchTypeMultiWord), queryMatchingSymbolIDsMultiWord},
+		{"legacyQuerySymbol", rawLegacyQuerySymbol, legacyQuerySymbol},
+		{"legacyQueryPackageDotSymbol", rawLegacyQueryPackageDotSymbol, legacyQueryPackageDotSymbol},
+		{"legacyQueryMultiWord", rawLegacyQueryMultiWord, legacyQueryMultiWord},
 	} {
 		t.Run(test.name, func(t *testing.T) {
 			if diff := cmp.Diff(test.want, test.q); diff != "" {