blob: 305716324c26ddabcec7953ab31c0be4ebe97b0d [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package database adds some useful functionality to a sql.DB.
// It is independent of the database driver and the
// DB schema.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"regexp"
"strings"
"sync"
"time"
"github.com/jackc/pgconn"
"github.com/lib/pq"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/log"
)
// DB wraps a sql.DB. The methods it exports correspond closely to those of
// sql.DB. They enhance the original by requiring a context argument, and by
// logging the query and any resulting errors.
//
// A DB may represent a transaction. If so, its execution and query methods
// operate within the transaction.
type DB struct {
	// db is the underlying connection pool. A DB that represents a
	// transaction shares the pool of the DB it was created from.
	db *sql.DB
	// instanceID is passed to logQuery along with each logged query.
	instanceID string
	// tx is the transaction this DB represents, or nil if it does not
	// represent one (see InTransaction).
	tx   *sql.Tx
	conn *sql.Conn     // the Conn of the Tx, when tx != nil
	opts sql.TxOptions // valid when tx != nil
	// mu guards maxRetries, which transactWithRetry updates and MaxRetries reads.
	mu         sync.Mutex
	maxRetries int // max times a single transaction was retried
}
// Open creates a new DB for the given connection string.
//
// It opens a sql.DB with the named driver and verifies connectivity with a
// ping bounded by a 30-second timeout. If the ping fails, the sql.DB is closed
// before returning, so its connection pool is not leaked. Returned errors are
// annotated with the driver name and with dbinfo after its password has been
// redacted.
func Open(driverName, dbinfo, instanceID string) (_ *DB, err error) {
	defer derrors.Wrap(&err, "database.Open(%q, %q)",
		driverName, redactPassword(dbinfo))

	db, err := sql.Open(driverName, dbinfo)
	if err != nil {
		return nil, err
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// The pool is not handed to the caller, so release it here. The close
		// is best effort: the ping error is the one worth reporting.
		_ = db.Close()
		return nil, err
	}
	return New(db, instanceID), nil
}
// New wraps an existing sql.DB in a DB that is not part of any transaction.
// instanceID is recorded for query logging.
func New(db *sql.DB, instanceID string) *DB {
	wrapped := &DB{}
	wrapped.db = db
	wrapped.instanceID = instanceID
	return wrapped
}
// Ping checks that the underlying database is reachable by calling
// sql.DB.Ping, which opens a connection if necessary. It takes no context.
func (db *DB) Ping() error {
	pool := db.db
	return pool.Ping()
}
// InTransaction reports whether db represents a transaction, in which case
// its execution and query methods run inside that transaction.
func (db *DB) InTransaction() bool {
	inTx := db.tx != nil
	return inTx
}
// IsRetryable reports whether db represents a transaction whose isolation
// level makes it subject to retry on serialization failure.
func (db *DB) IsRetryable() bool {
	if db.tx == nil {
		return false
	}
	return isRetryable(db.opts.Isolation)
}
// passwordRegexp matches a password in a keyword/value connection string such
// as "password=secret". Following libpq syntax, whitespace may surround the
// '=', and a single-quoted value may contain spaces and backslash escapes
// (e.g. "password = 'a b\' c'"). The whole quoted value is matched, so no part
// of the password leaks.
var passwordRegexp = regexp.MustCompile(`password\s*=\s*(?:'(?:[^'\\]|\\.)*'|\S+)`)

// urlPasswordRegexp matches the password in the userinfo of a URL-style
// connection string such as "postgres://user:secret@host/db". Submatch 1 is
// the text before the password ("://user:").
var urlPasswordRegexp = regexp.MustCompile(`(://[^:@/\s]*:)[^@/\s]*@`)

