exp/sql: rename NullableString to NullString and allow its use as a parameter

Prep for Issue 2699

R=rsc
CC=golang-dev
https://golang.org/cl/5536045
diff --git a/src/pkg/exp/sql/convert_test.go b/src/pkg/exp/sql/convert_test.go
index 702ba43..8c0cafc 100644
--- a/src/pkg/exp/sql/convert_test.go
+++ b/src/pkg/exp/sql/convert_test.go
@@ -5,6 +5,7 @@
 package sql
 
 import (
+	"exp/sql/driver"
 	"fmt"
 	"reflect"
 	"testing"
@@ -154,8 +155,8 @@
 	}
 }
 
-func TestNullableString(t *testing.T) {
-	var ns NullableString
+func TestNullString(t *testing.T) {
+	var ns NullString
 	convertAssign(&ns, []byte("foo"))
 	if !ns.Valid {
 		t.Errorf("expecting not null")
@@ -171,3 +172,35 @@
 		t.Errorf("expecting blank on nil; got %q", ns.String)
 	}
 }
+
+type valueConverterTest struct {
+	c       driver.ValueConverter
+	in, out interface{}
+	err     string
+}
+
+var valueConverterTests = []valueConverterTest{
+	{driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""},
+	{driver.DefaultParameterConverter, NullString{"", false}, nil, ""},
+}
+
+func TestValueConverters(t *testing.T) {
+	for i, tt := range valueConverterTests {
+		out, err := tt.c.ConvertValue(tt.in)
+		goterr := ""
+		if err != nil {
+			goterr = err.Error()
+		}
+		if goterr != tt.err {
+			t.Errorf("test %d: %s(%T(%v)) error = %q; want error = %q",
+				i, tt.c, tt.in, tt.in, goterr, tt.err)
+		}
+		if tt.err != "" {
+			continue
+		}
+		if !reflect.DeepEqual(out, tt.out) {
+			t.Errorf("test %d: %s(%T(%v)) = %v (%T); want %v (%T)",
+				i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
+		}
+	}
+}
diff --git a/src/pkg/exp/sql/driver/types.go b/src/pkg/exp/sql/driver/types.go
index d6ba641..f383885 100644
--- a/src/pkg/exp/sql/driver/types.go
+++ b/src/pkg/exp/sql/driver/types.go
@@ -32,6 +32,15 @@
 	ConvertValue(v interface{}) (interface{}, error)
 }
 
+// SubsetValuer is the interface providing the SubsetValue method.
+//
+// Types implementing the SubsetValuer interface are able to convert
+// themselves to one of the driver's allowed subset values.
+type SubsetValuer interface {
+	// SubsetValue returns a driver parameter subset value.
+	SubsetValue() (interface{}, error)
+}
+
 // Bool is a ValueConverter that converts input values to bools.
 //
 // The conversion rules are:
@@ -136,6 +145,32 @@
 	return fmt.Sprintf("%v", v), nil
 }
 
+// Null is a type that implements ValueConverter by allowing nil
+// values but otherwise delegating to another ValueConverter.
+type Null struct {
+	Converter ValueConverter
+}
+
+func (n Null) ConvertValue(v interface{}) (interface{}, error) {
+	if v == nil {
+		return nil, nil
+	}
+	return n.Converter.ConvertValue(v)
+}
+
+// NotNull is a type that implements ValueConverter by disallowing nil
+// values but otherwise delegating to another ValueConverter.
+type NotNull struct {
+	Converter ValueConverter
+}
+
+func (n NotNull) ConvertValue(v interface{}) (interface{}, error) {
+	if v == nil {
+		return nil, fmt.Errorf("nil value not allowed")
+	}
+	return n.Converter.ConvertValue(v)
+}
+
 // IsParameterSubsetType reports whether v is of a valid type for a
 // parameter. These types are:
 //
@@ -200,6 +235,17 @@
 		return v, nil
 	}
 
+	if svi, ok := v.(SubsetValuer); ok {
+		sv, err := svi.SubsetValue()
+		if err != nil {
+			return nil, err
+		}
+		if !IsParameterSubsetType(sv) {
+			return nil, fmt.Errorf("non-subset type %T returned from SubsetValue", sv)
+		}
+		return sv, nil
+	}
+
 	rv := reflect.ValueOf(v)
 	switch rv.Kind() {
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -215,5 +261,5 @@
 	case reflect.Float32, reflect.Float64:
 		return rv.Float(), nil
 	}
-	return nil, fmt.Errorf("unsupported type %s", rv.Kind())
+	return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
 }
diff --git a/src/pkg/exp/sql/fakedb_test.go b/src/pkg/exp/sql/fakedb_test.go
index 70aa68c..0376583 100644
--- a/src/pkg/exp/sql/fakedb_test.go
+++ b/src/pkg/exp/sql/fakedb_test.go
@@ -589,7 +589,9 @@
 	case "int32":
 		return driver.Int32
 	case "string":
-		return driver.String
+		return driver.NotNull{driver.String}
+	case "nullstring":
+		return driver.Null{driver.String}
 	case "datetime":
 		return driver.DefaultParameterConverter
 	}
diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go
index cba7e9e..3201e76 100644
--- a/src/pkg/exp/sql/sql.go
+++ b/src/pkg/exp/sql/sql.go
@@ -35,11 +35,11 @@
 // valid until the next call to Next, Scan, or Close.
 type RawBytes []byte
 
