internal/postgres: change upserts for symbol_names and package_symbols

To avoid unnecessary write attempts when inserting symbol information,
run a SELECT query first to get existing symbol names and package
symbols, and only insert rows that we know do not exist.

For golang/go#37102

Change-Id: I6aaf96a34702c3736be6bf1e9386d5139404fafb
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/295270
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Trust: Julie Qiu <julie@golang.org>
diff --git a/internal/postgres/symbol.go b/internal/postgres/symbol.go
index 6481a94..b3229f3 100644
--- a/internal/postgres/symbol.go
+++ b/internal/postgres/symbol.go
@@ -46,14 +46,16 @@
 			}
 			for _, build := range builds {
 				nameToSymbol := buildToNameToSym[internal.BuildContext{GOOS: build.GOOS, GOARCH: build.GOARCH}]
-				updateSymbols(doc.API, func(s *internal.Symbol) (err error) {
+				if err := updateSymbols(doc.API, func(s *internal.Symbol) (err error) {
 					defer derrors.WrapStack(&err, "updateSymbols(%q)", s.Name)
 					if !shouldUpdateSymbolHistory(s.Name, version, nameToSymbol) {
 						return nil
 					}
-					pkgsymID := pkgsymToID[packageSymbolKey(s.Section, s.Synopsis)]
+
+					pkgsym := packageSymbol{synopsis: s.Synopsis, section: s.Section}
+					pkgsymID := pkgsymToID[pkgsym]
 					if pkgsymID == 0 {
-						return fmt.Errorf("symbolID cannot be 0: %q", s.Name)
+						return fmt.Errorf("pkgsymID cannot be 0: %q", pkgsym)
 					}
 
 					// Validate that the unique constraint won't be violated.
@@ -66,7 +68,9 @@
 					uniqueKeys[key] = fmt.Sprintf("%q (%q)", s.Name, s.Synopsis)
 					symHistoryValues = append(symHistoryValues, pkgsymID, build.GOOS, build.GOARCH, version)
 					return nil
-				})
+				}); err != nil {
+					return err
+				}
 			}
 		}
 	}
@@ -89,54 +93,24 @@
 	return semver.Compare(newVersion, dh.SinceVersion) < 1
 }
 
