database/sql: refcounting and lifetime fixes

Simplifies the contract for driver.Stmt.Close in
the process of fixing issue 3865.

Fixes #3865
Update #4459 (maybe fixes it; uninvestigated)

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/7363043
diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go
index 376390a..2833992 100644
--- a/src/pkg/database/sql/sql.go
+++ b/src/pkg/database/sql/sql.go
@@ -11,6 +11,7 @@
 	"errors"
 	"fmt"
 	"io"
+	"runtime"
 	"sync"
 )
 
@@ -189,9 +190,66 @@
 	driver driver.Driver
 	dsn    string
 
-	mu       sync.Mutex // protects freeConn and closed
-	freeConn []driver.Conn
-	closed   bool
+	mu        sync.Mutex           // protects following fields
+	outConn   map[driver.Conn]bool // whether the conn is in use
+	freeConn  []driver.Conn
+	closed    bool
+	dep       map[finalCloser]depSet
+	onConnPut map[driver.Conn][]func() // code (with mu held) run when conn is next returned
+	lastPut   map[driver.Conn]string   // stacktrace of last conn's put; debug only
+}
+
+// depSet is a finalCloser's outstanding dependencies
+type depSet map[interface{}]bool // set of true bools
+
+// The finalCloser interface is used by (*DB).addDep and (*DB).removeDep.
+type finalCloser interface {
+	// finalClose is called when the reference count of an object
+	// goes to zero. (*DB).mu is not held while calling it.
+	finalClose() error
+}
+
+// addDep notes that x now depends on dep, and x's finalClose won't be
+// called until all of x's dependencies are removed with removeDep.
+func (db *DB) addDep(x finalCloser, dep interface{}) {
+	//println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep))
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	if db.dep == nil {
+		db.dep = make(map[finalCloser]depSet)
+	}
+	xdep := db.dep[x]
+	if xdep == nil {
+		xdep = make(depSet)
+		db.dep[x] = xdep
+	}
+	xdep[dep] = true
+}
+
+// removeDep notes that x no longer depends on dep.
+// If x still has dependencies, nil is returned.
+// If x no longer has any dependencies, its finalClose method will be
+// called and its error value will be returned.
+func (db *DB) removeDep(x finalCloser, dep interface{}) error {
+	//println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep))
+	done := false
+
+	db.mu.Lock()
+	xdep := db.dep[x]
+	if xdep != nil {
+		delete(xdep, dep)
+		if len(xdep) == 0 {
+			delete(db.dep, x)
+			done = true
+		}
+	}
+	db.mu.Unlock()
+
+	if !done {
+		return nil
+	}
+	//println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
+	return x.finalClose()
 }
 
 // Open opens a database specified by its database driver name and a
@@ -201,11 +259,20 @@
 // Most users will open a database via a driver-specific connection
 // helper function that returns a *DB.
 func Open(driverName, dataSourceName string) (*DB, error) {
-	driver, ok := drivers[driverName]
+	driveri, ok := drivers[driverName]
 	if !ok {
 		return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
 	}
-	return &DB{driver: driver, dsn: dataSourceName}, nil
+	// TODO: optionally proactively connect to a Conn to check
+	// the dataSourceName: golang.org/issue/4804
+	db := &DB{
+		driver:    driveri,
+		dsn:       dataSourceName,
+		outConn:   make(map[driver.Conn]bool),
+		lastPut:   make(map[driver.Conn]string),
+		onConnPut: make(map[driver.Conn][]func()),
+	}
+	return db, nil
 }
 
 // Close closes the database, releasing any open resources.
@@ -241,22 +308,38 @@
 	if n := len(db.freeConn); n > 0 {
 		conn := db.freeConn[n-1]
 		db.freeConn = db.freeConn[:n-1]
+		db.outConn[conn] = true
 		db.mu.Unlock()
 		return conn, nil
 	}
 	db.mu.Unlock()
-	return db.driver.Open(db.dsn)
+	conn, err := db.driver.Open(db.dsn)
+	if err == nil {
+		db.mu.Lock()
+		db.outConn[conn] = true
+		db.mu.Unlock()
+	}
+	return conn, err
 }
 
+// connIfFree returns (wanted, true) if wanted is still a valid conn and
+// isn't in use.
+//
+// If wanted is valid but in use, connIfFree returns (nil, false).
+// If wanted is invalid, connIfFree returns (nil, false).
 func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
 	db.mu.Lock()
 	defer db.mu.Unlock()
+	if db.outConn[wanted] {
+		return conn, false
+	}
 	for i, conn := range db.freeConn {
 		if conn != wanted {
 			continue
 		}
 		db.freeConn[i] = db.freeConn[len(db.freeConn)-1]
 		db.freeConn = db.freeConn[:len(db.freeConn)-1]
+		db.outConn[wanted] = true
 		return wanted, true
 	}
 	return nil, false
@@ -265,14 +348,52 @@
 // putConnHook is a hook for testing.
 var putConnHook func(*DB, driver.Conn)
 
