internal/postgres: refactor path insertion
Move the module-independent part of insertPaths to a separate
function, upsertPaths, and put it in the path.go file so it is near
its single-path cousin.
Write a simple test for it.
Change-Id: I3f4a6d3e21836b2840153354c880f8be823b466b
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/296551
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/postgres/insert_module.go b/internal/postgres/insert_module.go
index 02dc81b..51f186c 100644
--- a/internal/postgres/insert_module.go
+++ b/internal/postgres/insert_module.go
@@ -391,20 +391,6 @@
}
func insertPaths(ctx context.Context, db *database.DB, m *internal.Module) (pathToID map[string]int, err error) {
- // Add new unit paths to the paths table.
- pathToID = map[string]int{}
- collect := func(rows *sql.Rows) error {
- var (
- pathID int
- path string
- )
- if err := rows.Scan(&pathID, &path); err != nil {
- return err
- }
- pathToID[path] = pathID
- return nil
- }
-
// Read all existing paths for this module, to avoid a large bulk upsert.
// (We've seen these bulk upserts hang for so long that they time out (10
// minutes)).
@@ -418,28 +404,7 @@
for p := range curPathsSet {
curPaths = append(curPaths, p)
}
- if err := db.RunQuery(ctx, `SELECT id, path FROM paths WHERE path = ANY($1)`,
- collect, pq.Array(curPaths)); err != nil {
- return nil, err
- }
-
- // Insert any unit paths that we don't already have.
- var values []interface{}
- for _, v := range curPaths {
- if _, ok := pathToID[v]; !ok {
- values = append(values, v)
- }
- }
- if len(values) > 0 {
- // Insert data into the paths table.
- pathCols := []string{"path"}
- returningPathCols := []string{"id", "path"}
- if err := db.BulkInsertReturning(ctx, "paths", pathCols, values,
- database.OnConflictDoNothing, returningPathCols, collect); err != nil {
- return nil, err
- }
- }
- return pathToID, nil
+ return upsertPaths(ctx, db, curPaths)
}
func insertUnits(ctx context.Context, db *database.DB, unitValues []interface{}) (pathIDToUnitID map[int]int, err error) {
diff --git a/internal/postgres/path.go b/internal/postgres/path.go
index 778ef95..96d120c 100644
--- a/internal/postgres/path.go
+++ b/internal/postgres/path.go
@@ -12,6 +12,7 @@
"strconv"
"strings"
+ "github.com/lib/pq"
"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/derrors"
@@ -106,3 +107,46 @@
}
return id, nil
}
+
+// upsertPaths adds all the paths to the paths table if they aren't already
+// there, and returns their ID either way.
+// It assumes it is running inside a transaction.
+func upsertPaths(ctx context.Context, db *database.DB, paths []string) (pathToID map[string]int, err error) {
+ defer derrors.WrapStack(&err, "upsertPaths(%d paths)", len(paths))
+
+ pathToID = map[string]int{}
+ collect := func(rows *sql.Rows) error {
+ var (
+ pathID int
+ path string
+ )
+ if err := rows.Scan(&pathID, &path); err != nil {
+ return err
+ }
+ pathToID[path] = pathID
+ return nil
+ }
+
+ if err := db.RunQuery(ctx, `SELECT id, path FROM paths WHERE path = ANY($1)`,
+ collect, pq.Array(paths)); err != nil {
+ return nil, err
+ }
+
+ // Insert any unit paths that we don't already have.
+ var values []interface{}
+ for _, v := range paths {
+ if _, ok := pathToID[v]; !ok {
+ values = append(values, v)
+ }
+ }
+ if len(values) > 0 {
+ // Insert data into the paths table.
+ pathCols := []string{"path"}
+ returningPathCols := []string{"id", "path"}
+ if err := db.BulkInsertReturning(ctx, "paths", pathCols, values,
+ database.OnConflictDoNothing, returningPathCols, collect); err != nil {
+ return nil, err
+ }
+ }
+ return pathToID, nil
+}
diff --git a/internal/postgres/path_test.go b/internal/postgres/path_test.go
index a9350e8..9ce6497 100644
--- a/internal/postgres/path_test.go
+++ b/internal/postgres/path_test.go
@@ -102,6 +102,7 @@
func TestUpsertPathConcurrently(t *testing.T) {
// Verify that we get no constraint violations or other errors when
// the same path is upserted multiple times concurrently.
+ t.Parallel()
testDB, release := acquire(t)
defer release()
ctx := context.Background()
@@ -129,3 +130,50 @@
}
}
}
+
+func TestUpsertPaths(t *testing.T) {
+ t.Parallel()
+ testDB, release := acquire(t)
+ defer release()
+ ctx := context.Background()
+
+ check := func(paths []string) {
+ got, err := upsertPathsInTx(ctx, testDB.db, paths)
+ if err != nil {
+ t.Fatal(err)
+ }
+ checkPathMap(t, got, paths)
+ }
+
+ check([]string{"a", "b", "c"})
+ check([]string{"b", "c", "d", "e"})
+}
+
+func checkPathMap(t *testing.T, got map[string]int, paths []string) {
+ t.Helper()
+ if g, w := len(got), len(paths); g != w {
+ t.Errorf("got %d paths, want %d", g, w)
+ return
+ }
+ for _, p := range paths {
+ g, ok := got[p]
+ if !ok {
+ t.Errorf("missing path %q", p)
+ } else if g == 0 {
+ t.Errorf("path %q has a 0 ID", p)
+ }
+ }
+}
+
+func upsertPathsInTx(ctx context.Context, db *database.DB, paths []string) (map[string]int, error) {
+ var m map[string]int
+ err := db.Transact(ctx, sql.LevelRepeatableRead, func(tx *database.DB) error {
+ var err error
+ m, err = upsertPaths(ctx, tx, paths)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return m, nil
+}