+type packageSymbol struct {
+	synopsis string
+	section  internal.SymbolSection
+}
+
 func upsertPackageSymbolsReturningIDs(ctx context.Context, db *database.DB,
-	modulePath string, pathToID map[string]int, pathToDocs map[string][]*internal.Documentation) (_ map[string]int, err error) {
+	modulePath string, pathToID map[string]int, pathToDocs map[string][]*internal.Documentation) (_ map[packageSymbol]int, err error) {
 	defer derrors.WrapStack(&err, "upsertPackageSymbolsReturningIDs(ctx, db, %q, pathToID, pathToDocs)", modulePath)
 	nameToID, err := upsertSymbolNamesReturningIDs(ctx, db, pathToDocs)
 	if err != nil {
 		return nil, err
 	}
 
-	var values []interface{}
 	modulePathID := pathToID[modulePath]
 	if modulePathID == 0 {
 		return nil, fmt.Errorf("modulePathID cannot be 0: %q", modulePath)
 	}
-	for path, docs := range pathToDocs {
-		pathID := pathToID[path]
-		for _, doc := range docs {
-			updateSymbols(doc.API, func(s *internal.Symbol) error {
-				if s.ParentName == "" {
-					s.ParentName = s.Name
-				}
-				values = append(values, pathID, modulePathID, nameToID[s.Name], nameToID[s.ParentName], s.Section, s.Kind, s.Synopsis)
-				return nil
-			})
-		}
-	}
-	if err := db.BulkInsert(ctx, "package_symbols",
-		[]string{
-			"package_path_id",
-			"module_path_id",
-			"symbol_name_id",
-			"parent_symbol_name_id",
-			"section",
-			"type",
-			"synopsis",
-		}, values, database.OnConflictDoNothing); err != nil {
-		return nil, err
-	}
-
-	query := `
-        SELECT
-            ps.id,
-            ps.section,
-            ps.synopsis
-        FROM package_symbols ps
-        INNER JOIN symbol_names sn
-        ON ps.symbol_name_id = sn.id
-        WHERE module_path_id = $1;`
-	pkgsymToID := map[string]int{}
+	pkgsymToID := map[packageSymbol]int{}
 	collect := func(rows *sql.Rows) error {
 		var (
 			id       int
@@ -146,65 +120,116 @@
 		if err := rows.Scan(&id, &section, &synopsis); err != nil {
 			return fmt.Errorf("row.Scan(): %v", err)
 		}
-		pkgsymToID[packageSymbolKey(section, synopsis)] = id
+		pkgsymToID[packageSymbol{synopsis: synopsis, section: section}] = id
 		return nil
 	}
-	if err := db.RunQuery(ctx, query, collect, modulePathID); err != nil {
+	if err := db.RunQuery(ctx, `
+        SELECT
+            ps.id,
+            ps.section,
+            ps.synopsis
+        FROM package_symbols ps
+        INNER JOIN symbol_names sn ON ps.symbol_name_id = sn.id
+        WHERE module_path_id = $1;`, collect, modulePathID); err != nil {
 		return nil, err
 	}
-	for _, docs := range pathToDocs {
+
+	var packageSymbols []interface{}
+	for path, docs := range pathToDocs {
+		pathID := pathToID[path]
+		if pathID == 0 {
+			return nil, fmt.Errorf("pathID cannot be 0: %q", path)
+		}
 		for _, doc := range docs {
-			updateSymbols(doc.API, func(s *internal.Symbol) error {
-				if _, ok := pkgsymToID[packageSymbolKey(s.Section, s.Synopsis)]; !ok {
-					return fmt.Errorf("missing package symbol for %q %q (section=%q, type=%q)", s.Name, s.Synopsis, s.Section, s.Kind)
+			if err := updateSymbols(doc.API, func(s *internal.Symbol) error {
+				ps := packageSymbol{synopsis: s.Synopsis, section: s.Section}
+				symID := nameToID[s.Name]
+				if symID == 0 {
+					return fmt.Errorf("pathID cannot be 0: %q", s.Name)
+				}
+				if s.ParentName == "" {
+					s.ParentName = s.Name
+				}
+				parentID := nameToID[s.ParentName]
+				if parentID == 0 {
+					return fmt.Errorf("pathID cannot be 0: %q", s.ParentName)
+				}
+				if _, ok := pkgsymToID[ps]; !ok {
+					packageSymbols = append(packageSymbols, pathID,
+						modulePathID, symID, parentID, s.Section, s.Kind,
+						s.Synopsis)
 				}
 				return nil
-			})
+			}); err != nil {
+				return nil, err
+			}
 		}
 	}
+	// The order of pkgsymcols must match that of the SELECT query above,
+	// because the same collect function scans both result sets.
+	pkgsymcols := []string{"id", "section", "synopsis"}
+	if err := db.BulkInsertReturning(ctx, "package_symbols",
+		[]string{
+			"package_path_id",
+			"module_path_id",
+			"symbol_name_id",
+			"parent_symbol_name_id",
+			"section",
+			"type",
+			"synopsis",
+		}, packageSymbols, database.OnConflictDoNothing, pkgsymcols, collect); err != nil {
+		return nil, err
+	}
 	return pkgsymToID, nil
 }
 
-func packageSymbolKey(section internal.SymbolSection, synopsis string) string {
-	return fmt.Sprintf("section=%s_synopsis=%s", section, synopsis)
-}
-
 func upsertSymbolNamesReturningIDs(ctx context.Context, db *database.DB, pathToDocs map[string][]*internal.Documentation) (_ map[string]int, err error) {
 	defer derrors.WrapStack(&err, "upsertSymbolNamesReturningIDs")
-	var values []interface{}
+	var names []string
 	for _, docs := range pathToDocs {
 		for _, doc := range docs {
-			updateSymbols(doc.API, func(s *internal.Symbol) error {
-				values = append(values, s.Name)
+			if err := updateSymbols(doc.API, func(s *internal.Symbol) error {
+				names = append(names, s.Name)
 				return nil
-			})
+			}); err != nil {
+				return nil, err
+			}
 		}
 	}
-
-	if err := db.BulkInsert(ctx, "symbol_names", []string{"name"}, values, database.OnConflictDoNothing); err != nil {
-		return nil, err
-	}
 	query := `
         SELECT id, name
         FROM symbol_names
         WHERE name = ANY($1);`
-	symbols := map[string]int{}
+	nameToID := map[string]int{}
 	collect := func(rows *sql.Rows) error {
-		var name string
-		var id int
+		var (
+			id   int
+			name string
+		)
 		if err := rows.Scan(&id, &name); err != nil {
 			return fmt.Errorf("row.Scan(): %v", err)
 		}
-		symbols[name] = id
+		nameToID[name] = id
 		if id == 0 {
 			return fmt.Errorf("id can't be 0: %q", name)
 		}
 		return nil
 	}
-	if err := db.RunQuery(ctx, query, collect, pq.Array(values)); err != nil {
+	if err := db.RunQuery(ctx, query, collect, pq.Array(names)); err != nil {
 		return nil, err
 	}
-	return symbols, nil
+
+	var values []interface{}
+	for _, name := range names {
+		if _, ok := nameToID[name]; !ok {
+			values = append(values, name)
+		}
+	}
+	if err := db.BulkInsertReturning(ctx, "symbol_names", []string{"name"},
+		values, database.OnConflictDoNothing, []string{"id", "name"}, collect); err != nil {
+		return nil, err
+	}
+	return nameToID, nil
 }
 
 func getSymbolHistory(ctx context.Context, db *database.DB, packagePath, modulePath string) (_ map[internal.BuildContext]map[string]*internal.Symbol, err error) {
diff --git a/internal/postgres/symbol_test.go b/internal/postgres/symbol_test.go
index 16f9dda..aaee4f3 100644
--- a/internal/postgres/symbol_test.go
+++ b/internal/postgres/symbol_test.go
@@ -249,11 +249,6 @@
 	MustInsertModule(ctx, t, testDB, mod10)
 	MustInsertModule(ctx, t, testDB, mod11)
 
-	gotHist, err := getSymbolHistory(ctx, testDB.db, mod10.Packages()[0].Path, mod10.ModulePath)
-	if err != nil {
-		t.Fatal(err)
-	}
-
 	symbols := map[string]*internal.Symbol{
 		"Foo":   typ,
 		"Foo.A": methodA,
@@ -265,6 +260,10 @@
 		internal.BuildContextLinux:   symbols,
 		internal.BuildContextWindows: symbols,
 	}
+	gotHist, err := getSymbolHistory(ctx, testDB.db, mod12.Packages()[0].Path, mod12.ModulePath)
+	if err != nil {
+		t.Fatal(err)
+	}
 	if diff := cmp.Diff(wantHist, gotHist,
 		cmpopts.IgnoreFields(internal.Symbol{}, "GOOS", "GOARCH")); diff != "" {
 		t.Fatalf("mismatch (-want +got):\n%s", diff)
@@ -304,10 +303,22 @@
 	mod11 := moduleWithSymbols(t, "v1.1.0", nil)
 	makeDocs := func() []*internal.Documentation {
 		return []*internal.Documentation{
-			sample.Documentation("linux", "amd64", sample.DocContents),
-			sample.Documentation("windows", "amd64", sample.DocContents),
-			sample.Documentation("darwin", "amd64", sample.DocContents),
-			sample.Documentation("js", "wasm", sample.DocContents),
+			sample.Documentation(
+				internal.BuildContextLinux.GOOS,
+				internal.BuildContextLinux.GOARCH,
+				sample.DocContents),
+			sample.Documentation(
+				internal.BuildContextWindows.GOOS,
+				internal.BuildContextWindows.GOARCH,
+				sample.DocContents),
+			sample.Documentation(
+				internal.BuildContextDarwin.GOOS,
+				internal.BuildContextDarwin.GOARCH,
+				sample.DocContents),
+			sample.Documentation(
+				internal.BuildContextJS.GOOS,
+				internal.BuildContextJS.GOARCH,
+				sample.DocContents),
 		}
 	}
 	mod11.Packages()[0].Documentation = makeDocs()