database/sql: add support for multiple result sets

Many database systems allow returning multiple result sets
in a single query. This can be useful when dealing with many
intermediate results on the server and there is a need
to return more than one arity of data to the client.

Fixes #12382

Change-Id: I480a9ac6dadfc8743e0ba8b6d868ccf8442a9ca1
Reviewed-on: https://go-review.googlesource.com/30592
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go
index ccc283d..b3d83f3 100644
--- a/src/database/sql/driver/driver.go
+++ b/src/database/sql/driver/driver.go
@@ -213,6 +213,22 @@
 	Next(dest []Value) error
 }
 
+// RowsNextResultSet extends the Rows interface by providing a way to signal
+// the driver to advance to the next result set.
+type RowsNextResultSet interface {
+	Rows
+
+	// HasNextResultSet is called at the end of the current result set and
+	// reports whether there is another result set after the current one.
+	HasNextResultSet() bool
+
+	// NextResultSet advances the driver to the next result set even
+	// if there are remaining rows in the current result set.
+	//
+	// NextResultSet should return io.EOF when there are no more result sets.
+	NextResultSet() error
+}
+
 // Tx is a transaction.
 type Tx interface {
 	Commit() error
diff --git a/src/database/sql/example_test.go b/src/database/sql/example_test.go
index dcb74e0..9032eac 100644
--- a/src/database/sql/example_test.go
+++ b/src/database/sql/example_test.go
@@ -44,3 +44,65 @@
 		fmt.Printf("Username is %s\n", username)
 	}
 }
+
+func ExampleDB_QueryMultipleResultSets() {
+	age := 27
+	q := `
+create temp table uid (id bigint); -- Create temp table for queries.
+insert into uid
+select id from users where age < ?; -- Populate temp table.
+
+-- First result set.
+select
+	users.id, name
+from
+	users
+	join uid on users.id = uid.id
+;
+
+-- Second result set.
+select 
+	ur.user, ur.role
+from
+	user_roles as ur
+	join uid on uid.id = ur.user
+;
+	`
+	rows, err := db.Query(q, age)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var (
+			id   int64
+			name string
+		)
+		if err := rows.Scan(&id, &name); err != nil {
+			log.Fatal(err)
+		}
+		fmt.Printf("id %d name is %s\n", id, name)
+	}
+	if !rows.NextResultSet() {
+		log.Fatal("expected more result sets", rows.Err())
+	}
+	var roleMap = map[int64]string{
+		1: "user",
+		2: "admin",
+		3: "gopher",
+	}
+	for rows.Next() {
+		var (
+			id   int64
+			role int64
+		)
+		if err := rows.Scan(&id, &role); err != nil {
+			log.Fatal(err)
+		}
+		fmt.Printf("id %d has role %s\n", id, roleMap[role])
+	}
+	if err := rows.Err(); err != nil {
+		log.Fatal(err)
+	}
+}
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
index 5b238bf..aaa13a6 100644
--- a/src/database/sql/fakedb_test.go
+++ b/src/database/sql/fakedb_test.go
@@ -36,6 +36,8 @@
 // Any of these can be preceded by PANIC|<method>|, to cause the
 // named method on fakeStmt to panic.
 //
