database/sql: don't close a driver.Conn until its Stmts are closed

Fixes #5046

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/8016044
diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go
index 55597f7..24c255f 100644
--- a/src/pkg/database/sql/fakedb_test.go
+++ b/src/pkg/database/sql/fakedb_test.go
@@ -229,7 +229,26 @@
 	return c.currTx, nil
 }
 
-func (c *fakeConn) Close() error {
+var hookPostCloseConn struct {
+	sync.Mutex
+	fn func(*fakeConn, error)
+}
+
+func setHookpostCloseConn(fn func(*fakeConn, error)) {
+	hookPostCloseConn.Lock()
+	defer hookPostCloseConn.Unlock()
+	hookPostCloseConn.fn = fn
+}
+
+func (c *fakeConn) Close() (err error) {
+	defer func() {
+		hookPostCloseConn.Lock()
+		fn := hookPostCloseConn.fn
+		hookPostCloseConn.Unlock()
+		if fn != nil {
+			fn(c, err)
+		}
+	}()
 	if c.currTx != nil {
 		return errors.New("can't close fakeConn; in a Transaction")
 	}
diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go
index bc92ecd..d1f929e 100644
--- a/src/pkg/database/sql/sql.go
+++ b/src/pkg/database/sql/sql.go
@@ -204,8 +204,42 @@
 // interfaces returned via that Conn, such as calls on Tx, Stmt,
 // Result, Rows)
 type driverConn struct {
-	sync.Mutex
-	ci driver.Conn
+	db *DB
+
+	sync.Mutex // guards following
+	ci         driver.Conn
+	closed     bool
+}
+
+// closeDBLocked is like Close, but the caller must already hold dc.db's mutex.
+func (dc *driverConn) closeDBLocked() error {
+	dc.Lock()
+	if dc.closed {
+		dc.Unlock()
+		return errors.New("sql: duplicate driverConn close")
+	}
+	dc.closed = true
+	dc.Unlock() // not defer; removeDep finalClose calls may need to lock
+	return dc.db.removeDepLocked(dc, dc)()
+}
+
+func (dc *driverConn) Close() error {
+	dc.Lock()
+	if dc.closed {
+		dc.Unlock()
+		return errors.New("sql: duplicate driverConn close")
+	}
+	dc.closed = true
+	dc.Unlock() // not defer; removeDep finalClose calls may need to lock
+	return dc.db.removeDep(dc, dc)
+}
+
+func (dc *driverConn) finalClose() error {
+	dc.Lock()
+	err := dc.ci.Close()
+	dc.ci = nil
+	dc.Unlock()
+	return err
 }
 
 // driverStmt associates a driver.Stmt with the
@@ -238,6 +272,10 @@
 	//println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep))
 	db.mu.Lock()
 	defer db.mu.Unlock()
+	db.addDepLocked(x, dep)
+}
+
+func (db *DB) addDepLocked(x finalCloser, dep interface{}) {
 	if db.dep == nil {
 		db.dep = make(map[finalCloser]depSet)
 	}
@@ -254,10 +292,16 @@
 // If x no longer has any dependencies, its finalClose method will be
 // called and its error value will be returned.
 func (db *DB) removeDep(x finalCloser, dep interface{}) error {
+	db.mu.Lock()
+	fn := db.removeDepLocked(x, dep)
+	db.mu.Unlock()
+	return fn()
+}
+
+func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error {
 	//println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep))
 	done := false
 
-	db.mu.Lock()
 	xdep := db.dep[x]
 	if xdep != nil {
 		delete(xdep, dep)
@@ -266,13 +310,14 @@
 			done = true
 		}
 	}
-	db.mu.Unlock()
 
 	if !done {
-		return nil
+		return func() error { return nil }
 	}
-	//println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
-	return x.finalClose()
+	return func() error {
+		//println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
+		return x.finalClose()
+	}
 }
 
 // Open opens a database specified by its database driver name and a
@@ -320,9 +365,7 @@
 	defer db.mu.Unlock()
 	var err error
 	for _, dc := range db.freeConn {
-		dc.Lock()
-		err1 := dc.ci.Close()
-		dc.Unlock()
+		err1 := dc.closeDBLocked()
 		if err1 != nil {
 			err = err1
 		}
@@ -365,11 +408,7 @@
 		dc := db.freeConn[nfree-1]
 		db.freeConn[nfree-1] = nil
 		db.freeConn = db.freeConn[:nfree-1]
-		go func() {
-			dc.Lock()
-			dc.ci.Close()
-			dc.Unlock()
-		}()
+		go dc.Close()
 	}
 }
 