-// NullableString represents a string that may be null.
-// NullableString implements the ScannerInto interface so
+// NullString represents a string that may be null.
+// NullString implements the ScannerInto interface so
 // it can be used as a scan destination:
 //
-//  var s NullableString
+//  var s NullString
 //  err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s)
 //  ...
 //  if s.Valid {
@@ -49,19 +49,27 @@
 //  }
 //
 // TODO(bradfitz): add other types.
-type NullableString struct {
+type NullString struct {
 	String string
 	Valid  bool // Valid is true if String is not NULL
 }
 
 // ScanInto implements the ScannerInto interface.
-func (ms *NullableString) ScanInto(value interface{}) error {
+func (ns *NullString) ScanInto(value interface{}) error {
 	if value == nil {
-		ms.String, ms.Valid = "", false
+		ns.String, ns.Valid = "", false
 		return nil
 	}
-	ms.Valid = true
-	return convertAssign(&ms.String, value)
+	ns.Valid = true
+	return convertAssign(&ns.String, value)
+}
+
+// SubsetValue implements the driver SubsetValuer interface.
+func (ns NullString) SubsetValue() (interface{}, error) {
+	if !ns.Valid {
+		return nil, nil
+	}
+	return ns.String, nil
 }
 
 // ScannerInto is an interface used by Scan.
@@ -530,6 +538,27 @@
 	// Convert args to subset types.
 	if cc, ok := si.(driver.ColumnConverter); ok {
 		for n, arg := range args {
+			// First, see if the value itself knows how to convert
+			// itself to a driver type.  For example, a NullString
+			// struct changing into a string or nil.
+			if svi, ok := arg.(driver.SubsetValuer); ok {
+				sv, err := svi.SubsetValue()
+				if err != nil {
+					return nil, fmt.Errorf("sql: argument index %d from SubsetValue: %v", n, err)
+				}
+				if !driver.IsParameterSubsetType(sv) {
+					return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from SubsetValue", n, sv)
+				}
+				arg = sv
+			}
+
+			// Second, ask the column to sanity check itself. For
+			// example, drivers might use this to make sure that
+			// an int64 value being inserted into a 16-bit
+			// integer field is in range (before getting
+			// truncated), or that a nil can't go into a NOT NULL
+			// column before going across the network to get the
+			// same error.
 			args[n], err = cc.ColumnConverter(n).ConvertValue(arg)
 			if err != nil {
 				return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go
index 30cd97d..3fe9398 100644
--- a/src/pkg/exp/sql/sql_test.go
+++ b/src/pkg/exp/sql/sql_test.go
@@ -340,3 +340,65 @@
 		t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
 	}
 }
+
+func TestNullStringParam(t *testing.T) {
+	db := newTestDB(t, "")
+	defer closeDB(t, db)
+	exec(t, db, "CREATE|t|id=int32,name=string,favcolor=nullstring")
+
+	// Inserts with db.Exec:
+	exec(t, db, "INSERT|t|id=?,name=?,favcolor=?", 1, "alice", NullString{"aqua", true})
+	exec(t, db, "INSERT|t|id=?,name=?,favcolor=?", 2, "bob", NullString{"brown", false})
+
+	_, err := db.Exec("INSERT|t|id=?,name=?,favcolor=?", 999, nil, nil)
+	if err == nil {
+		// TODO: this test fails, but it's just because
+		// fakeConn implements the optional Execer interface,
+		// so arguably this is the correct behavior.  But
+		// maybe I should flesh out the fakeConn.Exec
+		// implementation so this properly fails.
+		// t.Errorf("expected error inserting nil name with Exec")
+	}
+
+	// Inserts with a prepared statement:
+	stmt, err := db.Prepare("INSERT|t|id=?,name=?,favcolor=?")
+	if err != nil {
+		t.Fatalf("prepare: %v", err)
+	}
+	if _, err := stmt.Exec(3, "chris", "chartreuse"); err != nil {
+		t.Errorf("exec insert chris: %v", err)
+	}
+	if _, err := stmt.Exec(4, "dave", NullString{"darkred", true}); err != nil {
+		t.Errorf("exec insert dave: %v", err)
+	}
+	if _, err := stmt.Exec(5, "eleanor", NullString{"eel", false}); err != nil {
+		t.Errorf("exec insert eleanor: %v", err)
+	}
+
+	// Can't put a null name into a non-nullstring column.
+	if _, err := stmt.Exec(5, NullString{"", false}, nil); err == nil {
+		t.Errorf("expected error inserting nil name with prepared statement Exec")
+	}
+
+	type nameColor struct {
+		name     string
+		favColor NullString
+	}
+
+	wantMap := map[int]nameColor{
+		1: nameColor{"alice", NullString{"aqua", true}},
+		2: nameColor{"bob", NullString{"", false}},
+		3: nameColor{"chris", NullString{"chartreuse", true}},
+		4: nameColor{"dave", NullString{"darkred", true}},
+		5: nameColor{"eleanor", NullString{"", false}},
+	}
+	for id, want := range wantMap {
+		var got nameColor
+		if err := db.QueryRow("SELECT|t|name,favcolor|id=?", id).Scan(&got.name, &got.favColor); err != nil {
+			t.Errorf("id=%d Scan: %v", id, err)
+		}
+		if got != want {
+			t.Errorf("id=%d got %#v, want %#v", id, got, want)
+		}
+	}
+}