internal/postgres/symbolsearch: add new queries

The symbolsearch queries are rewritten to first fetch the matching
symbol ids, then use these as input to the search query. Compared to
using CTEs, this performed significantly better (on the order of 1-2s,
sometimes minutes).

For golang/go#44142

Change-Id: I4703995b3f7f6423c8d424d908eb0ad6cdb87a92
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/338051
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/postgres/symbolsearch/content.go b/internal/postgres/symbolsearch/content.go
index 705fc98..d016a9a 100644
--- a/internal/postgres/symbolsearch/content.go
+++ b/internal/postgres/symbolsearch/content.go
@@ -14,21 +14,50 @@
 
 package symbolsearch
 
-// QuerySymbol is used when the search query is only one word, with no dots.
+// querySearchSymbol is used when the search query is only one word, with no dots.
 // In this case, the word must match a symbol name and ranking is completely
 // determined by the path_tokens.
 %s
 
-// QueryPackageDotSymbol is used when the search query is one element
+// querySearchPackageDotSymbol is used when the search query is one element
 // containing a dot, where the first part is assumed to be the package name and
 // the second the symbol name. For example, "sql.DB" or "sql.DB.Begin".
 %s
 
-// QueryMultiWord is used when the search query is multiple elements.
-%s`,
-	formatQuery("querySymbol", rawQuerySymbol),
-	formatQuery("queryPackageDotSymbol", rawQueryPackageDotSymbol),
-	formatQuery("queryMultiWord", rawQueryMultiWord))
+// querySearchMultiWord is used when the search query is multiple elements.
+%s
+
+// queryMatchingSymbolIDsSymbol is used to find the matching symbol
+// ids when the SearchType is SearchTypeSymbol.
+%s
+
+// queryMatchingSymbolIDsPackageDotSymbol is used to find the matching symbol
+// ids when the SearchType is SearchTypePackageDotSymbol.
+%s
+
+// queryMatchingSymbolIDsMultiWord is used to find the matching symbol ids when
+// the SearchType is SearchTypeMultiWord.
+%s
+
+// oldQuerySymbol - TODO(golang/go#44142): replace with querySearchSymbol.
+%s
+
+// oldQueryPackageDotSymbol - TODO(golang/go#44142): replace with
+// querySearchPackageDotSymbol.
+%s
+
+// oldQueryMultiWord - TODO(golang/go#44142): replace with queryMultiWord.
+%s
+`,
+	formatQuery("querySearchSymbol", newQuery(SearchTypeSymbol)),
+	formatQuery("querySearchPackageDotSymbol", newQuery(SearchTypePackageDotSymbol)),
+	formatQuery("querySearchMultiWord", newQuery(SearchTypeMultiWord)),
+	formatQuery("queryMatchingSymbolIDsSymbol", matchingIDsQuery(SearchTypeSymbol)),
+	formatQuery("queryMatchingSymbolIDsPackageDotSymbol", matchingIDsQuery(SearchTypePackageDotSymbol)),
+	formatQuery("queryMatchingSymbolIDsMultiWord", matchingIDsQuery(SearchTypeMultiWord)),
+	formatQuery("oldQuerySymbol", rawQuerySymbol),
+	formatQuery("oldQueryPackageDotSymbol", rawQueryPackageDotSymbol),
+	formatQuery("oldQueryMultiWord", rawQueryMultiWord))
 
 func formatQuery(name, query string) string {
 	return fmt.Sprintf("const %s = `%s`", name, query)
diff --git a/internal/postgres/symbolsearch/legacy_symbolsearch.go b/internal/postgres/symbolsearch/legacy_symbolsearch.go
index c2b88ac..16c15ad 100644
--- a/internal/postgres/symbolsearch/legacy_symbolsearch.go
+++ b/internal/postgres/symbolsearch/legacy_symbolsearch.go
@@ -125,12 +125,17 @@
 	return fmt.Sprintf("to_tsquery('%s', %s)", SymbolTextSearchConfiguration, processArg(arg))
 }
 
