| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package database |
| |
| import ( |
| "context" |
| "database/sql" |
| "errors" |
| "fmt" |
| "io" |
| "log" |
| "os" |
| "sort" |
| "strings" |
| "testing" |
| "time" |
| |
| "github.com/google/go-cmp/cmp" |
| "github.com/jackc/pgconn" |
| "github.com/jackc/pgx/v4" |
| "github.com/jackc/pgx/v4/stdlib" |
| "golang.org/x/pkgsite/internal/derrors" |
| ) |
| |
// Settings shared by the tests in this file.
const (
	// testTimeout is the base deadline used when building a test's context;
	// some tests scale it up.
	testTimeout = 5 * time.Second

	// testDBName names the database that TestMain creates (if needed) and
	// opens for the tests.
	testDBName = "discovery_postgres_test"
)
| |
| var testDB *DB |
| |
| func TestMain(m *testing.M) { |
| if err := CreateDBIfNotExists(testDBName); err != nil { |
| if errors.Is(err, derrors.NotFound) && os.Getenv("GO_DISCOVERY_TESTDB") != "true" { |
| log.Printf("SKIPPING: could not connect to DB (see doc/postgres.md to set up): %v", err) |
| return |
| } |
| log.Fatal(err) |
| } |
| |
| var err error |
| for _, driver := range []string{"postgres", "pgx"} { |
| log.Printf("with driver %q", driver) |
| testDB, err = Open(driver, DBConnURI(testDBName), "test") |
| if err != nil { |
| log.Fatalf("Open: %v %[1]T", err) |
| } |
| code := m.Run() |
| if err := testDB.Close(); err != nil { |
| log.Fatal(err) |
| } |
| if code != 0 { |
| os.Exit(code) |
| } |
| } |
| } |
| |
| func TestBulkInsert(t *testing.T) { |
| table := "test_bulk_insert" |
| |
| for _, test := range []struct { |
| name string |
| columns []string |
| values []interface{} |
| conflictAction string |
| wantErr bool |
| wantCount int |
| wantReturned []string |
| }{ |
| { |
| |
| name: "test-one-row", |
| columns: []string{"colA"}, |
| values: []interface{}{"valueA"}, |
| wantCount: 1, |
| }, |
| { |
| |
| name: "test-multiple-rows", |
| columns: []string{"colA"}, |
| values: []interface{}{"valueA1", "valueA2", "valueA3"}, |
| wantCount: 3, |
| }, |
| { |
| |
| name: "test-invalid-column-name", |
| columns: []string{"invalid_col"}, |
| values: []interface{}{"valueA"}, |
| wantErr: true, |
| }, |
| { |
| |
| name: "test-mismatch-num-cols-and-vals", |
| columns: []string{"colA", "colB"}, |
| values: []interface{}{"valueA1", "valueB1", "valueA2"}, |
| wantErr: true, |
| }, |
| { |
| |
| name: "insert-returning", |
| columns: []string{"colA", "colB"}, |
| values: []interface{}{"valueA1", "valueB1", "valueA2", "valueB2"}, |
| wantCount: 2, |
| wantReturned: []string{"valueA1", "valueA2"}, |
| }, |
| { |
| |
| name: "test-conflict", |
| columns: []string{"colA"}, |
| values: []interface{}{"valueA", "valueA"}, |
| wantErr: true, |
| }, |
| { |
| |
| name: "test-conflict-do-nothing", |
| columns: []string{"colA"}, |
| values: []interface{}{"valueA", "valueA"}, |
| conflictAction: OnConflictDoNothing, |
| wantCount: 1, |
| }, |
| { |
| |
| // This should execute the statement |
| // INSERT INTO series (path) VALUES ('''); TRUNCATE series CASCADE;)'); |
| // which will insert a row with path value: |
| // '); TRUNCATE series CASCADE;) |
| // Rather than the statement |
| // INSERT INTO series (path) VALUES (''); TRUNCATE series CASCADE;)); |
| // which would truncate most tables in the database. |
| name: "test-sql-injection", |
| columns: []string{"colA"}, |
| values: []interface{}{fmt.Sprintf("''); TRUNCATE %s CASCADE;))", table)}, |
| conflictAction: OnConflictDoNothing, |
| wantCount: 1, |
| }, |
| } { |
| t.Run(test.name, func(t *testing.T) { |
| ctx, cancel := context.WithTimeout(context.Background(), testTimeout) |
| defer cancel() |
| |
| createQuery := fmt.Sprintf(`CREATE TABLE %s ( |
| colA TEXT NOT NULL, |
| colB TEXT, |
| PRIMARY KEY (colA) |
| );`, table) |
| if _, err := testDB.Exec(ctx, createQuery); err != nil { |
| t.Fatal(err) |
| } |
| defer func() { |
| dropTableQuery := fmt.Sprintf("DROP TABLE %s;", table) |
| if _, err := testDB.Exec(ctx, dropTableQuery); err != nil { |
| t.Fatal(err) |
| } |
| }() |
| |
| var err error |
| var returned []string |
| if test.wantReturned == nil { |
| err = testDB.BulkInsert(ctx, table, test.columns, test.values, test.conflictAction) |
| } else { |
| err = testDB.BulkInsertReturning(ctx, table, test.columns, test.values, test.conflictAction, |
| []string{"colA"}, func(rows *sql.Rows) error { |
| var r string |
| if err := rows.Scan(&r); err != nil { |
| return err |
| } |
| returned = append(returned, r) |
| return nil |
| }) |
| } |
| if test.wantErr && err == nil || !test.wantErr && err != nil { |
| t.Errorf("got error %v, wantErr %t", err, test.wantErr) |
| } |
| if err != nil { |
| return |
| } |
| if test.wantCount != 0 { |
| var count int |
| query := "SELECT COUNT(*) FROM " + table |
| row := testDB.QueryRow(ctx, query) |
| err := row.Scan(&count) |
| if err != nil { |
| t.Fatalf("testDB.queryRow(%q): %v", query, err) |
| } |
| if count != test.wantCount { |
| t.Errorf("testDB.queryRow(%q) = %d; want = %d", query, count, test.wantCount) |
| } |
| } |
| if test.wantReturned != nil { |
| sort.Strings(returned) |
| if !cmp.Equal(returned, test.wantReturned) { |
| t.Errorf("returned: got %v, want %v", returned, test.wantReturned) |
| } |
| } |
| }) |
| } |
| } |
| |
| func TestLargeBulkInsert(t *testing.T) { |
| ctx, cancel := context.WithTimeout(context.Background(), testTimeout*3) |
| defer cancel() |
| if _, err := testDB.Exec(ctx, `CREATE TEMPORARY TABLE test_large_bulk (i BIGINT);`); err != nil { |
| t.Fatal(err) |
| } |
| const size = 150001 |
| vals := make([]interface{}, size) |
| for i := 0; i < size; i++ { |
| vals[i] = i + 1 |
| } |
| start := time.Now() |
| if err := testDB.Transact(ctx, sql.LevelDefault, func(db *DB) error { |
| return db.BulkInsert(ctx, "test_large_bulk", []string{"i"}, vals, "") |
| }); err != nil { |
| t.Fatal(err) |
| } |
| t.Logf("large bulk insert took %s", time.Since(start)) |
| rows, err := testDB.Query(ctx, `SELECT i FROM test_large_bulk;`) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer rows.Close() |
| sum := int64(0) |
| for rows.Next() { |
| var i int64 |
| if err := rows.Scan(&i); err != nil { |
| t.Fatal(err) |
| } |
| sum += i |
| } |
| var want int64 = size * (size + 1) / 2 |
| if sum != want { |
| t.Errorf("sum = %d, want %d", sum, want) |
| } |
| } |
| |
| func TestBulkUpsert(t *testing.T) { |
| ctx, cancel := context.WithTimeout(context.Background(), testTimeout*3) |
| defer cancel() |
| if _, err := testDB.Exec(ctx, `CREATE TEMPORARY TABLE test_replace (C1 int PRIMARY KEY, C2 int);`); err != nil { |
| t.Fatal(err) |
| } |
| for _, values := range [][]interface{}{ |
| {2, 4, 4, 8}, // First, insert some rows. |
| {1, -1, 2, -2, 3, -3, 4, -4}, // Then replace those rows while inserting others. |
| } { |
| err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error { |
| return tx.BulkUpsert(ctx, "test_replace", []string{"C1", "C2"}, values, []string{"C1"}) |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| var got []interface{} |
| err = testDB.RunQuery(ctx, `SELECT C1, C2 FROM test_replace ORDER BY C1`, func(rows *sql.Rows) error { |
| var a, b int |
| if err := rows.Scan(&a, &b); err != nil { |
| return err |
| } |
| got = append(got, a, b) |
| return nil |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !cmp.Equal(got, values) { |
| t.Errorf("%v: got %v, want %v", values, got, values) |
| } |
| } |
| } |
| |
| func TestBuildUpsertConflictAction(t *testing.T) { |
| got := buildUpsertConflictAction([]string{"a", "b"}, []string{"c", "d"}) |
| want := "ON CONFLICT (c, d) DO UPDATE SET a=excluded.a, b=excluded.b" |
| if got != want { |
| t.Errorf("got %q, want %q", got, want) |
| } |
| } |
| |
| func TestDBAfterTransactFails(t *testing.T) { |
| ctx := context.Background() |
| var tx *DB |
| err := testDB.Transact(ctx, sql.LevelDefault, func(d *DB) error { |
| tx = d |
| return nil |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| var i int |
| err = tx.QueryRow(ctx, `SELECT 1`).Scan(&i) |
| if err == nil { |
| t.Fatal("got nil, want error") |
| } |
| } |
| |
| func TestBuildBulkUpdateQuery(t *testing.T) { |
| q := buildBulkUpdateQuery("tab", []string{"K", "C1", "C2"}, []string{"TEXT", "INT", "BOOL"}) |
| got := strings.Join(strings.Fields(q), " ") |
| w := ` |
| UPDATE tab |
| SET C1 = data.C1, C2 = data.C2 |
| FROM (SELECT UNNEST($1::TEXT[]) AS K, UNNEST($2::INT[]) AS C1, UNNEST($3::BOOL[]) AS C2) AS data |
| WHERE tab.K = data.K` |
| want := strings.Join(strings.Fields(w), " ") |
| if got != want { |
| t.Errorf("\ngot\n%s\nwant\n%s", got, want) |
| } |
| } |
| |
| func TestBulkUpdate(t *testing.T) { |
| ctx, cancel := context.WithTimeout(context.Background(), testTimeout) |
| defer cancel() |
| |
| defer func(old int) { maxBulkUpdateArrayLen = old }(maxBulkUpdateArrayLen) |
| maxBulkUpdateArrayLen = 5 |
| |
| if _, err := testDB.Exec(ctx, `CREATE TABLE bulk_update (a INT, b INT)`); err != nil { |
| t.Fatal(err) |
| } |
| defer func() { |
| if _, err := testDB.Exec(ctx, `DROP TABLE bulk_update`); err != nil { |
| t.Fatal(err) |
| } |
| }() |
| |
| cols := []string{"a", "b"} |
| var values []interface{} |
| for i := 0; i < 50; i++ { |
| values = append(values, i, i) |
| } |
| err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error { |
| return tx.BulkInsert(ctx, "bulk_update", cols, values, "") |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Update all even values of column a. |
| updateVals := make([][]interface{}, 2) |
| for i := 0; i < len(values)/2; i += 2 { |
| updateVals[0] = append(updateVals[0], i) |
| updateVals[1] = append(updateVals[1], -i) |
| } |
| |
| err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error { |
| return tx.BulkUpdate(ctx, "bulk_update", cols, []string{"INT", "INT"}, updateVals) |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| err = testDB.RunQuery(ctx, `SELECT a, b FROM bulk_update`, func(rows *sql.Rows) error { |
| var a, b int |
| if err := rows.Scan(&a, &b); err != nil { |
| return err |
| } |
| want := a |
| if a%2 == 0 { |
| want = -a |
| } |
| if b != want { |
| t.Fatalf("a=%d: got %d, want %d", a, b, want) |
| } |
| return nil |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| func TestTransactSerializable(t *testing.T) { |
| // Test that serializable transactions retry until success. |
| // This test was taken from the example at https://www.postgresql.org/docs/11/transaction-iso.html, |
| // section 13.2.3. |
| ctx, cancel := context.WithTimeout(context.Background(), testTimeout) |
| defer cancel() |
| |
| // Once in while, the test doesn't work. Repeat to de-flake. |
| var msg string |
| for i := 0; i < 10; i++ { |
| msg = testTransactSerializable(ctx, t) |
| if msg == "" { |
| return |
| } |
| } |
| t.Fatal(msg) |
| } |
| |
| func testTransactSerializable(ctx context.Context, t *testing.T) string { |
| const numTransactions = 4 |
| // A transaction that sums values in class 1 and inserts that sum into class 2, |
| // or vice versa. |
| insertSum := func(tx *DB, queryClass int) error { |
| var sum int |
| err := tx.QueryRow(ctx, `SELECT SUM(value) FROM ser WHERE class = $1`, queryClass).Scan(&sum) |
| if err != nil { |
| return err |
| } |
| insertClass := 3 - queryClass |
| _, err = tx.Exec(ctx, `INSERT INTO ser (class, value) VALUES ($1, $2)`, insertClass, sum) |
| return err |
| } |
| |
| sawRetries := false |
| for i := 0; i < 10; i++ { |
| for _, stmt := range []string{ |
| `DROP TABLE IF EXISTS ser`, |
| `CREATE TABLE ser (id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY, class INTEGER, value INTEGER)`, |
| `INSERT INTO ser (class, value) VALUES (1, 10), (1, 20), (2, 100), (2, 200)`, |
| } { |
| if _, err := testDB.Exec(ctx, stmt); err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| // Run the following two transactions multiple times concurrently: |
| // sum rows with class = 1 and insert as a row with class 2 |
| // sum rows with class = 2 and insert as a row with class 1 |
| errc := make(chan error, numTransactions) |
| for i := 0; i < numTransactions; i++ { |
| i := i |
| go func() { |
| errc <- testDB.Transact(ctx, sql.LevelSerializable, |
| func(tx *DB) error { return insertSum(tx, 1+i%2) }) |
| }() |
| } |
| // None of the transactions should fail. |
| for i := 0; i < numTransactions; i++ { |
| if err := <-errc; err != nil { |
| return err.Error() |
| } |
| } |
| t.Logf("max retries: %d", testDB.MaxRetries()) |
| // If nothing got retried, this test isn't exercising some important behavior. |
| // Try again. |
| if testDB.MaxRetries() > 0 { |
| sawRetries = true |
| break |
| } |
| } |
| if !sawRetries { |
| return "did not see any retries" |
| } |
| |
| // Demonstrate serializability: there should be numTransactions new rows in |
| // addition to the 4 we started with, and viewing the rows in insertion |
| // order, each of the new rows should have the sum of the other class's rows |
| // so far. |
| type row struct { |
| Class, Value int |
| } |
| var rows []row |
| if err := testDB.CollectStructs(ctx, &rows, `SELECT class, value FROM ser ORDER BY id`); err != nil { |
| return err.Error() |
| } |
| const initialRows = 4 |
| if got, want := len(rows), initialRows+numTransactions; got != want { |
| return fmt.Sprintf("got %d rows, want %d", got, want) |
| } |
| sum := make([]int, 2) |
| for i, r := range rows { |
| if got, want := r.Value, sum[2-r.Class]; got != want && i >= initialRows { |
| return fmt.Sprintf("row #%d: got %d, want %d", i, got, want) |
| } |
| sum[r.Class-1] += r.Value |
| } |
| return "" |
| } |
| |
| func TestCopyDoesNotUpsert(t *testing.T) { |
| // This test verifies that copying rows into a table will not overwrite existing rows. |
| ctx, cancel := context.WithTimeout(context.Background(), testTimeout) |
| defer cancel() |
| conn, err := testDB.db.Conn(ctx) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| for _, stmt := range []string{ |
| `DROP TABLE IF EXISTS test_copy`, |
| `CREATE TABLE test_copy (i INTEGER PRIMARY KEY)`, |
| `INSERT INTO test_copy (i) VALUES (1)`, |
| } { |
| if _, err := testDB.Exec(ctx, stmt); err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| err = conn.Raw(func(c interface{}) error { |
| stdConn, ok := c.(*stdlib.Conn) |
| if !ok { |
| t.Skip("DB driver is not pgx") |
| } |
| rows := [][]interface{}{{1}, {2}} |
| _, err = stdConn.Conn().CopyFrom(ctx, []string{"test_copy"}, []string{"i"}, pgx.CopyFromRows(rows)) |
| return err |
| }) |
| |
| const constraintViolationCode = "23505" |
| var gerr *pgconn.PgError |
| if !errors.As(err, &gerr) || gerr.Code != constraintViolationCode { |
| t.Errorf("got %v, wanted code %s", gerr, constraintViolationCode) |
| } |
| } |
| |
| func TestRunQueryIncrementally(t *testing.T) { |
| ctx := context.Background() |
| for _, stmt := range []string{ |
| `DROP TABLE IF EXISTS test_rqi`, |
| `CREATE TABLE test_rqi (i INTEGER PRIMARY KEY)`, |
| `INSERT INTO test_rqi (i) VALUES (1), (2), (3), (4), (5)`, |
| } { |
| if _, err := testDB.Exec(ctx, stmt); err != nil { |
| t.Fatal(err) |
| } |
| } |
| query := `SELECT i FROM test_rqi ORDER BY i LIMIT $1` |
| var got []int |
| |
| // Run until all rows consumed. |
| err := testDB.RunQueryIncrementally(ctx, query, 2, func(rows *sql.Rows) error { |
| var i int |
| if err := rows.Scan(&i); err != nil { |
| return err |
| } |
| got = append(got, i) |
| return nil |
| }, 4) |
| if err != nil { |
| t.Fatal(err) |
| } |
| want := []int{1, 2, 3, 4} |
| if !cmp.Equal(got, want) { |
| t.Errorf("got %v, want %v", got, want) |
| } |
| |
| // Stop early. |
| got = nil |
| err = testDB.RunQueryIncrementally(ctx, query, 2, func(rows *sql.Rows) error { |
| var i int |
| if err := rows.Scan(&i); err != nil { |
| return err |
| } |
| got = append(got, i) |
| if len(got) == 3 { |
| return io.EOF |
| } |
| return nil |
| }, 10) |
| if err != nil { |
| t.Fatal(err) |
| } |
| want = []int{1, 2, 3} |
| if !cmp.Equal(got, want) { |
| t.Errorf("got %v, want %v", got, want) |
| } |
| |
| } |
| |
| func TestCollectStrings(t *testing.T) { |
| ctx := context.Background() |
| for _, stmt := range []string{ |
| `DROP TABLE IF EXISTS test_cs`, |
| `CREATE TABLE test_cs (s TEXT)`, |
| `INSERT INTO test_cs (s) VALUES ('a'), ('b'), ('c')`, |
| } { |
| if _, err := testDB.Exec(ctx, stmt); err != nil { |
| t.Fatal(err) |
| } |
| } |
| got, err := testDB.CollectStrings(ctx, `SELECT s FROM test_cs`) |
| if err != nil { |
| t.Fatal(err) |
| } |
| sort.Strings(got) |
| want := []string{"a", "b", "c"} |
| if !cmp.Equal(got, want) { |
| t.Errorf("got %v, want %v", got, want) |
| } |
| } |