blob: e057d0692d49dfc93e449b811cc9a9227a0bc4f5 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package database
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"unicode"
"unicode/utf8"
"github.com/lib/pq"
)
// StructScanner returns a function that, when called on a pointer to a value
// of the struct type T, returns a slice of arguments suitable for Row.Scan or
// Rows.Scan. The call to either Scan will populate the exported fields of the
// struct in the order they appear in the type definition.
//
// StructScanner panics if T is not a struct type.
//
// Example:
//
//	type Player struct { Name string; Score int }
//	playerScanArgs := database.StructScanner[Player]()
//	err := db.RunQuery(ctx, "SELECT name, score FROM players", func(rows *sql.Rows) error {
//		var p Player
//		if err := rows.Scan(playerScanArgs(&p)...); err != nil {
//			return err
//		}
//		// use p
//		return nil
//	})
func StructScanner[T any]() func(p *T) []interface{} {
	// The exported entry point is a thin veneer over the shared builder that
	// CollectStructs and CollectStructPtrs also use.
	scanArgs := structScannerForType[T]()
	return scanArgs
}
// fieldInfo describes one exported field of a struct type, as recorded by
// structScannerForType when it inspects the type.
//
// The field order here is significant: values are constructed with an
// unkeyed composite literal (index first, then kind).
type fieldInfo struct {
num int // index of the field within the struct, to pass to v.Field
kind reflect.Kind // kind of the field's type; decides how the field pointer is wrapped for Scan
}
// structScannerForType builds the scan-argument function for the struct type T.
//
// It inspects T once, recording the index and kind of every exported field,
// and returns a function that maps a *T to a slice holding one Scan argument
// per exported field, in declaration order:
//   - a slice field other than []byte is wrapped with pq.Array;
//   - a pointer field is wrapped with NullPtr;
//   - any other field is passed as a plain pointer to the field.
//
// structScannerForType panics if T is not a struct type. The returned
// function panics if it is called with a nil pointer.
func structScannerForType[T any]() func(p *T) []interface{} {
	// Derive the type from *T rather than from a zero T: reflect.TypeOf
	// returns nil for a nil interface value, so an interface type argument
	// would otherwise crash with a nil dereference instead of reaching the
	// descriptive panic below.
	t := reflect.TypeOf((*T)(nil)).Elem()
	if t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("%s is not a struct", t))
	}

	// Collect the numbers of the exported fields.
	var fieldInfos []fieldInfo
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		r, _ := utf8.DecodeRuneInString(f.Name)
		if unicode.IsUpper(r) {
			fieldInfos = append(fieldInfos, fieldInfo{i, f.Type.Kind()})
		}
	}

	// Return a function that gets pointers to the exported fields.
	return func(p *T) []interface{} {
		v := reflect.ValueOf(p).Elem()
		// This runs once per row, so size the result up front instead of
		// growing it field by field.
		args := make([]interface{}, 0, len(fieldInfos))
		for _, info := range fieldInfos {
			arg := v.Field(info.num).Addr().Interface()
			switch info.kind {
			case reflect.Slice:
				// []byte is scanned directly; every other slice goes
				// through pq.Array.
				if _, ok := arg.(*[]byte); !ok {
					arg = pq.Array(arg)
				}
			case reflect.Ptr:
				arg = NullPtr(arg)
			}
			args = append(args, arg)
		}
		return args
	}
}
// NullPtr is for scanning nullable database columns into pointer variables or
// fields. When given a pointer to a pointer to some type T, it returns a
// value that can be passed to a Scan function. If the corresponding column is
// NULL, the variable will be set to nil. Otherwise, it will be set to a newly
// allocated pointer to the column value.
//
// NullPtr panics if p is not a non-nil pointer to a pointer.
func NullPtr(p interface{}) nullPtr {
	v := reflect.ValueOf(p)
	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Ptr {
		panic("NullPtr arg must be pointer to pointer")
	}
	return nullPtr{v}
}

// nullPtr adapts a **T so that it can be used as a Scan destination and as a
// query argument.
type nullPtr struct {
	// ptr is a pointer to a pointer to something: **T
	ptr reflect.Value
}

// Scan stores value in the variable that n wraps.
//
// A nil value sets the variable to nil. Any other value is stored in a newly
// allocated T. If the value is not directly assignable to T it is converted
// when that can be done without silently changing it (see
// convertForNullPtr); otherwise Scan returns an error and leaves the variable
// untouched.
func (n nullPtr) Scan(value interface{}) error {
	// n.ptr is like a variable v of type **T
	dst := n.ptr.Elem()   // *v, of type *T
	ptrType := dst.Type() // *T
	if value == nil {
		dst.Set(reflect.Zero(ptrType)) // *v = nil
		return nil
	}
	if b, ok := value.([]byte); ok {
		// Byte slices handed to a Scanner are owned by the driver and are
		// only valid until the next Scan, so keep a private copy.
		c := make([]byte, len(b))
		copy(c, b)
		value = c
	}
	elemType := ptrType.Elem() // T
	src, ok := convertForNullPtr(reflect.ValueOf(value), elemType)
	if !ok {
		return fmt.Errorf("cannot scan value of type %T into %s", value, ptrType)
	}
	p := reflect.New(elemType) // p := new(T)
	p.Elem().Set(src)          // *p = value
	dst.Set(p)                 // *v = p
	return nil
}

