internal: delete stale symbol_search_documents

When the latest version of a module is inserted, delete
symbol_search_documents rows for symbols not in that version of the
package, so that stale rows are removed.

For golang/go#44142

Change-Id: I9a5c45eff4713cb765fd15b57530bf139385a3a8
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/349892
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/postgres/symbol.go b/internal/postgres/symbol.go
index bbec5c4..7a3ec3a 100644
--- a/internal/postgres/symbol.go
+++ b/internal/postgres/symbol.go
@@ -10,10 +10,12 @@
 	"fmt"
 	"sort"
 
+	"github.com/Masterminds/squirrel"
 	"github.com/lib/pq"
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/database"
 	"golang.org/x/pkgsite/internal/derrors"
+	"golang.org/x/pkgsite/internal/log"
 	"golang.org/x/pkgsite/internal/version"
 )
 
@@ -32,7 +34,10 @@
 	if versionType != version.TypeRelease && !isLatest {
 		return nil
 	}
-
+	modulePathID := pathToID[modulePath]
+	if modulePathID == 0 {
+		return fmt.Errorf("modulePathID cannot be 0: %q", modulePath)
+	}
 	pathToDocIDToDoc, err := getDocIDsForPath(ctx, tx, pathToUnitID, pathToDocs)
 	if err != nil {
 		return err
@@ -41,7 +46,7 @@
 	if err != nil {
 		return err
 	}
-	pathToPkgsymToID, err := upsertPackageSymbolsReturningIDs(ctx, tx, modulePath, pathToID, nameToID, pathToDocIDToDoc)
+	pathToPkgsymToID, err := upsertPackageSymbolsReturningIDs(ctx, tx, modulePathID, pathToID, nameToID, pathToDocIDToDoc)
 	if err != nil {
 		return err
 	}
@@ -49,8 +54,13 @@
 		return err
 	}
 	if versionType == version.TypeRelease {
-		return upsertSymbolHistory(ctx, tx, modulePath, v, nameToID,
-			pathToID, pathToPkgsymToID, pathToDocIDToDoc)
+		if err := upsertSymbolHistory(ctx, tx, modulePath, v, nameToID,
+			pathToID, pathToPkgsymToID, pathToDocIDToDoc); err != nil {
+			return err
+		}
+	}
+	if isLatest {
+		return deleteOldSymbolSearchDocuments(ctx, tx, modulePathID, pathToID, pathToDocIDToDoc, pathToPkgsymToID)
 	}
 	return nil
 }