+// noteUnusedDriverStatement notes that si is no longer used and should
+// be closed whenever possible (when c is next not in use), unless c is
+// already closed.
+func (db *DB) noteUnusedDriverStatement(c driver.Conn, si driver.Stmt) {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	if db.outConn[c] {
+		db.onConnPut[c] = append(db.onConnPut[c], func() {
+			si.Close()
+		})
+	} else {
+		si.Close()
+	}
+}
+
+// debugGetPut determines whether getConn & putConn calls' stack traces
+// are returned for more verbose crashes.
+const debugGetPut = false
+
 // putConn adds a connection to the db's free pool.
 // err is optionally the last error that occurred on this connection.
 func (db *DB) putConn(c driver.Conn, err error) {
+	db.mu.Lock()
+	if !db.outConn[c] {
+		if debugGetPut {
+			fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", c, stack(), db.lastPut[c])
+		}
+		panic("sql: connection returned that was never out")
+	}
+	if debugGetPut {
+		db.lastPut[c] = stack()
+	}
+	delete(db.outConn, c)
+
+	if fns, ok := db.onConnPut[c]; ok {
+		for _, fn := range fns {
+			fn()
+		}
+		delete(db.onConnPut, c)
+	}
+
 	if err == driver.ErrBadConn {
 		// Don't reuse bad connections.
+		db.mu.Unlock()
 		return
 	}
-	db.mu.Lock()
 	if putConnHook != nil {
 		putConnHook(db, c)
 	}
@@ -300,7 +421,7 @@
 	return stmt, err
 }
 
-func (db *DB) prepare(query string) (stmt *Stmt, err error) {
+func (db *DB) prepare(query string) (*Stmt, error) {
 	// TODO: check if db.driver supports an optional
 	// driver.Preparer interface and call that instead, if so,
 	// otherwise we make a prepared statement that's bound
@@ -311,19 +432,17 @@
 	if err != nil {
 		return nil, err
 	}
-	defer func() {
-		db.putConn(ci, err)
-	}()
-
 	si, err := ci.Prepare(query)
 	if err != nil {
+		db.putConn(ci, err)
 		return nil, err
 	}
-	stmt = &Stmt{
+	stmt := &Stmt{
 		db:    db,
 		query: query,
 		css:   []connStmt{{ci, si}},
 	}
+	db.addDep(stmt, stmt)
 	return stmt, nil
 }
 
@@ -698,6 +817,8 @@
 	query     string // that created the Stmt
 	stickyErr error  // if non-nil, this error is returned for all operations
 
+	closemu sync.RWMutex // held exclusively during close, for read otherwise.
+
 	// If in a transaction, else both nil:
 	tx   *Tx
 	txsi driver.Stmt
@@ -715,6 +836,8 @@
 // Exec executes a prepared statement with the given arguments and
 // returns a Result summarizing the effect of the statement.
 func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+	s.closemu.RLock()
+	defer s.closemu.RUnlock()
 	_, releaseConn, si, err := s.connStmt()
 	if err != nil {
 		return nil, err
@@ -813,6 +936,9 @@
 // Query executes a prepared query statement with the given arguments
 // and returns the query results as a *Rows.
 func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+	s.closemu.RLock()
+	defer s.closemu.RUnlock()
+
 	ci, releaseConn, si, err := s.connStmt()
 	if err != nil {
 		return nil, err
@@ -827,10 +953,15 @@
 	// Note: ownership of ci passes to the *Rows, to be freed
 	// with releaseConn.
 	rows := &Rows{
-		db:          s.db,
-		ci:          ci,
-		releaseConn: releaseConn,
-		rowsi:       rowsi,
+		db:    s.db,
+		ci:    ci,
+		rowsi: rowsi,
+		// releaseConn set below
+	}
+	s.db.addDep(s, rows)
+	rows.releaseConn = func(err error) {
+		releaseConn(err)
+		s.db.removeDep(s, rows)
 	}
 	return rows, nil
 }
@@ -876,6 +1007,9 @@
 
 // Close closes the statement.
 func (s *Stmt) Close() error {
+	s.closemu.Lock()
+	defer s.closemu.Unlock()
+
 	if s.stickyErr != nil {
 		return s.stickyErr
 	}
@@ -888,18 +1022,17 @@
 
 	if s.tx != nil {
 		s.txsi.Close()
-	} else {
-		for _, v := range s.css {
-			if ci, match := s.db.connIfFree(v.ci); match {
-				v.si.Close()
-				s.db.putConn(ci, nil)
-			} else {
-				// TODO(bradfitz): care that we can't close
-				// this statement because the statement's
-				// connection is in use?
-			}
-		}
+		return nil
 	}
+
+	return s.db.removeDep(s, s)
+}
+
+func (s *Stmt) finalClose() error {
+	for _, v := range s.css {
+		s.db.noteUnusedDriverStatement(v.ci, v.si)
+	}
+	s.css = nil
 	return nil
 }
 
@@ -918,7 +1051,7 @@
 //     ...
 type Rows struct {
 	db          *DB
-	ci          driver.Conn // owned; must call putconn when closed to release
+	ci          driver.Conn // owned; must call releaseConn when closed to release
 	releaseConn func(error)
 	rowsi       driver.Rows
 
@@ -1094,3 +1227,8 @@
 type result struct {
 	driver.Result
 }
+
+func stack() string {
+	var buf [1024]byte
+	return string(buf[:runtime.Stack(buf[:], false)])
+}