+// Multiple of these can be combined when separated with a semicolon.
+//
 // When opening a fakeDriver's database, it starts empty with no
 // tables. All tables and data are stored in memory only.
 type fakeDriver struct {
@@ -109,6 +111,8 @@
 	table string
 	panic string
 
+	next *fakeStmt // used for returning multiple results.
+
 	closed bool
 
 	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
@@ -377,7 +381,7 @@
 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
 // (note that where columns must always contain ? marks,
 //  just a limitation for fakedb)
-func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
 	if len(parts) != 3 {
 		stmt.Close()
 		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
@@ -411,7 +415,7 @@
 }
 
 // parts are table|col=type,col2=type2
-func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
 	if len(parts) != 2 {
 		stmt.Close()
 		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
@@ -430,7 +434,7 @@
 }
 
 // parts are table|col=?,col2=val
-func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
 	if len(parts) != 2 {
 		stmt.Close()
 		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
@@ -492,38 +496,52 @@
 		return nil, driver.ErrBadConn
 	}
 
-	parts := strings.Split(query, "|")
-	if len(parts) < 1 {
-		return nil, errf("empty query")
-	}
-	stmt := &fakeStmt{q: query, c: c}
-	if len(parts) >= 3 && parts[0] == "PANIC" {
-		stmt.panic = parts[1]
-		parts = parts[2:]
-	}
-	cmd := parts[0]
-	stmt.cmd = cmd
-	parts = parts[1:]
+	var firstStmt, prev *fakeStmt
+	for _, query := range strings.Split(query, ";") {
+		parts := strings.Split(query, "|")
+		if len(parts) < 1 {
+			return nil, errf("empty query")
+		}
+		stmt := &fakeStmt{q: query, c: c}
+		if firstStmt == nil {
+			firstStmt = stmt
+		}
+		if len(parts) >= 3 && parts[0] == "PANIC" {
+			stmt.panic = parts[1]
+			parts = parts[2:]
+		}
+		cmd := parts[0]
+		stmt.cmd = cmd
+		parts = parts[1:]
 
-	c.incrStat(&c.stmtsMade)
-	switch cmd {
-	case "WIPE":
-		// Nothing
-	case "SELECT":
-		return c.prepareSelect(stmt, parts)
-	case "CREATE":
-		return c.prepareCreate(stmt, parts)
-	case "INSERT":
-		return c.prepareInsert(stmt, parts)
-	case "NOSERT":
-		// Do all the prep-work like for an INSERT but don't actually insert the row.
-		// Used for some of the concurrent tests.
-		return c.prepareInsert(stmt, parts)
-	default:
-		stmt.Close()
-		return nil, errf("unsupported command type %q", cmd)
+		c.incrStat(&c.stmtsMade)
+		var err error
+		switch cmd {
+		case "WIPE":
+			// Nothing
+		case "SELECT":
+			stmt, err = c.prepareSelect(stmt, parts)
+		case "CREATE":
+			stmt, err = c.prepareCreate(stmt, parts)
+		case "INSERT":
+			stmt, err = c.prepareInsert(stmt, parts)
+		case "NOSERT":
+			// Do all the prep-work like for an INSERT but don't actually insert the row.
+			// Used for some of the concurrent tests.
+			stmt, err = c.prepareInsert(stmt, parts)
+		default:
+			stmt.Close()
+			return nil, errf("unsupported command type %q", cmd)
+		}
+		if err != nil {
+			return nil, err
+		}
+		if prev != nil {
+			prev.next = stmt
+		}
+		prev = stmt
 	}
-	return stmt, nil
+	return firstStmt, nil
 }
 
 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
@@ -550,6 +568,9 @@
 		s.c.incrStat(&s.c.stmtsClosed)
 		s.closed = true
 	}
+	if s.next != nil {
+		s.next.Close()
+	}
 	return nil
 }
 
@@ -667,64 +688,80 @@
 		panic("error in pkg db; should only get here if size is correct")
 	}
 
-	db.mu.Lock()
-	t, ok := db.table(s.table)
-	db.mu.Unlock()
-	if !ok {
-		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
-	}
+	setMRows := make([][]*row, 0, 1)
+	setColumns := make([][]string, 0, 1)
 
