// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package database

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"reflect"
	"unicode"
	"unicode/utf8"

	"github.com/lib/pq"
)

// StructScanner returns a function that, given a pointer to a value of the
// struct type T, produces a slice of arguments suitable for Row.Scan or
// Rows.Scan. Passing that slice to either Scan populates the exported fields
// of the struct in the order they appear in the type definition.
//
// StructScanner panics if T is not a struct type. The argument passed to the
// returned function must be non-nil.
//
// Example:
//
//	type Player struct { Name string; Score int }
//	playerScanArgs := database.StructScanner[Player]()
//	err := db.RunQuery(ctx, "SELECT name, score FROM players", func(rows *sql.Rows) error {
//	    var p Player
//	    if err := rows.Scan(playerScanArgs(&p)...); err != nil {
//	        return err
//	    }
//	    // use p
//	    return nil
//	})
func StructScanner[T any]() func(p *T) []any {
	scanArgs := structScannerForType[T]()
	return scanArgs
}

// fieldInfo describes one exported field of a struct type, as recorded by
// structScannerForType: num is the field's index within the struct and kind
// is the Kind of the field's type, which structScannerForType switches on to
// decide how the field's scan argument is wrapped.
type fieldInfo struct {
	num  int // to pass to v.Field
	kind reflect.Kind
}

// structScannerForType builds the scan-argument function for the struct
// type T.
//
// It inspects T once, recording the index and kind of every exported field,
// and returns a closure that converts a *T into one scan destination per
// exported field, in declaration order:
//   - a slice field whose type is not exactly []byte is wrapped with pq.Array;
//   - a pointer field is wrapped with NullPtr;
//   - any other field is passed as a plain pointer to the field.
//
// It panics if T is not a struct type.
func structScannerForType[T any]() func(p *T) []any {
	// Derive the type from *T rather than from a zero T. For an interface
	// type the zero value is a nil interface, for which reflect.TypeOf
	// returns nil; calling Kind on that would fail with a nil dereference
	// instead of the intended message below.
	t := reflect.TypeOf((*T)(nil)).Elem()
	if t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("%s is not a struct", t))
	}

	// Collect the numbers of the exported fields.
	var fieldInfos []fieldInfo
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		r, _ := utf8.DecodeRuneInString(f.Name)
		if unicode.IsUpper(r) {
			fieldInfos = append(fieldInfos, fieldInfo{num: i, kind: f.Type.Kind()})
		}
	}

	// Return a function that gets pointers to the exported fields.
	return func(p *T) []any {
		v := reflect.ValueOf(p).Elem()
		// The number of scan arguments is known up front: one per exported field.
		ps := make([]any, 0, len(fieldInfos))
		for _, info := range fieldInfos {
			arg := v.Field(info.num).Addr().Interface()
			switch info.kind {
			case reflect.Slice:
				// A *[]byte is passed to Scan as is; every other slice
				// pointer goes through pq.Array.
				if _, ok := arg.(*[]byte); !ok {
					arg = pq.Array(arg)
				}
			case reflect.Ptr:
				arg = NullPtr(arg)
			}
			ps = append(ps, arg)
		}
		return ps
	}
}

// NullPtr is for scanning nullable database columns into pointer variables or
// fields. When given a pointer to a pointer to some type T, it returns a
// value that can be passed to a Scan function. If the corresponding column is
// nil, the variable will be set to nil. Otherwise, it will be set to a newly
// allocated pointer to the column value.
//
// NullPtr panics if p is not a non-nil pointer to a pointer.
func NullPtr(p any) nullPtr {
	v := reflect.ValueOf(p)
	// The second check is only evaluated when v is a pointer, so Elem is
	// never called on a kind that does not support it.
	isPtrToPtr := v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Ptr
	if !isPtrToPtr {
		panic("NullPtr arg must be pointer to pointer")
	}
	return nullPtr{ptr: v}
}

// nullPtr wraps a pointer-to-pointer variable (**T). Its Scan method stores a
// scanned column value through that pointer, and its Value method reads the
// pointed-to value back out, mapping a nil *T to a nil driver value.
type nullPtr struct {
	// ptr is a pointer to a pointer to something: **T
	ptr reflect.Value
}

// Scan stores value into the wrapped variable. A nil value sets the variable
// to nil; any other value is stored in a newly allocated T and the variable is
// set to point at it.
//
// Scan returns an error, leaving the variable untouched, if value's type is
// not assignable to T.
func (n nullPtr) Scan(value any) error {
	// n.ptr is like &v for a variable v of type *T.
	dst := n.ptr.Elem()   // v
	ptrType := dst.Type() // *T
	if value == nil {
		dst.Set(reflect.Zero(ptrType)) // v = nil
		return nil
	}
	if b, ok := value.([]byte); ok {
		// The sql.Scanner contract says a []byte passed to Scan is owned by
		// the driver and only valid until the next call to Scan, so keep a
		// private copy rather than the driver's slice.
		c := make([]byte, len(b))
		copy(c, b)
		value = c
	}
	src := reflect.ValueOf(value)
	elemType := ptrType.Elem() // T
	// reflect.Value.Set panics on a type mismatch; report it as an ordinary
	// scan error instead.
	if !src.Type().AssignableTo(elemType) {
		return fmt.Errorf("cannot scan value of type %T into %s", value, ptrType)
	}
	p := reflect.New(elemType) // p := new(T)
	p.Elem().Set(src)          // *p = value
	dst.Set(p)                 // v = p
	return nil
}

// Value returns nil if the wrapped *T variable is nil, and otherwise the T
// value it points to. The error result is always nil.
func (n nullPtr) Value() (driver.Value, error) {
	target := n.ptr.Elem() // the *T variable
	if target.IsNil() {
		return nil, nil
	}
	return target.Elem().Interface(), nil
}

// CollectStructs runs the query, scans each resulting row into a struct of
// type T, and returns the structs as a slice. If the query or any scan fails,
// it returns a nil slice and the error.
//
// Example:
//
//	type Player struct { Name string; Score int }
//	players, err := database.CollectStructs[Player](ctx, db, "SELECT name, score FROM players")
func CollectStructs[T any](ctx context.Context, db *DB, query string, args ...any) ([]T, error) {
	var (
		scanArgs = structScannerForType[T]()
		results  []T
	)
	// collect scans the current row into a fresh T and appends it.
	collect := func(rows *sql.Rows) error {
		var row T
		if err := rows.Scan(scanArgs(&row)...); err != nil {
			return err
		}
		results = append(results, row)
		return nil
	}
	if err := db.RunQuery(ctx, query, collect, args...); err != nil {
		return nil, err
	}
	return results, nil
}

// CollectStructPtrs runs the query, scans each resulting row into a newly
// allocated struct of type T, and returns pointers to those structs as a
// slice. If the query or any scan fails, it returns a nil slice and the error.
func CollectStructPtrs[T any](ctx context.Context, db *DB, query string, args ...any) ([]*T, error) {
	var (
		scanArgs = structScannerForType[T]()
		results  []*T
	)
	// collect scans the current row into a fresh T and appends its address.
	collect := func(rows *sql.Rows) error {
		row := new(T)
		if err := rows.Scan(scanArgs(row)...); err != nil {
			return err
		}
		results = append(results, row)
		return nil
	}
	if err := db.RunQuery(ctx, query, collect, args...); err != nil {
		return nil, err
	}
	return results, nil
}
