internal/database: add CopyUpsert
Add the CopyUpsert method, which uses the Postgres COPY protocol
(via the pgx driver's CopyFrom) to upsert rows efficiently.
For this to work, we need the connection underlying a sql.Tx value.
Since sql.Tx doesn't expose its connection, we create one explicitly
in DB.transact.
Change-Id: Ie48ce7a4318f4531d4756f779943188a6f0fb6cd
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/304631
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/database/copy.go b/internal/database/copy.go
new file mode 100644
index 0000000..90dcc07
--- /dev/null
+++ b/internal/database/copy.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package database
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/jackc/pgx/v4"
+ "github.com/jackc/pgx/v4/stdlib"
+ "golang.org/x/pkgsite/internal/derrors"
+ "golang.org/x/pkgsite/internal/log"
+)
+
+// CopyUpsert upserts rows into table using the pgx driver's CopyFrom method.
+// It must be called on a DB that is in a transaction; otherwise it returns an
+// error. It also returns an error if the underlying driver connection is not
+// a pgx connection (possibly wrapped in a wrapConn).
+//
+// columns is the list of columns to upsert.
+// src is the source of the rows to upsert.
+// conflictColumns are the columns that might conflict (i.e. that have a UNIQUE
+// constraint).
+//
+// table and columns are interpolated into the SQL text as-is, without
+// quoting, so they must be trusted identifiers.
+//
+// CopyUpsert works by first creating a temporary table, populating it with
+// CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its
+// rows into the original table. The temporary table is named "__<table>_copy".
+// It is not dropped when CopyUpsert returns; any existing table of that name
+// is dropped at the start of each call.
+func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string) (err error) {
+ // derrors.Wrap is given a pointer to the named result so it can act on
+ // whatever error is finally returned.
+ defer derrors.Wrap(&err, "CopyUpsert(%q)", table)
+
+ if !db.InTransaction() {
+ return errors.New("not in a transaction")
+ }
+
+ // db.conn is the sql.Conn of the current transaction. Raw hands us the
+ // driver-level connection, which is needed to reach pgx's CopyFrom.
+ return db.conn.Raw(func(c interface{}) error {
+ // The driver connection may be wrapped; unwrap it before checking for pgx.
+ if w, ok := c.(*wrapConn); ok {
+ c = w.underlying
+ }
+ stdConn, ok := c.(*stdlib.Conn)
+ if !ok {
+ return fmt.Errorf("DB driver is not pgx or wrapper; conn type is %T", c)
+ }
+ conn := stdConn.Conn()
+ tempTable := fmt.Sprintf("__%s_copy", table)
+ // (Re)create an empty temp table with the same columns as table
+ // (SELECT * ... LIMIT 0 copies the column list but no rows).
+ // In the format string, %[1]s re-uses tempTable and the %s after it
+ // takes the next argument, table.
+ stmt := fmt.Sprintf(`
+ DROP TABLE IF EXISTS %s;
+ CREATE TEMP TABLE %[1]s AS SELECT * FROM %s LIMIT 0
+ `, tempTable, table)
+ // This assignment targets the enclosing function's named result err;
+ // no closure-local err exists yet.
+ _, err = conn.Exec(ctx, stmt)
+ if err != nil {
+ return err
+ }
+ start := time.Now()
+ // From here on, err is a new closure-local variable that shadows the
+ // named result.
+ // NOTE(review): pgx appears to quote the table identifier passed to
+ // CopyFrom, while the SQL above uses it unquoted; this would matter for
+ // table names that are not all lower-case — confirm.
+ n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src)
+ if err != nil {
+ return err
+ }
+ log.Debugf(ctx, "CopyUpsert(%q): copied %d rows in %s", table, n, time.Since(start))
+ conflictAction := buildUpsertConflictAction(columns, conflictColumns)
+ query := buildCopyUpsertQuery(table, tempTable, columns, conflictAction)
+
+ // logQuery is called now; the function it returns is deferred and is
+ // given a pointer to the closure-local err, so it sees the outcome of
+ // the upsert statement below.
+ defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err)
+ start = time.Now()
+ // ctag is new here; err is the same closure-local variable as above.
+ ctag, err := conn.Exec(ctx, query)
+ if err != nil {
+ return err
+ }
+ log.Debugf(ctx, "CopyUpsert(%q): upserted %d rows in %s", table, ctag.RowsAffected(), time.Since(start))
+ return nil
+ })
+}
+
+// buildCopyUpsertQuery returns the SQL statement that inserts the given
+// columns of every row of tempTable into table. conflictAction is appended
+// verbatim as the final clause. All identifiers are interpolated as-is,
+// without quoting.
+func buildCopyUpsertQuery(table, tempTable string, columns []string, conflictAction string) string {
+	columnList := strings.Join(columns, ", ")
+	query := fmt.Sprintf(
+		"INSERT INTO %s (%s) SELECT %s FROM %s %s",
+		table, columnList, columnList, tempTable, conflictAction,
+	)
+	return query
+}
+
+// A RowItem is a row of values or an error.
+// It is the element type of the channel consumed by CopyFromChan.
+type RowItem struct {
+ // Values holds the values of one row; the source's Values method returns it.
+ Values []interface{}
+ // Err, if non-nil, is returned by the source's Values and Err methods, and
+ // makes subsequent calls to Next return false.
+ Err error
+}
+
+// CopyFromChan adapts a channel of RowItems to the pgx.CopyFromSource
+// interface. Each item received from c supplies one row of values or an error.
+func CopyFromChan(c <-chan RowItem) pgx.CopyFromSource {
+	src := new(chanCopySource)
+	src.c = c
+	return src
+}
+
+// chanCopySource is the pgx.CopyFromSource returned by CopyFromChan.
+type chanCopySource struct {
+ c <-chan RowItem // channel the items are received from
+ next RowItem // item most recently received from c by Next
+}
+
+// Next implements CopyFromSource.Next.
+func (cs *chanCopySource) Next() bool {
+ if cs.next.Err != nil {
+ return false
+ }
+ var ok bool
+ cs.next, ok = <-cs.c
+ return ok
+}
+
+// Values implements CopyFromSource.Values. It returns the values and the
+// error of the item most recently received by Next.
+func (cs *chanCopySource) Values() ([]interface{}, error) {
+	item := cs.next
+	return item.Values, item.Err
+}
+
+// Err implements CopyFromSource.Err. It returns the error, if any, carried by
+// the item most recently received by Next.
+func (cs *chanCopySource) Err() error {
+	item := cs.next
+	return item.Err
+}
diff --git a/internal/database/copy_test.go b/internal/database/copy_test.go
new file mode 100644
index 0000000..5c6f41b
--- /dev/null
+++ b/internal/database/copy_test.go
@@ -0,0 +1,68 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package database
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/jackc/pgx/v4"
+ "github.com/jackc/pgx/v4/stdlib"
+)
+
+// TestCopyUpsert checks that CopyUpsert inserts rows whose key is new and
+// overwrites rows whose key already exists, leaving other rows untouched.
+// It is skipped unless the test database uses the pgx driver.
+func TestCopyUpsert(t *testing.T) {
+	ctx := context.Background()
+
+	// Grab a dedicated connection only to inspect the driver type.
+	conn, err := testDB.db.Conn(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Record the driver type inside Raw and act on it afterwards, so that
+	// t.Skip is not called from within the callback.
+	var isPGX bool
+	rawErr := conn.Raw(func(c interface{}) error {
+		// Mirror CopyUpsert, which also accepts a wrapped pgx connection.
+		if w, ok := c.(*wrapConn); ok {
+			c = w.underlying
+		}
+		_, isPGX = c.(*stdlib.Conn)
+		return nil
+	})
+	// Return the connection to the pool before doing anything else, so it is
+	// not leaked on any path (including the skip below).
+	if err := conn.Close(); err != nil {
+		t.Fatal(err)
+	}
+	if rawErr != nil {
+		t.Fatal(rawErr)
+	}
+	if !isPGX {
+		t.Skip("skipping; DB driver not pgx")
+	}
+
+	// Start from a known table containing two rows.
+	for _, stmt := range []string{
+		`DROP TABLE IF EXISTS test_streaming_upsert`,
+		`CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`,
+		`INSERT INTO test_streaming_upsert (key, value) VALUES (1, 'foo'), (2, 'bar')`,
+	} {
+		if _, err := testDB.Exec(ctx, stmt); err != nil {
+			t.Fatal(err)
+		}
+	}
+	rows := [][]interface{}{
+		{3, "baz"}, // new row
+		{1, "moo"}, // replace "foo" with "moo"
+	}
+	// CopyUpsert requires a transaction.
+	err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
+		return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"})
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	type row struct {
+		Key   int
+		Value string
+	}
+
+	wantRows := []row{
+		{1, "moo"},
+		{2, "bar"},
+		{3, "baz"},
+	}
+	var gotRows []row
+	if err := testDB.CollectStructs(ctx, &gotRows, `SELECT * FROM test_streaming_upsert ORDER BY key`); err != nil {
+		t.Fatal(err)
+	}
+	if diff := cmp.Diff(wantRows, gotRows); diff != "" {
+		t.Errorf("rows mismatch (-want, +got):\n%s", diff)
+	}
+}
diff --git a/internal/database/database.go b/internal/database/database.go
index 96e858f..312f83a 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -33,6 +33,7 @@
db *sql.DB
instanceID string
tx *sql.Tx
+ conn *sql.Conn // the Conn of the Tx, when tx != nil
opts sql.TxOptions // valid when tx != nil
mu sync.Mutex
maxRetries int // max times a single transaction was retried
@@ -237,9 +238,15 @@
if db.InTransaction() {
return errors.New("a DB Transact function was called on a DB already in a transaction")
}
- tx, err := db.db.BeginTx(ctx, opts)
+ conn, err := db.db.Conn(ctx)
if err != nil {
- return fmt.Errorf("db.BeginTx(): %w", err)
+ return err
+ }
+ defer conn.Close()
+
+ tx, err := conn.BeginTx(ctx, opts)
+ if err != nil {
+ return fmt.Errorf("conn.BeginTx(): %w", err)
}
defer func() {
if p := recover(); p != nil {
@@ -256,6 +263,7 @@
dbtx := New(db.db, db.instanceID)
dbtx.tx = tx
+ dbtx.conn = conn
dbtx.opts = *opts
defer dbtx.logTransaction(ctx)(&err)
if err := txFunc(dbtx); err != nil {