| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package database |
| |
| import ( |
| "context" |
| "database/sql" |
| "database/sql/driver" |
| "errors" |
| "fmt" |
| "reflect" |
| "unicode" |
| "unicode/utf8" |
| |
| "github.com/lib/pq" |
| ) |
| |
| // StructScanner takes a struct and returns a function that, when called on a |
| // struct pointer of that type, returns a slice of arguments suitable for |
| // Row.Scan or Rows.Scan. The call to either Scan will populate the exported |
| // fields of the struct in the order they appear in the type definition. |
| // |
| // StructScanner panics if p is not a struct or a pointer to a struct. |
| // The function it returns will panic if its argument is not a pointer |
| // to a struct. |
| // |
| // Example: |
| // type Player struct { Name string; Score int } |
| // playerScanArgs := database.StructScanner(Player{}) |
| // err := db.RunQuery(ctx, "SELECT name, score FROM players", func(rows *sql.Rows) error { |
| // var p Player |
| // if err := rows.Scan(playerScanArgs(&p)...); err != nil { |
| // return err |
| // } |
| // // use p |
| // return nil |
| // }) |
| func StructScanner(s interface{}) func(p interface{}) []interface{} { |
| v := reflect.ValueOf(s) |
| if v.Kind() == reflect.Ptr { |
| v = v.Elem() |
| } |
| return structScannerForType(v.Type()) |
| } |
| |
// fieldInfo describes one exported struct field that structScannerForType
// produces a Scan destination for.
type fieldInfo struct {
	// num is the field's index within the struct, as passed to
	// reflect.Value.Field.
	num int
	// kind is the field's reflect.Kind. It selects how the field's address
	// is wrapped before being handed to Scan.
	kind reflect.Kind
}
| |
| func structScannerForType(t reflect.Type) func(p interface{}) []interface{} { |
| if t.Kind() != reflect.Struct { |
| panic(fmt.Sprintf("%s is not a struct", t)) |
| } |
| |
| // Collect the numbers of the exported fields. |
| var fieldInfos []fieldInfo |
| for i := 0; i < t.NumField(); i++ { |
| r, _ := utf8.DecodeRuneInString(t.Field(i).Name) |
| if unicode.IsUpper(r) { |
| fieldInfos = append(fieldInfos, fieldInfo{i, t.Field(i).Type.Kind()}) |
| } |
| } |
| // Return a function that gets pointers to the exported fields. |
| return func(p interface{}) []interface{} { |
| v := reflect.ValueOf(p).Elem() |
| var ps []interface{} |
| for _, info := range fieldInfos { |
| p := v.Field(info.num).Addr().Interface() |
| switch info.kind { |
| case reflect.Slice: |
| p = pq.Array(p) |
| case reflect.Ptr: |
| p = NullPtr(p) |
| default: |
| } |
| ps = append(ps, p) |
| } |
| return ps |
| } |
| } |
| |
| // NullPtr is for scanning nullable database columns into pointer variables or |
| // fields. When given a pointer to to a pointer to some type T, it returns a |
| // value that can be passed to a Scan function. If the corresponding column is |
| // nil, the variable will be set to nil. Otherwise, it will be set to a newly |
| // allocated pointer to the column value. |
| func NullPtr(p interface{}) nullPtr { |
| v := reflect.ValueOf(p) |
| if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Ptr { |
| panic("NullPtr arg must be pointer to pointer") |
| } |
| return nullPtr{v} |
| } |
| |
// nullPtr wraps a **T so that it can act as a sql.Scanner (its Scan method)
// and as a driver.Valuer (its Value method). Construct it with NullPtr.
type nullPtr struct {
	ptr reflect.Value // holds a **T: a pointer to a pointer to something
}
| |
| func (n nullPtr) Scan(value interface{}) error { |
| // n.ptr is like a variable v of type **T |
| ntype := n.ptr.Elem().Type() // T |
| if value == nil { |
| n.ptr.Elem().Set(reflect.Zero(ntype)) // *v = nil |
| } else { |
| p := reflect.New(ntype.Elem()) // p := new(T) |
| p.Elem().Set(reflect.ValueOf(value)) // *p = value |
| n.ptr.Elem().Set(p) // *v = p |
| } |
| return nil |
| } |
| |
| func (n nullPtr) Value() (driver.Value, error) { |
| if n.ptr.Elem().IsNil() { |
| return nil, nil |
| } |
| return n.ptr.Elem().Elem().Interface(), nil |
| } |
| |
| // CollectStructs scans the the rows from the query into structs and appends |
| // them to pslice, which must be a pointer to a slice of structs. |
| // Example: |
| // type Player struct { Name string; Score int } |
| // var players []Player |
| // err := db.CollectStructs(ctx, &players, "SELECT name, score FROM players") |
| func (db *DB) CollectStructs(ctx context.Context, pslice interface{}, query string, args ...interface{}) error { |
| v := reflect.ValueOf(pslice) |
| if v.Kind() != reflect.Ptr { |
| return errors.New("collectStructs: arg is not a pointer") |
| } |
| ve := v.Elem() |
| if ve.Kind() != reflect.Slice { |
| return errors.New("collectStructs: arg is not a pointer to a slice") |
| } |
| isPointer := false |
| et := ve.Type().Elem() // slice element type |
| if et.Kind() == reflect.Ptr { |
| isPointer = true |
| et = et.Elem() |
| } |
| if et.Kind() != reflect.Struct { |
| return fmt.Errorf("slice element type is neither struct nor struct pointer: %s", ve.Type().Elem()) |
| } |
| |
| scanner := structScannerForType(et) |
| err := db.RunQuery(ctx, query, func(rows *sql.Rows) error { |
| e := reflect.New(et) |
| if err := rows.Scan(scanner(e.Interface())...); err != nil { |
| return err |
| } |
| if !isPointer { |
| e = e.Elem() |
| } |
| ve = reflect.Append(ve, e) |
| return nil |
| }, args...) |
| if err != nil { |
| return err |
| } |
| v.Elem().Set(ve) |
| return nil |
| } |