blob: b564fd1b24a552c81348b79ee7dc1251a31325f2 [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package database
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/log"
)
// CopyInsert insert rows into table using the pgx driver's CopyFrom method.
// It returns an error if the underlying driver is not pgx.
// columns is the list of columns to upsert.
// src is the source of the rows to upsert.
// If dropColumn is non-empty, that column will be dropped from the temporary
// table before copying. Use dropColumn for generated ID columns.
//
// CopyInsert works by first creating a temporary table, populating it with
// CopyFrom, and then running an INSERT...SELECT... to insert its rows into the
// original table.
func (db *DB) CopyInsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, dropColumn string) (err error) {
defer derrors.Wrap(&err, "CopyInsert(%q)", table)
return db.copy(ctx, table, columns, src, dropColumn, "")
}
// CopyUpsert upserts rows into table using the pgx driver's CopyFrom method.
// It returns an error if the underlying driver is not pgx.
// columns is the list of columns to upsert.
// src is the source of the rows to upsert.
// conflictColumns are the columns that might conflict (i.e. that have a UNIQUE
// constraint).
// If dropColumn is non-empty, that column will be dropped from the temporary
// table before copying. Use dropColumn for generated ID columns.
//
// CopyUpsert works by first creating a temporary table, populating it with
// CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its
// rows into the original table.
func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string, dropColumn string) (err error) {
defer derrors.Wrap(&err, "CopyUpsert(%q)", table)
return db.copy(ctx, table, columns, src, dropColumn, buildUpsertConflictAction(columns, conflictColumns))
}
func (db *DB) copy(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, dropColumn, conflictAction string) (err error) {
if !db.InTransaction() {
return errors.New("not in a transaction")
}
return db.WithPGXConn(func(conn *pgx.Conn) error {
tempTable := fmt.Sprintf("__%s_copy", table)
stmt := fmt.Sprintf(`
DROP TABLE IF EXISTS %s;
CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP;
`, tempTable, table)
if dropColumn != "" {
stmt += fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tempTable, dropColumn)
}
_, err = conn.Exec(ctx, stmt)
if err != nil {
return err
}
start := time.Now()
n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src)
if err != nil {
return fmt.Errorf("CopyFrom: %w", err)
}
if !QueryLoggingDisabled {
log.Debugf(ctx, "DB.copy(%q): copied %d rows in %s", table, n, time.Since(start))
}
cols := strings.Join(columns, ", ")
query := fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction)
defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err)
start = time.Now()
ctag, err := conn.Exec(ctx, query)
if err != nil {
return err
}
if !QueryLoggingDisabled {
log.Debugf(ctx, "DB.copy(%q): affected %d rows in %s", table, ctag.RowsAffected(), time.Since(start))
}
return nil
})
}
func (db *DB) WithPGXConn(f func(conn *pgx.Conn) error) error {
if !db.InTransaction() {
return errors.New("not in a transaction")
}
return db.conn.Raw(func(c interface{}) error {
if w, ok := c.(*wrapConn); ok {
c = w.underlying
}
stdConn, ok := c.(*stdlib.Conn)
if !ok {
return fmt.Errorf("DB driver is not pgx or wrapper; conn type is %T", c)
}
return f(stdConn.Conn())
})
}
// A RowItem is a row of values or an error.
type RowItem struct {
Values []interface{}
Err error
}
// CopyFromChan returns a CopyFromSource that gets its rows from a channel.
func CopyFromChan(c <-chan RowItem) pgx.CopyFromSource {
return &chanCopySource{c: c}
}
type chanCopySource struct {
c <-chan RowItem
next RowItem
}
// Next implements CopyFromSource.Next.
func (cs *chanCopySource) Next() bool {
if cs.next.Err != nil {
return false
}
var ok bool
cs.next, ok = <-cs.c
return ok
}
// Values implements CopyFromSource.Values.
func (cs *chanCopySource) Values() ([]interface{}, error) {
return cs.next.Values, cs.next.Err
}
// Err implements CopyFromSource.Err.
func (cs *chanCopySource) Err() error {
return cs.next.Err
}