-// processSymbol converts a symbol with underscores to slashes (for example,
+// processArg converts a symbol with underscores to slashes (for example,
 // "A_B" -> "A-B"). This is because the postgres parser treats underscores as
 // slashes, but we want a search for "A" to rank "A_B" lower than just "A". We
 // also want to be able to search specificially for "A_B".
 func processArg(arg string) string {
-	return strings.ReplaceAll(arg, "$1", "replace($1, '_', '-')")
+	s := "$1"
+	if len(arg) == 2 && strings.HasPrefix(arg, "$") {
+		// If the arg is a different $N, substitute that instead.
+		s = arg
+	}
+	return strings.ReplaceAll(arg, s, fmt.Sprintf("replace(%s, '_', '-')", s))
 }
 
 const symbolSearchBaseQuery = `
diff --git a/internal/postgres/symbolsearch/query.gen.go b/internal/postgres/symbolsearch/query.gen.go
index 5add395..3c7df3c 100644
--- a/internal/postgres/symbolsearch/query.gen.go
+++ b/internal/postgres/symbolsearch/query.gen.go
@@ -6,10 +6,153 @@
 
 package symbolsearch
 
-// QuerySymbol is used when the search query is only one word, with no dots.
+// querySearchSymbol is used when the search query is only one word, with no dots.
 // In this case, the word must match a symbol name and ranking is completely
 // determined by the path_tokens.
-const querySymbol = `
+const querySearchSymbol = `
+WITH ssd AS (
+	SELECT
+		ssd.unit_id,
+		ssd.package_symbol_id,
+		ssd.symbol_name_id,
+		ssd.goos,
+		ssd.goarch,
+		ssd.ln_imported_by_count AS score
+	FROM symbol_search_documents ssd
+	WHERE
+		symbol_name_id = ANY($1) 
+	ORDER BY score DESC
+	LIMIT $2
+)
+SELECT
+	s.name AS symbol_name,
+	sd.package_path,
+	sd.module_path,
+	sd.version,
+	sd.name,
+	sd.synopsis,
+	sd.license_types,
+	sd.commit_time,
+	sd.imported_by_count,
+	ssd.goos,
+	ssd.goarch,
+	ps.type AS symbol_kind,
+	ps.synopsis AS symbol_synopsis
+FROM ssd
+INNER JOIN symbol_names s ON s.id=ssd.symbol_name_id
+INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
+INNER JOIN package_symbols ps ON ps.id=ssd.package_symbol_id
+ORDER BY score DESC;`
+
+// querySearchPackageDotSymbol is used when the search query is one element
+// containing a dot, where the first part is assumed to be the package name and
+// the second the symbol name. For example, "sql.DB" or "sql.DB.Begin".
+const querySearchPackageDotSymbol = `
+WITH ssd AS (
+	SELECT
+		ssd.unit_id,
+		ssd.package_symbol_id,
+		ssd.symbol_name_id,
+		ssd.goos,
+		ssd.goarch,
+		ssd.ln_imported_by_count AS score
+	FROM symbol_search_documents ssd
+	WHERE
+		symbol_name_id = ANY($1) 
+	AND (
+		ssd.uuid_package_name=uuid_generate_v5(uuid_nil(), split_part($3, '.', 1)) OR
+		ssd.uuid_package_path=uuid_generate_v5(uuid_nil(), split_part($3, '.', 1))
+	)
+	ORDER BY score DESC
+	LIMIT $2
+)
+SELECT
+	s.name AS symbol_name,
+	sd.package_path,
+	sd.module_path,
+	sd.version,
+	sd.name,
+	sd.synopsis,
+	sd.license_types,
+	sd.commit_time,
+	sd.imported_by_count,
+	ssd.goos,
+	ssd.goarch,
+	ps.type AS symbol_kind,
+	ps.synopsis AS symbol_synopsis
+FROM ssd
+INNER JOIN symbol_names s ON s.id=ssd.symbol_name_id
+INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
+INNER JOIN package_symbols ps ON ps.id=ssd.package_symbol_id
+ORDER BY score DESC;`
+
+// querySearchMultiWord is used when the search query is multiple elements.
+const querySearchMultiWord = `
+WITH ssd AS (
+	SELECT
+		ssd.unit_id,
+		ssd.package_symbol_id,
+		ssd.symbol_name_id,
+		ssd.goos,
+		ssd.goarch,
+		(
+			ts_rank(
+				'{0.1, 0.2, 1.0, 1.0}',
+				sd.tsv_path_tokens,
+				to_tsquery('symbols', replace(replace($3, '_', '-'), ' ', ' | '))
+			) * ssd.ln_imported_by_count
+		) AS score
+	FROM symbol_search_documents ssd
+	INNER JOIN search_documents sd ON sd.package_path_id = ssd.package_path_id
+	WHERE
+		symbol_name_id = ANY($1)
+		AND sd.tsv_path_tokens @@ to_tsquery('symbols', replace(replace($3, '_', '-'), ' ', ' | '))
+	ORDER BY score DESC
+	LIMIT $2
+)
+SELECT
+	s.name AS symbol_name,
+	sd.package_path,
+	sd.module_path,
+	sd.version,
+	sd.name,
+	sd.synopsis,
+	sd.license_types,
+	sd.commit_time,
+	sd.imported_by_count,
+	ssd.goos,
+	ssd.goarch,
+	ps.type AS symbol_kind,
+	ps.synopsis AS symbol_synopsis
+FROM ssd
+INNER JOIN symbol_names s ON s.id=ssd.symbol_name_id
+INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
+INNER JOIN package_symbols ps ON ps.id=ssd.package_symbol_id
+ORDER BY score DESC;`
+
+// queryMatchingSymbolIDsSymbol is used to find the matching symbol
+// ids when the SearchType is SearchTypeSymbol.
+const queryMatchingSymbolIDsSymbol = `
+		SELECT id
+		FROM symbol_names
+		WHERE tsv_name_tokens @@ to_tsquery('symbols', replace($1, '_', '-')) OR lower(name) = lower($1)`
+
+// queryMatchingSymbolIDsPackageDotSymbol is used to find the matching symbol
+// ids when the SearchType is SearchTypePackageDotSymbol.
+const queryMatchingSymbolIDsPackageDotSymbol = `
+		SELECT id
+		FROM symbol_names
+		WHERE lower(name) = lower(substring($1 from E'[^.]*\.(.+)$'))`
+
+// queryMatchingSymbolIDsMultiWord is used to find the matching symbol ids when
+// the SearchType is SearchTypeMultiWord.
+const queryMatchingSymbolIDsMultiWord = `
+		SELECT id
+		FROM symbol_names
+		WHERE tsv_name_tokens @@ to_tsquery('symbols', replace(replace($1, '_', '-'), ' ', ' | '))`
+
+// oldQuerySymbol - TODO(golang/go#44142): replace with querySearchSymbol.
+const oldQuerySymbol = `
 WITH results AS (
 	SELECT
 			s.name AS symbol_name,
@@ -54,10 +197,9 @@
 	package_path
 LIMIT $2;`
 
