internal/postgres/{symbolsearch}: update queries

The symbolsearch queries are rewritten to filter on
symbol_search_documents.symbol_name, instead of symbols.name or
symbols.tsv_name_tokens.

This significantly improves performance on the first search because:

(1) We no longer need to fetch symbol IDs first, before performing the
    search query on symbol_search_documents, which can save ~100-200ms.
(2) Searching on the lower(symbol_name) index is faster than the
    symbol_name_id index. This likely because the lower(symbol_name) =
    lower($1) fetches rows that are paged together, whereas the
    symbol_name_ids maybe completely random.

For golang/go#44142

Change-Id: Ibe78676d4f5424dcace9da544f46a6445a06b160
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/342471
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Jamal Carvalho <jamal@golang.org>
diff --git a/internal/postgres/symbol_test.go b/internal/postgres/symbol_test.go
index 5d9431d..8b56ec3 100644
--- a/internal/postgres/symbol_test.go
+++ b/internal/postgres/symbol_test.go
@@ -7,6 +7,7 @@
 import (
 	"context"
 	"database/sql"
+	"errors"
 	"fmt"
 	"sort"
 	"testing"
@@ -529,3 +530,35 @@
 	}
 	return buildToSymbols, nil
 }
+
+func TestSplitSymbolName(t *testing.T) {
+	for _, test := range []struct {
+		q, wantPkg, wantSym string
+	}{
+		{"sql.DB", "sql", "DB"},
+		{"sql.DB.Begin", "sql", "DB.Begin"},
+	} {
+		t.Run(test.q, func(t *testing.T) {
+			pkg, symbol, err := splitPackageAndSymbolNames(test.q)
+			if err != nil || pkg != test.wantPkg || symbol != test.wantSym {
+				t.Errorf("splitPackageAndSymbolNames(%q) = %q, %q, %v; want = %q, %q, nil",
+					test.q, pkg, symbol, err, test.wantPkg, test.wantSym)
+			}
+		})
+	}
+
+	for _, test := range []string{
+		"DB",
+		".DB",
+		"sql.",
+		"sql.DB.Begin.Blah",
+	} {
+		t.Run(test, func(t *testing.T) {
+			pkg, symbol, err := splitPackageAndSymbolNames(test)
+			if !errors.Is(err, derrors.NotFound) {
+				t.Errorf("splitPackageAndSymbolNames(%q) = %q, %q, %v; want %v",
+					test, pkg, symbol, err, derrors.NotFound)
+			}
+		})
+	}
+}
diff --git a/internal/postgres/symbolsearch.go b/internal/postgres/symbolsearch.go
index 5165d94..04b3fa1 100644
--- a/internal/postgres/symbolsearch.go
+++ b/internal/postgres/symbolsearch.go
@@ -113,7 +113,7 @@
 	case symbolsearch.InputTypeNoDot:
 		results, err = runSymbolSearch(ctx, db.db, symbolsearch.SearchTypeSymbol, q, limit)
 	case symbolsearch.InputTypeTwoDots:
-		results, err = runSymbolSearch(ctx, db.db, symbolsearch.SearchTypePackageDotSymbol, q, limit, q)
+		results, err = runSymbolSearchPackageDotSymbol(ctx, db.db, q, limit)
 	default:
 		// There is no supported situation where we will get results for one
 		// element containing more than 2 dots.
@@ -161,6 +161,11 @@
 		// There are no words in the query that could be a symbol name.
 		return nil, derrors.NotFound
 	}
+	if strings.Contains(q, "|") {
+		// TODO(golang/go#44142): The symbolsearch.SearchTypeMultiWordOr case
+		// is currently not supported.
+		return nil, derrors.NotFound
+	}
 	group, searchCtx := errgroup.WithContext(ctx)
 	resultsArray := make([][]*SearchResult, len(symbolToPathTokens))
 	count := 0
