blob: 0e8c421525627e91aaa5244827ab028c826e4ac8 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"testing"
"time"
"github.com/golang-migrate/migrate/v4"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/testing/dbtest"
"golang.org/x/pkgsite/internal/testing/sample"
"golang.org/x/pkgsite/internal/testing/testhelper"
// imported to register the postgres migration driver
_ "github.com/golang-migrate/migrate/v4/database/postgres"
// imported to register the file source migration driver
_ "github.com/golang-migrate/migrate/v4/source/file"
// imported to register the postgres database driver
_ "github.com/lib/pq"
)
// recreateDB drops and recreates the database named dbName, leaving it empty.
//
// The statements are issued over the connection returned by
// dbtest.DBConnURI("") (presumably the server's default database — Postgres
// refuses to drop the database a session is connected to). dbName is quoted
// with Go's %q verb, which yields a double-quoted SQL identifier for the
// simple, test-controlled names used here; it is not a general-purpose SQL
// identifier escaper.
//
// DROP uses IF EXISTS so that a database that has already disappeared does
// not prevent it from being recreated. Errors are wrapped with %w so callers
// can inspect the underlying driver error.
func recreateDB(dbName string) error {
	return dbtest.ConnectAndExecute(dbtest.DBConnURI(""), func(pg *sql.DB) error {
		if _, err := pg.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q;", dbName)); err != nil {
			return fmt.Errorf("error dropping %q: %w", dbName, err)
		}
		if _, err := pg.Exec(fmt.Sprintf("CREATE DATABASE %q;", dbName)); err != nil {
			return fmt.Errorf("error creating %q: %w", dbName, err)
		}
		return nil
	})
}
// migrationsSource returns a file:// URI pointing to the repository's
// migrations directory, suitable as the source argument to migrate.New.
//
// The directory is located via testhelper.TestDataPath("../../migrations").
// The function itself never returns an error; NOTE(review): how TestDataPath
// behaves if the path cannot be resolved is not visible here — confirm.
// The path is converted to forward slashes so the URI is well-formed on
// platforms whose separator is not '/'.
func migrationsSource() string {
	migrationsDir := testhelper.TestDataPath("../../migrations")
	return "file://" + filepath.ToSlash(migrationsDir)
}
// tryToMigrate attempts to migrate the database named dbName to the latest
// migration available from migrationsSource.
//
// isMigrationError is true only when the migration step itself (m.Up) fails,
// signalling that the database should be recreated and the migration
// retried. A failure to construct the migrator reports isMigrationError=false.
// If closing the migrator fails, its source and database errors are combined
// with any error already being returned into a dbtest.MultiErr.
func tryToMigrate(dbName string) (isMigrationError bool, outerErr error) {
	dbURI := dbtest.DBConnURI(dbName)
	source := migrationsSource()
	m, err := migrate.New(source, dbURI)
	if err != nil {
		return false, fmt.Errorf("migrate.New(): %w", err)
	}
	defer func() {
		// Surface Close failures alongside whatever is already being returned.
		if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
			outerErr = dbtest.MultiErr{outerErr, srcErr, dbErr}
		}
	}()
	// ErrNoChange means the schema is already current, which counts as success.
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return true, fmt.Errorf("m.Up(): %w", err)
	}
	return false, nil
}
// SetupTestDB ensures a test database named dbName exists, brings its schema
// up to date from the migrations directory, and returns a *DB connected to it.
//
// If the migration step itself fails, the database is dropped and recreated
// once and the migration is retried; any error remaining after that is
// returned. Every returned error is additionally wrapped with the function
// name and dbName.
func SetupTestDB(dbName string) (_ *DB, err error) {
	defer derrors.Wrap(&err, "SetupTestDB(%q)", dbName)

	if createErr := dbtest.CreateDBIfNotExists(dbName); createErr != nil {
		return nil, fmt.Errorf("CreateDBIfNotExists(%q): %w", dbName, createErr)
	}

	migrationStepFailed, migrateErr := tryToMigrate(dbName)
	if migrateErr != nil && migrationStepFailed {
		// The schema is in an unusable state: start over from an empty
		// database and give the migration one more chance.
		log.Printf("Migration failed for %s: %v, recreating database.", dbName, migrateErr)
		if recreateErr := recreateDB(dbName); recreateErr != nil {
			return nil, fmt.Errorf("recreateDB(%q): %v", dbName, recreateErr)
		}
		_, migrateErr = tryToMigrate(dbName)
	}
	if migrateErr != nil {
		return nil, fmt.Errorf("unfixable error migrating database: %v.\nConsider running ./devtools/drop_test_dbs.sh", migrateErr)
	}

	conn, openErr := database.Open("postgres", dbtest.DBConnURI(dbName), "test")
	if openErr != nil {
		return nil, openErr
	}
	return New(conn), nil
}
// ResetTestDB truncates all data from the given test DB. It should be called
// after every test that mutates the database.
//
// The modules, version_map, imports_unique, experiments,
// module_version_states and excluded_prefixes tables are truncated inside a
// single transaction; the test fails immediately (t.Fatalf) if any statement
// or the transaction as a whole returns an error. Only after the truncation
// has succeeded is the excluded-prefixes "last fetched" time reset to the
// zero time.
func ResetTestDB(db *DB, t *testing.T) {
	t.Helper()
	ctx := context.Background()
	if err := db.db.Transact(ctx, sql.LevelDefault, func(tx *database.DB) error {
		if _, err := tx.Exec(ctx, `
TRUNCATE modules CASCADE;
TRUNCATE version_map;
TRUNCATE imports_unique;
TRUNCATE experiments;`); err != nil {
			return err
		}
		if _, err := tx.Exec(ctx, `TRUNCATE module_version_states CASCADE;`); err != nil {
			return err
		}
		if _, err := tx.Exec(ctx, `TRUNCATE excluded_prefixes;`); err != nil {
			return err
		}
		return nil
	}); err != nil {
		t.Fatalf("error resetting test DB: %v", err)
	}
	// Reset the in-process marker only once the truncation is committed, so it
	// never reflects a database state that could still be rolled back.
	// NOTE(review): presumably this forces excluded prefixes to be re-read on
	// next use — confirm against setExcludedPrefixesLastFetched.
	setExcludedPrefixesLastFetched(time.Time{})
}
// RunDBTests runs the test suite m (from a TestMain) against a test database
// named dbName. Before any test runs, *testDB is set to the instantiated
// database.
//
// It sets database.QueryLoggingDisabled. If setup fails with a
// derrors.NotFound error and GO_DISCOVERY_TESTDB is not "true", it logs a
// SKIPPING message and returns without running m; any other setup error is
// fatal. On the normal path it closes the database after m.Run and exits the
// process with m.Run's code (fatal if Close fails).
func RunDBTests(dbName string, m *testing.M, testDB **DB) {
	database.QueryLoggingDisabled = true

	db, err := SetupTestDB(dbName)
	if err != nil {
		skippable := errors.Is(err, derrors.NotFound) && os.Getenv("GO_DISCOVERY_TESTDB") != "true"
		if !skippable {
			log.Fatal(err)
		}
		log.Printf("SKIPPING: could not connect to DB (see doc/postgres.md to set up): %v", err)
		return
	}

	*testDB = db
	exitCode := m.Run()
	if closeErr := db.Close(); closeErr != nil {
		log.Fatal(closeErr)
	}
	os.Exit(exitCode)
}
// InsertSampleDirectoryTree inserts a set of sample modules and packages for
// testing GetDirectory and frontend.FetchDirectoryDetails.
//
// The fixture covers two versions of "std", a nested module
// (github.com/hashicorp/vault/api) alongside its parent module at several
// versions, and package sets that differ between versions. Each module is
// built with sample.Module and inserted with testDB.InsertModule; the test
// fails immediately if an insert returns an error.
func InsertSampleDirectoryTree(ctx context.Context, t *testing.T, testDB *DB) {
	t.Helper()
	for _, data := range []struct {
		modulePath, version string
		suffixes            []string
	}{
		{
			"std",
			"v1.13.4",
			[]string{
				"archive/zip",
				"archive/tar",
				"cmd/go",
				"cmd/internal/obj",
				"cmd/internal/obj/arm",
				"cmd/internal/obj/arm64",
			},
		},
		{
			"std",
			"v1.13.0",
			[]string{
				"archive/zip",
				"archive/tar",
				"cmd/go",
				"cmd/internal/obj",
				"cmd/internal/obj/arm",
				"cmd/internal/obj/arm64",
			},
		},
		{
			"github.com/hashicorp/vault/api",
			"v1.1.2",
			[]string{""},
		},
		{
			"github.com/hashicorp/vault",
			"v1.1.2",
			[]string{
				"api",
				"builtin/audit/file",
				"builtin/audit/socket",
				"vault/replication",
				"vault/seal/transit",
			},
		},
		{
			"github.com/hashicorp/vault",
			"v1.2.3",
			[]string{
				"internal/foo",
				"builtin/audit/file",
				"builtin/audit/socket",
				"vault/replication",
				"vault/seal/transit",
			},
		},
		{
			"github.com/hashicorp/vault",
			"v1.0.3",
			[]string{
				"api",
				"builtin/audit/file",
				"builtin/audit/socket",
			},
		},
	} {
		m := sample.Module(data.modulePath, data.version, data.suffixes...)
		// Strip imports from the sample packages before inserting.
		// NOTE(review): presumably so directory tests do not depend on import
		// rows — confirm.
		for _, p := range m.LegacyPackages {
			p.Imports = nil
		}
		if err := testDB.InsertModule(ctx, m); err != nil {
			t.Fatal(err)
		}
	}
}
// GetFromSearchDocuments retrieves the module path and version for the given
// package path from the search_documents table.
//
// If the path is not in the table, found is false and modulePath and version
// are empty. Any other query error fails the test via t.Fatal.
func GetFromSearchDocuments(ctx context.Context, t *testing.T, db *DB, packagePath string) (modulePath, version string, found bool) {
	t.Helper()
	row := db.db.QueryRow(ctx, `
SELECT module_path, version
FROM search_documents
WHERE package_path = $1`,
		packagePath)
	err := row.Scan(&modulePath, &version)
	switch {
	case err == nil:
		return modulePath, version, true
	case errors.Is(err, sql.ErrNoRows):
		return "", "", false
	default:
		t.Fatal(err)
		// Unreachable: t.Fatal stops the test goroutine.
		return "", "", false
	}
}