internal/database: CopyUpsert: support dropping a column

You can't do a CopyFrom on a table with a generated column: postgres
complains about the column value being null. To fix, drop the column
on the temporary table.

Change-Id: Ia52f59af6d026b3fcdaafe3c7865a2eb85deb179
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/305830
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/database/copy.go b/internal/database/copy.go
index 74c5d00..7a092c0 100644
--- a/internal/database/copy.go
+++ b/internal/database/copy.go
@@ -23,11 +23,13 @@
 // src is the source of the rows to upsert.
 // conflictColumns are the columns that might conflict (i.e. that have a UNIQUE
 // constraint).
+// If dropColumn is non-empty, that column will be dropped from the temporary
+// table before copying. Use dropColumn for generated ID columns.
 //
 // CopyUpsert works by first creating a temporary table, populating it with
 // CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its
 // rows into the original table.
-func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string) (err error) {
+func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string, dropColumn string) (err error) {
 	defer derrors.Wrap(&err, "CopyUpsert(%q)", table)
 
 	if !db.InTransaction() {
@@ -46,8 +48,11 @@
 		tempTable := fmt.Sprintf("__%s_copy", table)
 		stmt := fmt.Sprintf(`
 			DROP TABLE IF EXISTS %s;
-			CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP
+			CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP;
 		`, tempTable, table)
+		if dropColumn != "" {
+			stmt += fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tempTable, dropColumn)
+		}
 		_, err = conn.Exec(ctx, stmt)
 		if err != nil {
 			return err
@@ -55,12 +60,12 @@
 		start := time.Now()
 		n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src)
 		if err != nil {
-			return err
+			return fmt.Errorf("CopyFrom: %w", err)
 		}
 		log.Debugf(ctx, "CopyUpsert(%q): copied %d rows in %s", table, n, time.Since(start))
 		conflictAction := buildUpsertConflictAction(columns, conflictColumns)
-		query := buildCopyUpsertQuery(table, tempTable, columns, conflictAction)
-
+		cols := strings.Join(columns, ", ")
+		query := fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction)
 		defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err)
 		start = time.Now()
 		ctag, err := conn.Exec(ctx, query)
@@ -72,11 +77,6 @@
 	})
 }
 
-func buildCopyUpsertQuery(table, tempTable string, columns []string, conflictAction string) string {
-	cols := strings.Join(columns, ", ")
-	return fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction)
-}
-
 // A RowItem is a row of values or an error.
 type RowItem struct {
 	Values []interface{}
diff --git a/internal/database/copy_test.go b/internal/database/copy_test.go
index 5c6f41b..238af8e 100644
--- a/internal/database/copy_test.go
+++ b/internal/database/copy_test.go
@@ -15,18 +15,8 @@
 )
 
 func TestCopyUpsert(t *testing.T) {
+	pgxOnly(t)
 	ctx := context.Background()
-	conn, err := testDB.db.Conn(ctx)
-	if err != nil {
-		t.Fatal(err)
-	}
-	conn.Raw(func(c interface{}) error {
-		if _, ok := c.(*stdlib.Conn); !ok {
-			t.Skip("skipping; DB driver not pgx")
-		}
-		return nil
-	})
-
 	for _, stmt := range []string{
 		`DROP TABLE IF EXISTS test_streaming_upsert`,
 		`CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`,
@@ -40,8 +30,8 @@
 		{3, "baz"}, // new row
 		{1, "moo"}, // replace "foo" with "moo"
 	}
-	err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
-		return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"})
+	err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+		return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "")
 	})
 	if err != nil {
 		t.Fatal(err)
@@ -66,3 +56,57 @@
 	}
 
 }
+
+func TestCopyUpsertGeneratedColumn(t *testing.T) {
+	pgxOnly(t)
+	ctx := context.Background()
+	stmt := `
+		DROP TABLE IF EXISTS test_copy_gen;
+		CREATE TABLE test_copy_gen (id bigint PRIMARY KEY GENERATED ALWAYS AS IDENTITY, key INT, value TEXT, UNIQUE (key));
+		INSERT INTO test_copy_gen (key, value) VALUES (11, 'foo'), (12, 'bar')`
+	if _, err := testDB.Exec(ctx, stmt); err != nil {
+		t.Fatal(err)
+	}
+
+	rows := [][]interface{}{
+		{13, "baz"}, // new row
+		{11, "moo"}, // replace "foo" with "moo"
+	}
+	err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+		return tx.CopyUpsert(ctx, "test_copy_gen", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "id")
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	type row struct {
+		ID    int64
+		Key   int
+		Value string
+	}
+	wantRows := []row{
+		{1, 11, "moo"},
+		{2, 12, "bar"},
+		{3, 13, "baz"},
+	}
+	var gotRows []row
+	if err := testDB.CollectStructs(ctx, &gotRows, `SELECT * FROM test_copy_gen ORDER BY ID`); err != nil {
+		t.Fatal(err)
+	}
+	if !cmp.Equal(gotRows, wantRows) {
+		t.Errorf("got %v, want %v", gotRows, wantRows)
+	}
+}
+
+func pgxOnly(t *testing.T) {
+	conn, err := testDB.db.Conn(context.Background())
+	if err != nil {
+		t.Fatal(err)
+	}
+	conn.Raw(func(c interface{}) error {
+		if _, ok := c.(*stdlib.Conn); !ok {
+			t.Skip("skipping; DB driver not pgx")
+		}
+		return nil
+	})
+}