// redactPassword returns dbinfo with any password replaced by "REDACTED", so
// the result can be included in logs and error messages.
func redactPassword(dbinfo string) string {
	s := passwordRegexp.ReplaceAllLiteralString(dbinfo, "password=REDACTED")
	return urlPasswordRegexp.ReplaceAllString(s, "${1}REDACTED@")
}
// Close closes the underlying sql.DB. A DB that represents a transaction
// shares its pool with the DB it came from, so closing it closes that pool too.
func (db *DB) Close() error {
	pool := db.db
	return pool.Close()
}
// Exec executes a SQL statement and returns the number of rows it affected.
// The statement runs inside the transaction if db represents one. The query
// and any resulting error are logged.
func (db *DB) Exec(ctx context.Context, query string, args ...any) (_ int64, err error) {
	defer logQuery(ctx, query, args, db.instanceID, db.IsRetryable())(&err)
	res, err := db.execResult(ctx, query, args...)
	if err != nil {
		return 0, err
	}
	n, err := res.RowsAffected()
	if err != nil {
		// Wrap with %w so callers can still inspect the driver's error.
		return 0, fmt.Errorf("RowsAffected: %w", err)
	}
	return n, nil
}
// execResult runs a SQL statement without logging it, inside the transaction
// if db represents one, and returns the driver's sql.Result.
func (db *DB) execResult(ctx context.Context, query string, args ...any) (res sql.Result, err error) {
	if db.tx == nil {
		return db.db.ExecContext(ctx, query, args...)
	}
	return db.tx.ExecContext(ctx, query, args...)
}
// Query runs query, inside the transaction if db represents one, and returns
// the resulting rows. The caller must close them. The query and any error are
// logged.
func (db *DB) Query(ctx context.Context, query string, args ...any) (_ *sql.Rows, err error) {
	defer logQuery(ctx, query, args, db.instanceID, db.IsRetryable())(&err)
	if db.tx == nil {
		return db.db.QueryContext(ctx, query, args...)
	}
	return db.tx.QueryContext(ctx, query, args...)
}
// QueryRow runs the query and returns a single row, inside the transaction if
// db represents one. As with sql.DB.QueryRowContext, any query error is
// reported when the returned Row is scanned.
//
// If ctx is already done when QueryRow returns, the context error is logged
// together with the query arguments, the elapsed time and the context
// deadline, to help diagnose timeouts.
func (db *DB) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
	defer logQuery(ctx, query, args, db.instanceID, db.IsRetryable())(nil)
	start := time.Now()
	defer func() {
		if ctx.Err() != nil {
			d, _ := ctx.Deadline()
			msg := fmt.Sprintf("args=%v; elapsed=%q, start=%q, deadline=%q", args, time.Since(start), start, d)
			// Pass msg as an argument, not as part of the format string:
			// the query args may contain '%', which would otherwise be
			// interpreted as formatting verbs.
			log.Errorf(ctx, "QueryRow context error: %v %s", ctx.Err(), msg)
		}
	}()
	if db.tx != nil {
		return db.tx.QueryRowContext(ctx, query, args...)
	}
	return db.db.QueryRowContext(ctx, query, args...)
}
// Prepare creates a prepared statement for query, inside the transaction if db
// represents one. The preparation is logged like other queries, with the query
// text prefixed by "preparing ".
func (db *DB) Prepare(ctx context.Context, query string) (_ *sql.Stmt, err error) {
	// As at the other call sites, logQuery returns a completion function that
	// takes the result error. Deferring only logQuery itself would run it at
	// exit and silently drop that completion function.
	defer logQuery(ctx, "preparing "+query, nil, db.instanceID, db.IsRetryable())(&err)
	if db.tx != nil {
		return db.tx.PrepareContext(ctx, query)
	}
	return db.db.PrepareContext(ctx, query)
}
// RunQuery runs query with params and calls f once per result row. It stops
// at the first error from f or at the end of the rows. It returns the query
// error, f's error, or the error from iterating the rows.
func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) error, params ...any) error {
	rows, err := db.Query(ctx, query, params...)
	if err == nil {
		_, err = processRows(rows, f)
	}
	return err
}
// processRows calls f on each remaining row of rows and always closes rows.
// It stops at the first error returned by f. It returns the number of rows
// handed to f (including the one on which f failed) and either f's error or
// rows.Err().
func processRows(rows *sql.Rows, f func(*sql.Rows) error) (processed int, err error) {
	defer rows.Close()
	for rows.Next() {
		processed++
		if err = f(rows); err != nil {
			return processed, err
		}
	}
	err = rows.Err()
	return processed, err
}
// RunQueryIncrementally executes query, then calls f on each row. It fetches
// rows in groups of size batchSize, which must be positive. It stops when
// there are no more rows, or when f returns io.EOF. Any other error from f, or
// from reading a batch of rows, is returned.
//
// The query runs in a transaction at the default isolation level, because
// cursors require one, so db must not already be in a transaction.
func (db *DB) RunQueryIncrementally(ctx context.Context, query string, batchSize int, f func(*sql.Rows) error, params ...any) error {
	if batchSize <= 0 {
		// FETCH with a non-positive count never advances the cursor, so the
		// loop below would end without reading anything.
		return fmt.Errorf("RunQueryIncrementally: batchSize must be positive, got %d", batchSize)
	}
	return db.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
		// Declare a cursor and associate it with the query.
		// It will be closed when the transaction commits.
		if _, err := tx.Exec(ctx, fmt.Sprintf(`DECLARE c CURSOR FOR %s`, query), params...); err != nil {
			return err
		}
		for {
			// Fetch batchSize rows and process them.
			rows, err := tx.Query(ctx, fmt.Sprintf(`FETCH %d FROM c`, batchSize))
			if err != nil {
				return err
			}
			n, err := processRows(rows, f)
			// Inspect err before the row count, so that an error reported by
			// rows.Err() for an empty batch is not mistaken for the end of
			// the cursor.
			switch {
			case err == io.EOF:
				return nil
			case err != nil:
				return err
			case n == 0:
				return nil
			}
		}
	})
}
// Transact runs txFunc inside a SQL transaction at isolation level iso. The
// transaction is rolled back if txFunc panics or returns an error, and
// committed otherwise.
//
// txFunc receives a DB bound to the transaction. Use it only inside txFunc;
// calls made through it after txFunc returns will fail.
//
// For isolation levels that Postgres may abort with a serialization failure
// (see https://www.postgresql.org/docs/11/transaction-iso.html), the whole
// transaction is retried, so txFunc may run more than once.
func (db *DB) Transact(ctx context.Context, iso sql.IsolationLevel, txFunc func(*DB) error) (err error) {
	defer derrors.Wrap(&err, "Transact(%s)", iso)
	opts := &sql.TxOptions{Isolation: iso}
	run := db.transact
	if isRetryable(iso) {
		run = db.transactWithRetry
	}
	return run(ctx, opts, txFunc)
}
// isRetryable reports whether transactions at isolation level iso are run
// with retry on serialization failure (repeatable read and serializable).
func isRetryable(iso sql.IsolationLevel) bool {
	switch iso {
	case sql.LevelRepeatableRead, sql.LevelSerializable:
		return true
	default:
		return false
	}
}
// serializationFailureCode is the Postgres SQLSTATE code for a serialization
// failure: a transaction that could not be completed without violating its
// isolation level. transactWithRetry retries transactions that fail with it.
// See https://www.postgresql.org/docs/current/errcodes-appendix.html.
const (
	serializationFailureCode = "40001"
)
// transactWithRetry runs db.transact with opts, retrying the whole transaction
// when it fails with a serialization failure. It makes at most maxRetries+1
// attempts. The wait between attempts starts at 125ms and doubles each time,
// and it is cut short with an error if ctx is done. The largest number of
// retries any call needed is recorded for MaxRetries.
//
// See https://www.postgresql.org/docs/11/transaction-iso.html.
func (db *DB) transactWithRetry(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB) error) (err error) {
	defer derrors.Wrap(&err, "transactWithRetry(%v)", opts)
	const maxRetries = 10
	sleepDur := 125 * time.Millisecond
	for i := 0; i <= maxRetries; i++ {
		err = db.transact(ctx, opts, txFunc)
		if !isSerializationFailure(err) {
			if err != nil {
				log.Debugf(ctx, "transactWithRetry: error type %T: %[1]v", err)
				// Surface serialization failures that isSerializationFailure
				// does not recognize (e.g. an unexpected driver error type);
				// otherwise they would be returned without any retry.
				if strings.Contains(err.Error(), serializationFailureCode) {
					return fmt.Errorf("error text has %q but not recognized as serialization failure: type %T, err %v",
						serializationFailureCode, err, err)
				}
			}
			if i > 0 {
				log.Debugf(ctx, "retried serializable transaction %d time(s)", i)
			}
			return err
		}
		if i == maxRetries {
			// That was the final attempt: don't sleep before giving up.
			break
		}
		// Attempt i failed, so the transaction is about to be retried for the
		// (i+1)th time. Record that count, matching the log message above.
		db.mu.Lock()
		if i+1 > db.maxRetries {
			db.maxRetries = i + 1
		}
		db.mu.Unlock()
		log.Debugf(ctx, "serialization failure; retrying after %s", sleepDur)
		timer := time.NewTimer(sleepDur)
		select {
		case <-ctx.Done():
			timer.Stop()
			return fmt.Errorf("waiting to retry after serialization failure (%v): %w", err, ctx.Err())
		case <-timer.C:
		}
		sleepDur *= 2
	}
	return fmt.Errorf("reached max number of tries due to serialization failure (%d): %w", maxRetries, err)
}
// isSerializationFailure reports whether err, or any error it wraps, is a
// Postgres error carrying serializationFailureCode. The concrete error type
// depends on the driver, so both the lib/pq and the pgx (pgconn) types are
// checked.
func isSerializationFailure(err error) bool {
	var pqErr *pq.Error
	var pgErr *pgconn.PgError
	return (errors.As(err, &pqErr) && pqErr.Code == serializationFailureCode) ||
		(errors.As(err, &pgErr) && pgErr.Code == serializationFailureCode)
}
// transact runs txFunc within one new transaction, begun with opts on a
// connection reserved for it.
//
// It fails immediately if db already represents a transaction; nested
// transactions are not supported. Otherwise the transaction is committed if
// txFunc returns nil. It is rolled back if txFunc returns an error or panics,
// and a panic is re-raised after the rollback. A commit failure is returned
// as the result. Errors from Rollback are ignored.
func (db *DB) transact(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB) error) (err error) {
	if db.InTransaction() {
		return errors.New("a DB Transact function was called on a DB already in a transaction")
	}
	// Reserve one connection for the whole transaction. Its Close is deferred
	// first, so it runs last, after the commit or rollback below.
	conn, err := db.db.Conn(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()
	tx, err := conn.BeginTx(ctx, opts)
	if err != nil {
		return fmt.Errorf("conn.BeginTx(): %w", err)
	}
	// Finish the transaction according to how txFunc ended: roll back on
	// panic or error, otherwise commit and report a commit error through err.
	defer func() {
		if p := recover(); p != nil {
			tx.Rollback()
			panic(p)
		} else if err != nil {
			tx.Rollback()
		} else {
			if txErr := tx.Commit(); txErr != nil {
				err = fmt.Errorf("tx.Commit(): %w", txErr)
			}
		}
	}()
	// The DB passed to txFunc shares db's pool and instance ID, but its
	// execution and query methods go through tx.
	dbtx := New(db.db, db.instanceID)
	dbtx.tx = tx
	dbtx.conn = conn
	dbtx.opts = *opts
	// Deferred last, so the function logTransaction returns is invoked before
	// the commit above; a commit failure is not yet in err at that point.
	defer dbtx.logTransaction(ctx)(&err)
	if err := txFunc(dbtx); err != nil {
		return fmt.Errorf("txFunc(tx): %w", err)
	}
	return nil
}
// MaxRetries returns the largest number of times that a single transaction
// run through db at a retryable isolation level (repeatable read or
// serializable) was retried after serialization failures.
func (db *DB) MaxRetries() int {
	db.mu.Lock()
	n := db.maxRetries
	db.mu.Unlock()
	return n
}
// OnConflictDoNothing can be passed as the conflictAction of BulkInsert or
// BulkInsertReturning. It is appended unchanged to the generated INSERT
// statement, which makes Postgres skip conflicting rows instead of failing.
const (
	OnConflictDoNothing = "ON CONFLICT DO NOTHING"
)
// BulkInsert inserts rows into table using multi-value statements of the form
//
//	INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>)
//
// values holds the column values of every row, interleaved in column order,
// so its length must be a multiple of len(columns). A non-empty
// conflictAction (such as OnConflictDoNothing) is appended to each statement.
//
// The statements are prepared and then executed with the provided values.
func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, values []any, conflictAction string) (err error) {
	defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)",
		table, columns, len(values), conflictAction)
	var (
		noReturning []string
		noScan      func(*sql.Rows) error
	)
	return db.bulkInsert(ctx, table, columns, noReturning, values, conflictAction, noScan)
}
// BulkInsertReturning is like BulkInsert, but supports returning values from the INSERT statement.
// In addition to the arguments of BulkInsert, it takes a list of columns to return and a function
// to scan those columns. To get the returned values, provide a function that scans them as if
// they were the selected columns of a query. See TestBulkInsert for an example.
//
// It is an error for returningColumns to be empty or for scanFunc to be nil.
func (db *DB) BulkInsertReturning(ctx context.Context, table string, columns []string, values []any, conflictAction string, returningColumns []string, scanFunc func(*sql.Rows) error) (err error) {
	defer derrors.Wrap(&err, "DB.BulkInsertReturning(ctx, %q, %v, [%d values], %q, %v, scanFunc)",
		table, columns, len(values), conflictAction, returningColumns)
	// Check the length, not just nil: an empty non-nil slice would produce an
	// INSERT with no RETURNING clause, and scanFunc would never be called.
	if len(returningColumns) == 0 || scanFunc == nil {
		return errors.New("need returningColumns and scan function")
	}
	return db.bulkInsert(ctx, table, columns, returningColumns, values, conflictAction, scanFunc)
}
// BulkUpsert inserts rows like BulkInsert. Instead of taking a conflict
// action, it takes the conflicting columns and appends
// "ON CONFLICT (conflict_columns) DO UPDATE SET c=excluded.c, ..." with one
// assignment for every column c in columns.
func (db *DB) BulkUpsert(ctx context.Context, table string, columns []string, values []any, conflictColumns []string) error {
	return db.BulkInsert(ctx, table, columns, values, buildUpsertConflictAction(columns, conflictColumns))
}
// BulkUpsertReturning upserts rows as BulkUpsert does, and passes the
// returningColumns of the affected rows to scanFunc as BulkInsertReturning does.
func (db *DB) BulkUpsertReturning(ctx context.Context, table string, columns []string, values []any, conflictColumns, returningColumns []string, scanFunc func(*sql.Rows) error) error {
	action := buildUpsertConflictAction(columns, conflictColumns)
	return db.BulkInsertReturning(ctx, table, columns, values, action, returningColumns, scanFunc)
}
// bulkInsert implements BulkInsert and BulkInsertReturning.
//
// values holds the column values of every row, interleaved, so its length must
// be a multiple of len(columns). Rows are inserted in chunks so that no
// statement has more than maxParameters parameters. One prepared statement is
// reused for every full-size chunk; a trailing partial chunk gets its own.
//
// If returningColumns is empty, the statements are executed. Otherwise they
// are queried and scanFunc is called on each returned row.
func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningColumns []string, values []any, conflictAction string, scanFunc func(*sql.Rows) error) (err error) {
	if len(columns) == 0 {
		// Guards the modulus and division below against a zero divisor.
		return errors.New("need at least one column")
	}
	if remainder := len(values) % len(columns); remainder != 0 {
		return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder)
	}
	// Postgres supports up to 65535 parameters, but stop well before that
	// so we don't construct humongous queries.
	const maxParameters = 1000
	stride := (maxParameters / len(columns)) * len(columns)
	if stride == 0 {
		// This is a pathological case (len(columns) > maxParameters), but we
		// handle it cautiously.
		return fmt.Errorf("too many columns to insert: %d", len(columns))
	}
	prepare := func(n int) (*sql.Stmt, error) {
		return db.Prepare(ctx, buildInsertQuery(table, columns, returningColumns, n, conflictAction))
	}
	returning := len(returningColumns) > 0
	var fullStmt *sql.Stmt // prepared lazily for chunks of exactly stride values
	for leftBound := 0; leftBound < len(values); leftBound += stride {
		rightBound := leftBound + stride
		var stmt *sql.Stmt
		if rightBound <= len(values) {
			if fullStmt == nil {
				if fullStmt, err = prepare(stride); err != nil {
					return err
				}
				// Runs at most once: fullStmt is prepared only once.
				defer fullStmt.Close()
			}
			stmt = fullStmt
		} else {
			// Only the last chunk can be partial, so this defer also runs at
			// most once.
			rightBound = len(values)
			if stmt, err = prepare(rightBound - leftBound); err != nil {
				return err
			}
			defer stmt.Close()
		}
		chunk := values[leftBound:rightBound]
		var chunkErr error
		if !returning {
			_, chunkErr = stmt.ExecContext(ctx, chunk...)
		} else {
			var rows *sql.Rows
			if rows, chunkErr = stmt.QueryContext(ctx, chunk...); chunkErr == nil {
				_, chunkErr = processRows(rows, scanFunc)
			}
		}
		if chunkErr != nil {
			return fmt.Errorf("running bulk insert query, values[%d:%d]: %w", leftBound, rightBound, chunkErr)
		}
	}
	return nil
}
// buildInsertQuery builds a multi-value insert query of the form
//
//	INSERT INTO <table>(<columns>) VALUES(<placeholders>), (<placeholders>), ... <conflictAction> RETURNING <returningColumns>
//
// with one parenthesized group of numbered placeholders ($1, $2, ...) per row.
// The conflict action is omitted when conflictAction is empty, and the
// RETURNING clause is omitted when returningColumns is empty.
//
// When calling buildInsertQuery, it must be true that nvalues % len(columns) == 0.
func buildInsertQuery(table string, columns, returningColumns []string, nvalues int, conflictAction string) string {
	ncols := len(columns)
	var sb strings.Builder
	sb.WriteString("INSERT INTO " + table)
	sb.WriteString("(" + strings.Join(columns, ", ") + ") VALUES")
	if nvalues > 0 {
		group := make([]string, ncols)
		for row, nrows := 0, nvalues/ncols; row < nrows; row++ {
			first := row * ncols
			for j := range group {
				group[j] = fmt.Sprintf("$%d", first+j+1)
			}
			sb.WriteString("(" + strings.Join(group, ", ") + ")")
			// Groups are comma-separated; nothing follows the final value.
			if first+ncols != nvalues {
				sb.WriteString(", ")
			}
		}
	}
	if conflictAction != "" {
		sb.WriteString(" " + conflictAction)
	}
	if len(returningColumns) > 0 {
		sb.WriteString(" RETURNING " + strings.Join(returningColumns, ", "))
	}
	return sb.String()
}
// buildUpsertConflictAction returns the clause
// "ON CONFLICT (<conflictColumns>) DO UPDATE SET c=excluded.c, ..." with one
// assignment for every column c in columns, in order.
func buildUpsertConflictAction(columns, conflictColumns []string) string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "ON CONFLICT (%s) DO UPDATE SET ", strings.Join(conflictColumns, ", "))
	for i, col := range columns {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(col + "=excluded." + col)
	}
	return sb.String()
}
// maxBulkUpdateArrayLen is the largest number of rows that BulkUpdate sends to
// Postgres in one UPDATE statement. Larger inputs are split into chunks of
// this size. Postgres does not limit array sizes; the bound only keeps
// statements reasonably small. It is a variable so that tests can lower it.
var maxBulkUpdateArrayLen = 10_000
// BulkUpdate updates existing rows of table in bulk.
//
// Columns must contain the names of some of table's columns. The first is treated
// as a key; that is, the values to update are matched with existing rows by comparing
// the values of the first column. The remaining columns are set from the
// corresponding values.
//
// Types holds the database type of each column, in the same order as columns. For example,
//
//	[]string{"INT", "TEXT"}
//
// Values contains one slice of values per column. (Note that this is unlike BulkInsert, which
// takes a single slice of interleaved values.) All of the slices must have the same length.
//
// The rows are sent as arrays of at most maxBulkUpdateArrayLen elements, and
// one UPDATE statement is executed per chunk. BulkUpdate does not start a
// transaction itself. To apply all chunks atomically, call it on the DB
// passed to a Transact function.
func (db *DB) BulkUpdate(ctx context.Context, table string, columns, types []string, values [][]any) (err error) {
	defer derrors.Wrap(&err, "DB.BulkUpdate(ctx, %q, %v, [%d values])",
		table, columns, len(values))
	if len(columns) < 2 {
		return errors.New("need at least two columns")
	}
	if len(columns) != len(values) {
		return errors.New("len(values) != len(columns)")
	}
	// buildBulkUpdateQuery indexes types by column position.
	if len(columns) != len(types) {
		return errors.New("len(types) != len(columns)")
	}
	nRows := len(values[0])
	for _, v := range values[1:] {
		if len(v) != nRows {
			return errors.New("all values slices must be the same length")
		}
	}
	query := buildBulkUpdateQuery(table, columns, types)
	for left := 0; left < nRows; left += maxBulkUpdateArrayLen {
		right := left + maxBulkUpdateArrayLen
		if right > nRows {
			right = nRows
		}
		// One array argument per column, holding that column's chunk of values.
		args := make([]any, 0, len(values))
		for _, vs := range values {
			args = append(args, pq.Array(vs[left:right]))
		}
		if _, err := db.Exec(ctx, query, args...); err != nil {
			return fmt.Errorf("db.Exec(%q, values[%d:%d]): %w", query, left, right, err)
		}
	}
	return nil
}
// buildBulkUpdateQuery returns an UPDATE statement for table that takes one
// array parameter per column ($1 for columns[0], and so on), where types[i] is
// the element type of the i'th array. The arrays are unnested into a derived
// table named data. Each row of table whose columns[0] value matches a row of
// data has every other column set from that row.
func buildBulkUpdateQuery(table string, columns, types []string) string {
	key := columns[0]
	// "c = data.c" for every column except the key.
	assignments := make([]string, 0, len(columns)-1)
	for _, col := range columns[1:] {
		assignments = append(assignments, col+" = data."+col)
	}
	// "UNNEST($n::TYPE[]) AS c" for every column, the key included. The
	// explicit array type is required, or Postgres complains that UNNEST is
	// not unique.
	unnests := make([]string, len(columns))
	for i, col := range columns {
		unnests[i] = fmt.Sprintf("UNNEST($%d::%s[]) AS %s", i+1, types[i], col)
	}
	return fmt.Sprintf(`
UPDATE %[1]s
SET %[2]s
FROM (SELECT %[3]s) AS data
WHERE %[1]s.%[4]s = data.%[4]s`,
		table,                          // 1
		strings.Join(assignments, ", "), // 2
		strings.Join(unnests, ", "),     // 3
		key,                            // 4
	)
}
// Collect1 runs query with args. The query must select exactly one column
// whose values can be scanned into a T. Collect1 returns the scanned values
// in row order, or a nil slice and an error.
func Collect1[T any](ctx context.Context, db *DB, query string, args ...any) (ts []T, err error) {
	defer derrors.WrapStack(&err, "Collect1(%q)", query)
	scanOne := func(rows *sql.Rows) error {
		var v T
		if err := rows.Scan(&v); err != nil {
			return err
		}
		ts = append(ts, v)
		return nil
	}
	if err := db.RunQuery(ctx, query, scanOne, args...); err != nil {
		return nil, err
	}
	return ts, nil
}
// emptyStringScanner is a sql.Scanner that stores a scanned string through
// ptr, writing "" when the database value is NULL. Conversion is delegated to
// sql.NullString.
type emptyStringScanner struct {
	ptr *string
}

// Scan implements sql.Scanner. If value cannot be converted to a string, the
// error is returned and *e.ptr is left unchanged.
func (e emptyStringScanner) Scan(value any) error {
	var ns sql.NullString
	err := ns.Scan(value)
	if err == nil {
		*e.ptr = ns.String
	}
	return err
}
// NullIsEmpty returns a sql.Scanner that stores a scanned string into *s,
// writing the empty string when the database value is NULL.
func NullIsEmpty(s *string) sql.Scanner {
	return emptyStringScanner{ptr: s}
}