internal/postgres: refactor path insertion

Move the module-independent part of insertPaths to a separate
function, upsertPaths, and put it in the paths.go file so it is near
it's single-path cousin.

Write a simple test for it.

Change-Id: I3f4a6d3e21836b2840153354c880f8be823b466b
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/296551
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/postgres/insert_module.go b/internal/postgres/insert_module.go
index 02dc81b..51f186c 100644
--- a/internal/postgres/insert_module.go
+++ b/internal/postgres/insert_module.go
@@ -391,20 +391,6 @@
 }
 
 func insertPaths(ctx context.Context, db *database.DB, m *internal.Module) (pathToID map[string]int, err error) {
-	// Add new unit paths to the paths table.
-	pathToID = map[string]int{}
-	collect := func(rows *sql.Rows) error {
-		var (
-			pathID int
-			path   string
-		)
-		if err := rows.Scan(&pathID, &path); err != nil {
-			return err
-		}
-		pathToID[path] = pathID
-		return nil
-	}
-
 	// Read all existing paths for this module, to avoid a large bulk upsert.
 	// (We've seen these bulk upserts hang for so long that they time out (10
 	// minutes)).
@@ -418,28 +404,7 @@
 	for p := range curPathsSet {
 		curPaths = append(curPaths, p)
 	}
-	if err := db.RunQuery(ctx, `SELECT id, path FROM paths WHERE path = ANY($1)`,
-		collect, pq.Array(curPaths)); err != nil {
-		return nil, err
-	}
-
-	// Insert any unit paths that we don't already have.
-	var values []interface{}
-	for _, v := range curPaths {
-		if _, ok := pathToID[v]; !ok {
-			values = append(values, v)
-		}
-	}
-	if len(values) > 0 {
-		// Insert data into the paths table.
-		pathCols := []string{"path"}
-		returningPathCols := []string{"id", "path"}
-		if err := db.BulkInsertReturning(ctx, "paths", pathCols, values,
-			database.OnConflictDoNothing, returningPathCols, collect); err != nil {
-			return nil, err
-		}
-	}
-	return pathToID, nil
+	return upsertPaths(ctx, db, curPaths)
 }
 
 func insertUnits(ctx context.Context, db *database.DB, unitValues []interface{}) (pathIDToUnitID map[int]int, err error) {
diff --git a/internal/postgres/path.go b/internal/postgres/path.go
index 778ef95..96d120c 100644
--- a/internal/postgres/path.go
+++ b/internal/postgres/path.go
@@ -12,6 +12,7 @@
 	"strconv"
 	"strings"
 
+	"github.com/lib/pq"
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/database"
 	"golang.org/x/pkgsite/internal/derrors"
@@ -106,3 +107,46 @@
 	}
 	return id, nil
 }
+
+// upsertPaths adds all the paths to the paths table if they aren't already
+// there, and returns their ID either way.
+// It assumes it is running inside a transaction.
+func upsertPaths(ctx context.Context, db *database.DB, paths []string) (pathToID map[string]int, err error) {
+	defer derrors.WrapStack(&err, "upsertPaths(%d paths)", len(paths))
+
+	pathToID = map[string]int{}
+	collect := func(rows *sql.Rows) error {
+		var (
+			pathID int
+			path   string
+		)
+		if err := rows.Scan(&pathID, &path); err != nil {
+			return err
+		}
+		pathToID[path] = pathID
+		return nil
+	}
+
+	if err := db.RunQuery(ctx, `SELECT id, path FROM paths WHERE path = ANY($1)`,
+		collect, pq.Array(paths)); err != nil {
+		return nil, err
+	}
+
+	// Insert any unit paths that we don't already have.
+	var values []interface{}
+	for _, v := range paths {
+		if _, ok := pathToID[v]; !ok {
+			values = append(values, v)
+		}
+	}
+	if len(values) > 0 {
+		// Insert data into the paths table.
+		pathCols := []string{"path"}
+		returningPathCols := []string{"id", "path"}
+		if err := db.BulkInsertReturning(ctx, "paths", pathCols, values,
+			database.OnConflictDoNothing, returningPathCols, collect); err != nil {
+			return nil, err
+		}
+	}
+	return pathToID, nil
+}
diff --git a/internal/postgres/path_test.go b/internal/postgres/path_test.go
index a9350e8..9ce6497 100644
--- a/internal/postgres/path_test.go
+++ b/internal/postgres/path_test.go
@@ -102,6 +102,7 @@
 func TestUpsertPathConcurrently(t *testing.T) {
 	// Verify that we get no constraint violations or other errors when
 	// the same path is upserted multiple times concurrently.
+	t.Parallel()
 	testDB, release := acquire(t)
 	defer release()
 	ctx := context.Background()
@@ -129,3 +130,50 @@
 		}
 	}
 }
+
+func TestUpsertPaths(t *testing.T) {
+	t.Parallel()
+	testDB, release := acquire(t)
+	defer release()
+	ctx := context.Background()
+
+	check := func(paths []string) {
+		got, err := upsertPathsInTx(ctx, testDB.db, paths)
+		if err != nil {
+			t.Fatal(err)
+		}
+		checkPathMap(t, got, paths)
+	}
+
+	check([]string{"a", "b", "c"})
+	check([]string{"b", "c", "d", "e"})
+}
+
+func checkPathMap(t *testing.T, got map[string]int, paths []string) {
+	t.Helper()
+	if g, w := len(got), len(paths); g != w {
+		t.Errorf("got %d paths, want %d", g, w)
+		return
+	}
+	for _, p := range paths {
+		g, ok := got[p]
+		if !ok {
+			t.Errorf("missing path %q", p)
+		} else if g == 0 {
+			t.Errorf("path %q has a 0 ID", p)
+		}
+	}
+}
+
+func upsertPathsInTx(ctx context.Context, db *database.DB, paths []string) (map[string]int, error) {
+	var m map[string]int
+	err := db.Transact(ctx, sql.LevelRepeatableRead, func(tx *database.DB) error {
+		var err error
+		m, err = upsertPaths(ctx, tx, paths)
+		return err
+	})
+	if err != nil {
+		return nil, err
+	}
+	return m, nil
+}