-// QueryPackageDotSymbol is used when the search query is one element
-// containing a dot, where the first part is assumed to be the package name and
-// the second the symbol name. For example, "sql.DB" or "sql.DB.Begin".
-const queryPackageDotSymbol = `
+// oldQueryPackageDotSymbol - TODO(golang/go#44142): replace with
+// querySearchPackageDotSymbol.
+const oldQueryPackageDotSymbol = `
 WITH results AS (
 	SELECT
 			s.name AS symbol_name,
@@ -104,8 +246,8 @@
 	package_path
 LIMIT $2;`
 
-// QueryMultiWord is used when the search query is multiple elements.
-const queryMultiWord = `
+// oldQueryMultiWord - TODO(golang/go#44142): replace with queryMultiWord.
+const oldQueryMultiWord = `
 WITH results AS (
 	SELECT
 			s.name AS symbol_name,
diff --git a/internal/postgres/symbolsearch/symbolsearch.go b/internal/postgres/symbolsearch/symbolsearch.go
new file mode 100644
index 0000000..b5c7b78
--- /dev/null
+++ b/internal/postgres/symbolsearch/symbolsearch.go
@@ -0,0 +1,127 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package symbolsearch
+
+import (
+	"fmt"
+)
+
+func newQuery(st SearchType) string {
+	var filter string
+	switch st {
+	case SearchTypeMultiWord:
+		return fmt.Sprintf(baseQuery, multiwordCTE)
+	case SearchTypeSymbol:
+		filter = ""
+	case SearchTypePackageDotSymbol:
+		// PackageDotSymbol case.
+		filter = newfilterPackageDotSymbol
+	}
+	q := fmt.Sprintf(baseQuery, fmt.Sprintf(symbolCTE, filter))
+	return q
+}
+
+// TODO(golang/go#44142): Filtering on package path currently only works for
+// standard library packages, since non-standard library packages will have a
+// dot.
+const newfilterPackageDotSymbol = `
+	AND (
+		ssd.uuid_package_name=uuid_generate_v5(uuid_nil(), split_part($3, '.', 1)) OR
+		ssd.uuid_package_path=uuid_generate_v5(uuid_nil(), split_part($3, '.', 1))
+	)`
+
+const symbolCTE = `
+	SELECT
+		ssd.unit_id,
+		ssd.package_symbol_id,
+		ssd.symbol_name_id,
+		ssd.goos,
+		ssd.goarch,
+		ssd.ln_imported_by_count AS score
+	FROM symbol_search_documents ssd
+	WHERE
+		symbol_name_id = ANY($1) %s
+	ORDER BY score DESC
+	LIMIT $2
+`
+
+var multiwordCTE = fmt.Sprintf(`
+	SELECT
+		ssd.unit_id,
+		ssd.package_symbol_id,
+		ssd.symbol_name_id,
+		ssd.goos,
+		ssd.goarch,
+		(
+			ts_rank(
+				'{0.1, 0.2, 1.0, 1.0}',
+				sd.tsv_path_tokens,
+				to_tsquery('%s', %s)
+			) * ssd.ln_imported_by_count
+		) AS score
+	FROM symbol_search_documents ssd
+	INNER JOIN search_documents sd ON sd.package_path_id = ssd.package_path_id
+	WHERE
+		symbol_name_id = ANY($1)
+		AND sd.tsv_path_tokens @@ to_tsquery('%[1]s', %[2]s)
+	ORDER BY score DESC
+	LIMIT $2
+`,
+	SymbolTextSearchConfiguration,
+	splitORFunc(processArg("$3")))
+
+const baseQuery = `
+WITH ssd AS (%s)
+SELECT
+	s.name AS symbol_name,
+	sd.package_path,
+	sd.module_path,
+	sd.version,
+	sd.name,
+	sd.synopsis,
+	sd.license_types,
+	sd.commit_time,
+	sd.imported_by_count,
+	ssd.goos,
+	ssd.goarch,
+	ps.type AS symbol_kind,
+	ps.synopsis AS symbol_synopsis
+FROM ssd
+INNER JOIN symbol_names s ON s.id=ssd.symbol_name_id
+INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
+INNER JOIN package_symbols ps ON ps.id=ssd.package_symbol_id
+ORDER BY score DESC;`
+
+// matchingIDsQuery returns a query to fetch the symbol ids that match the
+// search input, based on the SearchType.
+func matchingIDsQuery(st SearchType) string {
+	var filter string
+	switch st {
+	case SearchTypeSymbol:
+		// When searching for just a symbol, match on both the identifier name
+		// and just the field or method name. For example, "Begin" will return
+		// "DB.Begin".
+		// tsv_name_tokens does a bad job of indexing symbol names with
+		// multiple "_", so also do an exact match search.
+		filter = fmt.Sprintf(`tsv_name_tokens @@ %s OR lower(name) = lower($1)`,
+			toTSQuery("$1"))
+	case SearchTypePackageDotSymbol:
+		// When searching for a <package>.<symbol>, only match on the exact
+		// symbol name. It is assumed that $1 = <package>.<symbol>.
+		filter = fmt.Sprintf("lower(name) = lower(%s)", "substring($1 from E'[^.]*\\.(.+)$')")
+	case SearchTypeMultiWord:
+		// TODO(44142): This is currently somewhat slow, since many IDs can be
+		// returned.
+		filter = fmt.Sprintf(`tsv_name_tokens @@ %s`, toTSQuery(splitORFunc("$1")))
+	}
+	return fmt.Sprintf(`
+		SELECT id
+		FROM symbol_names
+		WHERE %s`, filter)
+}
+
+func splitORFunc(arg string) string {
+	return fmt.Sprintf("replace(%s, ' ', ' | ')", arg)
+}
diff --git a/internal/postgres/symbolsearch/symbolsearch_test.go b/internal/postgres/symbolsearch/symbolsearch_test.go
index ef9233b..e1d5c23 100644
--- a/internal/postgres/symbolsearch/symbolsearch_test.go
+++ b/internal/postgres/symbolsearch/symbolsearch_test.go
@@ -16,9 +16,15 @@
 	for _, test := range []struct {
 		name, q, want string
 	}{
-		{"querySymbol", rawQuerySymbol, querySymbol},
-		{"queryPackageDotSymbol", rawQueryPackageDotSymbol, queryPackageDotSymbol},
-		{"queryMultiWord", rawQueryMultiWord, queryMultiWord},
+		{"querySearchSymbol", newQuery(SearchTypeSymbol), querySearchSymbol},
+		{"querySearchPackageDotSymbol", newQuery(SearchTypePackageDotSymbol), querySearchPackageDotSymbol},
+		{"querySearchMultiWord", newQuery(SearchTypeMultiWord), querySearchMultiWord},
+		{"queryMatchingSymbolIDsSymbol", matchingIDsQuery(SearchTypeSymbol), queryMatchingSymbolIDsSymbol},
+		{"queryMatchingSymbolIDsPackageDotSymbol", matchingIDsQuery(SearchTypePackageDotSymbol), queryMatchingSymbolIDsPackageDotSymbol},
+		{"queryMatchingSymbolIDsMultiWord", matchingIDsQuery(SearchTypeMultiWord), queryMatchingSymbolIDsMultiWord},
+		{"oldQuerySymbol", rawQuerySymbol, oldQuerySymbol},
+		{"oldQueryPackageDotSymbol", rawQueryPackageDotSymbol, oldQueryPackageDotSymbol},
+		{"oldQueryMultiWord", rawQueryMultiWord, oldQueryMultiWord},
 	} {
 		t.Run(test.name, func(t *testing.T) {
 			if diff := cmp.Diff(test.want, test.q); diff != "" {