| // Copyright 2011 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Type conversions for Scan. |
| |
| package sql |
| |
| import ( |
| "database/sql/driver" |
| "errors" |
| "fmt" |
| "reflect" |
| "strconv" |
| "time" |
| "unicode" |
| "unicode/utf8" |
| ) |
| |
// errNilPtr reports that a Scan destination is a typed nil pointer.
// Callers are expected to embed it in a more descriptive error.
var (
	errNilPtr = errors.New("destination pointer is nil")
)
| |
// describeNamedValue renders nv for use in error messages. Unnamed values
// are shown by their ordinal as "$N"; named values as `with name "foo"`.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
| |
// validateNamedValueName accepts an empty name or a name whose first rune
// is a Unicode letter; any other name yields a descriptive error.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
| |
| // ccChecker wraps the driver.ColumnConverter and allows it to be used |
| // as if it were a NamedValueChecker. If the driver ColumnConverter |
| // is not present then the NamedValueChecker will return driver.ErrSkip. |
| type ccChecker struct { |
| cci driver.ColumnConverter |
| want int |
| } |
| |
| func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { |
| if c.cci == nil { |
| return driver.ErrSkip |
| } |
| // The column converter shouldn't be called on any index |
| // it isn't expecting. The final error will be thrown |
| // in the argument converter loop. |
| index := nv.Ordinal - 1 |
| if c.want <= index { |
| return nil |
| } |
| |
| // First, see if the value itself knows how to convert |
| // itself to a driver type. For example, a NullString |
| // struct changing into a string or nil. |
| if vr, ok := nv.Value.(driver.Valuer); ok { |
| sv, err := callValuerValue(vr) |
| if err != nil { |
| return err |
| } |
| if !driver.IsValue(sv) { |
| return fmt.Errorf("non-subset type %T returned from Value", sv) |
| } |
| nv.Value = sv |
| } |
| |
| // Second, ask the column to sanity check itself. For |
| // example, drivers might use this to make sure that |
| // an int64 values being inserted into a 16-bit |
| // integer field is in range (before getting |
| // truncated), or that a nil can't go into a NOT NULL |
| // column before going across the network to get the |
| // same error. |
| var err error |
| arg := nv.Value |
| nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) |
| if err != nil { |
| return err |
| } |
| if !driver.IsValue(nv.Value) { |
| return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value) |
| } |
| return nil |
| } |
| |
// defaultCheckNamedValue has the signature of
// driver.NamedValueChecker.CheckNamedValue. It converts nv.Value in place
// with driver.DefaultParameterConverter, storing the converter's result
// even when an error is returned.
func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
	converted, err := driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
