blob: a89118aeb67b78dcd2a4265cdda067ac0ed4c340 [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
"github.com/lib/pq"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/postgres/search"
"golang.org/x/sync/errgroup"
)
func upsertSymbolSearchDocuments(ctx context.Context, tx *database.DB,
modulePath, v string) (err error) {
defer derrors.Wrap(&err, "upsertSymbolSearchDocuments(ctx, ddb, %q, %q)", modulePath, v)
defer middleware.ElapsedStat(ctx, "upsertSymbolSearchDocuments")()
// If a user is looking for the symbol "DB.Begin", from package
// database/sql, we want them to be able to find this by searching for
// "DB.Begin" and "sql.DB.Begin". Searching for "sql.DB", "DB", "Begin" or
// "sql.DB" will not return "DB.Begin".
// If a user is looking for the symbol "DB.Begin", from package
// database/sql, we want them to be able to find this by searching for
// "DB.Begin", "Begin", and "sql.DB.Begin". Searching for "sql.DB" or
// "DB" will not return "DB.Begin".
q := `
INSERT INTO symbol_search_documents (
package_path_id,
symbol_name_id,
unit_id,
package_symbol_id,
goos,
goarch,
package_name,
package_path,
imported_by_count,
symbol_name
)
SELECT DISTINCT ON (sd.package_path_id, ps.symbol_name_id)
sd.package_path_id,
ps.symbol_name_id,
sd.unit_id,
ps.id AS package_symbol_id,
d.goos,
d.goarch,
sd.name,
sd.package_path,
sd.imported_by_count,
s.name
FROM search_documents sd
INNER JOIN units u ON sd.unit_id = u.id
INNER JOIN documentation d ON d.unit_id = sd.unit_id
INNER JOIN documentation_symbols ds ON d.id = ds.documentation_id
INNER JOIN package_symbols ps ON ps.id = ds.package_symbol_id
INNER JOIN symbol_names s ON s.id = ps.symbol_name_id
WHERE
sd.module_path = $1 AND sd.version = $2
AND u.name != 'main' -- do not insert data for commands
AND sd.redistributable
ORDER BY
sd.package_path_id,
ps.symbol_name_id,
-- Order should match internal.BuildContexts.
CASE WHEN d.goos = 'all' THEN 0
WHEN d.goos = 'linux' THEN 1
WHEN d.goos = 'windows' THEN 2
WHEN d.goos = 'darwin' THEN 3
WHEN d.goos = 'js' THEN 4
END
ON CONFLICT (package_path_id, symbol_name_id)
DO UPDATE SET
unit_id = excluded.unit_id,
package_symbol_id = excluded.package_symbol_id,
goos = excluded.goos,
goarch = excluded.goarch,
package_name = excluded.package_name,
package_path = excluded.package_path,
imported_by_count = excluded.imported_by_count,
symbol_name = excluded.symbol_name;`
_, err = tx.Exec(ctx, q, modulePath, v)
return err
}
// symbolSearch searches all symbols in the symbol_search_documents table for
// the query.
//
// TODO(https://golang.org/issue/44142): factor out common code between
// symbolSearch and deepSearch.
func (db *DB) symbolSearch(ctx context.Context, q string, limit int, opts SearchOptions) searchResponse {
defer middleware.ElapsedStat(ctx, "symbolSearch")()
var (
results []*SearchResult
err error
)
sr := searchResponse{source: "symbol"}
it := search.ParseInputType(q)
switch it {
case search.InputTypeOneDot:
results, err = runSymbolSearchOneDot(ctx, db.db, q, limit)
case search.InputTypeMultiWord:
results, err = runSymbolSearchMultiWord(ctx, db.db, q, limit, opts.SymbolFilter)
case search.InputTypeNoDot:
results, err = runSymbolSearch(ctx, db.db, search.SearchTypeSymbol, q, limit)
case search.InputTypeTwoDots:
results, err = runSymbolSearchPackageDotSymbol(ctx, db.db, q, limit)
default:
// There is no supported situation where we will get results for one
// element containing more than 2 dots.
return sr
}
if len(results) == 0 {
if err != nil && !errors.Is(err, derrors.NotFound) {
sr.err = err
}
return sr
}
sort.Slice(results, func(i, j int) bool {
if results[i].NumImportedBy != results[j].NumImportedBy {
return results[i].NumImportedBy > results[j].NumImportedBy
}
// If two packages have the same imported by count, return them in
// alphabetical order by package path.
if results[i].PackagePath != results[j].PackagePath {
return results[i].PackagePath < results[j].PackagePath
}
// If one package has multiple matching symbols, return them by
// alphabetical order of symbol name.
return results[i].SymbolName < results[j].SymbolName
})
if len(results) > limit {
results = results[0:limit]
}
for _, r := range results {
r.NumResults = uint64(len(results))
}
sr.results = results
return sr
}
// runSymbolSearchMultiWord executes a symbol search for SearchTypeMultiWord.
func runSymbolSearchMultiWord(ctx context.Context, ddb *database.DB, q string, limit int,
symbolFilter string) (_ []*SearchResult, err error) {
defer derrors.Wrap(&err, "runSymbolSearchMultiWord(ctx, ddb, query, %q, %d, %q)",
q, limit, symbolFilter)
defer middleware.ElapsedStat(ctx, "runSymbolSearchMultiWord")()
symbolToPathTokens := multiwordSearchCombinations(q, symbolFilter)
if len(symbolToPathTokens) == 0 {
// There are no words in the query that could be a symbol name.
return nil, derrors.NotFound
}
if strings.Contains(q, "|") {
// TODO(golang/go#44142): The search.SearchTypeMultiWordOr case
// is currently not supported.
return nil, derrors.NotFound
}
group, searchCtx := errgroup.WithContext(ctx)
resultsArray := make([][]*SearchResult, len(symbolToPathTokens))
count := 0
for symbol, pathTokens := range symbolToPathTokens {
symbol := symbol
pathTokens := pathTokens
i := count
count += 1
group.Go(func() error {
st := search.SearchTypeMultiWordExact
r, err := runSymbolSearch(searchCtx, ddb, st, symbol, limit, pathTokens)
if err != nil {
return err
}
resultsArray[i] = r
return nil
})
}
if err := group.Wait(); err != nil {
return nil, err
}
return mergedResults(resultsArray, limit), nil
}
func mergedResults(resultsArray [][]*SearchResult, limit int) []*SearchResult {
var results []*SearchResult
deduped := map[string]bool{}
for _, array := range resultsArray {
for _, r := range array {
key := fmt.Sprintf("%s@%s", r.PackagePath, r.SymbolName)
if !deduped[key] {
results = append(results, r)
deduped[key] = true
}
}
}
sort.Slice(results, func(i, j int) bool { return results[i].NumImportedBy > results[j].NumImportedBy })
if len(results) > limit {
results = results[0:limit]
}
return results
}
// multiwordSearchCombinations returns a map of symbol name to path_tokens to
// be used for possible search combinations.
//
// For each word, check if there is an invalid symbol character or if it
// matches a common hostname. If so, the search on tsv_path_tokens must match
// that search.
//
// It is assumed that the symbol name is always 1 word. For example, if the
// user wants sql.DB.Begin, "sql DB.Begin", "sql Begin", or "sql DB" will all
// be return the relevant result, but "sql DB Begin" will not.
func multiwordSearchCombinations(q, symbolFilter string) map[string]string {
words := strings.Fields(q)
symbolToPathTokens := map[string]string{}
for i, w := range words {
// Is this word a possible symbol name? If not, continue.
if strings.Contains(w, "/") || strings.Contains(w, "-") || commonHostnames[w] {
continue
}
// A symbolFilter was used, and this word does not match it, so
// it can't be the symbol name.
if symbolFilter != "" && w != symbolFilter {
continue
}
// If it is, try search for this word assuming it is the symbol name
// and everything else is a path element.
pathTokens := append(append([]string{}, words[0:i]...), words[i+1:]...)
sort.Strings(pathTokens)
symbolToPathTokens[w] = strings.Join(pathTokens, " & ")
}
if len(symbolToPathTokens) == 0 {
return nil
}
if len(symbolToPathTokens) > 3 {
// There are too many searches that can be performed, so
// return no results.
// TODO(golang/go#44142): Leave add support for an OR query.
return nil
}
return symbolToPathTokens
}
// runSymbolSearchOneDot is used when q contains only 1 dot, so the search must
// either be for <package>.<symbol> or <type>.<methodOrFieldName>.
//
// This search is split into two parallel queries, since the query is very slow
// when using an OR in the WHERE clause.
func runSymbolSearchOneDot(ctx context.Context, ddb *database.DB, q string, limit int) (_ []*SearchResult, err error) {
defer derrors.Wrap(&err, "runSymbolSearchOneDot(ctx, ddb, %q, %d)", q, limit)
defer middleware.ElapsedStat(ctx, "runSymbolSearchOneDot")()
group, searchCtx := errgroup.WithContext(ctx)
resultsArray := make([][]*SearchResult, 2)
for i, st := range []search.SearchType{
search.SearchTypeSymbol,
search.SearchTypePackageDotSymbol,
} {
i := i
st := st
group.Go(func() error {
var (
results []*SearchResult
err error
)
if st == search.SearchTypePackageDotSymbol {
results, err = runSymbolSearchPackageDotSymbol(searchCtx, ddb, q, limit)
} else {
results, err = runSymbolSearch(searchCtx, ddb, st, q, limit)
}
if err != nil {
return err
}
resultsArray[i] = results
return nil
})
}
if err := group.Wait(); err != nil {
return nil, err
}
return mergedResults(resultsArray, limit), nil
}
func runSymbolSearchPackageDotSymbol(ctx context.Context, ddb *database.DB, q string, limit int) (_ []*SearchResult, err error) {
pkg, symbol, err := splitPackageAndSymbolNames(q)
if err != nil {
return nil, err
}
return runSymbolSearch(ctx, ddb, search.SearchTypePackageDotSymbol, symbol, limit, pkg)
}
func splitPackageAndSymbolNames(q string) (pkgName string, symbolName string, err error) {
parts := strings.Split(q, ".")
if len(parts) != 2 && len(parts) != 3 {
return "", "", derrors.NotFound
}
for _, p := range parts {
// Handle cases where we have odd dot placement, such as .Foo or
// Foo..
if p == "" {
return "", "", derrors.NotFound
}
}
return parts[0], strings.Join(parts[1:], "."), nil
}
func runSymbolSearch(ctx context.Context, ddb *database.DB,
st search.SearchType, q string, limit int, args ...any) (results []*SearchResult, err error) {
defer derrors.Wrap(&err, "runSymbolSearch(ctx, ddb, %q, %q, %d, %v)", st, q, limit, args)
defer middleware.ElapsedStat(ctx, fmt.Sprintf("%s-runSymbolSearch", st))()
collect := func(rows *sql.Rows) error {
var r SearchResult
if err := rows.Scan(
&r.SymbolName,
&r.PackagePath,
&r.ModulePath,
&r.Version,
&r.Name,
&r.Synopsis,
pq.Array(&r.Licenses),
&r.CommitTime,
&r.NumImportedBy,
&r.SymbolGOOS,
&r.SymbolGOARCH,
&r.SymbolKind,
&r.SymbolSynopsis); err != nil {
return fmt.Errorf("symbolSearch: rows.Scan(): %v", err)
}
results = append(results, &r)
return nil
}
query := search.SymbolQuery(st)
args = append([]any{q, limit}, args...)
if err := ddb.RunQuery(ctx, query, collect, args...); err != nil {
return nil, err
}
return results, nil
}