internal/database: add CopyUpsert

Add the CopyUpsert method, which upserts rows efficiently by
bulk-loading them with pgx's CopyFrom (the Postgres COPY protocol).

For this to work, we need the connection underlying a sql.Tx value.
Since sql.Tx doesn't expose its connection, we create one explicitly
in DB.transact.

Change-Id: Ie48ce7a4318f4531d4756f779943188a6f0fb6cd
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/304631
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/database/copy.go b/internal/database/copy.go
new file mode 100644
index 0000000..90dcc07
--- /dev/null
+++ b/internal/database/copy.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package database
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/jackc/pgx/v4"
+	"github.com/jackc/pgx/v4/stdlib"
+	"golang.org/x/pkgsite/internal/derrors"
+	"golang.org/x/pkgsite/internal/log"
+)
+
+// CopyUpsert upserts rows into table using the pgx driver's CopyFrom method.
+// It must be called on a DB that is in a transaction, and it returns an error
+// if the underlying driver is not pgx.
+// columns is the list of columns to upsert.
+// src is the source of the rows to upsert.
+// conflictColumns are the columns that might conflict (i.e. that have a UNIQUE
+// constraint).
+//
+// CopyUpsert works by first creating a temporary table, populating it with
+// CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its
+// rows into the original table.
+func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string) (err error) {
+	defer derrors.Wrap(&err, "CopyUpsert(%q)", table)
+
+	if !db.InTransaction() {
+		return errors.New("not in a transaction")
+	}
+
+	return db.conn.Raw(func(c interface{}) error {
+		if w, ok := c.(*wrapConn); ok {
+			c = w.underlying // look through the wrapper at the connection it holds
+		}
+		stdConn, ok := c.(*stdlib.Conn)
+		if !ok {
+			return fmt.Errorf("DB driver is not pgx or wrapper; conn type is %T", c)
+		}
+		conn := stdConn.Conn()
+		tempTable := fmt.Sprintf("__%s_copy", table)
+		stmt := fmt.Sprintf(`
+			DROP TABLE IF EXISTS %s;
+			CREATE TEMP TABLE %[1]s AS SELECT * FROM %s LIMIT 0
+		`, tempTable, table) // LIMIT 0: the temp table gets table's columns but none of its rows
+		if _, err := conn.Exec(ctx, stmt); err != nil {
+			return err
+		}
+		start := time.Now()
+		n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src)
+		if err != nil {
+			return err
+		}
+		log.Debugf(ctx, "CopyUpsert(%q): copied %d rows in %s", table, n, time.Since(start))
+		conflictAction := buildUpsertConflictAction(columns, conflictColumns)
+		query := buildCopyUpsertQuery(table, tempTable, columns, conflictAction)
+
+		defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err)
+		start = time.Now()
+		ctag, err := conn.Exec(ctx, query)
+		if err != nil {
+			return err
+		}
+		log.Debugf(ctx, "CopyUpsert(%q): upserted %d rows in %s", table, ctag.RowsAffected(), time.Since(start))
+		return nil
+	})
+}
+
+func buildCopyUpsertQuery(table, tempTable string, columns []string, conflictAction string) string {
+	colList := strings.Join(columns, ", ") // shared by the INSERT target list and the SELECT list
+	return fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, colList, colList, tempTable, conflictAction)
+}
+
+// A RowItem is a row of values or an error.
+type RowItem struct {
+	Values []interface{} // the row's values, handed back by the channel source's Values method
+	Err    error         // if non-nil, returned by the source's Values and Err methods; no further items are read
+}
+
+// CopyFromChan returns a CopyFromSource that yields the RowItems received from c until c is closed or an item has a non-nil Err.
+func CopyFromChan(c <-chan RowItem) pgx.CopyFromSource {
+	return &chanCopySource{c: c}
+}
+
+type chanCopySource struct {
+	c    <-chan RowItem // channel the items are received from
+	next RowItem        // item most recently received by Next; the zero RowItem before the first call and after c is closed
+}
+
+// Next implements CopyFromSource.Next; it reports false once the channel is closed or the last item received carried an error.
+func (cs *chanCopySource) Next() bool {
+	if cs.next.Err != nil {
+		return false
+	}
+	item, more := <-cs.c
+	cs.next = item
+	return more
+}
+
+// Values implements CopyFromSource.Values, returning the values and error of the item most recently received by Next.
+func (cs *chanCopySource) Values() ([]interface{}, error) {
+	return cs.next.Values, cs.next.Err
+}
+
+// Err implements CopyFromSource.Err, returning the error of the item most recently received by Next, if any.
+func (cs *chanCopySource) Err() error {
+	return cs.next.Err
+}
diff --git a/internal/database/copy_test.go b/internal/database/copy_test.go
new file mode 100644
index 0000000..5c6f41b
--- /dev/null
+++ b/internal/database/copy_test.go
@@ -0,0 +1,68 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package database
+
+import (
+	"context"
+	"database/sql"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/jackc/pgx/v4"
+	"github.com/jackc/pgx/v4/stdlib"
+)
+
+// TestCopyUpsert checks that CopyUpsert adds new rows and overwrites rows whose key already exists.
+func TestCopyUpsert(t *testing.T) {
+	ctx := context.Background()
+	conn, err := testDB.db.Conn(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer conn.Close() // return the dedicated connection to the pool when the test ends
+	conn.Raw(func(c interface{}) error {
+		if _, ok := c.(*stdlib.Conn); !ok {
+			t.Skip("skipping; DB driver not pgx")
+		}
+		return nil
+	})
+
+	for _, stmt := range []string{
+		`DROP TABLE IF EXISTS test_streaming_upsert`,
+		`CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`,
+		`INSERT INTO test_streaming_upsert (key, value) VALUES (1, 'foo'), (2, 'bar')`,
+	} {
+		if _, err := testDB.Exec(ctx, stmt); err != nil {
+			t.Fatal(err)
+		}
+	}
+	rows := [][]interface{}{
+		{3, "baz"}, // new row
+		{1, "moo"}, // replace "foo" with "moo"
+	}
+	err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+		return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"})
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	type row struct {
+		Key   int
+		Value string
+	}
+	wantRows := []row{
+		{1, "moo"},
+		{2, "bar"},
+		{3, "baz"},
+	}
+	var gotRows []row
+	if err := testDB.CollectStructs(ctx, &gotRows, `SELECT * FROM test_streaming_upsert ORDER BY key`); err != nil {
+		t.Fatal(err)
+	}
+	if !cmp.Equal(gotRows, wantRows) {
+		t.Errorf("got %v, want %v", gotRows, wantRows)
+	}
+}
diff --git a/internal/database/database.go b/internal/database/database.go
index 96e858f..312f83a 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -33,6 +33,7 @@
 	db         *sql.DB
 	instanceID string
 	tx         *sql.Tx
+	conn       *sql.Conn     // the Conn of the Tx, when tx != nil
 	opts       sql.TxOptions // valid when tx != nil
 	mu         sync.Mutex
 	maxRetries int // max times a single transaction was retried
@@ -237,9 +238,15 @@
 	if db.InTransaction() {
 		return errors.New("a DB Transact function was called on a DB already in a transaction")
 	}
-	tx, err := db.db.BeginTx(ctx, opts)
+	conn, err := db.db.Conn(ctx)
 	if err != nil {
-		return fmt.Errorf("db.BeginTx(): %w", err)
+		return err
+	}
+	defer conn.Close()
+
+	tx, err := conn.BeginTx(ctx, opts)
+	if err != nil {
+		return fmt.Errorf("conn.BeginTx(): %w", err)
 	}
 	defer func() {
 		if p := recover(); p != nil {
@@ -256,6 +263,7 @@
 
 	dbtx := New(db.db, db.instanceID)
 	dbtx.tx = tx
+	dbtx.conn = conn
 	dbtx.opts = *opts
 	defer dbtx.logTransaction(ctx)(&err)
 	if err := txFunc(dbtx); err != nil {