-	if s.table == "magicquery" {
-		if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
-			if args[0] == "sleep" {
-				time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
+	for {
+		db.mu.Lock()
+		t, ok := db.table(s.table)
+		db.mu.Unlock()
+		if !ok {
+			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
+		}
+
+		if s.table == "magicquery" {
+			if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
+				if args[0] == "sleep" {
+					time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
+				}
 			}
 		}
-	}
 
-	t.mu.Lock()
-	defer t.mu.Unlock()
+		t.mu.Lock()
 
-	colIdx := make(map[string]int) // select column name -> column index in table
-	for _, name := range s.colName {
-		idx := t.columnIndex(name)
-		if idx == -1 {
-			return nil, fmt.Errorf("fakedb: unknown column name %q", name)
-		}
-		colIdx[name] = idx
-	}
-
-	mrows := []*row{}
-rows:
-	for _, trow := range t.rows {
-		// Process the where clause, skipping non-match rows. This is lazy
-		// and just uses fmt.Sprintf("%v") to test equality. Good enough
-		// for test code.
-		for widx, wcol := range s.whereCol {
-			idx := t.columnIndex(wcol)
+		colIdx := make(map[string]int) // select column name -> column index in table
+		for _, name := range s.colName {
+			idx := t.columnIndex(name)
 			if idx == -1 {
-				return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+				t.mu.Unlock()
+				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
 			}
-			tcol := trow.cols[idx]
-			if bs, ok := tcol.([]byte); ok {
-				// lazy hack to avoid sprintf %v on a []byte
-				tcol = string(bs)
-			}
-			if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
-				continue rows
-			}
+			colIdx[name] = idx
 		}
-		mrow := &row{cols: make([]interface{}, len(s.colName))}
-		for seli, name := range s.colName {
-			mrow.cols[seli] = trow.cols[colIdx[name]]
+
+		mrows := []*row{}
+	rows:
+		for _, trow := range t.rows {
+			// Process the where clause, skipping non-match rows. This is lazy
+			// and just uses fmt.Sprintf("%v") to test equality. Good enough
+			// for test code.
+			for widx, wcol := range s.whereCol {
+				idx := t.columnIndex(wcol)
+				if idx == -1 {
+					t.mu.Unlock()
+					return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+				}
+				tcol := trow.cols[idx]
+				if bs, ok := tcol.([]byte); ok {
+					// lazy hack to avoid sprintf %v on a []byte
+					tcol = string(bs)
+				}
+				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
+					continue rows
+				}
+			}
+			mrow := &row{cols: make([]interface{}, len(s.colName))}
+			for seli, name := range s.colName {
+				mrow.cols[seli] = trow.cols[colIdx[name]]
+			}
+			mrows = append(mrows, mrow)
 		}
-		mrows = append(mrows, mrow)
+
+		t.mu.Unlock()
+
+		setMRows = append(setMRows, mrows)
+		setColumns = append(setColumns, s.colName)
+
+		if s.next == nil {
+			break
+		}
+		s = s.next
 	}
 
 	cursor := &rowsCursor{
-		pos:    -1,
-		rows:   mrows,
-		cols:   s.colName,
+		posRow: -1,
+		rows:   setMRows,
+		cols:   setColumns,
 		errPos: -1,
 	}
 	return cursor, nil
@@ -760,9 +797,10 @@
 }
 
 type rowsCursor struct {
-	cols   []string
-	pos    int
-	rows   []*row
+	cols   [][]string
+	posSet int
+	posRow int
+	rows   [][]*row
 	closed bool
 
 	// errPos and err are for making Next return early with error.
@@ -786,7 +824,7 @@
 }
 
 func (rc *rowsCursor) Columns() []string {
-	return rc.cols
+	return rc.cols[rc.posSet]
 }
 
 var rowsCursorNextHook func(dest []driver.Value) error
@@ -799,14 +837,14 @@
 	if rc.closed {
 		return errors.New("fakedb: cursor is closed")
 	}