@@ -183,11 +193,11 @@
 }
 
 func upsertPackageSymbolsReturningIDs(ctx context.Context, db *database.DB,
-	modulePath string,
+	modulePathID int,
 	pathToID map[string]int,
 	nameToID map[string]int,
 	pathToDocIDToDoc map[string]map[int]*internal.Documentation) (_ map[string]map[packageSymbol]int, err error) {
-	defer derrors.WrapStack(&err, "upsertPackageSymbolsReturningIDs(ctx, db, %q, pathToID, pathToDocIDToDoc)", modulePath)
+	defer derrors.WrapStack(&err, "upsertPackageSymbolsReturningIDs(ctx, db, %d, pathToID, pathToDocIDToDoc)", modulePathID)
 
 	idToPath := map[int]string{}
 	for path, id := range pathToID {
@@ -200,10 +210,6 @@
 		names = append(names, name)
 	}
 
-	modulePathID := pathToID[modulePath]
-	if modulePathID == 0 {
-		return nil, fmt.Errorf("modulePathID cannot be 0: %q", modulePath)
-	}
 	pathTopkgsymToID := map[string]map[packageSymbol]int{}
 	collect := func(rows *sql.Rows) error {
 		var (
@@ -383,3 +389,75 @@
 	}
 	return nil
 }
+
+func deleteOldSymbolSearchDocuments(ctx context.Context, db *database.DB,
+	modulePathID int,
+	pathToID map[string]int,
+	pathToDocIDToDoc map[string]map[int]*internal.Documentation,
+	latestPathToPkgsymToID map[string]map[packageSymbol]int) (err error) {
+	defer derrors.WrapStack(&err, "deleteOldSymbolSearchDocuments(ctx, db, %q, pathToID, pathToDocIDToDoc)", modulePathID)
+
+	// Get all package_symbol_ids for the latest module (the current one we are
+	// trying to insert).
+	latestPkgsymIDs := map[int]bool{}
+	for path := range pathToID {
+		docs := pathToDocIDToDoc[path]
+		pathID := pathToID[path]
+		if pathID == 0 {
+			return fmt.Errorf("pathID cannot be 0: %q", path)
+		}
+		for _, doc := range docs {
+			err := updateSymbols(doc.API, func(sm *internal.SymbolMeta) error {
+				pkgsymToID, ok := latestPathToPkgsymToID[path]
+				if !ok {
+					return fmt.Errorf("path could not be found: %q", path)
+				}
+				ps := packageSymbol{synopsis: sm.Synopsis, name: sm.Name, parentName: sm.ParentName}
+				pkgsymID, ok := pkgsymToID[ps]
+				if !ok {
+					return fmt.Errorf("package symbol could not be found: %v", ps)
+				}
+				latestPkgsymIDs[pkgsymID] = true
+				return nil
+			})
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	var pathIDs []int
+	for _, id := range pathToID {
+		pathIDs = append(pathIDs, id)
+	}
+	// Fetch package_symbol_id currently in symbol_search_documents.
+	dbPkgSymIDs, err := db.CollectInts(ctx, `
+		SELECT package_symbol_id
+		FROM symbol_search_documents
+		WHERE package_path_id = ANY($1);`,
+		pq.Array(pathIDs))
+	if err != nil {
+		return err
+	}
+
+	var toDelete []int
+	for _, id := range dbPkgSymIDs {
+		if _, ok := latestPkgsymIDs[id]; !ok {
+			toDelete = append(toDelete, id)
+		}
+	}
+
+	// Delete stale rows.
+	q, args, err := squirrel.Delete("symbol_search_documents").
+		Where("package_symbol_id = ANY(?)", pq.Array(toDelete)).
+		PlaceholderFormat(squirrel.Dollar).ToSql()
+	if err != nil {
+		return err
+	}
+	n, err := db.Exec(ctx, q, args...)
+	if err != nil {
+		return err
+	}
+	log.Infof(ctx, "deleted %d rows from symbol_search_documents", n)
+	return nil
+}
diff --git a/internal/postgres/symbol_test.go b/internal/postgres/symbol_test.go
index 8b56ec3..9ccdd9e 100644
--- a/internal/postgres/symbol_test.go
+++ b/internal/postgres/symbol_test.go
@@ -562,3 +562,55 @@
 		})
 	}
 }
+
+func TestDeleteOldSymbolSearchDocuments(t *testing.T) {
+	ctx := context.Background()
+	testDB, release := acquire(t)
+	defer release()
+
+	q := `SELECT symbol_name FROM symbol_search_documents;`
+	checkRows := func(t *testing.T, v string, api []*internal.Symbol) {
+		t.Helper()
+		m := sample.DefaultModule()
+		m.Version = v
+		m.Packages()[0].Documentation[0].API = api
+		MustInsertModule(ctx, t, testDB, m)
+		got, err := testDB.db.CollectStrings(ctx, q)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		var want []string
+		if err := updateSymbols(api, func(sm *internal.SymbolMeta) error {
+			want = append(want, sm.Name)
+			return nil
+		}); err != nil {
+			t.Fatal(err)
+		}
+
+		sort.Strings(got)
+		sort.Strings(want)
+		if diff := cmp.Diff(want, got); diff != "" {
+			t.Errorf("mismatch for %q (-want +got):\n%s", v, diff)
+		}
+	}
+
+	api := []*internal.Symbol{
+		sample.Constant,
+		sample.Variable,
+		sample.Function,
+		sample.Type,
+	}
+	checkRows(t, "v1.1.0", api)
+
+	// Symbol deleted in newer version.
+	api2 := []*internal.Symbol{
+		sample.Constant,
+		sample.Variable,
+		sample.Function,
+	}
+	checkRows(t, "v1.2.0", api2)
+
+	// Older version inserted, no effect.
+	checkRows(t, "v1.0.0", api2)
+}