internal/database: support bulk upsert

Add DB.BulkUpsert, which adds an ON CONFLICT clause to the INSERT
that replaces existing column values.

Change-Id: I59f36be0bcb0c0854f42da489e265f2a1396c439
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/766360
Reviewed-by: Julie Qiu <julieqiu@google.com>
diff --git a/internal/database/database.go b/internal/database/database.go
index 25c90b9..48d0a60 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -219,10 +219,11 @@
 const OnConflictDoNothing = "ON CONFLICT DO NOTHING"
 
 // BulkInsert constructs and executes a multi-value insert statement. The
-// query is constructed using the format: INSERT TO <table> (<columns>) VALUES
-// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it
-// append ON CONFLICT DO NOTHING to the end of the query. The query is executed
-// using a PREPARE statement with the provided values.
+// query is constructed using the format:
+//   INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>)
+// If conflictAction is not empty, it is appended to the statement.
+//
+// The query is executed using a PREPARE statement with the provided values.
 func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, values []interface{}, conflictAction string) (err error) {
 	defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)",
 		table, columns, len(values), conflictAction)
@@ -244,6 +245,21 @@
 	return db.bulkInsert(ctx, table, columns, returningColumns, values, conflictAction, scanFunc)
 }
 
+// BulkUpsert is like BulkInsert, but instead of a conflict action, a list of
+// conflicting columns is provided. An "ON CONFLICT (conflict_columns) DO
+// UPDATE" clause is added to the statement, with assignments "c=excluded.c" for
+// every column c.
+func (db *DB) BulkUpsert(ctx context.Context, table string, columns []string, values []interface{}, conflictColumns []string) error {
+	conflictAction := buildUpsertConflictAction(columns, conflictColumns)
+	return db.BulkInsert(ctx, table, columns, values, conflictAction)
+}
+
+// BulkUpsertReturning is like BulkInsertReturning, but performs an upsert like BulkUpsert.
+func (db *DB) BulkUpsertReturning(ctx context.Context, table string, columns []string, values []interface{}, conflictColumns, returningColumns []string, scanFunc func(*sql.Rows) error) error {
+	conflictAction := buildUpsertConflictAction(columns, conflictColumns)
+	return db.BulkInsertReturning(ctx, table, columns, values, conflictAction, returningColumns, scanFunc)
+}
+
 func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningColumns []string, values []interface{}, conflictAction string, scanFunc func(*sql.Rows) error) (err error) {
 	if remainder := len(values) % len(columns); remainder != 0 {
 		return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder)
@@ -338,6 +354,16 @@
 	return b.String()
 }
 
+func buildUpsertConflictAction(columns, conflictColumns []string) string {
+	var sets []string
+	for _, c := range columns {
+		sets = append(sets, fmt.Sprintf("%s=excluded.%[1]s", c))
+	}
+	return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s",
+		strings.Join(conflictColumns, ", "),
+		strings.Join(sets, ", "))
+}
+
 // maxBulkUpdateArrayLen is the maximum size of an array that BulkUpdate will send to
 // Postgres. (Postgres has no size limit on arrays, but we want to keep the statements
 // to a reasonable size.)
diff --git a/internal/database/database_test.go b/internal/database/database_test.go
index 2ff6f9d..cb6acbd 100644
--- a/internal/database/database_test.go
+++ b/internal/database/database_test.go
@@ -83,14 +83,6 @@
 		},
 		{
 
-			name:           "test-conflict-no-action-true",
-			columns:        []string{"colA"},
-			values:         []interface{}{"valueA", "valueA"},
-			conflictAction: OnConflictDoNothing,
-			wantCount:      1,
-		},
-		{
-
 			name:         "insert-returning",
 			columns:      []string{"colA", "colB"},
 			values:       []interface{}{"valueA1", "valueB1", "valueA2", "valueB2"},
@@ -99,13 +91,21 @@
 		},
 		{
 
-			name:    "test-conflict-no-action-false",
+			name:    "test-conflict",
 			columns: []string{"colA"},
 			values:  []interface{}{"valueA", "valueA"},
 			wantErr: true,
 		},
 		{
 
+			name:           "test-conflict-do-nothing",
+			columns:        []string{"colA"},
+			values:         []interface{}{"valueA", "valueA"},
+			conflictAction: OnConflictDoNothing,
+			wantCount:      1,
+		},
+		{
+
 			// This should execute the statement
 			// INSERT INTO series (path) VALUES ('''); TRUNCATE series CASCADE;)');
 			// which will insert a row with path value:
@@ -219,6 +219,48 @@
 	}
 }
 
+func TestBulkUpsert(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), testTimeout*3)
+	defer cancel()
+	if _, err := testDB.Exec(ctx, `CREATE TEMPORARY TABLE test_replace (C1 int PRIMARY KEY, C2 int);`); err != nil {
+		t.Fatal(err)
+	}
+	for _, values := range [][]interface{}{
+		{2, 4, 4, 8},                 // First, insert some rows.
+		{1, -1, 2, -2, 3, -3, 4, -4}, // Then replace those rows while inserting others.
+	} {
+		err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+			return tx.BulkUpsert(ctx, "test_replace", []string{"C1", "C2"}, values, []string{"C1"})
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		var got []interface{}
+		err = testDB.RunQuery(ctx, `SELECT C1, C2 FROM test_replace ORDER BY C1`, func(rows *sql.Rows) error {
+			var a, b int
+			if err := rows.Scan(&a, &b); err != nil {
+				return err
+			}
+			got = append(got, a, b)
+			return nil
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !cmp.Equal(got, values) {
+			t.Errorf("%v: got %v, want %v", values, got, values)
+		}
+	}
+}
+
+func TestBuildUpsertConflictAction(t *testing.T) {
+	got := buildUpsertConflictAction([]string{"a", "b"}, []string{"c", "d"})
+	want := "ON CONFLICT (c, d) DO UPDATE SET a=excluded.a, b=excluded.b"
+	if got != want {
+		t.Errorf("got %q, want %q", got, want)
+	}
+}
+
 func TestDBAfterTransactFails(t *testing.T) {
 	ctx := context.Background()
 	var tx *DB