internal/database: CopyUpsert: support dropping a column
You can't do a CopyFrom on a table with a generated column: postgres
complains about the column value being null. To fix, drop the column
on the temporary table.
Change-Id: Ia52f59af6d026b3fcdaafe3c7865a2eb85deb179
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/305830
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/database/copy.go b/internal/database/copy.go
index 74c5d00..7a092c0 100644
--- a/internal/database/copy.go
+++ b/internal/database/copy.go
@@ -23,11 +23,13 @@
// src is the source of the rows to upsert.
// conflictColumns are the columns that might conflict (i.e. that have a UNIQUE
// constraint).
+// If dropColumn is non-empty, that column will be dropped from the temporary
+// table before copying. Use dropColumn for generated ID columns.
//
// CopyUpsert works by first creating a temporary table, populating it with
// CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its
// rows into the original table.
-func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string) (err error) {
+func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string, dropColumn string) (err error) {
defer derrors.Wrap(&err, "CopyUpsert(%q)", table)
if !db.InTransaction() {
@@ -46,8 +48,11 @@
tempTable := fmt.Sprintf("__%s_copy", table)
stmt := fmt.Sprintf(`
DROP TABLE IF EXISTS %s;
- CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP
+ CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP;
`, tempTable, table)
+ if dropColumn != "" {
+ stmt += fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tempTable, dropColumn)
+ }
_, err = conn.Exec(ctx, stmt)
if err != nil {
return err
@@ -55,12 +60,12 @@
start := time.Now()
n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src)
if err != nil {
- return err
+ return fmt.Errorf("CopyFrom: %w", err)
}
log.Debugf(ctx, "CopyUpsert(%q): copied %d rows in %s", table, n, time.Since(start))
conflictAction := buildUpsertConflictAction(columns, conflictColumns)
- query := buildCopyUpsertQuery(table, tempTable, columns, conflictAction)
-
+ cols := strings.Join(columns, ", ")
+ query := fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction)
defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err)
start = time.Now()
ctag, err := conn.Exec(ctx, query)
@@ -72,11 +77,6 @@
})
}
-func buildCopyUpsertQuery(table, tempTable string, columns []string, conflictAction string) string {
- cols := strings.Join(columns, ", ")
- return fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction)
-}
-
// A RowItem is a row of values or an error.
type RowItem struct {
Values []interface{}
diff --git a/internal/database/copy_test.go b/internal/database/copy_test.go
index 5c6f41b..238af8e 100644
--- a/internal/database/copy_test.go
+++ b/internal/database/copy_test.go
@@ -15,18 +15,8 @@
)
func TestCopyUpsert(t *testing.T) {
+ pgxOnly(t)
ctx := context.Background()
- conn, err := testDB.db.Conn(ctx)
- if err != nil {
- t.Fatal(err)
- }
- conn.Raw(func(c interface{}) error {
- if _, ok := c.(*stdlib.Conn); !ok {
- t.Skip("skipping; DB driver not pgx")
- }
- return nil
- })
-
for _, stmt := range []string{
`DROP TABLE IF EXISTS test_streaming_upsert`,
`CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`,
@@ -40,8 +30,8 @@
{3, "baz"}, // new row
{1, "moo"}, // replace "foo" with "moo"
}
- err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
- return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"})
+ err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+ return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "")
})
if err != nil {
t.Fatal(err)
@@ -66,3 +56,57 @@
}
}
+
+func TestCopyUpsertGeneratedColumn(t *testing.T) {
+ pgxOnly(t)
+ ctx := context.Background()
+ stmt := `
+ DROP TABLE IF EXISTS test_copy_gen;
+ CREATE TABLE test_copy_gen (id bigint PRIMARY KEY GENERATED ALWAYS AS IDENTITY, key INT, value TEXT, UNIQUE (key));
+ INSERT INTO test_copy_gen (key, value) VALUES (11, 'foo'), (12, 'bar')`
+ if _, err := testDB.Exec(ctx, stmt); err != nil {
+ t.Fatal(err)
+ }
+
+ rows := [][]interface{}{
+ {13, "baz"}, // new row
+ {11, "moo"}, // replace "foo" with "moo"
+ }
+ err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+ return tx.CopyUpsert(ctx, "test_copy_gen", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "id")
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ type row struct {
+ ID int64
+ Key int
+ Value string
+ }
+ wantRows := []row{
+ {1, 11, "moo"},
+ {2, 12, "bar"},
+ {3, 13, "baz"},
+ }
+ var gotRows []row
+ if err := testDB.CollectStructs(ctx, &gotRows, `SELECT * FROM test_copy_gen ORDER BY ID`); err != nil {
+ t.Fatal(err)
+ }
+ if !cmp.Equal(gotRows, wantRows) {
+ t.Errorf("got %v, want %v", gotRows, wantRows)
+ }
+}
+
+func pgxOnly(t *testing.T) {
+ conn, err := testDB.db.Conn(context.Background())
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.Raw(func(c interface{}) error {
+ if _, ok := c.(*stdlib.Conn); !ok {
+ t.Skip("skipping; DB driver not pgx")
+ }
+ return nil
+ })
+}