blob: 84f569776b1dc6cdc655d2eef7f4ad555e37b11c [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package database
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"strings"
"testing"
"time"
"golang.org/x/discovery/internal/testing/dbtest"
)
// testTimeout is the per-test context deadline used by most tests in this file.
const testTimeout time.Duration = 5 * time.Second

// testDB is the shared database handle opened in TestMain and closed once all
// tests have run.
var testDB *DB
// TestMain makes sure the test database exists, opens testDB against it, runs
// the tests, and closes testDB before exiting with the tests' status code.
// Any setup or teardown failure aborts the process via log.Fatal.
func TestMain(m *testing.M) {
	const dbName = "discovery_postgres_test"

	err := dbtest.CreateDBIfNotExists(dbName)
	if err != nil {
		log.Fatal(err)
	}

	testDB, err = Open("postgres", dbtest.DBConnURI(dbName))
	if err != nil {
		log.Fatal(err)
	}

	exitCode := m.Run()

	if err = testDB.Close(); err != nil {
		log.Fatal(err)
	}
	os.Exit(exitCode)
}
// TestBulkInsert runs BulkInsert inside a transaction against a freshly
// created two-column table (colA is the primary key). Each case checks
// whether an error is returned as expected and, when wantCount is non-zero,
// how many rows the table holds afterwards. The table is dropped after each
// subtest.
func TestBulkInsert(t *testing.T) {
	const table = "test_bulk_insert"
	for _, tc := range []struct {
		name           string
		columns        []string
		values         []interface{}
		conflictAction string
		wantErr        bool
		wantCount      int
	}{
		{
			name:      "test-one-row",
			columns:   []string{"colA"},
			values:    []interface{}{"valueA"},
			wantCount: 1,
		},
		{
			name:      "test-multiple-rows",
			columns:   []string{"colA"},
			values:    []interface{}{"valueA1", "valueA2", "valueA3"},
			wantCount: 3,
		},
		{
			name:    "test-invalid-column-name",
			columns: []string{"invalid_col"},
			values:  []interface{}{"valueA"},
			wantErr: true,
		},
		{
			name:    "test-mismatch-num-cols-and-vals",
			columns: []string{"colA", "colB"},
			values:  []interface{}{"valueA1", "valueB1", "valueA2"},
			wantErr: true,
		},
		{
			name:           "test-conflict-no-action-true",
			columns:        []string{"colA"},
			values:         []interface{}{"valueA", "valueA"},
			conflictAction: OnConflictDoNothing,
			wantCount:      1,
		},
		{
			name:    "test-conflict-no-action-false",
			columns: []string{"colA"},
			values:  []interface{}{"valueA", "valueA"},
			wantErr: true,
		},
		{
			// The value contains quotes, closing parentheses and a
			// "TRUNCATE test_bulk_insert CASCADE" statement. If BulkInsert
			// spliced values into the SQL text unsafely, the TRUNCATE could
			// execute and empty the table. Instead the whole string must be
			// stored verbatim as the colA value of a single row, so exactly
			// one row should remain.
			name:           "test-sql-injection",
			columns:        []string{"colA"},
			values:         []interface{}{fmt.Sprintf("''); TRUNCATE %s CASCADE;))", table)},
			conflictAction: OnConflictDoNothing,
			wantCount:      1,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
			defer cancel()

			createQuery := fmt.Sprintf(`CREATE TABLE %s (
colA TEXT NOT NULL,
colB TEXT,
PRIMARY KEY (colA)
);`, table)
			if _, err := testDB.Exec(ctx, createQuery); err != nil {
				t.Fatal(err)
			}
			defer func() {
				dropTableQuery := fmt.Sprintf("DROP TABLE %s;", table)
				if _, err := testDB.Exec(ctx, dropTableQuery); err != nil {
					t.Fatal(err)
				}
			}()

			// Fail when the presence of an error disagrees with wantErr.
			if err := testDB.Transact(ctx, func(db *DB) error {
				return db.BulkInsert(ctx, table, tc.columns, tc.values, tc.conflictAction)
			}); (err != nil) != tc.wantErr {
				t.Errorf("testDB.Transact: %v | wantErr = %t", err, tc.wantErr)
			}

			if tc.wantCount != 0 {
				var count int
				query := "SELECT COUNT(*) FROM " + table
				if err := testDB.QueryRow(ctx, query).Scan(&count); err != nil {
					t.Fatalf("testDB.QueryRow(%q): %v", query, err)
				}
				if count != tc.wantCount {
					t.Errorf("testDB.QueryRow(%q) = %d; want = %d", query, count, tc.wantCount)
				}
			}
		})
	}
}
// TestLargeBulkInsert inserts the values 1..size with a single BulkInsert call
// and checks that all of them were stored by comparing the column sum with
// size*(size+1)/2. The row count is large, presumably so that BulkInsert has
// to split the insert into several statements.
func TestLargeBulkInsert(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
	defer cancel()

	// Use an ordinary table rather than a TEMPORARY one. Temporary tables are
	// visible only to the session that created them. If testDB is backed by
	// a connection pool (as database/sql is), the later Transact and Query
	// calls are not guaranteed to reuse that session. Dropping the table
	// afterwards also lets the test run more than once per process
	// (e.g. with -count=2).
	if _, err := testDB.Exec(ctx, `CREATE TABLE test_large_bulk (i BIGINT);`); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if _, err := testDB.Exec(ctx, `DROP TABLE test_large_bulk`); err != nil {
			t.Fatal(err)
		}
	}()

	const size = 150000
	vals := make([]interface{}, size)
	for i := 0; i < size; i++ {
		vals[i] = i + 1
	}
	if err := testDB.Transact(ctx, func(db *DB) error {
		return db.BulkInsert(ctx, "test_large_bulk", []string{"i"}, vals, "")
	}); err != nil {
		t.Fatal(err)
	}

	rows, err := testDB.Query(ctx, `SELECT i FROM test_large_bulk;`)
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()
	sum := int64(0)
	for rows.Next() {
		var i int64
		if err := rows.Scan(&i); err != nil {
			t.Fatal(err)
		}
		sum += i
	}
	// Next returns false both at the end of the results and on error; report
	// a read failure as such instead of as a misleading sum mismatch.
	if err := rows.Err(); err != nil {
		t.Fatal(err)
	}

	var want int64 = size * (size + 1) / 2
	if sum != want {
		t.Errorf("sum = %d, want %d", sum, want)
	}
}
// TestDBAfterTransactFails checks that the *DB passed to a Transact callback
// can no longer be queried once Transact has returned.
func TestDBAfterTransactFails(t *testing.T) {
	ctx := context.Background()

	var leaked *DB
	if err := testDB.Transact(ctx, func(d *DB) error {
		leaked = d
		return nil
	}); err != nil {
		t.Fatal(err)
	}

	var got int
	if err := leaked.QueryRow(ctx, `SELECT 1`).Scan(&got); err == nil {
		t.Fatal("got nil, want error")
	}
}
// TestBuildBulkUpdateQuery compares the SQL produced by buildBulkUpdateQuery
// with the expected statement, treating any run of whitespace as one space.
func TestBuildBulkUpdateQuery(t *testing.T) {
	normalize := func(s string) string {
		return strings.Join(strings.Fields(s), " ")
	}
	const expected = `
UPDATE tab
SET C1 = data.C1, C2 = data.C2
FROM (SELECT UNNEST($1::TEXT[]) AS K, UNNEST($2::INT[]) AS C1, UNNEST($3::BOOL[]) AS C2) AS data
WHERE tab.K = data.K`

	got := normalize(buildBulkUpdateQuery("tab", []string{"K", "C1", "C2"}, []string{"TEXT", "INT", "BOOL"}))
	if want := normalize(expected); got != want {
		t.Errorf("\ngot\n%s\nwant\n%s", got, want)
	}
}
// TestBulkUpdate fills a two-column table with rows (i, i) for i in
// [0, numRows). It then uses BulkUpdate to set b = -a for every even a and
// checks every row afterwards. maxBulkUpdateArrayLen is temporarily lowered,
// presumably so the update is split into several batches.
func TestBulkUpdate(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
	defer cancel()
	defer func(old int) { maxBulkUpdateArrayLen = old }(maxBulkUpdateArrayLen)
	maxBulkUpdateArrayLen = 5

	if _, err := testDB.Exec(ctx, `CREATE TABLE bulk_update (a INT, b INT)`); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if _, err := testDB.Exec(ctx, `DROP TABLE bulk_update`); err != nil {
			t.Fatal(err)
		}
	}()

	const numRows = 50
	cols := []string{"a", "b"}
	values := make([]interface{}, 0, 2*numRows)
	for i := 0; i < numRows; i++ {
		values = append(values, i, i)
	}
	err := testDB.Transact(ctx, func(tx *DB) error {
		return tx.BulkInsert(ctx, "bulk_update", cols, values, "")
	})
	if err != nil {
		t.Fatal(err)
	}

	// Update all even values of column a: updateVals[0] holds the keys (a),
	// updateVals[1] the new b values.
	updateVals := make([][]interface{}, 2)
	for a := 0; a < numRows; a += 2 {
		updateVals[0] = append(updateVals[0], a)
		updateVals[1] = append(updateVals[1], -a)
	}
	err = testDB.Transact(ctx, func(tx *DB) error {
		return tx.BulkUpdate(ctx, "bulk_update", cols, []string{"INT", "INT"}, updateVals)
	})
	if err != nil {
		t.Fatal(err)
	}

	// Count the rows visited so the check cannot pass vacuously on an empty
	// result. This assumes RunQuery invokes the callback once per row.
	seen := 0
	err = testDB.RunQuery(ctx, `SELECT a, b FROM bulk_update`, func(rows *sql.Rows) error {
		var a, b int
		if err := rows.Scan(&a, &b); err != nil {
			return err
		}
		seen++
		want := a
		if a%2 == 0 {
			want = -a
		}
		if b != want {
			// Return rather than t.Fatalf so RunQuery can clean up and
			// the failure is reported through the same path as Scan errors.
			return fmt.Errorf("a=%d: got %d, want %d", a, b, want)
		}
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
	if seen != numRows {
		t.Errorf("got %d rows, want %d", seen, numRows)
	}
}