blob: 884915970d1d6f0f2f0084b267e207a75e683bb5 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package database
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"unicode"
"unicode/utf8"
"github.com/lib/pq"
)
// StructScanner takes a struct, or a pointer to a struct, and returns a
// function that, when called on a struct pointer of that type, returns a slice
// of arguments suitable for Row.Scan or Rows.Scan. The call to either Scan will
// populate the exported fields of the struct in the order they appear in the
// type definition.
//
// Only the type of s is consulted, so a typed nil pointer such as
// (*Player)(nil) is accepted.
//
// StructScanner panics if s is nil, or is neither a struct nor a pointer to a
// struct. The function it returns will panic if its argument is not a pointer
// to a struct.
//
// Example:
//
//	type Player struct { Name string; Score int }
//	playerScanArgs := database.StructScanner(Player{})
//	err := db.RunQuery(ctx, "SELECT name, score FROM players", func(rows *sql.Rows) error {
//		var p Player
//		if err := rows.Scan(playerScanArgs(&p)...); err != nil {
//			return err
//		}
//		// use p
//		return nil
//	})
func StructScanner(s interface{}) func(p interface{}) []interface{} {
	// Work from the static type rather than the value: dereferencing a nil
	// pointer value yields an invalid reflect.Value, whereas the pointer's
	// element type is always available.
	t := reflect.TypeOf(s)
	if t == nil {
		panic("StructScanner: argument is nil, want a struct or a pointer to a struct")
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return structScannerForType(t)
}
// fieldInfo records what structScannerForType needs to know about one
// exported field of a struct type.
type fieldInfo struct {
num int // index of the field within the struct type; passed to v.Field
kind reflect.Kind // kind of the field's declared type; selects how the field is scanned
}
// structScannerForType returns a function that maps a pointer to a struct of
// type t to a slice of Scan arguments, one per exported field, in declaration
// order. It panics if t is not a struct type.
//
// Each argument is a pointer to the corresponding field, except that slice
// fields other than []byte are wrapped with pq.Array and pointer fields are
// wrapped with NullPtr.
func structScannerForType(t reflect.Type) func(p interface{}) []interface{} {
	if t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("%s is not a struct", t))
	}

	// Inspect the type once, up front, so the returned function only has to
	// walk the precomputed list of exported fields.
	var exported []fieldInfo
	for i, n := 0, t.NumField(); i < n; i++ {
		field := t.Field(i)
		first, _ := utf8.DecodeRuneInString(field.Name)
		if !unicode.IsUpper(first) {
			continue // unexported field: not scanned
		}
		exported = append(exported, fieldInfo{i, field.Type.Kind()})
	}

	return func(p interface{}) []interface{} {
		structVal := reflect.ValueOf(p).Elem()
		var args []interface{}
		for _, fi := range exported {
			arg := structVal.Field(fi.num).Addr().Interface()
			if fi.kind == reflect.Slice {
				// A *[]byte can be handed to Scan as is; any other slice
				// goes through pq.Array.
				if _, isBytes := arg.(*[]byte); !isBytes {
					arg = pq.Array(arg)
				}
			} else if fi.kind == reflect.Ptr {
				arg = NullPtr(arg)
			}
			args = append(args, arg)
		}
		return args
	}
}
// NullPtr is for scanning nullable database columns into pointer variables or
// fields. When given a pointer to a pointer to some type T, it returns a
// value that can be passed to a Scan function. If the corresponding column is
// NULL, the variable will be set to nil. Otherwise, it will be set to a newly
// allocated pointer to the column value.
//
// NullPtr panics if p is not a non-nil pointer to a pointer.
func NullPtr(p interface{}) nullPtr {
	v := reflect.ValueOf(p)
	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Ptr {
		panic("NullPtr arg must be pointer to pointer")
	}
	return nullPtr{v}
}

// nullPtr adapts a **T so that it can be used both as a Scan destination and
// as a query argument.
type nullPtr struct {
	// ptr is a pointer to a pointer to something: **T
	ptr reflect.Value
}

