sql: add Tx.Stmt to use an existing prepared stmt in a transaction

R=rsc
CC=golang-dev
https://golang.org/cl/5433059
diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go
index c055fdd..f17d12e 100644
--- a/src/pkg/exp/sql/sql.go
+++ b/src/pkg/exp/sql/sql.go
@@ -344,25 +344,26 @@
 	return tx.txi.Rollback()
 }
 
-// Prepare creates a prepared statement.
+// Prepare creates a prepared statement for use within a transaction.
 //
-// The statement is only valid within the scope of this transaction.
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
 func (tx *Tx) Prepare(query string) (*Stmt, error) {
-	// TODO(bradfitz): the restriction that the returned statement
-	// is only valid for this Transaction is lame and negates a
-	// lot of the benefit of prepared statements.  We could be
-	// more efficient here and either provide a method to take an
-	// existing Stmt (created on perhaps a different Conn), and
-	// re-create it on this Conn if necessary. Or, better: keep a
-	// map in DB of query string to Stmts, and have Stmt.Execute
-	// do the right thing and re-prepare if the Conn in use
-	// doesn't have that prepared statement.  But we'll want to
-	// avoid caching the statement in the case where we only call
-	// conn.Prepare implicitly (such as in db.Exec or tx.Exec),
-	// but the caller package can't be holding a reference to the
-	// returned statement.  Perhaps just looking at the reference
-	// count (by noting Stmt.Close) would be enough. We might also
-	// want a finalizer on Stmt to drop the reference count.
+	// TODO(bradfitz): We could be more efficient here. Tx.Stmt takes
+	// an existing Stmt (created on perhaps a different Conn) and
+	// re-creates it on this Conn, but re-prepares on every call.
+	// It would be better to keep a map in DB of query string to
+	// Stmts, and have Stmt.Execute do the right thing and
+	// re-prepare if the Conn in use doesn't have that prepared
+	// statement.  But we'll want to avoid caching the statement
+	// in the case where we only call conn.Prepare implicitly
+	// (such as in db.Exec or tx.Exec), but the caller package
+	// can't be holding a reference to the returned statement.
+	// Perhaps just looking at the reference count (by noting
+	// Stmt.Close) would be enough. We might also want a finalizer
+	// on Stmt to drop the reference count.
 	ci, err := tx.grabConn()
 	if err != nil {
 		return nil, err
@@ -383,6 +384,39 @@
 	return stmt, nil
 }
 
+// Stmt returns a transaction-specific prepared statement from an
+// existing statement. Any error is reported when the result is used.
+//
+// Example:
+//  updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+//  ...
+//  tx, err := db.Begin()
+//  ...
+//  res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+	// TODO(bradfitz): optimize this. Currently this re-prepares
+	// each time.  This is fine for now to illustrate the API but
+	// we should really cache already-prepared statements
+	// per-Conn. See also the big comment in Tx.Prepare.
+
+	if tx.db != stmt.db {
+		return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
+	}
+	ci, err := tx.grabConn()
+	if err != nil {
+		return &Stmt{stickyErr: err}
+	}
+	defer tx.releaseConn()
+	si, err := ci.Prepare(stmt.query)
+	return &Stmt{
+		db:        tx.db,
+		tx:        tx,
+		txsi:      si,
+		query:     stmt.query,
+		stickyErr: err,
+	}
+}
+
 // Exec executes a query that doesn't return rows.
 // For example: an INSERT and UPDATE.
 func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
@@ -448,8 +482,9 @@
 // Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines.
 type Stmt struct {
 	// Immutable:
-	db    *DB    // where we came from
-	query string // that created the Sttm
+	db        *DB    // where we came from
+	query     string // that created the Stmt
+	stickyErr error  // if non-nil, this error is returned for all operations
 
 	// If in a transaction, else both nil:
 	tx   *Tx
@@ -513,6 +548,9 @@
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
 func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
+	if s.stickyErr != nil {
+		return nil, nil, nil, s.stickyErr
+	}
 	s.mu.Lock()
 	if s.closed {
 		s.mu.Unlock()
@@ -621,6 +659,9 @@
 
 // Close closes the statement.
 func (s *Stmt) Close() error {
+	if s.stickyErr != nil {
+		return s.stickyErr
+	}
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if s.closed {