internal/database: support bulk upsert
Add DB.BulkUpsert and DB.BulkUpsertReturning, which append an
ON CONFLICT ... DO UPDATE clause to the INSERT so that rows conflicting
on the given columns have their existing column values replaced.
Change-Id: I59f36be0bcb0c0854f42da489e265f2a1396c439
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/766360
Reviewed-by: Julie Qiu <julieqiu@google.com>
diff --git a/internal/database/database.go b/internal/database/database.go
index 25c90b9..48d0a60 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -219,10 +219,11 @@
const OnConflictDoNothing = "ON CONFLICT DO NOTHING"
// BulkInsert constructs and executes a multi-value insert statement. The
-// query is constructed using the format: INSERT TO <table> (<columns>) VALUES
-// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it
-// append ON CONFLICT DO NOTHING to the end of the query. The query is executed
-// using a PREPARE statement with the provided values.
+// query is constructed using the format:
+// INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>)
+// If conflictAction is not empty, it is appended to the statement.
+//
+// The query is executed using a PREPARE statement with the provided values.
func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, values []interface{}, conflictAction string) (err error) {
defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)",
table, columns, len(values), conflictAction)
@@ -244,6 +245,21 @@
return db.bulkInsert(ctx, table, columns, returningColumns, values, conflictAction, scanFunc)
}
+// BulkUpsert works like BulkInsert, except that the caller supplies a list of
+// conflicting columns rather than a conflict action. The statement gets an
+// "ON CONFLICT (conflict_columns) DO UPDATE" clause containing an assignment
+// "c=excluded.c" for every column c in columns, so that a conflicting row has
+// its column values replaced.
+func (db *DB) BulkUpsert(ctx context.Context, table string, columns []string, values []interface{}, conflictColumns []string) error {
+	return db.BulkInsert(ctx, table, columns, values, buildUpsertConflictAction(columns, conflictColumns))
+}
+
+// BulkUpsertReturning is the upsert counterpart of BulkInsertReturning: it builds
+// the conflict clause from conflictColumns the same way BulkUpsert does.
+func (db *DB) BulkUpsertReturning(ctx context.Context, table string, columns []string, values []interface{}, conflictColumns, returningColumns []string, scanFunc func(*sql.Rows) error) error {
+	return db.BulkInsertReturning(ctx, table, columns, values, buildUpsertConflictAction(columns, conflictColumns), returningColumns, scanFunc)
+}
+
func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningColumns []string, values []interface{}, conflictAction string, scanFunc func(*sql.Rows) error) (err error) {
if remainder := len(values) % len(columns); remainder != 0 {
return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder)
@@ -338,6 +354,16 @@
return b.String()
}
+// buildUpsertConflictAction returns "ON CONFLICT (<conflictColumns>) DO UPDATE
+// SET c=excluded.c, ..." with one assignment for each c in columns.
+func buildUpsertConflictAction(columns, conflictColumns []string) string {
+	assignments := make([]string, len(columns))
+	for i, col := range columns {
+		assignments[i] = col + "=excluded." + col
+	}
+	return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s", strings.Join(conflictColumns, ", "), strings.Join(assignments, ", "))
+}
+
// maxBulkUpdateArrayLen is the maximum size of an array that BulkUpdate will send to
// Postgres. (Postgres has no size limit on arrays, but we want to keep the statements
// to a reasonable size.)
diff --git a/internal/database/database_test.go b/internal/database/database_test.go
index 2ff6f9d..cb6acbd 100644
--- a/internal/database/database_test.go
+++ b/internal/database/database_test.go
@@ -83,14 +83,6 @@
},
{
- name: "test-conflict-no-action-true",
- columns: []string{"colA"},
- values: []interface{}{"valueA", "valueA"},
- conflictAction: OnConflictDoNothing,
- wantCount: 1,
- },
- {
-
name: "insert-returning",
columns: []string{"colA", "colB"},
values: []interface{}{"valueA1", "valueB1", "valueA2", "valueB2"},
@@ -99,13 +91,21 @@
},
{
- name: "test-conflict-no-action-false",
+ name: "test-conflict",
columns: []string{"colA"},
values: []interface{}{"valueA", "valueA"},
wantErr: true,
},
{
+ name: "test-conflict-do-nothing",
+ columns: []string{"colA"},
+ values: []interface{}{"valueA", "valueA"},
+ conflictAction: OnConflictDoNothing,
+ wantCount: 1,
+ },
+ {
+
// This should execute the statement
// INSERT INTO series (path) VALUES ('''); TRUNCATE series CASCADE;)');
// which will insert a row with path value:
@@ -219,6 +219,48 @@
}
}
+// TestBulkUpsert runs BulkUpsert twice against a temporary table and checks
+// that the table then holds exactly the rows of the most recent batch.
+func TestBulkUpsert(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), testTimeout*3)
+	defer cancel()
+	if _, err := testDB.Exec(ctx, `CREATE TEMPORARY TABLE test_replace (C1 int PRIMARY KEY, C2 int);`); err != nil {
+		t.Fatal(err)
+	}
+	for _, want := range [][]interface{}{
+		{2, 4, 4, 8},                 // Seed the empty table with two rows.
+		{1, -1, 2, -2, 3, -3, 4, -4}, // Replace both rows and insert two more.
+	} {
+		if err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+			return tx.BulkUpsert(ctx, "test_replace", []string{"C1", "C2"}, want, []string{"C1"})
+		}); err != nil {
+			t.Fatal(err)
+		}
+		var got []interface{}
+		if err := testDB.RunQuery(ctx, `SELECT C1, C2 FROM test_replace ORDER BY C1`, func(rows *sql.Rows) error {
+			var c1, c2 int
+			if err := rows.Scan(&c1, &c2); err != nil {
+				return err
+			}
+			got = append(got, c1, c2)
+			return nil
+		}); err != nil {
+			t.Fatal(err)
+		}
+		if !cmp.Equal(got, want) {
+			t.Errorf("%v: got %v, want %v", want, got, want)
+		}
+	}
+}
+
+// TestBuildUpsertConflictAction pins the exact clause text generated for two columns.
+func TestBuildUpsertConflictAction(t *testing.T) {
+	const want = "ON CONFLICT (c, d) DO UPDATE SET a=excluded.a, b=excluded.b"
+	if got := buildUpsertConflictAction([]string{"a", "b"}, []string{"c", "d"}); got != want {
+		t.Errorf("got %q, want %q", got, want)
+	}
+}
+
func TestDBAfterTransactFails(t *testing.T) {
ctx := context.Background()
var tx *DB