// Scan stores a column value in the variable n was created from. A nil value
// sets the variable to nil; any other value is stored in a newly allocated T
// and the variable is set to point at it.
//
// If value cannot be represented as a T, Scan returns an error and leaves the
// variable unchanged.
func (n nullPtr) Scan(value interface{}) error {
	// n.ptr is like a variable v of type **T.
	target := n.ptr.Elem()   // *v, of type *T
	ptrType := target.Type() // *T
	if value == nil {
		target.Set(reflect.Zero(ptrType)) // *v = nil
		return nil
	}
	// Byte slices handed to a Scanner are owned by the driver and are only
	// valid until the next call to Scan, so they must be copied before being
	// retained.
	if b, ok := value.([]byte); ok {
		value = append([]byte{}, b...)
	}
	elemType := ptrType.Elem() // T
	converted, err := convertScanned(reflect.ValueOf(value), elemType)
	if err != nil {
		return err
	}
	p := reflect.New(elemType) // p := new(T)
	p.Elem().Set(converted)    // *p = value
	target.Set(p)              // *v = p
	return nil
}

// convertScanned returns src as a value of type dst. Values that are already
// assignable are returned unchanged. Otherwise only conversions that preserve
// the value are allowed: int64 to any integer or float type that can hold it,
// float64 to a float type that can hold it, and conversions between strings,
// byte slices and named types with the same underlying type.
func convertScanned(src reflect.Value, dst reflect.Type) (reflect.Value, error) {
	st := src.Type()
	if st.AssignableTo(dst) {
		return src, nil
	}
	probe := reflect.Zero(dst) // consulted only for its Overflow* range checks
	sk := src.Kind()
	ok := false
	switch dst.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		ok = sk == reflect.Int64 && !probe.OverflowInt(src.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		ok = sk == reflect.Int64 && src.Int() >= 0 && !probe.OverflowUint(uint64(src.Int()))
	case reflect.Float32, reflect.Float64:
		ok = sk == reflect.Int64 || (sk == reflect.Float64 && !probe.OverflowFloat(src.Float()))
	case reflect.String:
		// Deliberately excludes Go's integer-to-string conversion, which
		// would turn 65 into "A".
		ok = (sk == reflect.String || sk == reflect.Slice) && st.ConvertibleTo(dst)
	case reflect.Bool, reflect.Slice, reflect.Struct:
		ok = st.ConvertibleTo(dst)
	}
	if !ok {
		return reflect.Value{}, fmt.Errorf("nullPtr.Scan: cannot represent %s value as %s", st, dst)
	}
	return src.Convert(dst), nil
}

// Value returns nil if the pointer n was created from is nil, and the value it
// points to otherwise.
func (n nullPtr) Value() (driver.Value, error) {
	target := n.ptr.Elem()
	if target.IsNil() {
		return nil, nil
	}
	raw := target.Elem().Interface()
	// A driver.Value must be one of a small set of types (int64, float64,
	// bool, []byte, string, time.Time), so normalize values such as int or
	// named string types. Anything the default converter rejects is returned
	// as is, leaving it to a driver-specific value checker.
	if v, err := driver.DefaultParameterConverter.ConvertValue(raw); err == nil {
		return v, nil
	}
	return raw, nil
}
// CollectStructs runs query with args, scans each resulting row into a new
// struct, and appends the structs to the slice that pslice points to. pslice
// must be a pointer to a slice of structs or of pointers to structs. The
// slice is updated only if the whole query succeeds.
//
// Example:
//
//	type Player struct { Name string; Score int }
//	var players []Player
//	err := db.CollectStructs(ctx, &players, "SELECT name, score FROM players")
func (db *DB) CollectStructs(ctx context.Context, pslice interface{}, query string, args ...interface{}) error {
	ptr := reflect.ValueOf(pslice)
	if ptr.Kind() != reflect.Ptr {
		return errors.New("collectStructs: arg is not a pointer")
	}
	slice := ptr.Elem()
	if slice.Kind() != reflect.Slice {
		return errors.New("collectStructs: arg is not a pointer to a slice")
	}

	// Work out the struct type behind the slice's element type, remembering
	// whether the elements themselves are pointers.
	elemType := slice.Type().Elem()
	structType := elemType
	wantPtr := elemType.Kind() == reflect.Ptr
	if wantPtr {
		structType = elemType.Elem()
	}
	if structType.Kind() != reflect.Struct {
		return fmt.Errorf("slice element type is neither struct nor struct pointer: %s", elemType)
	}

	scanArgs := structScannerForType(structType)
	collect := func(rows *sql.Rows) error {
		item := reflect.New(structType) // *struct
		if err := rows.Scan(scanArgs(item.Interface())...); err != nil {
			return err
		}
		if wantPtr {
			slice = reflect.Append(slice, item)
		} else {
			slice = reflect.Append(slice, item.Elem())
		}
		return nil
	}
	if err := db.RunQuery(ctx, query, collect, args...); err != nil {
		return err
	}
	// Publish the grown slice back through the caller's pointer.
	ptr.Elem().Set(slice)
	return nil
}