blob: 76ca9467722b668996253198e95d5764ec9779ea [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package relui
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/golang-migrate/migrate/v4"
dbpgx "github.com/golang-migrate/migrate/v4/database/pgx"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/jackc/pgx/v4"
)
// errDBNotExist is the error DropDB reports when the database named in
// the supplied config is not present on the server.
var (
	errDBNotExist = errors.New("database does not exist")
)
// InitDB ensures the database named in conn exists and then applies
// every migration to it.
//
// When the database is missing it is created using the credentials
// contained in conn.
//
// conn may be any key/value or URI connection string understood by
// libpq.
func InitDB(ctx context.Context, conn string) error {
	cfg, parseErr := pgx.ParseConfig(conn)
	if parseErr != nil {
		return fmt.Errorf("pgx.ParseConfig() = %w", parseErr)
	}
	if createErr := CreateDBIfNotExists(ctx, cfg); createErr != nil {
		return createErr
	}
	// Migration is the final step; its result is the result of InitDB.
	return MigrateDB(conn, false)
}
// MigrateDB applies all migrations to the database specified in conn.
//
// Any key/value or URI string compatible with libpq is a valid conn.
// If downUp is true, all migrations will be run, then the down and up
// migrations of the final migration are run.
//
// The database handle opened here is closed before MigrateDB returns,
// whether it succeeds or fails.
func MigrateDB(conn string, downUp bool) error {
	cfg, err := pgx.ParseConfig(conn)
	if err != nil {
		return fmt.Errorf("pgx.ParseConfig() = %w", err)
	}
	db, err := sql.Open("pgx", conn)
	if err != nil {
		return fmt.Errorf("sql.Open(%q, _) = %w", "pgx", err)
	}
	// Close the handle on every return path, not only on success.
	// sql.DB.Close is idempotent, so this is safe even if m.Close below
	// also closes db.
	defer db.Close()
	mcfg := &dbpgx.Config{
		MigrationsTable: "migrations",
		DatabaseName:    cfg.Database,
	}
	mdb, err := dbpgx.WithInstance(db, mcfg)
	if err != nil {
		return fmt.Errorf("dbpgx.WithInstance(_, %v) = %w", mcfg, err)
	}
	mfs, err := iofs.New(migrations, "migrations")
	if err != nil {
		return fmt.Errorf("iofs.New(_, %q) = %w", "migrations", err)
	}
	m, err := migrate.NewWithInstance("iofs", mfs, "pgx", mdb)
	if err != nil {
		return fmt.Errorf("migrate.NewWithInstance(%q, _, %q, _) = %w", "iofs", "pgx", err)
	}
	// Release the migration source and database driver. NOTE(review):
	// the driver presumably holds its own connection from db that
	// db.Close alone would not release — confirm against the dbpgx
	// version in use.
	defer m.Close()
	// ErrNoChange only means the schema is already current.
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("m.Up() = %w", err)
	}
	if downUp {
		// Revert the most recent migration, then re-apply it.
		if err := m.Steps(-1); err != nil {
			return fmt.Errorf("m.Steps(%d) = %w", -1, err)
		}
		if err := m.Up(); err != nil {
			return fmt.Errorf("m.Up() = %w", err)
		}
	}
	return nil
}
// ConnectMaintenanceDB connects to the maintenance database using the
// credentials from cfg. If maintDB is an empty string, the database
// named "postgres" is used.
//
// cfg itself is not modified: the connection is made with a copy of cfg
// whose Database field is replaced by maintDB. The caller is responsible
// for closing the returned connection.
func ConnectMaintenanceDB(ctx context.Context, cfg *pgx.ConnConfig, maintDB string) (*pgx.Conn, error) {
	if maintDB == "" {
		maintDB = "postgres"
	}
	maintCfg := cfg.Copy()
	maintCfg.Database = maintDB
	return pgx.ConnectConfig(ctx, maintCfg)
}
// CreateDBIfNotExists checks whether the database named by cfg.Database
// exists, and creates it if not. Both the existence check and the
// CREATE DATABASE statement run on a connection to the maintenance
// database made with the credentials in cfg.
func CreateDBIfNotExists(ctx context.Context, cfg *pgx.ConnConfig) error {
	exists, err := checkIfDBExists(ctx, cfg)
	if err != nil {
		return fmt.Errorf("checkIfDBExists() = %w", err)
	}
	if exists {
		return nil
	}
	conn, err := ConnectMaintenanceDB(ctx, cfg, "")
	if err != nil {
		return fmt.Errorf("ConnectMaintenanceDB = %w", err)
	}
	// Previously this connection was never closed and leaked.
	defer conn.Close(ctx)
	// Identifiers cannot be passed as query parameters, so the database
	// name is quoted with Sanitize instead.
	createSQL := fmt.Sprintf("CREATE DATABASE %s", pgx.Identifier{cfg.Database}.Sanitize())
	if _, err := conn.Exec(ctx, createSQL); err != nil {
		return fmt.Errorf("conn.Exec(%q) = %w", createSQL, err)
	}
	return nil
}
// DropDB drops the database specified in cfg. If that database does not
// exist, errDBNotExist is returned.
//
// The DROP DATABASE statement runs on a connection to the maintenance
// database made with the credentials in cfg.
func DropDB(ctx context.Context, cfg *pgx.ConnConfig) error {
	exists, err := checkIfDBExists(ctx, cfg)
	if err != nil {
		return fmt.Errorf("checkIfDBExists() = %w", err)
	}
	if !exists {
		return errDBNotExist
	}
	conn, err := ConnectMaintenanceDB(ctx, cfg, "")
	if err != nil {
		return fmt.Errorf("ConnectMaintenanceDB = %w", err)
	}
	// Previously this connection was never closed and leaked.
	defer conn.Close(ctx)
	// Identifiers cannot be passed as query parameters, so the database
	// name is quoted with Sanitize instead.
	dropSQL := fmt.Sprintf("DROP DATABASE %s", pgx.Identifier{cfg.Database}.Sanitize())
	if _, err := conn.Exec(ctx, dropSQL); err != nil {
		return fmt.Errorf("conn.Exec(%q) = %w", dropSQL, err)
	}
	return nil
}
// checkIfDBExists reports whether a database named cfg.Database is
// listed in pg_database. The query runs on a connection to the
// maintenance database, which is closed before returning.
func checkIfDBExists(ctx context.Context, cfg *pgx.ConnConfig) (bool, error) {
	conn, err := ConnectMaintenanceDB(ctx, cfg, "")
	if err != nil {
		return false, fmt.Errorf("ConnectMaintenanceDB = %w", err)
	}
	// Previously this connection was never closed and leaked.
	defer conn.Close(ctx)
	var exists int
	err = conn.QueryRow(ctx, "SELECT 1 from pg_database WHERE datname=$1 LIMIT 1", cfg.Database).Scan(&exists)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		// No matching row means the database does not exist.
		return false, nil
	case err != nil:
		return false, fmt.Errorf("row.Scan() = %w", err)
	}
	return exists == 1, nil
}