| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package database |
| |
| import ( |
| "context" |
| "database/sql" |
| "fmt" |
| "log" |
| "net/url" |
| "path/filepath" |
| "strings" |
| |
| "golang.org/x/pkgsite/internal/config/serverconfig" |
| "golang.org/x/pkgsite/internal/derrors" |
| "golang.org/x/pkgsite/internal/testing/testhelper" |
| |
| // imported to register the postgres migration driver |
| "github.com/golang-migrate/migrate/v4" |
| _ "github.com/golang-migrate/migrate/v4/database/postgres" |
| |
| // imported to register the file source migration driver |
| _ "github.com/golang-migrate/migrate/v4/source/file" |
| // imported to register the postgres database driver |
| _ "github.com/lib/pq" |
| ) |
| |
| // DBConnURI generates a postgres connection string in URI format. This is |
| // necessary as migrate expects a URI. |
| func DBConnURI(dbName string) string { |
| var ( |
| user = serverconfig.GetEnv("GO_DISCOVERY_DATABASE_USER", "postgres") |
| password = serverconfig.GetEnv("GO_DISCOVERY_DATABASE_PASSWORD", "") |
| host = serverconfig.GetEnv("GO_DISCOVERY_DATABASE_HOST", "localhost") |
| port = serverconfig.GetEnv("GO_DISCOVERY_DATABASE_PORT", "5432") |
| ) |
| cs := fmt.Sprintf("postgres://%s/%s?sslmode=disable&user=%s&password=%s&port=%s&timezone=UTC", |
| host, dbName, url.QueryEscape(user), url.QueryEscape(password), url.QueryEscape(port)) |
| return cs |
| } |
| |
// MultiErr can be used to combine one or more errors into a single error.
// Nil entries are allowed and are ignored by both Error and Unwrap.
type MultiErr []error

// Error returns the messages of the non-nil errors in m, in order, joined
// by "|". It returns "" if m holds no non-nil errors.
func (m MultiErr) Error() string {
	var sb strings.Builder
	for _, err := range m {
		if err == nil {
			continue
		}
		if sb.Len() > 0 {
			sb.WriteByte('|')
		}
		sb.WriteString(err.Error())
	}
	return sb.String()
}

// Unwrap returns the non-nil errors in m. This lets errors.Is and errors.As
// (which inspect Unwrap() []error since Go 1.20) match any error that was
// combined into m.
func (m MultiErr) Unwrap() []error {
	var errs []error
	for _, err := range m {
		if err != nil {
			errs = append(errs, err)
		}
	}
	return errs
}
| |
| // ConnectAndExecute connects to the postgres database specified by uri and |
| // executes dbFunc, then cleans up the database connection. |
| // It returns an error that Is derrors.NotFound if no connection could be made. |
| func ConnectAndExecute(uri string, dbFunc func(*sql.DB) error) (outerErr error) { |
| pg, err := sql.Open("postgres", uri) |
| if err == nil { |
| err = pg.Ping() |
| } |
| if err != nil { |
| return fmt.Errorf("%w: %v", derrors.NotFound, err) |
| } |
| defer func() { |
| if err := pg.Close(); err != nil { |
| outerErr = MultiErr{outerErr, err} |
| } |
| }() |
| return dbFunc(pg) |
| } |
| |
| // CreateDB creates a new database dbName. |
| func CreateDB(dbName string) error { |
| return ConnectAndExecute(DBConnURI(""), func(pg *sql.DB) error { |
| if _, err := pg.Exec(fmt.Sprintf(` |
| CREATE DATABASE %q |
| TEMPLATE=template0 |
| LC_COLLATE='C' |
| LC_CTYPE='C';`, dbName)); err != nil { |
| return fmt.Errorf("error creating %q: %v", dbName, err) |
| } |
| |
| return nil |
| }) |
| } |
| |
| // DropDB drops the database named dbName. |
| func DropDB(dbName string) error { |
| return ConnectAndExecute(DBConnURI(""), func(pg *sql.DB) error { |
| if _, err := pg.Exec(fmt.Sprintf("DROP DATABASE %q;", dbName)); err != nil { |
| return fmt.Errorf("error dropping %q: %v", dbName, err) |
| } |
| return nil |
| }) |
| } |
| |
| // CreateDBIfNotExists checks whether the given dbName is an existing database, |
| // and creates one if not. |
| func CreateDBIfNotExists(dbName string) error { |
| exists, err := checkIfDBExists(dbName) |
| if err != nil || exists { |
| return err |
| } |
| |
| log.Printf("Database %q does not exist, creating.", dbName) |
| return CreateDB(dbName) |
| } |
| |
| // checkIfDBExists check if dbName exists. |
| func checkIfDBExists(dbName string) (bool, error) { |
| var exists bool |
| |
| err := ConnectAndExecute(DBConnURI(""), func(pg *sql.DB) error { |
| rows, err := pg.Query("SELECT 1 from pg_database WHERE datname = $1 LIMIT 1", dbName) |
| if err != nil { |
| return err |
| } |
| defer rows.Close() |
| |
| if rows.Next() { |
| exists = true |
| return nil |
| } |
| |
| return rows.Err() |
| }) |
| |
| return exists, err |
| } |
| |
| // TryToMigrate attempts to migrate the database named dbName to the latest |
| // migration. If this operation fails in the migration step, it returns |
| // isMigrationError=true to signal that the database should be recreated. |
| func TryToMigrate(dbName string) (isMigrationError bool, outerErr error) { |
| dbURI := DBConnURI(dbName) |
| source := migrationsSource() |
| m, err := migrate.New(source, dbURI) |
| if err != nil { |
| return false, fmt.Errorf("migrate.New(): %v", err) |
| } |
| defer func() { |
| if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil { |
| outerErr = MultiErr{outerErr, srcErr, dbErr} |
| } |
| }() |
| if err := m.Up(); err != nil && err != migrate.ErrNoChange { |
| return true, fmt.Errorf("m.Up() %q: %v", source, err) |
| } |
| return false, nil |
| } |
| |
| // migrationsSource returns a uri pointing to the migrations directory. It |
| // returns an error if unable to determine this path. |
| func migrationsSource() string { |
| migrationsDir := testhelper.TestDataPath("../../migrations") |
| return "file://" + filepath.ToSlash(migrationsDir) |
| } |
| |
| // ResetDB truncates all data from the given test DB. It should be called |
| // after every test that mutates the database. |
| func ResetDB(ctx context.Context, db *DB) error { |
| if err := db.Transact(ctx, sql.LevelDefault, func(tx *DB) error { |
| if _, err := tx.Exec(ctx, ` |
| TRUNCATE modules CASCADE; |
| TRUNCATE search_documents CASCADE; |
| TRUNCATE version_map; |
| TRUNCATE paths CASCADE; |
| TRUNCATE symbol_names CASCADE; |
| TRUNCATE imports_unique; |
| TRUNCATE latest_module_versions;`); err != nil { |
| return err |
| } |
| if _, err := tx.Exec(ctx, `TRUNCATE module_version_states CASCADE;`); err != nil { |
| return err |
| } |
| if _, err := tx.Exec(ctx, `TRUNCATE excluded_prefixes;`); err != nil { |
| return err |
| } |
| return nil |
| }); err != nil { |
| return fmt.Errorf("error resetting test DB: %v", err) |
| } |
| return nil |
| } |