blob: 1f990a0487c7706f3e8d07debd8735359ef0919d [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package postgres
import (
// imported to register the postgres migration driver
_ ""
// imported to register the file source migration driver
_ ""
// recreateDB drops and recreates the database named dbName.
func recreateDB(dbName string) error {
err := database.DropDB(dbName)
if err != nil {
return err
return database.CreateDB(dbName)
// SetupTestDB creates a test database named dbName if it does not already
// exist, and migrates it to the latest schema from the migrations directory.
func SetupTestDB(dbName string) (_ *DB, err error) {
defer derrors.Wrap(&err, "SetupTestDB(%q)", dbName)
if err := database.CreateDBIfNotExists(dbName); err != nil {
return nil, fmt.Errorf("CreateDBIfNotExists(%q): %w", dbName, err)
if isMigrationError, err := database.TryToMigrate(dbName); err != nil {
if isMigrationError {
// failed during migration stage, recreate and try again
log.Printf("Migration failed for %s: %v, recreating database.", dbName, err)
if err := recreateDB(dbName); err != nil {
return nil, fmt.Errorf("recreateDB(%q): %v", dbName, err)
_, err = database.TryToMigrate(dbName)
if err != nil {
return nil, fmt.Errorf("unfixable error migrating database: %v.\nConsider running ./devtools/", err)
db, err := database.Open("pgx", database.DBConnURI(dbName), "test")
if err != nil {
return nil, err
return New(db), nil
// ResetTestDB truncates all data from the given test DB. It should be called
// after every test that mutates the database.
func ResetTestDB(db *DB, t *testing.T) {
ctx := context.Background()
if err := database.ResetDB(ctx, db.db); err != nil {
t.Fatalf("error resetting test DB: %v", err)
db.expoller.Poll(ctx) // clear excluded prefixes
// RunDBTests is a wrapper that runs the given testing suite in a test database
// named dbName. The given *DB reference will be set to the instantiated test
// database.
func RunDBTests(dbName string, m *testing.M, testDB **DB) {
database.QueryLoggingDisabled = true
db, err := SetupTestDB(dbName)
if err != nil {
if errors.Is(err, derrors.NotFound) && os.Getenv("GO_DISCOVERY_TESTDB") != "true" {
log.Printf("SKIPPING: could not connect to DB (see doc/ to set up): %v", err)
*testDB = db
code := m.Run()
if err := db.Close(); err != nil {
// RunDBTestsInParallel sets up numDBs databases, then runs the tests. Before it runs them,
// it sets acquirep to a function that tests should use to acquire a database. The second
// return value of the function should be called in a defer statement to release the database.
// For example:
// func Test(t *testing.T) {
// db, release := acquire(t)
// defer release()
// }
func RunDBTestsInParallel(dbBaseName string, numDBs int, m *testing.M, acquirep *func(*testing.T) (*DB, func())) {
start := time.Now()
database.QueryLoggingDisabled = true
dbs := make(chan *DB, numDBs)
for i := 0; i < numDBs; i++ {
db, err := SetupTestDB(fmt.Sprintf("%s-%d", dbBaseName, i))
if err != nil {
if errors.Is(err, derrors.NotFound) && os.Getenv("GO_DISCOVERY_TESTDB") != "true" {
log.Printf("SKIPPING: could not connect to DB (see doc/ to set up): %v", err)
dbs <- db
*acquirep = func(t *testing.T) (*DB, func()) {
db := <-dbs
release := func() {
ResetTestDB(db, t)
dbs <- db
return db, release
log.Printf("parallel test setup for %d DBs took %s", numDBs, time.Since(start))
code := m.Run()
if len(dbs) != cap(dbs) {
log.Fatal("not all DBs were released")
for i := 0; i < numDBs; i++ {
db := <-dbs
if err := db.Close(); err != nil {
// MustInsertModule inserts m into db, calling t.Fatal on error.
// It also updates the latest-version information for m.
func MustInsertModule(ctx context.Context, t *testing.T, db *DB, m *internal.Module) {
mustInsertModule(ctx, t, db, m, "", true)
func MustInsertModuleGoMod(ctx context.Context, t *testing.T, db *DB, m *internal.Module, goMod string) {
mustInsertModule(ctx, t, db, m, goMod, true)
func MustInsertModuleNotLatest(ctx context.Context, t *testing.T, db *DB, m *internal.Module) {
mustInsertModule(ctx, t, db, m, "", false)
func mustInsertModule(ctx context.Context, t *testing.T, db *DB, m *internal.Module, goMod string, latest bool) {
var lmv *internal.LatestModuleVersions
if goMod == "-" {
if err := db.UpdateLatestModuleVersionsStatus(ctx, m.ModulePath, 404); err != nil {
} else if latest {
lmv = addLatest(ctx, t, db, m.ModulePath, m.Version, goMod)
if _, err := db.InsertModule(ctx, m, lmv); err != nil {
func addLatest(ctx context.Context, t *testing.T, db *DB, modulePath, version, modFile string) *internal.LatestModuleVersions {
if !strings.HasPrefix(strings.TrimSpace(modFile), "module") {
modFile = "module " + modulePath + "\n" + modFile
info, err := internal.NewLatestModuleVersions(modulePath, version, version, "", []byte(modFile))
if err != nil {
lmv, err := db.UpdateLatestModuleVersions(ctx, info)
if err != nil {
return lmv
// InsertSampleDirectoryTree inserts a set of packages for testing
// GetUnit and frontend.FetchDirectoryDetails.
//
// For every entry of its table it builds a module with sample.Module from the
// entry's module path, version and package suffixes, and inserts that module
// into testDB with MustInsertModule (which calls t.Fatal on error).
func InsertSampleDirectoryTree(ctx context.Context, t *testing.T, testDB *DB) {
// Each table entry describes one module to insert.
// NOTE(review): the entries of this literal are not visible in this view of
// the file — confirm the actual module paths, versions and suffixes there.
for _, data := range []struct {
modulePath, version string
suffixes []string
} {
m := sample.Module(data.modulePath, data.version, data.suffixes...)
MustInsertModule(ctx, t, testDB, m)
// GetFromSearchDocuments retrieves the module path and version for the given
// package path from the search_documents table. If the path is not in the table,
// the third return value is false.
func GetFromSearchDocuments(ctx context.Context, t *testing.T, db *DB, packagePath string) (modulePath, version string, found bool) {
row := db.db.QueryRow(ctx, `
SELECT module_path, version
FROM search_documents
WHERE package_path = $1`,
err := row.Scan(&modulePath, &version)
switch err {
case sql.ErrNoRows:
return "", "", false
case nil:
return modulePath, version, true