@@ -393,8 +432,12 @@
 	if err != nil {
 		return nil, err
 	}
-	dc := &driverConn{ci: ci}
+	dc := &driverConn{
+		db: db,
+		ci: ci,
+	}
 	db.mu.Lock()
+	db.addDepLocked(dc, dc)
 	db.outConn[dc] = true
 	db.mu.Unlock()
 	return dc, nil
@@ -484,9 +527,7 @@
 	// statements which are still active?
 	db.mu.Unlock()
 
-	dc.Lock()
-	dc.ci.Close()
-	dc.Unlock()
+	dc.Close()
 }
 
 // Prepare creates a prepared statement for later queries or executions.
@@ -528,6 +569,7 @@
 		css:   []connStmt{{dc, si}},
 	}
 	db.addDep(stmt, stmt)
+	db.addDep(dc, stmt)
 	db.putConn(dc, nil)
 	return stmt, nil
 }
@@ -1031,6 +1073,7 @@
 			if err != nil {
 				return nil, nil, nil, err
 			}
+			s.db.addDep(dc, s)
 			s.mu.Lock()
 			cs = connStmt{dc, si}
 			s.css = append(s.css, cs)
@@ -1149,6 +1192,7 @@
 func (s *Stmt) finalClose() error {
 	for _, v := range s.css {
 		s.db.noteUnusedDriverStatement(v.dc, v.si)
+		s.db.removeDep(v.dc, s)
 	}
 	s.css = nil
 	return nil
diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go
index 2a9592e..54aad3a 100644
--- a/src/pkg/database/sql/sql_test.go
+++ b/src/pkg/database/sql/sql_test.go
@@ -65,6 +65,12 @@
 		fmt.Printf("Panic: %v\n", e)
 		panic(e)
 	}
+	defer setHookpostCloseConn(nil)
+	setHookpostCloseConn(func(_ *fakeConn, err error) {
+		if err != nil {
+			t.Errorf("Error closing fakeConn: %v", err)
+		}
+	})
 	err := db.Close()
 	if err != nil {
 		t.Fatalf("error closing DB: %v", err)
@@ -790,3 +796,51 @@
 		t.Errorf("freeConns = %d; want 0", got)
 	}
 }
+
+// Closing a DB must not close a driver.Conn before its Stmts are closed (golang.org/issue/5046).
+func TestCloseConnBeforeStmts(t *testing.T) {
+	defer setHookpostCloseConn(nil)
+	setHookpostCloseConn(func(_ *fakeConn, err error) {
+		if err != nil {
+			t.Errorf("Error closing fakeConn: %v", err)
+		}
+	})
+
+	db := newTestDB(t, "people")
+
+	stmt, err := db.Prepare("SELECT|people|name|")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(db.freeConn) != 1 {
+		t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
+	}
+	dc := db.freeConn[0]
+	if dc.closed {
+		t.Errorf("conn shouldn't be closed")
+	}
+
+	err = db.Close()
+	if err != nil {
+		t.Errorf("db Close = %v", err)
+	}
+	if !dc.closed {
+		t.Errorf("after db.Close, driverConn should be closed")
+	}
+	if dc.ci == nil {
+		t.Errorf("after db.Close, driverConn should still have its Conn interface")
+	}
+
+	err = stmt.Close()
+	if err != nil {
+		t.Errorf("Stmt close = %v", err)
+	}
+
+	if !dc.closed {
+		t.Errorf("conn should be closed")
+	}
+	if dc.ci != nil {
+		t.Errorf("after Stmt Close, driverConn's Conn interface should be nil")
+	}
+}