| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Package dbtest supports testing with a database. |
| package dbtest |
| |
| import ( |
| "database/sql" |
| "fmt" |
| "log" |
| "net/url" |
| "os" |
| "strings" |
| |
| "golang.org/x/pkgsite/internal/derrors" |
| // imported to register the postgres migration driver |
| _ "github.com/golang-migrate/migrate/v4/database/postgres" |
| // imported to register the file source migration driver |
| _ "github.com/golang-migrate/migrate/v4/source/file" |
| // imported to register the postgres database driver |
| _ "github.com/lib/pq" |
| ) |
| |
| func getEnv(key, fallback string) string { |
| if value, ok := os.LookupEnv(key); ok { |
| return value |
| } |
| return fallback |
| } |
| |
| // DBConnURI generates a postgres connection string in URI format. This is |
| // necessary as migrate expects a URI. |
| func DBConnURI(dbName string) string { |
| var ( |
| user = getEnv("GO_DISCOVERY_DATABASE_TEST_USER", "postgres") |
| password = getEnv("GO_DISCOVERY_DATABASE_TEST_PASSWORD", "") |
| host = getEnv("GO_DISCOVERY_DATABASE_TEST_HOST", "localhost") |
| port = getEnv("GO_DISCOVERY_DATABASE_TEST_PORT", "5432") |
| ) |
| cs := fmt.Sprintf("postgres://%s/%s?sslmode=disable&user=%s&password=%s&port=%s", |
| host, dbName, url.QueryEscape(user), url.QueryEscape(password), url.QueryEscape(port)) |
| return cs |
| } |
| |
| // MultiErr can be used to combine one or more errors into a single error. |
| type MultiErr []error |
| |
| func (m MultiErr) Error() string { |
| var sb strings.Builder |
| for _, err := range m { |
| sep := "" |
| if sb.Len() > 0 { |
| sep = "|" |
| } |
| if err != nil { |
| sb.WriteString(sep + err.Error()) |
| } |
| } |
| return sb.String() |
| } |
| |
| // ConnectAndExecute connects to the postgres database specified by uri and |
| // executes dbFunc, then cleans up the database connection. |
| // It returns an error that Is derrors.NotFound if no connection could be made. |
| func ConnectAndExecute(uri string, dbFunc func(*sql.DB) error) (outerErr error) { |
| pg, err := sql.Open("postgres", uri) |
| if err == nil { |
| err = pg.Ping() |
| } |
| if err != nil { |
| return fmt.Errorf("%w: %v", derrors.NotFound, err) |
| } |
| defer func() { |
| if err := pg.Close(); err != nil { |
| outerErr = MultiErr{outerErr, err} |
| } |
| }() |
| return dbFunc(pg) |
| } |
| |
| // CreateDB creates a new database dbName. |
| func CreateDB(dbName string) error { |
| return ConnectAndExecute(DBConnURI(""), func(pg *sql.DB) error { |
| if _, err := pg.Exec(fmt.Sprintf(` |
| CREATE DATABASE %q |
| TEMPLATE=template0 |
| LC_COLLATE='C' |
| LC_CTYPE='C';`, dbName)); err != nil { |
| return fmt.Errorf("error creating %q: %v", dbName, err) |
| } |
| |
| return nil |
| }) |
| } |
| |
| // DropDB drops the database named dbName. |
| func DropDB(dbName string) error { |
| return ConnectAndExecute(DBConnURI(""), func(pg *sql.DB) error { |
| if _, err := pg.Exec(fmt.Sprintf("DROP DATABASE %q;", dbName)); err != nil { |
| return fmt.Errorf("error dropping %q: %v", dbName, err) |
| } |
| return nil |
| }) |
| } |
| |
| // CreateDBIfNotExists checks whether the given dbName is an existing database, |
| // and creates one if not. |
| func CreateDBIfNotExists(dbName string) error { |
| exists, err := checkIfDBExists(dbName) |
| if err != nil || exists { |
| return err |
| } |
| |
| log.Printf("Test database %q does not exist, creating.", dbName) |
| return CreateDB(dbName) |
| } |
| |
| // checkIfDBExists check if dbName exists. |
| func checkIfDBExists(dbName string) (bool, error) { |
| var exists bool |
| |
| err := ConnectAndExecute(DBConnURI(""), func(pg *sql.DB) error { |
| rows, err := pg.Query("SELECT 1 from pg_database WHERE datname = $1 LIMIT 1", dbName) |
| if err != nil { |
| return err |
| } |
| defer rows.Close() |
| |
| if rows.Next() { |
| exists = true |
| return nil |
| } |
| |
| return rows.Err() |
| }) |
| |
| return exists, err |
| } |