database/sql: Close per-tx prepared statements when the associated tx ends LGTM=bradfitz R=golang-codereviews, bradfitz, mattn.jp CC=golang-codereviews https://golang.org/cl/131650043
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 90f813d..731b7a7 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go
@@ -1043,6 +1043,13 @@ // or Rollback. once done, all operations fail with // ErrTxDone. done bool + + // All Stmts prepared for this transaction. These will be closed after the + // transaction has been committed or rolled back. + stmts struct { + sync.Mutex + v []*Stmt + } } var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") @@ -1064,6 +1071,15 @@ return tx.dc, nil } +// Closes all Stmts prepared for this transaction. +func (tx *Tx) closePrepared() { + tx.stmts.Lock() + for _, stmt := range tx.stmts.v { + stmt.Close() + } + tx.stmts.Unlock() +} + // Commit commits the transaction. func (tx *Tx) Commit() error { if tx.done { @@ -1071,8 +1087,12 @@ } defer tx.close() tx.dc.Lock() - defer tx.dc.Unlock() - return tx.txi.Commit() + err := tx.txi.Commit() + tx.dc.Unlock() + if err != driver.ErrBadConn { + tx.closePrepared() + } + return err } // Rollback aborts the transaction. @@ -1082,8 +1102,12 @@ } defer tx.close() tx.dc.Lock() - defer tx.dc.Unlock() - return tx.txi.Rollback() + err := tx.txi.Rollback() + tx.dc.Unlock() + if err != driver.ErrBadConn { + tx.closePrepared() + } + return err } // Prepare creates a prepared statement for use within a transaction. @@ -1127,6 +1151,9 @@ }, query: query, } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, stmt) + tx.stmts.Unlock() return stmt, nil } @@ -1155,7 +1182,7 @@ dc.Lock() si, err := dc.ci.Prepare(stmt.query) dc.Unlock() - return &Stmt{ + txs := &Stmt{ db: tx.db, tx: tx, txsi: &driverStmt{ @@ -1165,6 +1192,10 @@ query: stmt.query, stickyErr: err, } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, txs) + tx.stmts.Unlock() + return txs } // Exec executes a query that doesn't return rows.
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 12e5a6f..34efdf2 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go
@@ -441,6 +441,33 @@ } } +func TestTxPrepare(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + stmt, err := tx.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + _, err = stmt.Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } + // Commit() should have closed the statement + if !stmt.closed { + t.Fatal("Stmt not closed after Commit") + } +} + func TestTxStmt(t *testing.T) { db := newTestDB(t, "") defer closeDB(t, db) @@ -464,6 +491,10 @@ if err != nil { t.Fatalf("Commit = %v", err) } + // Commit() should have closed the statement + if !txs.closed { + t.Fatal("Stmt not closed after Commit") + } } // Issue: http://golang.org/issue/2784