database/sql: close per-tx prepared statements when the associated tx ends

LGTM=bradfitz
R=golang-codereviews, bradfitz, mattn.jp
CC=golang-codereviews
https://golang.org/cl/131650043
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 90f813d..731b7a7 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -1043,6 +1043,13 @@
 	// or Rollback. once done, all operations fail with
 	// ErrTxDone.
 	done bool
+
+	// All Stmts prepared for this transaction.  These will be closed after the
+	// transaction has been committed or rolled back.
+	stmts struct {
+		sync.Mutex
+		v []*Stmt
+	}
 }
 
 var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
@@ -1064,6 +1071,15 @@
 	return tx.dc, nil
 }
 
+// closePrepared closes every Stmt recorded in tx.stmts for this transaction.
+func (tx *Tx) closePrepared() {
+	tx.stmts.Lock()
+	defer tx.stmts.Unlock() // unlock even if a Close call panics
+	for _, stmt := range tx.stmts.v {
+		stmt.Close()
+	}
+}
+
 // Commit commits the transaction.
 func (tx *Tx) Commit() error {
 	if tx.done {
@@ -1071,8 +1087,12 @@
 	}
 	defer tx.close()
 	tx.dc.Lock()
-	defer tx.dc.Unlock()
-	return tx.txi.Commit()
+	err := tx.txi.Commit()
+	tx.dc.Unlock()
+	if err != driver.ErrBadConn {
+		tx.closePrepared()
+	}
+	return err
 }
 
 // Rollback aborts the transaction.
@@ -1082,8 +1102,12 @@
 	}
 	defer tx.close()
 	tx.dc.Lock()
-	defer tx.dc.Unlock()
-	return tx.txi.Rollback()
+	err := tx.txi.Rollback()
+	tx.dc.Unlock()
+	if err != driver.ErrBadConn {
+		tx.closePrepared()
+	}
+	return err
 }
 
 // Prepare creates a prepared statement for use within a transaction.
@@ -1127,6 +1151,9 @@
 		},
 		query: query,
 	}
+	tx.stmts.Lock()
+	tx.stmts.v = append(tx.stmts.v, stmt)
+	tx.stmts.Unlock()
 	return stmt, nil
 }
 
@@ -1155,7 +1182,7 @@
 	dc.Lock()
 	si, err := dc.ci.Prepare(stmt.query)
 	dc.Unlock()
-	return &Stmt{
+	txs := &Stmt{
 		db: tx.db,
 		tx: tx,
 		txsi: &driverStmt{
@@ -1165,6 +1192,10 @@
 		query:     stmt.query,
 		stickyErr: err,
 	}
+	tx.stmts.Lock()
+	tx.stmts.v = append(tx.stmts.v, txs)
+	tx.stmts.Unlock()
+	return txs
 }
 
 // Exec executes a query that doesn't return rows.
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 12e5a6f..34efdf2 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -441,6 +441,33 @@
 	}
 }
 
+// TestTxPrepare verifies that Commit closes a Stmt obtained from Tx.Prepare.
+func TestTxPrepare(t *testing.T) {
+	db := newTestDB(t, "")
+	defer closeDB(t, db)
+	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+	tx, err := db.Begin()
+	if err != nil {
+		t.Fatalf("Begin = %v", err)
+	}
+	st, err := tx.Prepare("INSERT|t1|name=?,age=?")
+	if err != nil {
+		t.Fatalf("Stmt, err = %v, %v", st, err)
+	}
+	// Deferred Close runs after Commit, which is expected to have closed st already.
+	defer st.Close()
+	if _, err = st.Exec("Bobby", 7); err != nil {
+		t.Fatalf("Exec = %v", err)
+	}
+	if err = tx.Commit(); err != nil {
+		t.Fatalf("Commit = %v", err)
+	}
+	// Commit must have closed the statement prepared on the tx.
+	if !st.closed {
+		t.Fatal("Stmt not closed after Commit")
+	}
+}
+
 func TestTxStmt(t *testing.T) {
 	db := newTestDB(t, "")
 	defer closeDB(t, db)
@@ -464,6 +491,10 @@
 	if err != nil {
 		t.Fatalf("Commit = %v", err)
 	}
+	// Commit() should have closed the statement
+	if !txs.closed {
+		t.Fatal("Stmt not closed after Commit")
+	}
 }
 
 // Issue: http://golang.org/issue/2784