| |
| // driverArgs converts arguments from callers of Stmt.Exec and |
| // Stmt.Query into driver Values. |
| // |
| // The statement ds may be nil, if no statement is available. |
| func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { |
| nvargs := make([]driver.NamedValue, len(args)) |
| |
| // -1 means the driver doesn't know how to count the number of |
| // placeholders, so we won't sanity check input here and instead let the |
| // driver deal with errors. |
| want := -1 |
| |
| var si driver.Stmt |
| var cc ccChecker |
| if ds != nil { |
| si = ds.si |
| want = ds.si.NumInput() |
| cc.want = want |
| } |
| |
| // Check all types of interfaces from the start. |
| // Drivers may opt to use the NamedValueChecker for special |
| // argument types, then return driver.ErrSkip to pass it along |
| // to the column converter. |
| nvc, ok := si.(driver.NamedValueChecker) |
| if !ok { |
| nvc, ok = ci.(driver.NamedValueChecker) |
| } |
| cci, ok := si.(driver.ColumnConverter) |
| if ok { |
| cc.cci = cci |
| } |
| |
| // Loop through all the arguments, checking each one. |
| // If no error is returned simply increment the index |
| // and continue. However if driver.ErrRemoveArgument |
| // is returned the argument is not included in the query |
| // argument list. |
| var err error |
| var n int |
| for _, arg := range args { |
| nv := &nvargs[n] |
| if np, ok := arg.(NamedArg); ok { |
| if err = validateNamedValueName(np.Name); err != nil { |
| return nil, err |
| } |
| arg = np.Value |
| nv.Name = np.Name |
| } |
| nv.Ordinal = n + 1 |
| nv.Value = arg |
| |
| // Checking sequence has four routes: |
| // A: 1. Default |
| // B: 1. NamedValueChecker 2. Column Converter 3. Default |
| // C: 1. NamedValueChecker 3. Default |
| // D: 1. Column Converter 2. Default |
| // |
| // The only time a Column Converter is called is first |
| // or after NamedValueConverter. If first it is handled before |
| // the nextCheck label. Thus for repeats tries only when the |
| // NamedValueConverter is selected should the Column Converter |
| // be used in the retry. |
| checker := defaultCheckNamedValue |
| nextCC := false |
| switch { |
| case nvc != nil: |
| nextCC = cci != nil |
| checker = nvc.CheckNamedValue |
| case cci != nil: |
| checker = cc.CheckNamedValue |
| } |
| |
| nextCheck: |
| err = checker(nv) |
| switch err { |
| case nil: |
| n++ |
| continue |
| case driver.ErrRemoveArgument: |
| nvargs = nvargs[:len(nvargs)-1] |
| continue |
| case driver.ErrSkip: |
| if nextCC { |
| nextCC = false |
| checker = cc.CheckNamedValue |
| } else { |
| checker = defaultCheckNamedValue |
| } |
| goto nextCheck |
| default: |
| return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err) |
| } |
| } |
| |
| // Check the length of arguments after conversion to allow for omitted |
| // arguments. |
| if want != -1 && len(nvargs) != want { |
| return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs)) |
| } |
| |
| return nvargs, nil |
| |
| } |
| |
| // convertAssign is the same as convertAssignRows, but without the optional |
| // rows argument. |
| func convertAssign(dest, src interface{}) error { |
| return convertAssignRows(dest, src, nil) |
| } |
| |
| // convertAssignRows copies to dest the value in src, converting it if possible. |
| // An error is returned if the copy would result in loss of information. |
| // dest should be a pointer type. If rows is passed in, the rows will |
| // be used as the parent for any cursor values converted from a |
| // driver.Rows to a *Rows. |
| func convertAssignRows(dest, src interface{}, rows *Rows) error { |
| // Common cases, without reflect. |
| switch s := src.(type) { |
| case string: |
| switch d := dest.(type) { |
| case *string: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = s |
| return nil |
| case *[]byte: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = []byte(s) |
| return nil |
| case *RawBytes: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = append((*d)[:0], s...) |
| return nil |
| } |
| case []byte: |
| switch d := dest.(type) { |
| case *string: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = string(s) |
| return nil |
| case *interface{}: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = cloneBytes(s) |
| return nil |
| case *[]byte: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = cloneBytes(s) |
| return nil |
| case *RawBytes: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = s |
| return nil |
| } |
| case time.Time: |
| switch d := dest.(type) { |
| case *time.Time: |
| *d = s |
| return nil |
| case *string: |
| *d = s.Format(time.RFC3339Nano) |
| return nil |
| case *[]byte: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = []byte(s.Format(time.RFC3339Nano)) |
| return nil |
| case *RawBytes: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) |
| return nil |
| } |
| case nil: |
| switch d := dest.(type) { |
| case *interface{}: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = nil |
| return nil |
| case *[]byte: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = nil |
| return nil |
| case *RawBytes: |
| if d == nil { |
| return errNilPtr |
| } |
| *d = nil |
| return nil |
| } |
| // The driver is returning a cursor the client may iterate over. |
| case driver.Rows: |
| switch d := dest.(type) { |
| case *Rows: |
| if d == nil { |
| return errNilPtr |
| } |
| if rows == nil { |
| return errors.New("invalid context to convert cursor rows, missing parent *Rows") |
| } |
| rows.closemu.Lock() |
| *d = Rows{ |
| dc: rows.dc, |
| releaseConn: func(error) {}, |
| rowsi: s, |
| } |
| // Chain the cancel function. |
| parentCancel := rows.cancel |
| rows.cancel = func() { |
| // When Rows.cancel is called, the closemu will be locked as well. |
| // So we can access rs.lasterr. |
| d.close(rows.lasterr) |
| if parentCancel != nil { |
| parentCancel() |
| } |
| } |
| rows.closemu.Unlock() |
| return nil |
| } |
| } |
| |
| var sv reflect.Value |
| |
| switch d := dest.(type) { |
| case *string: |
| sv = reflect.ValueOf(src) |
| switch sv.Kind() { |
| case reflect.Bool, |
| reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, |
| reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, |
| reflect.Float32, reflect.Float64: |
| *d = asString(src) |
| return nil |
| } |
| case *[]byte: |
| sv = reflect.ValueOf(src) |
| if b, ok := asBytes(nil, sv); ok { |
| *d = b |
| return nil |
| } |
| case *RawBytes: |
| sv = reflect.ValueOf(src) |
| if b, ok := asBytes([]byte(*d)[:0], sv); ok { |
| *d = RawBytes(b) |
| return nil |
| } |
| case *bool: |
| bv, err := driver.Bool.ConvertValue(src) |
| if err == nil { |
| *d = bv.(bool) |
| } |
| return err |
| case *interface{}: |
| *d = src |
| return nil |
| } |
| |
| if scanner, ok := dest.(Scanner); ok { |
| return scanner.Scan(src) |
| } |
| |
| dpv := reflect.ValueOf(dest) |
| if dpv.Kind() != reflect.Ptr { |
| return errors.New("destination not a pointer") |
| } |
| if dpv.IsNil() { |
| return errNilPtr |
| } |
| |
| if !sv.IsValid() { |
| sv = reflect.ValueOf(src) |
| } |
| |
| dv := reflect.Indirect(dpv) |
| if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { |
| switch b := src.(type) { |
| case []byte: |
| dv.Set(reflect.ValueOf(cloneBytes(b))) |
| default: |
| dv.Set(sv) |
| } |
| return nil |
| } |
| |
| if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { |
| dv.Set(sv.Convert(dv.Type())) |
| return nil |
| } |
| |
| // The following conversions use a string value as an intermediate representation |
| // to convert between various numeric types. |
| // |
| // This also allows scanning into user defined types such as "type Int int64". |
| // For symmetry, also check for string destination types. |
| switch dv.Kind() { |
| case reflect.Ptr: |
| if src == nil { |
| dv.Set(reflect.Zero(dv.Type())) |
| return nil |
| } |
| dv.Set(reflect.New(dv.Type().Elem())) |
| return convertAssignRows(dv.Interface(), src, rows) |
| case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| s := asString(src) |
| i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) |
| if err != nil { |
| err = strconvErr(err) |
| return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) |
| } |
| dv.SetInt(i64) |
| return nil |
| case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
| s := asString(src) |
| u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) |
| if err != nil { |
| err = strconvErr(err) |
| return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) |
| } |
| dv.SetUint(u64) |
| return nil |
| case reflect.Float32, reflect.Float64: |
| s := asString(src) |
| f64, err := strconv.ParseFloat(s, dv.Type().Bits()) |
| if err != nil { |
| err = strconvErr(err) |
| return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) |
| } |
| dv.SetFloat(f64) |
| return nil |
| case reflect.String: |
| switch v := src.(type) { |
| case string: |
| dv.SetString(v) |
| return nil |
| case []byte: |
| dv.SetString(string(v)) |
| return nil |
| } |
| } |
| |
| return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) |
| } |
| |
// strconvErr unwraps a *strconv.NumError to its underlying cause (for
// example strconv.ErrSyntax or strconv.ErrRange). Any other error,
// including nil, is returned unchanged.
func strconvErr(err error) error {
	switch e := err.(type) {
	case *strconv.NumError:
		return e.Err
	default:
		return err
	}
}
| |
// cloneBytes returns a copy of b that shares no memory with it. A nil b
// yields nil; an empty non-nil b yields an empty non-nil slice.
func cloneBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	return append(make([]byte, 0, len(b)), b...)
}
| |
// asString renders src as text.
//   - Values whose dynamic type is exactly string or []byte are returned
//     verbatim.
//   - Integer, unsigned integer, float and bool kinds (including named
//     types) are formatted with strconv.
//   - Everything else falls back to fmt's %v.
func asString(src interface{}) string {
	if s, ok := src.(string); ok {
		return s
	}
	if b, ok := src.([]byte); ok {
		return string(b)
	}

	rv := reflect.ValueOf(src)
	switch rv.Kind() {
	case reflect.Bool:
		return strconv.FormatBool(rv.Bool())
	case reflect.Float32, reflect.Float64:
		return strconv.FormatFloat(rv.Float(), 'g', -1, rv.Type().Bits())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(rv.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(rv.Uint(), 10)
	default:
		return fmt.Sprintf("%v", src)
	}
}
| |
// asBytes appends the textual form of rv to buf and reports true when rv
// is a string, bool, float, integer or unsigned integer kind. For any other
// kind (including an invalid Value) it returns (nil, false).
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	switch rv.Kind() {
	case reflect.String:
		return append(buf, rv.String()...), true
	case reflect.Bool:
		return strconv.AppendBool(buf, rv.Bool()), true
	case reflect.Float32, reflect.Float64:
		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, rv.Type().Bits()), true
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.AppendInt(buf, rv.Int(), 10), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.AppendUint(buf, rv.Uint(), 10), true
	}
	return nil, false
}
| |
// valuerReflectType is the reflect.Type of the driver.Valuer interface.
var valuerReflectType = reflect.TypeOf(new(driver.Valuer)).Elem()

// callValuerValue returns vr.Value(), with one exception: if vr is a nil
// pointer whose element type implements driver.Valuer by value, then the
// Value method on the pointer is auto-generated. Calling it would panic
// (Issue 8415), so it is treated as a nil value instead.
//
// This lets people implement driver.Valuer on value types and still use
// nil pointers to those types to mean nil/NULL, just like string/*string.
//
// This function is mirrored in the database/sql/driver package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)
	isNilPtr := rv.Kind() == reflect.Ptr && rv.IsNil()
	if isNilPtr && rv.Type().Elem().Implements(valuerReflectType) {
		return nil, nil
	}
	return vr.Value()
}