blob: 1011c5023a97e7147d49067463d020f2a13d5838 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package database
import (
"context"
"sort"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/lib/pq"
)
// testStruct is the row type used by the tests in this file. Its fields are
// declared in the same order as the name, score and slice columns selected in
// TestCollectStructs, and TestStructScanner relies on this declaration order
// when it indexes the scan targets positionally.
type testStruct struct {
	Name  string
	Score int
	// Slice is populated through a *pq.Int64Array scan target
	// (see TestStructScanner).
	Slice []int64
}
// TestStructScanner checks that the targets produced by StructScanner point
// into the destination struct, one per field in declaration order, with the
// []int64 field exposed as a *pq.Int64Array. It also checks that
// StructScanner accepts a pointer to a struct, not only a struct value.
func TestStructScanner(t *testing.T) {
	var dest testStruct
	scan := StructScanner(dest)
	targets := scan(&dest)

	// Simulate a row scan by writing through each target in turn.
	namePtr := targets[0].(*string)
	*namePtr = "foo"
	scorePtr := targets[1].(*int)
	*scorePtr = 3
	slicePtr := targets[2].(*pq.Int64Array)
	*slicePtr = []int64{1, 2, 3}

	expected := testStruct{Name: "foo", Score: 3, Slice: []int64{1, 2, 3}}
	if !cmp.Equal(dest, expected) {
		t.Errorf("got %+v, want %+v", dest, expected)
	}

	// Passing a pointer must not panic; the result itself is not used.
	_ = StructScanner(&dest)
}
// TestCollectStructs verifies that CollectStructs scans every row of a query
// into a slice of structs and into a slice of pointers to structs. The
// fixture rows carry a nil, a non-empty and an empty int64 slice, and the
// expectations require nil and empty slices to come back distinct.
func TestCollectStructs(t *testing.T) {
	ctx := context.Background()
	if _, err := testDB.Exec(ctx, "DROP TABLE IF EXISTS structs"); err != nil {
		t.Fatal(err)
	}
	_, err := testDB.Exec(ctx, `
CREATE TABLE structs (
name text NOT NULL,
score integer NOT NULL,
slice integer[]
)`)
	if err != nil {
		t.Fatal(err)
	}
	if err := testDB.BulkInsert(ctx, "structs", []string{"name", "score", "slice"}, []interface{}{
		"A", 1, pq.Array([]int64(nil)),
		"B", 2, pq.Array([]int64{1, 2}),
		"C", 3, pq.Array([]int64{}),
	}, ""); err != nil {
		t.Fatal(err)
	}

	// The query has no ORDER BY, so row order is unspecified; results are
	// sorted by Name before being compared.
	query := `SELECT name, score, slice FROM structs`

	var got []testStruct
	if err := testDB.CollectStructs(ctx, query, &got); err != nil {
		t.Fatal(err)
	}
	sort.Slice(got, func(i, j int) bool { return got[i].Name < got[j].Name })
	want := []testStruct{{"A", 1, nil}, {"B", 2, []int64{1, 2}}, {"C", 3, []int64{}}}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("struct slice mismatch (-want +got):\n%s", diff)
	}

	// Same, but with a slice of struct pointers.
	var gotp []*testStruct
	if err := testDB.CollectStructs(ctx, query, &gotp); err != nil {
		t.Fatal(err)
	}
	// The less function must index the slice being sorted. Indexing got
	// here would compare unrelated elements, so gotp would not be reliably
	// sorted and the comparison below would depend on database row order.
	sort.Slice(gotp, func(i, j int) bool { return gotp[i].Name < gotp[j].Name })
	wantp := []*testStruct{{"A", 1, nil}, {"B", 2, []int64{1, 2}}, {"C", 3, []int64{}}}
	// cmp.Diff follows the pointers and reports field-level differences;
	// %+v would only print the element addresses.
	if diff := cmp.Diff(wantp, gotp); diff != "" {
		t.Errorf("pointer slice mismatch (-want +got):\n%s", diff)
	}
}