-	rc.pos++
-	if rc.pos == rc.errPos {
+	rc.posRow++
+	if rc.posRow == rc.errPos {
 		return rc.err
 	}
-	if rc.pos >= len(rc.rows) {
+	if rc.posRow >= len(rc.rows[rc.posSet]) {
 		return io.EOF // per interface spec
 	}
-	for i, v := range rc.rows[rc.pos].cols {
+	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
 		// TODO(bradfitz): convert to subset types? naah, I
 		// think the subset types should only be input to
 		// driver, but the sql package should be able to handle
@@ -831,6 +869,19 @@
 	return nil
 }
 
+func (rc *rowsCursor) HasNextResultSet() bool {
+	return rc.posSet < len(rc.rows)-1
+}
+
+func (rc *rowsCursor) NextResultSet() error {
+	if rc.HasNextResultSet() {
+		rc.posSet++
+		rc.posRow = -1
+		return nil
+	}
+	return io.EOF // Per interface spec.
+}
+
 // fakeDriverString is like driver.String, but indirects pointers like
 // DefaultValueConverter.
 //
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index c26d7d3..9703342 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -1943,6 +1943,47 @@
 	}
 	rs.lasterr = rs.rowsi.Next(rs.lastcols)
 	if rs.lasterr != nil {
+		// Close the connection if there is a driver error.
+		if rs.lasterr != io.EOF {
+			rs.Close()
+			return false
+		}
+		nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+		if !ok {
+			rs.Close()
+			return false
+		}
+		// The driver is at the end of the current result set.
+		// Test to see if there is another result set after the current one.
+		// Only close Rows if there are no further result sets to read.
+		if !nextResultSet.HasNextResultSet() {
+			rs.Close()
+		}
+		return false
+	}
+	return true
+}
+
+// NextResultSet prepares the next result set for reading. It returns true if
+// there are further result sets, or false if there is no further result set
+// or if there is an error advancing to it. The Err method should be consulted
+// to distinguish between the two cases.
+//
+// After calling NextResultSet, the Next method should always be called before
+// scanning. If there are further result sets they may not have rows in the result
+// set.
+func (rs *Rows) NextResultSet() bool {
+	if rs.isClosed() {
+		return false
+	}
+	rs.lastcols = nil
+	nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+	if !ok {
+		rs.Close()
+		return false
+	}
+	rs.lasterr = nextResultSet.NextResultSet()
+	if rs.lasterr != nil {
 		rs.Close()
 		return false
 	}
@@ -2047,7 +2088,8 @@
 	return atomic.LoadInt32(&rs.closed) != 0
 }
 
-// Close closes the Rows, preventing further enumeration. If Next returns
+// Close closes the Rows, preventing further enumeration. If Next and
+// NextResultSet both return
 // false, the Rows are closed automatically and it will suffice to check the
 // result of Err. Close is idempotent and does not affect the result of Err.
 func (rs *Rows) Close() error {
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index ca14af7..bce210d 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -319,6 +319,82 @@
 	}
 }
 
+func TestMultiResultSetQuery(t *testing.T) {
+	db := newTestDB(t, "people")
+	defer closeDB(t, db)
+	prepares0 := numPrepares(t, db)
+	rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
+	if err != nil {
+		t.Fatalf("Query: %v", err)
+	}
+	type row1 struct {
+		age  int
+		name string
+	}
+	type row2 struct {
+		name string
+	}
+	got1 := []row1{}
+	for rows.Next() {
+		var r row1
+		err = rows.Scan(&r.age, &r.name)
+		if err != nil {
+			t.Fatalf("Scan: %v", err)
+		}
+		got1 = append(got1, r)
+	}
+	err = rows.Err()
+	if err != nil {
+		t.Fatalf("Err: %v", err)
+	}
+	want1 := []row1{
+		{age: 1, name: "Alice"},
+		{age: 2, name: "Bob"},
+		{age: 3, name: "Chris"},
+	}
+	if !reflect.DeepEqual(got1, want1) {
+		t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
+	}
+
+	if !rows.NextResultSet() {
+		t.Errorf("expected another result set")
+	}
+
+	got2 := []row2{}
+	for rows.Next() {
+		var r row2
+		err = rows.Scan(&r.name)
+		if err != nil {
+			t.Fatalf("Scan: %v", err)
+		}
+		got2 = append(got2, r)
+	}
+	err = rows.Err()
+	if err != nil {
+		t.Fatalf("Err: %v", err)
+	}
+	want2 := []row2{
+		{name: "Alice"},
+		{name: "Bob"},
+		{name: "Chris"},
+	}
+	if !reflect.DeepEqual(got2, want2) {
+		t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
+	}
+	if rows.NextResultSet() {
+		t.Errorf("expected no more result sets")
+	}
+
+	// And verify that the final rows.Next() call, which hit EOF,
+	// also closed the rows connection.
+	if n := db.numFreeConns(); n != 1 {
+		t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+	}
+	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+		t.Errorf("executed %d Prepare statements; want 1", prepares)
+	}
+}
+
 func TestByteOwnership(t *testing.T) {
 	db := newTestDB(t, "people")
 	defer closeDB(t, db)