blob: a878985edf5617d523addec1543e21d666b9ab07 [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package database
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
)
// copyRow is the struct that compareRows collects query results into when
// reading back the two-column (key, value) test tables created in this file.
// NOTE(review): presumably CollectStructs fills fields in column order, so
// the field order here should match the table definition — confirm.
type copyRow struct {
// Key holds the table's integer "key" column.
Key int
// Value holds the table's text "value" column.
Value string
}
// TestCopyInsert seeds a table with two rows, copies two more into it with
// CopyInsert inside a transaction, and verifies that all four rows are
// present afterwards. The test is skipped when the driver is not pgx.
func TestCopyInsert(t *testing.T) {
	pgxOnly(t)
	ctx := context.Background()

	// Recreate the table from scratch so earlier runs cannot affect the result.
	setup := []string{
		`DROP TABLE IF EXISTS test_streaming_insert`,
		`CREATE TABLE test_streaming_insert (key INTEGER PRIMARY KEY, value TEXT)`,
		`INSERT INTO test_streaming_insert (key, value) VALUES (1, 'foo'), (2, 'bar')`,
	}
	for i := range setup {
		_, err := testDB.Exec(ctx, setup[i])
		if err != nil {
			t.Fatal(err)
		}
	}

	const table = "test_streaming_insert"
	columns := []string{"key", "value"}
	newRows := [][]any{
		{3, "baz"},
		{4, "moo"},
	}
	// The copy source is built inside the closure so that each invocation
	// of the closure starts with a fresh source.
	insert := func(tx *DB) error {
		return tx.CopyInsert(ctx, table, columns, pgx.CopyFromRows(newRows), "")
	}
	if err := testDB.Transact(ctx, sql.LevelDefault, insert); err != nil {
		t.Fatal(err)
	}

	want := []copyRow{
		{1, "foo"},
		{2, "bar"},
		{3, "baz"},
		{4, "moo"},
	}
	compareRows(ctx, t, table, want)
}
// TestCopyUpsert seeds a table with two rows, then runs CopyUpsert with one
// row whose key is new and one whose key already exists, using "key" as the
// conflict column. It verifies that the new row was added and the existing
// row's value was replaced. The test is skipped when the driver is not pgx.
func TestCopyUpsert(t *testing.T) {
	pgxOnly(t)
	ctx := context.Background()

	// Recreate the table from scratch so earlier runs cannot affect the result.
	setup := []string{
		`DROP TABLE IF EXISTS test_streaming_upsert`,
		`CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`,
		`INSERT INTO test_streaming_upsert (key, value) VALUES (1, 'foo'), (2, 'bar')`,
	}
	for i := range setup {
		_, err := testDB.Exec(ctx, setup[i])
		if err != nil {
			t.Fatal(err)
		}
	}

	const table = "test_streaming_upsert"
	columns := []string{"key", "value"}
	conflictColumns := []string{"key"}
	upsertRows := [][]any{
		{3, "baz"}, // new row
		{1, "moo"}, // replace "foo" with "moo"
	}
	// The copy source is built inside the closure so that each invocation
	// of the closure starts with a fresh source.
	upsert := func(tx *DB) error {
		return tx.CopyUpsert(ctx, table, columns, pgx.CopyFromRows(upsertRows), conflictColumns, "")
	}
	if err := testDB.Transact(ctx, sql.LevelDefault, upsert); err != nil {
		t.Fatal(err)
	}

	want := []copyRow{
		{1, "moo"},
		{2, "bar"},
		{3, "baz"},
	}
	compareRows(ctx, t, table, want)
}
// compareRows reads every row of table, ordered by its key column, into
// copyRow values and reports a test error if they differ from wantRows.
// A failure to run the query aborts the test.
func compareRows(ctx context.Context, t *testing.T, table string, wantRows []copyRow) {
	t.Helper()

	query := fmt.Sprintf(`SELECT * FROM %s ORDER BY key`, table)
	gotRows, err := CollectStructs[copyRow](ctx, testDB, query)
	if err != nil {
		t.Fatal(err)
	}
	if cmp.Equal(gotRows, wantRows) {
		return
	}
	t.Errorf("got %v, want %v", gotRows, wantRows)
}
// TestCopyUpsertGeneratedColumn runs CopyUpsert against a table whose primary
// key is a GENERATED ALWAYS identity column and whose conflict column is the
// unique "key" column. It verifies that the existing row is updated in place,
// the new row is added, and the id values come out as 1, 2, 3.
// The test is skipped when the driver is not pgx.
func TestCopyUpsertGeneratedColumn(t *testing.T) {
	pgxOnly(t)
	ctx := context.Background()

	// All setup statements are sent in a single Exec call.
	const setup = `
DROP TABLE IF EXISTS test_copy_gen;
CREATE TABLE test_copy_gen (id bigint PRIMARY KEY GENERATED ALWAYS AS IDENTITY, key INT, value TEXT, UNIQUE (key));
INSERT INTO test_copy_gen (key, value) VALUES (11, 'foo'), (12, 'bar')`
	_, err := testDB.Exec(ctx, setup)
	if err != nil {
		t.Fatal(err)
	}

	columns := []string{"key", "value"}
	conflictColumns := []string{"key"}
	upsertRows := [][]any{
		{13, "baz"}, // new row
		{11, "moo"}, // replace "foo" with "moo"
	}
	// NOTE(review): the trailing "id" argument presumably names the generated
	// column that CopyUpsert must not write — confirm against CopyUpsert.
	upsert := func(tx *DB) error {
		return tx.CopyUpsert(ctx, "test_copy_gen", columns, pgx.CopyFromRows(upsertRows), conflictColumns, "id")
	}
	err = testDB.Transact(ctx, sql.LevelDefault, upsert)
	if err != nil {
		t.Fatal(err)
	}

	// genRow mirrors the three columns of test_copy_gen.
	type genRow struct {
		ID    int64
		Key   int
		Value string
	}
	wantRows := []genRow{
		{1, 11, "moo"},
		{2, 12, "bar"},
		{3, 13, "baz"},
	}
	gotRows, err := CollectStructs[genRow](ctx, testDB, `SELECT * FROM test_copy_gen ORDER BY ID`)
	if err != nil {
		t.Fatal(err)
	}
	if cmp.Equal(gotRows, wantRows) {
		return
	}
	t.Errorf("got %v, want %v", gotRows, wantRows)
}
// pgxOnly skips the calling test unless the test database's underlying
// driver connection is a pgx *stdlib.Conn. It fails the test if a connection
// cannot be obtained or inspected.
func pgxOnly(t *testing.T) {
	t.Helper()

	conn, err := testDB.db.Conn(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	// A *sql.Conn stays checked out of the pool until it is closed, so
	// return it here; otherwise every call leaks one connection.
	defer conn.Close()

	// Only record the driver type inside the callback. Calling t.Skip there
	// would exit the goroutine from within Raw, which Raw treats like a
	// panic and responds to by discarding the connection as bad.
	isPgx := false
	err = conn.Raw(func(c any) error {
		_, isPgx = c.(*stdlib.Conn)
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
	if !isPgx {
		t.Skip("skipping; DB driver not pgx")
	}
}