database/sql: move from exp/sql

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/5536076
diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go
new file mode 100644
index 0000000..b0d137c
--- /dev/null
+++ b/src/pkg/database/sql/fakedb_test.go
@@ -0,0 +1,598 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+)
+
+var _ = log.Printf
+
+// fakeDriver is a fake database that implements Go's driver.Driver
+// interface, just for testing.
+//
+// It speaks a query language that's semantically similar to but
+// syntactically different and simpler than SQL.  The syntax is as
+// follows:
+//
+//   WIPE
+//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
+//     where types are: "string", [u]int{8,16,32,64}, "bool"
+//   INSERT|<tablename>|col=val,col2=val2,col3=?
+//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
+//
+// When opening a fakeDriver's database, it starts empty with no
+// tables.  All tables and data are stored in memory only.
+type fakeDriver struct {
+	mu        sync.Mutex         // guards openCount and dbs
+	openCount int                // incremented by every Open call
+	dbs       map[string]*fakeDB // databases by name, created on first use by getDB
+}
+
+type fakeDB struct {
+	name string
+
+	mu     sync.Mutex
+	free   []*fakeConn
+	tables map[string]*table
+}
+
+type table struct {
+	mu      sync.Mutex
+	colname []string
+	coltype []string
+	rows    []*row
+}
+
+func (t *table) columnIndex(name string) int {
+	for n, nname := range t.colname {
+		if name == nname {
+			return n
+		}
+	}
+	return -1
+}
+
+type row struct {
+	cols []interface{} // must be same size as its table colname + coltype
+}
+
+func (r *row) clone() *row {
+	nrow := &row{cols: make([]interface{}, len(r.cols))}
+	copy(nrow.cols, r.cols)
+	return nrow
+}
+
+type fakeConn struct {
+	db *fakeDB // where to return ourselves to
+
+	currTx *fakeTx
+
+	// Stats for tests:
+	mu          sync.Mutex
+	stmtsMade   int
+	stmtsClosed int
+}
+
+func (c *fakeConn) incrStat(v *int) {
+	c.mu.Lock()
+	*v++
+	c.mu.Unlock()
+}
+
+type fakeTx struct {
+	c *fakeConn
+}
+
+type fakeStmt struct {
+	c *fakeConn
+	q string // just for debugging
+
+	cmd   string
+	table string
+
+	closed bool
+
+	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
+	colType      []string      // used by CREATE
+	colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
+	placeholders int           // used by INSERT/SELECT: number of ? params
+
+	whereCol []string // used by SELECT (all placeholders)
+
+	placeholderConverter []driver.ValueConverter // used by INSERT
+}
+
+var fdriver driver.Driver = &fakeDriver{}
+
+func init() {
+	Register("test", fdriver)
+}
+
+// Supports dsn forms:
+//    <dbname>
+//    <dbname>;<opts>  (no currently supported options)
+func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
+	parts := strings.Split(dsn, ";")
+	if len(parts) < 1 {
+		return nil, errors.New("fakedb: no database name")
+	}
+	name := parts[0]
+
+	db := d.getDB(name)
+
+	d.mu.Lock()
+	d.openCount++
+	d.mu.Unlock()
+	return &fakeConn{db: db}, nil
+}
+
+func (d *fakeDriver) getDB(name string) *fakeDB {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if d.dbs == nil {
+		d.dbs = make(map[string]*fakeDB)
+	}
+	db, ok := d.dbs[name]
+	if !ok {
+		db = &fakeDB{name: name}
+		d.dbs[name] = db
+	}
+	return db
+}
+
+func (db *fakeDB) wipe() {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	db.tables = nil
+}
+
+func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	if db.tables == nil {
+		db.tables = make(map[string]*table)
+	}
+	if _, exist := db.tables[name]; exist {
+		return fmt.Errorf("table %q already exists", name)
+	}
+	if len(columnNames) != len(columnTypes) {
+		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
+			name, len(columnNames), len(columnTypes))
+	}
+	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
+	return nil
+}
+
+// table returns the named table, if it exists. db.mu must be held by the caller.
+func (db *fakeDB) table(table string) (*table, bool) {
+	if db.tables == nil {
+		return nil, false
+	}
+	t, ok := db.tables[table]
+	return t, ok
+}
+
+func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	t, ok := db.table(table)
+	if !ok {
+		return
+	}
+	for n, cname := range t.colname {
+		if cname == column {
+			return t.coltype[n], true
+		}
+	}
+	return "", false
+}
+
+func (c *fakeConn) Begin() (driver.Tx, error) {
+	if c.currTx != nil {
+		return nil, errors.New("already in a transaction")
+	}
+	c.currTx = &fakeTx{c: c}
+	return c.currTx, nil
+}
+
+func (c *fakeConn) Close() error {
+	if c.currTx != nil {
+		return errors.New("can't close; in a Transaction")
+	}
+	if c.db == nil {
+		return errors.New("can't close; already closed")
+	}
+	c.db = nil
+	return nil
+}
+
+func checkSubsetTypes(args []interface{}) error {
+	for n, arg := range args {
+		switch arg.(type) {
+		case int64, float64, bool, nil, []byte, string, time.Time:
+		default:
+			return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
+		}
+	}
+	return nil
+}
+
+func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
+	// This is an optional interface, but it's implemented here
+	// just to check that all the args are of the proper types.
+	// ErrSkip is returned so the caller acts as if we didn't
+	// implement this at all.
+	err := checkSubsetTypes(args)
+	if err != nil {
+		return nil, err
+	}
+	return nil, driver.ErrSkip
+}
+
+func errf(msg string, args ...interface{}) error {
+	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
+}
+
+// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
+// (note that where columns must always contain ? marks,
+//  just a limitation for fakedb)
+func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+	if len(parts) != 3 {
+		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
+	}
+	stmt.table = parts[0]
+	stmt.colName = strings.Split(parts[1], ",")
+	for n, colspec := range strings.Split(parts[2], ",") {
+		if colspec == "" {
+			continue
+		}
+		nameVal := strings.Split(colspec, "=")
+		if len(nameVal) != 2 {
+			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+		}
+		column, value := nameVal[0], nameVal[1]
+		_, ok := c.db.columnType(stmt.table, column)
+		if !ok {
+			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
+		}
+		if value != "?" {
+			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
+				stmt.table, column)
+		}
+		stmt.whereCol = append(stmt.whereCol, column)
+		stmt.placeholders++
+	}
+	return stmt, nil
+}
+
+// parts are table|col=type,col2=type2
+func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+	if len(parts) != 2 {
+		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
+	}
+	stmt.table = parts[0]
+	for n, colspec := range strings.Split(parts[1], ",") {
+		nameType := strings.Split(colspec, "=")
+		if len(nameType) != 2 {
+			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+		}
+		stmt.colName = append(stmt.colName, nameType[0])
+		stmt.colType = append(stmt.colType, nameType[1])
+	}
+	return stmt, nil
+}
+
+// parts are table|col=?,col2=val; pre-bound values are converted to driver subset types
+func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+	if len(parts) != 2 {
+		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
+	}
+	stmt.table = parts[0]
+	for n, colspec := range strings.Split(parts[1], ",") {
+		nameVal := strings.Split(colspec, "=")
+		if len(nameVal) != 2 {
+			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+		}
+		column, value := nameVal[0], nameVal[1]
+		ctype, ok := c.db.columnType(stmt.table, column)
+		if !ok {
+			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
+		}
+		stmt.colName = append(stmt.colName, column)
+
+		if value != "?" {
+			var subsetVal interface{}
+			// Convert to driver subset type
+			switch ctype {
+			case "string":
+				subsetVal = []byte(value)
+			case "blob":
+				subsetVal = []byte(value)
+			case "int32":
+				i, err := strconv.ParseInt(value, 10, 32) // bitSize 32 rejects values outside the int32 range
+				if err != nil {
+					return nil, errf("invalid conversion to int32 from %q", value)
+				}
+				subsetVal = i // int64 is a subset type, but not int32
+			default:
+				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
+			}
+			stmt.colValue = append(stmt.colValue, subsetVal)
+		} else {
+			stmt.placeholders++
+			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
+			stmt.colValue = append(stmt.colValue, "?")
+		}
+	}
+	return stmt, nil
+}
+
+func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
+	if c.db == nil {
+		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
+	}
+	parts := strings.Split(query, "|")
+	if len(parts) < 1 {
+		return nil, errf("empty query")
+	}
+	cmd := parts[0]
+	parts = parts[1:]
+	stmt := &fakeStmt{q: query, c: c, cmd: cmd}
+	c.incrStat(&c.stmtsMade)
+	switch cmd {
+	case "WIPE":
+		// Nothing
+	case "SELECT":
+		return c.prepareSelect(stmt, parts)
+	case "CREATE":
+		return c.prepareCreate(stmt, parts)
+	case "INSERT":
+		return c.prepareInsert(stmt, parts)
+	default:
+		return nil, errf("unsupported command type %q", cmd)
+	}
+	return stmt, nil
+}
+
+func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
+	return s.placeholderConverter[idx]
+}
+
+func (s *fakeStmt) Close() error {
+	if !s.closed {
+		s.c.incrStat(&s.c.stmtsClosed)
+		s.closed = true
+	}
+	return nil
+}
+
+var errClosed = errors.New("fakedb: statement has been closed")
+
+func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
+	if s.closed {
+		return nil, errClosed
+	}
+	err := checkSubsetTypes(args)
+	if err != nil {
+		return nil, err
+	}
+
+	db := s.c.db
+	switch s.cmd {
+	case "WIPE":
+		db.wipe()
+		return driver.DDLSuccess, nil
+	case "CREATE":
+		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
+			return nil, err
+		}
+		return driver.DDLSuccess, nil
+	case "INSERT":
+		return s.execInsert(args)
+	}
+	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
+	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
+}
+
+func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
+	db := s.c.db
+	if len(args) != s.placeholders {
+		panic("error in pkg db; should only get here if size is correct")
+	}
+	db.mu.Lock()
+	t, ok := db.table(s.table)
+	db.mu.Unlock()
+	if !ok {
+		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
+	}
+
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	cols := make([]interface{}, len(t.colname))
+	argPos := 0
+	for n, colname := range s.colName {
+		colidx := t.columnIndex(colname)
+		if colidx == -1 {
+			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
+		}
+		var val interface{}
+		if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
+			val = args[argPos]
+			argPos++
+		} else {
+			val = s.colValue[n]
+		}
+		cols[colidx] = val
+	}
+
+	t.rows = append(t.rows, &row{cols: cols})
+	return driver.RowsAffected(1), nil
+}
+
+func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
+	if s.closed {
+		return nil, errClosed
+	}
+	err := checkSubsetTypes(args)
+	if err != nil {
+		return nil, err
+	}
+
+	db := s.c.db
+	if len(args) != s.placeholders {
+		panic("error in pkg db; should only get here if size is correct")
+	}
+
+	db.mu.Lock()
+	t, ok := db.table(s.table)
+	db.mu.Unlock()
+	if !ok {
+		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
+	}
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	colIdx := make(map[string]int) // select column name -> column index in table
+	for _, name := range s.colName {
+		idx := t.columnIndex(name)
+		if idx == -1 {
+			return nil, fmt.Errorf("fakedb: unknown column name %q", name)
+		}
+		colIdx[name] = idx
+	}
+
+	mrows := []*row{}
+rows:
+	for _, trow := range t.rows {
+		// Process the where clause, skipping non-match rows. This is lazy
+		// and just uses fmt.Sprintf("%v") to test equality.  Good enough
+		// for test code.
+		for widx, wcol := range s.whereCol {
+			idx := t.columnIndex(wcol)
+			if idx == -1 {
+				return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+			}
+			tcol := trow.cols[idx]
+			if bs, ok := tcol.([]byte); ok {
+				// lazy hack to avoid sprintf %v on a []byte
+				tcol = string(bs)
+			}
+			if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
+				continue rows
+			}
+		}
+		mrow := &row{cols: make([]interface{}, len(s.colName))}
+		for seli, name := range s.colName {
+			mrow.cols[seli] = trow.cols[colIdx[name]]
+		}
+		mrows = append(mrows, mrow)
+	}
+
+	cursor := &rowsCursor{
+		pos:  -1,
+		rows: mrows,
+		cols: s.colName,
+	}
+	return cursor, nil
+}
+
+func (s *fakeStmt) NumInput() int {
+	return s.placeholders
+}
+
+func (tx *fakeTx) Commit() error {
+	tx.c.currTx = nil
+	return nil
+}
+
+func (tx *fakeTx) Rollback() error {
+	tx.c.currTx = nil
+	return nil
+}
+
+type rowsCursor struct {
+	cols   []string // column names, as returned by Columns
+	pos    int      // index into rows of the row last returned by Next
+	rows   []*row   // rows to iterate over
+	closed bool     // set by Close; Next returns an error afterwards
+
+	// a clone of slices to give out to clients, indexed by
+	// the original slice's first byte address.  we clone them
+	// just so we're able to corrupt them on close.
+	bytesClone map[*byte][]byte
+}
+
+func (rc *rowsCursor) Close() error {
+	if !rc.closed {
+		for _, bs := range rc.bytesClone {
+			bs[0] = 255 // first byte corrupted
+		}
+	}
+	rc.closed = true
+	return nil
+}
+
+func (rc *rowsCursor) Columns() []string {
+	return rc.cols
+}
+
+// Next copies the next row into dest; it returns io.EOF once all rows are consumed.
+func (rc *rowsCursor) Next(dest []interface{}) error {
+	if rc.closed {
+		return errors.New("fakedb: cursor is closed")
+	}
+	rc.pos++
+	if rc.pos >= len(rc.rows) {
+		return io.EOF // per interface spec
+	}
+	for i, v := range rc.rows[rc.pos].cols {
+		// TODO(bradfitz): convert to subset types? naah, I
+		// think the subset types should only be input to
+		// driver, but the sql package should be able to handle
+		// a wider range of types coming out of drivers. all
+		// for ease of drivers, and to prevent drivers from
+		// messing up conversions or doing them differently.
+		dest[i] = v
+		if bs, ok := v.([]byte); ok && len(bs) > 0 { // empty slices have no &bs[0] to key on; hand them out as-is
+			if rc.bytesClone == nil {
+				rc.bytesClone = make(map[*byte][]byte)
+			}
+			clone, ok := rc.bytesClone[&bs[0]]
+			if !ok {
+				clone = make([]byte, len(bs))
+				copy(clone, bs)
+				rc.bytesClone[&bs[0]] = clone
+			}
+			dest[i] = clone
+		}
+	}
+	return nil
+}
+
+func converterForType(typ string) driver.ValueConverter {
+	switch typ {
+	case "bool":
+		return driver.Bool
+	case "int32":
+		return driver.Int32
+	case "string":
+		return driver.NotNull{driver.String}
+	case "nullstring":
+		return driver.Null{driver.String}
+	case "datetime":
+		return driver.DefaultParameterConverter
+	}
+	panic("invalid fakedb column type of " + typ)
+}