@@ -171,17 +176,7 @@
 		count += 1
 		group.Go(func() error {
 			st := symbolsearch.SearchTypeMultiWordExact
-			if strings.Contains(q, "|") {
-				st = symbolsearch.SearchTypeMultiWordOr
-			}
-			ids, err := fetchMatchingSymbolIDs(searchCtx, ddb, st, symbol)
-			if err != nil {
-				if !errors.Is(err, derrors.NotFound) {
-					return err
-				}
-				return nil
-			}
-			r, err := fetchSymbolSearchResults(ctx, ddb, st, ids, limit, pathTokens)
+			r, err := runSymbolSearch(searchCtx, ddb, st, symbol, limit, pathTokens)
 			if err != nil {
 				return err
 			}
@@ -263,11 +258,15 @@
 		i := i
 		st := st
 		group.Go(func() error {
-			var args []interface{}
+			var (
+				results []*SearchResult
+				err     error
+			)
 			if st == symbolsearch.SearchTypePackageDotSymbol {
-				args = append(args, q)
+				results, err = runSymbolSearchPackageDotSymbol(searchCtx, ddb, q, limit)
+			} else {
+				results, err = runSymbolSearch(searchCtx, ddb, st, q, limit)
 			}
-			results, err := runSymbolSearch(searchCtx, ddb, st, q, limit, args...)
 			if err != nil {
 				return err
 			}
@@ -281,55 +280,34 @@
 	return mergedResults(resultsArray, limit), nil
 }
 
+func runSymbolSearchPackageDotSymbol(ctx context.Context, ddb *database.DB, q string, limit int) (_ []*SearchResult, err error) {
+	pkg, symbol, err := splitPackageAndSymbolNames(q)
+	if err != nil {
+		return nil, err
+	}
+	return runSymbolSearch(ctx, ddb, symbolsearch.SearchTypePackageDotSymbol, symbol, limit, pkg)
+}
+
+func splitPackageAndSymbolNames(q string) (pkgName string, symbolName string, err error) {
+	parts := strings.Split(q, ".")
+	if len(parts) != 2 && len(parts) != 3 {
+		return "", "", derrors.NotFound
+	}
+	for _, p := range parts {
+		// Handle cases where we have odd dot placement, such as .Foo or
+		// Foo..
+		if p == "" {
+			return "", "", derrors.NotFound
+		}
+	}
+	return parts[0], strings.Join(parts[1:], "."), nil
+}
+
 func runSymbolSearch(ctx context.Context, ddb *database.DB,
-	st symbolsearch.SearchType, q string, limit int, args ...interface{}) (_ []*SearchResult, err error) {
+	st symbolsearch.SearchType, q string, limit int, args ...interface{}) (results []*SearchResult, err error) {
 	defer derrors.Wrap(&err, "runSymbolSearch(ctx, ddb, %q, %q, %d, %v)", st, q, limit, args)
 	defer middleware.ElapsedStat(ctx, fmt.Sprintf("%s-runSymbolSearch", st))()
 
-	ids, err := fetchMatchingSymbolIDs(ctx, ddb, st, q)
-	if err != nil {
-		if errors.Is(err, derrors.NotFound) {
-			return nil, nil
-		}
-		return nil, err
-	}
-	return fetchSymbolSearchResults(ctx, ddb, st, ids, limit, args...)
-}
-
-// fetchMatchingSymbolIDs fetches the symbol ids to be used for a given
-// symbolsearch.SearchType. It runs the query returned by
-// symbolsearch.MatchingSymbolIDsQuery. The ids returned will be used by in
-// runSymbolSearch.
-func fetchMatchingSymbolIDs(ctx context.Context, ddb *database.DB, st symbolsearch.SearchType, q string) (_ []int, err error) {
-	defer derrors.Wrap(&err, "fetchMatchingSymbolIDs(ctx, ddb, %q, %q)", st, q)
-	defer middleware.ElapsedStat(ctx, fmt.Sprintf("%s-fetchMatchingSymbolIDs", st))()
-
-	var ids []int
-	collect := func(rows *sql.Rows) error {
-		var id int
-		if err := rows.Scan(&id); err != nil {
-			return err
-		}
-		ids = append(ids, id)
-		return nil
-	}
-	query := symbolsearch.MatchingSymbolIDsQuery(st)
-	if err := ddb.RunQuery(ctx, query, collect, q); err != nil {
-		return nil, err
-	}
-	if len(ids) == 0 {
-		return nil, derrors.NotFound
-	}
-	return ids, nil
-}
-
-// fetchSymbolSearchResults executes a symbol search for the given
-// symbolsearch.SearchType and args.
-func fetchSymbolSearchResults(ctx context.Context, ddb *database.DB,
-	st symbolsearch.SearchType, ids []int, limit int, args ...interface{}) (results []*SearchResult, err error) {
-	defer derrors.Wrap(&err, "fetchSymbolSearchResults(ctx, ddb, %q, ids: %v, limit:  %d, args: %v)", st.String(), ids, limit, args)
-	defer middleware.ElapsedStat(ctx, fmt.Sprintf("%s-fetchSymbolSearchResults", st))()
-
 	collect := func(rows *sql.Rows) error {
 		var r SearchResult
 		if err := rows.Scan(
@@ -352,7 +330,7 @@
 		return nil
 	}
 	query := symbolsearch.Query(st)
-	args = append([]interface{}{pq.Array(ids), limit}, args...)
+	args = append([]interface{}{q, limit}, args...)
 	if err := ddb.RunQuery(ctx, query, collect, args...); err != nil {
 		return nil, err
 	}
diff --git a/internal/postgres/symbolsearch/content.go b/internal/postgres/symbolsearch/content.go
index 5a07684..c091545 100644
--- a/internal/postgres/symbolsearch/content.go
+++ b/internal/postgres/symbolsearch/content.go
@@ -26,34 +26,10 @@
 
 // querySearchMultiWordExact is used when the search query is multiple elements.
 %s
-
-// querySearchMultiWordOr is used when the search query is multiple elements.
-%s
-
-// queryMatchingSymbolIDsSymbol is used to find the matching symbol
-// ids when the SearchType is SearchTypeSymbol.
-%s
-
-// queryMatchingSymbolIDsPackageDotSymbol is used to find the matching symbol
-// ids when the SearchType is SearchTypePackageDotSymbol.
-%s
-
-// queryMatchingSymbolIDsMultiWordExact is used to find the matching symbol ids when
-// the SearchType is SearchTypeMultiWordExact.
-%s
-
-// queryMatchingSymbolIDsMultiWordOr is used to find the matching symbol ids when
-// the SearchType is SearchTypeMultiWordOr.
-%s
 `,
 	formatQuery("querySearchSymbol", Query(SearchTypeSymbol)),
 	formatQuery("querySearchPackageDotSymbol", Query(SearchTypePackageDotSymbol)),
-	formatQuery("querySearchMultiWordOr", Query(SearchTypeMultiWordOr)),
-	formatQuery("querySearchMultiWordExact", Query(SearchTypeMultiWordExact)),
-	formatQuery("queryMatchingSymbolIDsSymbol", MatchingSymbolIDsQuery(SearchTypeSymbol)),
-	formatQuery("queryMatchingSymbolIDsPackageDotSymbol", MatchingSymbolIDsQuery(SearchTypePackageDotSymbol)),
-	formatQuery("queryMatchingSymbolIDsMultiWordOr", MatchingSymbolIDsQuery(SearchTypeMultiWordOr)),
-	formatQuery("queryMatchingSymbolIDsMultiWordExact", MatchingSymbolIDsQuery(SearchTypeMultiWordExact)))
+	formatQuery("querySearchMultiWordExact", Query(SearchTypeMultiWordExact)))
 
 func formatQuery(name, query string) string {
 	return fmt.Sprintf("const %s = `%s`", name, query)
diff --git a/internal/postgres/symbolsearch/query.gen.go b/internal/postgres/symbolsearch/query.gen.go
index b360aab..ee038ae 100644
--- a/internal/postgres/symbolsearch/query.gen.go
+++ b/internal/postgres/symbolsearch/query.gen.go
@@ -19,8 +19,8 @@
 		ssd.goarch,
 		ssd.ln_imported_by_count AS score
 	FROM symbol_search_documents ssd
-	WHERE
-		symbol_name_id = ANY($1)
+	WHERE 
+		lower(symbol_name) = lower($1)
 	ORDER BY
 		score DESC,
 		package_path,
@@ -60,12 +60,12 @@
 		ssd.goarch,
 		ssd.ln_imported_by_count AS score
 	FROM symbol_search_documents ssd
-	WHERE
-		symbol_name_id = ANY($1)
-	AND (
-		ssd.uuid_package_name=uuid_generate_v5(uuid_nil(), split_part($3, '.', 1)) OR
-		ssd.uuid_package_path=uuid_generate_v5(uuid_nil(), split_part($3, '.', 1))
-	)
+	WHERE 
+		lower(symbol_name) = lower($1)
+		AND (
+			ssd.uuid_package_name=uuid_generate_v5(uuid_nil(), split_part($3, '.', 1)) OR
+			ssd.uuid_package_path=uuid_generate_v5(uuid_nil(), split_part($3, '.', 1))
+		)
 	ORDER BY
 		score DESC,
 		package_path,
@@ -93,50 +93,6 @@
 ORDER BY score DESC;`
 
 // querySearchMultiWordExact is used when the search query is multiple elements.
-const querySearchMultiWordOr = `
-WITH ssd AS (
-	SELECT
-		ssd.unit_id,
-		ssd.package_symbol_id,
-		ssd.symbol_name_id,
-		ssd.goos,
-		ssd.goarch,
-		(
-			ts_rank(
-				'{0.1, 0.2, 1.0, 1.0}',
-				sd.tsv_path_tokens,
-				to_tsquery('symbols', quote_literal(replace($3, '_', '-')))
-			) * ssd.ln_imported_by_count
-		) AS score
-	FROM symbol_search_documents ssd
-	INNER JOIN search_documents sd ON sd.package_path_id = ssd.package_path_id
-	WHERE
-		symbol_name_id = ANY($1)
-		AND sd.tsv_path_tokens @@ to_tsquery('symbols', quote_literal(replace($3, '_', '-')))
-	ORDER BY score DESC
-	LIMIT $2
-)
-SELECT
-	s.name AS symbol_name,
-	sd.package_path,
-	sd.module_path,
-	sd.version,
-	sd.name,
-	sd.synopsis,
-	sd.license_types,
-	sd.commit_time,
-	sd.imported_by_count,
-	ssd.goos,
-	ssd.goarch,
-	ps.type AS symbol_kind,
-	ps.synopsis AS symbol_synopsis
-FROM ssd
-INNER JOIN symbol_names s ON s.id=ssd.symbol_name_id
-INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
-INNER JOIN package_symbols ps ON ps.id=ssd.package_symbol_id
-ORDER BY score DESC;`
-
-// querySearchMultiWordOr is used when the search query is multiple elements.
 const querySearchMultiWordExact = `
 WITH ssd AS (
 	SELECT
@@ -155,7 +111,7 @@
 	FROM symbol_search_documents ssd
 	INNER JOIN search_documents sd ON sd.package_path_id = ssd.package_path_id
 	WHERE
-		symbol_name_id = ANY($1)
+		lower(symbol_name) = lower($1)
 		AND sd.tsv_path_tokens @@ to_tsquery('symbols', quote_literal(replace($3, '_', '-')))
 	ORDER BY score DESC
 	LIMIT $2
@@ -179,31 +135,3 @@
 INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
 INNER JOIN package_symbols ps ON ps.id=ssd.package_symbol_id
 ORDER BY score DESC;`
-
-// queryMatchingSymbolIDsSymbol is used to find the matching symbol
-// ids when the SearchType is SearchTypeSymbol.
-const queryMatchingSymbolIDsSymbol = `
-		SELECT id
-		FROM symbol_names
-		WHERE lower(name) = lower($1)`
-
-// queryMatchingSymbolIDsPackageDotSymbol is used to find the matching symbol
-// ids when the SearchType is SearchTypePackageDotSymbol.
-const queryMatchingSymbolIDsPackageDotSymbol = `
-		SELECT id
-		FROM symbol_names
-		WHERE lower(name) = lower(substring($1 from E'[^.]*\.(.+)$'))`
-
-// queryMatchingSymbolIDsMultiWordExact is used to find the matching symbol ids when
-// the SearchType is SearchTypeMultiWordExact.
-const queryMatchingSymbolIDsMultiWordOr = `
-		SELECT id
-		FROM symbol_names
-		WHERE tsv_name_tokens @@ to_tsquery('symbols', quote_literal(replace(replace($1, '_', '-'), ' ', ' | ')))`
-
-// queryMatchingSymbolIDsMultiWordOr is used to find the matching symbol ids when
-// the SearchType is SearchTypeMultiWordOr.
-const queryMatchingSymbolIDsMultiWordExact = `
-		SELECT id
-		FROM symbol_names
-		WHERE lower(name) = lower($1)`
diff --git a/internal/postgres/symbolsearch/symbolsearch.go b/internal/postgres/symbolsearch/symbolsearch.go
index 5a428c3..70b9f64 100644
--- a/internal/postgres/symbolsearch/symbolsearch.go
+++ b/internal/postgres/symbolsearch/symbolsearch.go
@@ -17,21 +17,29 @@
 
 // Query returns a symbol search query to be used in internal/postgres.
 // Each query that is returned accepts the following args:
-// $1 = ids
+// $1 = query
 // $2 = limit
-// $3 = search query input (not used by SearchTypeSymbol)
+// $3 = only used by multi-word-exact for path tokens
 func Query(st SearchType) string {
-	var filter string
 	switch st {
-	case SearchTypeMultiWordOr, SearchTypeMultiWordExact:
-		return fmt.Sprintf(baseQuery, multiwordCTE())
+	case SearchTypeMultiWordExact:
+		return fmt.Sprintf(baseQuery, multiwordCTE)
 	case SearchTypePackageDotSymbol:
-		// PackageDotSymbol case.
-		filter = filterPackageDotSymbol
+		// When $1 is either <package>.<symbol> OR
+		// <package>.<type>.<methodOrField>, only match on the exact
+		// symbol name.
+		return fmt.Sprintf(baseQuery, fmt.Sprintf(symbolCTE, filterPackageDotSymbol))
 	case SearchTypeSymbol:
-		filter = ""
+		// When $1 is the full symbol name, either <symbol> or
+		// <type>.<methodOrField>, match on just the identifier name.
+		//
+		// Matching on just <field> and <method> is too slow at the moment (can
+		// take several seconds to return results), but we
+		// might want to add support for that later. For example, searching for
+		// "Begin" should return "DB.Begin".
+		return fmt.Sprintf(baseQuery, fmt.Sprintf(symbolCTE, filterSymbol))
 	}
-	return fmt.Sprintf(baseQuery, fmt.Sprintf(symbolCTE, filter))
+	return ""
 }
 
 const symbolCTE = `
@@ -43,8 +51,7 @@
 		ssd.goarch,
 		ssd.ln_imported_by_count AS score
 	FROM symbol_search_documents ssd
-	WHERE
-		symbol_name_id = ANY($1)%s
+	WHERE %s
 	ORDER BY
 		score DESC,
 		package_path,
@@ -52,17 +59,21 @@
 	LIMIT $2
 `
 
+const filterSymbol = `
+		lower(symbol_name) = lower($1)`
+
 // TODO(golang/go#44142): Filtering on package path currently only works for
 // standard library packages, since non-standard library packages will have a
 // dot.
 var filterPackageDotSymbol = fmt.Sprintf(`
-	AND (
-		ssd.uuid_package_name=%s OR
-		ssd.uuid_package_path=%[1]s
-	)`, "uuid_generate_v5(uuid_nil(), split_part($3, '.', 1))")
+		lower(symbol_name) = lower($1)
+		AND (
+			ssd.uuid_package_name=%[1]s OR
+			ssd.uuid_package_path=%[1]s
+		)`,
+	"uuid_generate_v5(uuid_nil(), split_part($3, '.', 1))")
 
-func multiwordCTE() string {
-	return fmt.Sprintf(`
+var multiwordCTE = fmt.Sprintf(`
 	SELECT
 		ssd.unit_id,
 		ssd.package_symbol_id,
@@ -73,18 +84,17 @@
 			ts_rank(
 				'{0.1, 0.2, 1.0, 1.0}',
 				sd.tsv_path_tokens,
-				%s
+				%[1]s
 			) * ssd.ln_imported_by_count
 		) AS score
 	FROM symbol_search_documents ssd
 	INNER JOIN search_documents sd ON sd.package_path_id = ssd.package_path_id
 	WHERE
-		symbol_name_id = ANY($1)
+		lower(symbol_name) = lower($1)
 		AND sd.tsv_path_tokens @@ %[1]s
 	ORDER BY score DESC
 	LIMIT $2
 `, toTSQuery("$3"))
-}
 
 const baseQuery = `
 WITH ssd AS (%s)
@@ -108,39 +118,6 @@
 INNER JOIN package_symbols ps ON ps.id=ssd.package_symbol_id
 ORDER BY score DESC;`
 
-// MatchingSymbolIDsQuery returns a query to fetch the symbol ids that match the
-// search input, based on the SearchType.
-func MatchingSymbolIDsQuery(st SearchType) string {
-	var filter string
-	switch st {
-	case SearchTypeSymbol, SearchTypeMultiWordExact:
-		// When $1 is the full symbol name, either <symbol> or
-		// <type>.<methodOrField>, match on just the identifier name.
-		//
-		// Matching on just <field> and <method> is too slow at the moment (can
-		// take several seconds to return results), but we
-		// might want to add support for that later. For example, searching for
-		// "Begin" should return "DB.Begin".
-		filter = `lower(name) = lower($1)`
-	case SearchTypePackageDotSymbol:
-		// When $1 is either <package>.<symbol> OR
-		// <package>.<type>.<methodOrField>, only match on the exact
-		// symbol name.
-		filter = fmt.Sprintf("lower(name) = lower(%s)", "substring($1 from E'[^.]*\\.(.+)$')")
-	case SearchTypeMultiWordOr:
-		// When $1 contains multiple words, separated by spaces, at least one
-		// element for the query must match a symbol name.
-		//
-		// TODO(44142): This is currently somewhat slow, since many IDs can be
-		// returned.
-		filter = fmt.Sprintf(`tsv_name_tokens @@ %s`, toTSQuery("replace($1, ' ', ' | ')"))
-	}
-	return fmt.Sprintf(`
-		SELECT id
-		FROM symbol_names
-		WHERE %s`, filter)
-}
-
 func toTSQuery(arg string) string {
 	return fmt.Sprintf("to_tsquery('%s', quote_literal(%s))", SymbolTextSearchConfiguration, processArg(arg))
 }
diff --git a/internal/postgres/symbolsearch/symbolsearch_test.go b/internal/postgres/symbolsearch/symbolsearch_test.go
index 0addd97..b0fb0f1 100644
--- a/internal/postgres/symbolsearch/symbolsearch_test.go
+++ b/internal/postgres/symbolsearch/symbolsearch_test.go
@@ -59,11 +59,6 @@
 		{"querySearchSymbol", Query(SearchTypeSymbol), querySearchSymbol},
 		{"querySearchPackageDotSymbol", Query(SearchTypePackageDotSymbol), querySearchPackageDotSymbol},
 		{"querySearchMultiWordExact", Query(SearchTypeMultiWordExact), querySearchMultiWordExact},
-		{"querySearchMultiWordOr", Query(SearchTypeMultiWordOr), querySearchMultiWordOr},
-		{"queryMatchingSymbolIDsSymbol", MatchingSymbolIDsQuery(SearchTypeSymbol), queryMatchingSymbolIDsSymbol},
-		{"queryMatchingSymbolIDsPackageDotSymbol", MatchingSymbolIDsQuery(SearchTypePackageDotSymbol), queryMatchingSymbolIDsPackageDotSymbol},
-		{"queryMatchingSymbolIDsMultiWordExact", MatchingSymbolIDsQuery(SearchTypeMultiWordExact), queryMatchingSymbolIDsMultiWordExact},
-		{"queryMatchingSymbolIDsMultiWordOr", MatchingSymbolIDsQuery(SearchTypeMultiWordOr), queryMatchingSymbolIDsMultiWordOr},
 	} {
 		t.Run(test.name, func(t *testing.T) {
 			if diff := cmp.Diff(test.want, test.q); diff != "" {