// convertForNullPtr returns src as a value of type to, and reports whether
// that was possible.
//
// Assignable values are returned unchanged. Otherwise a Go conversion is
// applied, but only between types of the same family: integer to integer
// (rejecting out-of-range values), float to float (rejecting overflow),
// string to or from a slice, or types sharing the same kind. Conversions that
// Go permits but that would misrepresent the value, such as integer to
// string, are refused.
func convertForNullPtr(src reflect.Value, to reflect.Type) (reflect.Value, bool) {
	from := src.Type()
	if from.AssignableTo(to) {
		return src, true
	}
	if !src.CanConvert(to) {
		return reflect.Value{}, false
	}
	isInt := func(k reflect.Kind) bool { return k >= reflect.Int && k <= reflect.Int64 }
	isUint := func(k reflect.Kind) bool { return k >= reflect.Uint && k <= reflect.Uintptr }
	isFloat := func(k reflect.Kind) bool { return k == reflect.Float32 || k == reflect.Float64 }

	fk, tk := from.Kind(), to.Kind()
	zero := reflect.Zero(to) // used only for its Overflow* range checks
	switch {
	case isInt(fk) && isInt(tk):
		if zero.OverflowInt(src.Int()) {
			return reflect.Value{}, false
		}
	case isInt(fk) && isUint(tk):
		if i := src.Int(); i < 0 || zero.OverflowUint(uint64(i)) {
			return reflect.Value{}, false
		}
	case isUint(fk) && isUint(tk):
		if zero.OverflowUint(src.Uint()) {
			return reflect.Value{}, false
		}
	case isFloat(fk) && isFloat(tk):
		if zero.OverflowFloat(src.Float()) {
			return reflect.Value{}, false
		}
	case fk == tk:
		// A defined type over the same underlying kind, e.g. string into
		// a named string type.
	case fk == reflect.String && tk == reflect.Slice,
		fk == reflect.Slice && tk == reflect.String:
		// Text arriving as bytes or vice versa.
	default:
		return reflect.Value{}, false
	}
	return src.Convert(to), true
}

// Value returns the value that the wrapped pointer points to, or nil if the
// wrapped pointer is nil.
func (n nullPtr) Value() (driver.Value, error) {
	if n.ptr.Elem().IsNil() {
		return nil, nil
	}
	return n.ptr.Elem().Elem().Interface(), nil
}
// CollectStructs scans the rows from the query into structs of type T and
// returns a slice of them. On any error it returns a nil slice together with
// the error. It panics if T is not a struct type.
//
// Example:
//
//	type Player struct { Name string; Score int }
//	players, err := database.CollectStructs[Player](ctx, db, "SELECT name, score FROM players")
func CollectStructs[T any](ctx context.Context, db *DB, query string, args ...interface{}) ([]T, error) {
	scanArgs := structScannerForType[T]()
	var results []T
	// collect scans the current row into a fresh T and appends it.
	collect := func(rows *sql.Rows) error {
		var row T
		err := rows.Scan(scanArgs(&row)...)
		if err == nil {
			results = append(results, row)
		}
		return err
	}
	if err := db.RunQuery(ctx, query, collect, args...); err != nil {
		return nil, err
	}
	return results, nil
}
// CollectStructPtrs scans the rows from the query into newly allocated
// structs of type T and returns a slice of pointers to them. On any error it
// returns a nil slice together with the error. It panics if T is not a struct
// type.
//
// Example:
//
//	type Player struct { Name string; Score int }
//	players, err := database.CollectStructPtrs[Player](ctx, db, "SELECT name, score FROM players")
func CollectStructPtrs[T any](ctx context.Context, db *DB, query string, args ...interface{}) ([]*T, error) {
	scanArgs := structScannerForType[T]()
	var ptrs []*T
	// collect scans the current row into a fresh T and keeps its address.
	collect := func(rows *sql.Rows) error {
		row := new(T)
		if err := rows.Scan(scanArgs(row)...); err != nil {
			return err
		}
		ptrs = append(ptrs, row)
		return nil
	}
	if err := db.RunQuery(ctx, query, collect, args...); err != nil {
		return nil, err
	}
	return ptrs, nil
}