database/sql: add context methods

Add context-aware variants of the sql and sql/driver methods. If
the driver doesn't implement the context methods, the connection
pool will still handle timeouts when a query fails to return
in time or when a connection is not available from the pool
in time.

There will be a follow-up CL that will add support for
context values that specify transaction levels and modes
that a driver can use.

Fixes #15123

Change-Id: Ia99f3957aa3f177b23044dd99d4ec217491a30a7
Reviewed-on: https://go-review.googlesource.com/29381
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go
new file mode 100644
index 0000000..65e1652
--- /dev/null
+++ b/src/database/sql/ctxutil.go
@@ -0,0 +1,135 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+	"context"
+	"database/sql/driver"
+	"errors"
+)
+
+// The ctxDriver* helpers pass ctx to the driver when it implements the
+// matching context-aware interface. Otherwise they call the plain driver
+// method synchronously on the calling goroutine: a driver.Conn is not used
+// concurrently by multiple goroutines, so the call must not be abandoned on a
+// background goroutine while the caller hands the connection back to the
+// pool. Instead ctx is checked before the call and, where the driver returns
+// a resource that can be released, again after it.
+
+// ctxDriverPrepare prepares query on ci. If ci does not implement
+// driver.ConnPrepareContext, ctx is checked before and after ci.Prepare and a
+// statement prepared after ctx is done is closed before ctx.Err() is returned.
+func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
+	if ciCtx, is := ci.(driver.ConnPrepareContext); is {
+		return ciCtx.PrepareContext(ctx, query)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	si, err := ci.Prepare(query)
+	if err != nil {
+		return nil, err
+	}
+	if err := ctx.Err(); err != nil {
+		si.Close()
+		return nil, err
+	}
+	return si, nil
+}
+
+// ctxDriverExec executes query on execer. If execer does not implement
+// driver.ExecerContext, ctx is only checked before execer.Exec: once Exec has
+// run its effects cannot be undone, so its outcome is reported as is.
+func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) {
+	if execerCtx, is := execer.(driver.ExecerContext); is {
+		return execerCtx.ExecContext(ctx, query, dargs)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	return execer.Exec(query, dargs)
+}
+
+// ctxDriverQuery runs query on queryer. If queryer does not implement
+// driver.QueryerContext, ctx is checked before and after queryer.Query and
+// rows returned after ctx is done are closed before ctx.Err() is returned.
+func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) {
+	if queryerCtx, is := queryer.(driver.QueryerContext); is {
+		return queryerCtx.QueryContext(ctx, query, dargs)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	rowsi, err := queryer.Query(query, dargs)
+	if err != nil {
+		return nil, err
+	}
+	if err := ctx.Err(); err != nil {
+		rowsi.Close()
+		return nil, err
+	}
+	return rowsi, nil
+}
+
+// ctxDriverStmtExec executes si with dargs. If si does not implement
+// driver.StmtExecContext, ctx is only checked before si.Exec, for the same
+// reason as in ctxDriverExec.
+func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) {
+	if siCtx, is := si.(driver.StmtExecContext); is {
+		return siCtx.ExecContext(ctx, dargs)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	return si.Exec(dargs)
+}
+
+// ctxDriverStmtQuery runs si with dargs. If si does not implement
+// driver.StmtQueryContext, ctx is checked before and after si.Query and rows
+// returned after ctx is done are closed before ctx.Err() is returned.
+func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) {
+	if siCtx, is := si.(driver.StmtQueryContext); is {
+		return siCtx.QueryContext(ctx, dargs)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	rowsi, err := si.Query(dargs)
+	if err != nil {
+		return nil, err
+	}
+	if err := ctx.Err(); err != nil {
+		rowsi.Close()
+		return nil, err
+	}
+	return rowsi, nil
+}
+
+var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
+
+// ctxDriverBegin starts a transaction on ci. If ci does not implement
+// driver.ConnBeginContext, ctx is checked before and after ci.Begin and a
+// transaction begun after ctx is done is rolled back before ctx.Err() is
+// returned.
+func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
+	if ciCtx, is := ci.(driver.ConnBeginContext); is {
+		return ciCtx.BeginContext(ctx)
+	}
+	// TODO(kardianos): check the transaction level in ctx. If set and non-default
+	// then return an error here as the BeginContext driver value is not supported.
+
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	txi, err := ci.Begin()
+	if err != nil {
+		return nil, err
+	}
+	if err := ctx.Err(); err != nil {
+		txi.Rollback()
+		return nil, err
+	}
+	return txi, nil
+}
diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go
index 4dba85a..ccc283d 100644
--- a/src/database/sql/driver/driver.go
+++ b/src/database/sql/driver/driver.go
@@ -8,7 +8,10 @@
 // Most code should use package sql.
 package driver
 
-import "errors"
+import (
+	"context"
+	"errors"
+)
 
 // Value is a value that drivers must be able to handle.
 // It is either nil or an instance of one of these types:
@@ -65,6 +68,12 @@
 	Exec(query string, args []Value) (Result, error)
 }
 
+// ExecerContext is like Execer, but must honor the context timeout and return
+// when the context is cancelled.
+type ExecerContext interface {
+	ExecContext(ctx context.Context, query string, args []Value) (Result, error)
+}
+
 // Queryer is an optional interface that may be implemented by a Conn.
 //
 // If a Conn does not implement Queryer, the sql package's DB.Query will
@@ -76,6 +85,12 @@
 	Query(query string, args []Value) (Rows, error)
 }
 
+// QueryerContext is like Queryer, but must honor the context timeout and return
+// when the context is cancelled.
+type QueryerContext interface {
+	QueryContext(ctx context.Context, query string, args []Value) (Rows, error)
+}
+
 // Conn is a connection to a database. It is not used concurrently
 // by multiple goroutines.
 //
@@ -98,6 +113,23 @@
 	Begin() (Tx, error)
 }
 
+// ConnPrepareContext enhances the Conn interface with context.
+type ConnPrepareContext interface {
+	// PrepareContext returns a prepared statement, bound to this connection.
+	// The context is for the preparation of the statement;
+	// it must not store the context within the statement itself.
+	PrepareContext(ctx context.Context, query string) (Stmt, error)
+}
+
+// ConnBeginContext enhances the Conn interface with context.
+type ConnBeginContext interface {
+	// BeginContext starts and returns a new transaction.
+	// The provided context should be used to roll the transaction back
+	// if it is cancelled. If there is an isolation level in context
+	// that is not supported by the driver an error must be returned.
+	BeginContext(ctx context.Context) (Tx, error)
+}
+
 // Result is the result of a query execution.
 type Result interface {
 	// LastInsertId returns the database's auto-generated ID
@@ -139,6 +171,18 @@
 	Query(args []Value) (Rows, error)
 }
 
+// StmtExecContext enhances the Stmt interface by providing Exec with context.
+type StmtExecContext interface {
+	// ExecContext must honor the context timeout and return when it is cancelled.
+	ExecContext(ctx context.Context, args []Value) (Result, error)
+}
+
+// StmtQueryContext enhances the Stmt interface by providing Query with context.
+type StmtQueryContext interface {
+	// QueryContext must honor the context timeout and return when it is cancelled.
+	QueryContext(ctx context.Context, args []Value) (Rows, error)
+}
+
 // ColumnConverter may be optionally implemented by Stmt if the
 // statement is aware of its own columns' types and can convert from
 // any type to a driver Value.
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 1e09a31..4c44e2b 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -13,6 +13,7 @@
 package sql
 
 import (
+	"context"
 	"database/sql/driver"
 	"errors"
 	"fmt"
@@ -297,8 +298,8 @@
 	return dc.createdAt.Add(timeout).Before(nowFunc())
 }
 
-func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
-	si, err := dc.ci.Prepare(query)
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
+	si, err := ctxDriverPrepare(ctx, dc.ci, query)
 	if err == nil {
 		// Track each driverConn's open statements, so we can close them
 		// before closing the conn.
@@ -494,13 +495,13 @@
 	return db, nil
 }
 
-// Ping verifies a connection to the database is still alive,
+// PingContext verifies a connection to the database is still alive,
 // establishing a connection if necessary.
-func (db *DB) Ping() error {
+func (db *DB) PingContext(ctx context.Context) error {
 	// TODO(bradfitz): give drivers an optional hook to implement
 	// this in a more efficient or more reliable way, if they
 	// have one.
-	dc, err := db.conn(cachedOrNewConn)
+	dc, err := db.conn(ctx, cachedOrNewConn)
 	if err != nil {
 		return err
 	}
@@ -508,6 +509,12 @@
 	return nil
 }
 
+// Ping verifies a connection to the database is still alive,
+// establishing a connection if necessary.
+func (db *DB) Ping() error {
+	return db.PingContext(context.Background())
+}
+
 // Close closes the database, releasing any open resources.
 //
 // It is rare to Close a DB, as the DB handle is meant to be
@@ -777,12 +784,16 @@
 var errDBClosed = errors.New("sql: database is closed")
 
 // conn returns a newly-opened or cached *driverConn.
-func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
+func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
+	// Check if the context is expired before taking db.mu.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
 	db.mu.Lock()
 	if db.closed {
 		db.mu.Unlock()
 		return nil, errDBClosed
 	}
 	lifetime := db.maxLifetime
 
 	// Prefer a free connection, if possible.
@@ -808,15 +819,39 @@
 		req := make(chan connRequest, 1)
 		db.connRequests = append(db.connRequests, req)
 		db.mu.Unlock()
-		ret, ok := <-req
-		if !ok {
-			return nil, errDBClosed
+
+		// Timeout the connection request with the context.
+		select {
+		case <-ctx.Done():
+			// Withdraw the request so putConn stops handing it
+			// connections, then return any connection that was
+			// delivered before it was withdrawn.
+			db.mu.Lock()
+			for i, r := range db.connRequests {
+				if r == req {
+					db.connRequests = append(db.connRequests[:i], db.connRequests[i+1:]...)
+					break
+				}
+			}
+			db.mu.Unlock()
+			select {
+			case ret, ok := <-req:
+				if ok && ret.err == nil {
+					db.putConn(ret.conn, ret.err)
+				}
+			default:
+			}
+			return nil, ctx.Err()
+		case ret, ok := <-req:
+			if !ok {
+				return nil, errDBClosed
+			}
+			if ret.err == nil && ret.conn.expired(lifetime) {
+				ret.conn.Close()
+				return nil, driver.ErrBadConn
+			}
+			return ret.conn, ret.err
 		}
-		if ret.err == nil && ret.conn.expired(lifetime) {
-			ret.conn.Close()
-			return nil, driver.ErrBadConn
-		}
-		return ret.conn, ret.err
 	}
 
 	db.numOpen++ // optimistically
 
 	db.numOpen++ // optimistically
@@ -952,40 +969,51 @@
 // connection to be opened.
 const maxBadConnRetries = 2
 
+// PrepareContext creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+// Context is for the preparation of the statement, not for the execution of
+// the statement.
+func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
+	var stmt *Stmt
+	var err error
+	for i := 0; i < maxBadConnRetries; i++ {
+		stmt, err = db.prepare(ctx, query, cachedOrNewConn)
+		if err != driver.ErrBadConn {
+			break
+		}
+	}
+	if err == driver.ErrBadConn {
+		return db.prepare(ctx, query, alwaysNewConn)
+	}
+	return stmt, err
+}
+
 // Prepare creates a prepared statement for later queries or executions.
 // Multiple queries or executions may be run concurrently from the
 // returned statement.
 // The caller must call the statement's Close method
 // when the statement is no longer needed.
 func (db *DB) Prepare(query string) (*Stmt, error) {
-	var stmt *Stmt
-	var err error
-	for i := 0; i < maxBadConnRetries; i++ {
-		stmt, err = db.prepare(query, cachedOrNewConn)
-		if err != driver.ErrBadConn {
-			break
-		}
-	}
-	if err == driver.ErrBadConn {
-		return db.prepare(query, alwaysNewConn)
-	}
-	return stmt, err
+	return db.PrepareContext(context.Background(), query)
 }
 
-func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
+func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
 	// TODO: check if db.driver supports an optional
 	// driver.Preparer interface and call that instead, if so,
 	// otherwise we make a prepared statement that's bound
 	// to a connection, and to execute this prepared statement
 	// we either need to use this connection (if it's free), else
 	// get a new connection + re-prepare + execute on that one.
-	dc, err := db.conn(strategy)
+	dc, err := db.conn(ctx, strategy)
 	if err != nil {
 		return nil, err
 	}
 	var si driver.Stmt
 	withLock(dc, func() {
-		si, err = dc.prepareLocked(query)
+		si, err = dc.prepareLocked(ctx, query)
 	})
 	if err != nil {
 		db.putConn(dc, err)
@@ -1002,25 +1030,31 @@
 	return stmt, nil
 }
 
-// Exec executes a query without returning any rows.
+// ExecContext executes a query without returning any rows.
 // The args are for any placeholder parameters in the query.
-func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
 	var res Result
 	var err error
 	for i := 0; i < maxBadConnRetries; i++ {
-		res, err = db.exec(query, args, cachedOrNewConn)
+		res, err = db.exec(ctx, query, args, cachedOrNewConn)
 		if err != driver.ErrBadConn {
 			break
 		}
 	}
 	if err == driver.ErrBadConn {
-		return db.exec(query, args, alwaysNewConn)
+		return db.exec(ctx, query, args, alwaysNewConn)
 	}
 	return res, err
 }
 
-func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
-	dc, err := db.conn(strategy)
+// Exec executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+	return db.ExecContext(context.Background(), query, args...)
+}
+
+func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
+	dc, err := db.conn(ctx, strategy)
 	if err != nil {
 		return nil, err
 	}
@@ -1036,7 +1070,7 @@
 		}
 		var resi driver.Result
 		withLock(dc, func() {
-			resi, err = execer.Exec(query, dargs)
+			resi, err = ctxDriverExec(ctx, execer, query, dargs)
 		})
 		if err != driver.ErrSkip {
 			if err != nil {
@@ -1048,44 +1082,50 @@
 
 	var si driver.Stmt
 	withLock(dc, func() {
-		si, err = dc.ci.Prepare(query)
+		si, err = ctxDriverPrepare(ctx, dc.ci, query)
 	})
 	if err != nil {
 		return nil, err
 	}
 	defer withLock(dc, func() { si.Close() })
-	return resultFromStatement(driverStmt{dc, si}, args...)
+	return resultFromStatement(ctx, driverStmt{dc, si}, args...)
 }
 
-// Query executes a query that returns rows, typically a SELECT.
+// QueryContext executes a query that returns rows, typically a SELECT.
 // The args are for any placeholder parameters in the query.
-func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
 	var rows *Rows
 	var err error
 	for i := 0; i < maxBadConnRetries; i++ {
-		rows, err = db.query(query, args, cachedOrNewConn)
+		rows, err = db.query(ctx, query, args, cachedOrNewConn)
 		if err != driver.ErrBadConn {
 			break
 		}
 	}
 	if err == driver.ErrBadConn {
-		return db.query(query, args, alwaysNewConn)
+		return db.query(ctx, query, args, alwaysNewConn)
 	}
 	return rows, err
 }
 
-func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
-	ci, err := db.conn(strategy)
+// Query executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+	return db.QueryContext(context.Background(), query, args...)
+}
+
+func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
+	ci, err := db.conn(ctx, strategy)
 	if err != nil {
 		return nil, err
 	}
 
-	return db.queryConn(ci, ci.releaseConn, query, args)
+	return db.queryConn(ctx, ci, ci.releaseConn, query, args)
 }
 
 // queryConn executes a query on the given connection.
 // The connection gets released by the releaseConn function.
-func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
 	if queryer, ok := dc.ci.(driver.Queryer); ok {
 		dargs, err := driverArgs(nil, args)
 		if err != nil {
@@ -1094,7 +1134,7 @@
 		}
 		var rowsi driver.Rows
 		withLock(dc, func() {
-			rowsi, err = queryer.Query(query, dargs)
+			rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
 		})
 		if err != driver.ErrSkip {
 			if err != nil {
@@ -1115,7 +1155,7 @@
 	var si driver.Stmt
 	var err error
 	withLock(dc, func() {
-		si, err = dc.ci.Prepare(query)
+		si, err = ctxDriverPrepare(ctx, dc.ci, query)
 	})
 	if err != nil {
 		releaseConn(err)
@@ -1123,7 +1163,7 @@
 	}
 
 	ds := driverStmt{dc, si}
-	rowsi, err := rowsiFromStatement(ds, args...)
+	rowsi, err := rowsiFromStatement(ctx, ds, args...)
 	if err != nil {
 		withLock(dc, func() {
 			si.Close()
@@ -1143,49 +1183,77 @@
 	return rows, nil
 }
 
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+	rows, err := db.QueryContext(ctx, query, args...)
+	return &Row{rows: rows, err: err}
+}
+
 // QueryRow executes a query that is expected to return at most one row.
 // QueryRow always returns a non-nil value. Errors are deferred until
 // Row's Scan method is called.
 func (db *DB) QueryRow(query string, args ...interface{}) *Row {
-	rows, err := db.Query(query, args...)
-	return &Row{rows: rows, err: err}
+	return db.QueryRowContext(context.Background(), query, args...)
 }
 
-// Begin starts a transaction. The isolation level is dependent on
-// the driver.
-func (db *DB) Begin() (*Tx, error) {
+// BeginContext starts a transaction. If a non-default isolation level is used
+// that the driver doesn't support an error will be returned. Different drivers
+// may have slightly different meanings for the same isolation level.
+func (db *DB) BeginContext(ctx context.Context) (*Tx, error) {
 	var tx *Tx
 	var err error
 	for i := 0; i < maxBadConnRetries; i++ {
-		tx, err = db.begin(cachedOrNewConn)
+		tx, err = db.begin(ctx, cachedOrNewConn)
 		if err != driver.ErrBadConn {
 			break
 		}
 	}
 	if err == driver.ErrBadConn {
-		return db.begin(alwaysNewConn)
+		return db.begin(ctx, alwaysNewConn)
 	}
 	return tx, err
 }
 
-func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
-	dc, err := db.conn(strategy)
+// Begin starts a transaction. The default isolation level is dependent on
+// the driver.
+func (db *DB) Begin() (*Tx, error) {
+	return db.BeginContext(context.Background())
+}
+
+func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, err error) {
+	dc, err := db.conn(ctx, strategy)
 	if err != nil {
 		return nil, err
 	}
 	var txi driver.Tx
 	withLock(dc, func() {
-		txi, err = dc.ci.Begin()
+		txi, err = ctxDriverBegin(ctx, dc.ci)
 	})
 	if err != nil {
 		db.putConn(dc, err)
 		return nil, err
 	}
-	return &Tx{
-		db:  db,
-		dc:  dc,
-		txi: txi,
-	}, nil
+
+	// Schedule the transaction to rollback when the context is cancelled.
+	// The cancel function in Tx will be called after done is set to true.
+	ctx, cancel := context.WithCancel(ctx)
+	tx = &Tx{
+		db:     db,
+		dc:     dc,
+		txi:    txi,
+		cancel: cancel,
+	}
+	go func() {
+		select {
+		case <-ctx.Done():
+			if !tx.done {
+				tx.Rollback()
+			}
+		}
+	}()
+	return tx, nil
 }
 
 // Driver returns the database's underlying driver.
@@ -1222,6 +1290,9 @@
 		sync.Mutex
 		v []*Stmt
 	}
+
+	// cancel is called after done transitions from false to true.
+	cancel func()
 }
 
 // ErrTxDone is returned by any operation that is performed on a transaction
@@ -1234,11 +1305,12 @@
 	}
 	tx.done = true
 	tx.db.putConn(tx.dc, err)
+	tx.cancel()
 	tx.dc = nil
 	tx.txi = nil
 }
 
-func (tx *Tx) grabConn() (*driverConn, error) {
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
 	if tx.done {
 		return nil, ErrTxDone
 	}
@@ -1292,7 +1364,10 @@
 // be used once the transaction has been committed or rolled back.
 //
 // To use an existing prepared statement on this transaction, see Tx.Stmt.
-func (tx *Tx) Prepare(query string) (*Stmt, error) {
+// Context will be used for the preparation of the statement, not
+// for the execution of the returned statement. The returned statement
+// will run in the transaction context.
+func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
 	// TODO(bradfitz): We could be more efficient here and either
 	// provide a method to take an existing Stmt (created on
 	// perhaps a different Conn), and re-create it on this Conn if
@@ -1306,7 +1381,7 @@
 	// Perhaps just looking at the reference count (by noting
 	// Stmt.Close) would be enough. We might also want a finalizer
 	// on Stmt to drop the reference count.
-	dc, err := tx.grabConn()
+	dc, err := tx.grabConn(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -1334,7 +1409,17 @@
 	return stmt, nil
 }
 
-// Stmt returns a transaction-specific prepared statement from
+// Prepare creates a prepared statement for use within a transaction.
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
+func (tx *Tx) Prepare(query string) (*Stmt, error) {
+	return tx.PrepareContext(context.Background(), query)
+}
+
+// StmtContext returns a transaction-specific prepared statement from
 // an existing statement.
 //
 // Example:
@@ -1342,11 +1427,11 @@
 //  ...
 //  tx, err := db.Begin()
 //  ...
-//  res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//  res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
 //
 // The returned statement operates within the transaction and can no longer
 // be used once the transaction has been committed or rolled back.
-func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
 	// TODO(bradfitz): optimize this. Currently this re-prepares
 	// each time. This is fine for now to illustrate the API but
 	// we should really cache already-prepared statements
@@ -1355,7 +1440,7 @@
 	if tx.db != stmt.db {
 		return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
 	}
-	dc, err := tx.grabConn()
+	dc, err := tx.grabConn(ctx)
 	if err != nil {
 		return &Stmt{stickyErr: err}
 	}
@@ -1379,10 +1464,26 @@
 	return txs
 }
 
-// Exec executes a query that doesn't return rows.
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+//  updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+//  ...
+//  tx, err := db.Begin()
+//  ...
+//  res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+	return tx.StmtContext(context.Background(), stmt)
+}
+
+// ExecContext executes a query that doesn't return rows.
 // For example: an INSERT and UPDATE.
-func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
-	dc, err := tx.grabConn()
+func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
+	dc, err := tx.grabConn(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -1413,25 +1514,43 @@
 	}
 	defer withLock(dc, func() { si.Close() })
 
-	return resultFromStatement(driverStmt{dc, si}, args...)
+	return resultFromStatement(ctx, driverStmt{dc, si}, args...)
 }
 
-// Query executes a query that returns rows, typically a SELECT.
-func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
-	dc, err := tx.grabConn()
+// Exec executes a query that doesn't return rows.
+// For example: an INSERT and UPDATE.
+func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
+	return tx.ExecContext(context.Background(), query, args...)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+	dc, err := tx.grabConn(ctx)
 	if err != nil {
 		return nil, err
 	}
 	releaseConn := func(error) {}
-	return tx.db.queryConn(dc, releaseConn, query, args)
+	return tx.db.queryConn(ctx, dc, releaseConn, query, args)
+}
+
+// Query executes a query that returns rows, typically a SELECT.
+func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
+	return tx.QueryContext(context.Background(), query, args...)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+	rows, err := tx.QueryContext(ctx, query, args...)
+	return &Row{rows: rows, err: err}
 }
 
 // QueryRow executes a query that is expected to return at most one row.
 // QueryRow always returns a non-nil value. Errors are deferred until
 // Row's Scan method is called.
 func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
-	rows, err := tx.Query(query, args...)
-	return &Row{rows: rows, err: err}
+	return tx.QueryRowContext(context.Background(), query, args...)
 }
 
 // connStmt is a prepared statement on a particular connection.
@@ -1468,15 +1587,15 @@
 	lastNumClosed uint64
 }
 
-// Exec executes a prepared statement with the given arguments and
+// ExecContext executes a prepared statement with the given arguments and
 // returns a Result summarizing the effect of the statement.
-func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
 	s.closemu.RLock()
 	defer s.closemu.RUnlock()
 
 	var res Result
 	for i := 0; i < maxBadConnRetries; i++ {
-		dc, releaseConn, si, err := s.connStmt()
+		dc, releaseConn, si, err := s.connStmt(ctx)
 		if err != nil {
 			if err == driver.ErrBadConn {
 				continue
@@ -1484,7 +1603,7 @@
 			return nil, err
 		}
 
-		res, err = resultFromStatement(driverStmt{dc, si}, args...)
+		res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
 		releaseConn(err)
 		if err != driver.ErrBadConn {
 			return res, err
@@ -1493,13 +1612,19 @@
 	return nil, driver.ErrBadConn
 }
 
+// Exec executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+	return s.ExecContext(context.Background(), args...)
+}
+
 func driverNumInput(ds driverStmt) int {
 	ds.Lock()
 	defer ds.Unlock() // in case NumInput panics
 	return ds.si.NumInput()
 }
 
-func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
 	want := driverNumInput(ds)
 
 	// -1 means the driver doesn't know how to count the number of
@@ -1516,7 +1641,8 @@
 
 	ds.Lock()
 	defer ds.Unlock()
-	resi, err := ds.si.Exec(dargs)
+
+	resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
 	if err != nil {
 		return nil, err
 	}
@@ -1552,7 +1678,7 @@
 // connStmt returns a free driver connection on which to execute the
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
-func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
 	if err = s.stickyErr; err != nil {
 		return
 	}
@@ -1567,7 +1693,7 @@
 	// transaction was created on.
 	if s.tx != nil {
 		s.mu.Unlock()
-		ci, err = s.tx.grabConn() // blocks, waiting for the connection.
+		ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
 		if err != nil {
 			return
 		}
@@ -1578,8 +1704,7 @@
 	s.removeClosedStmtLocked()
 	s.mu.Unlock()
 
-	// TODO(bradfitz): or always wait for one? make configurable later?
-	dc, err := s.db.conn(cachedOrNewConn)
+	dc, err := s.db.conn(ctx, cachedOrNewConn)
 	if err != nil {
 		return nil, nil, nil, err
 	}
@@ -1595,7 +1720,7 @@
 
 	// No luck; we need to prepare the statement on this connection
 	withLock(dc, func() {
-		si, err = dc.prepareLocked(s.query)
+		si, err = dc.prepareLocked(ctx, s.query)
 	})
 	if err != nil {
 		s.db.putConn(dc, err)
@@ -1609,15 +1734,15 @@
 	return dc, dc.releaseConn, si, nil
 }
 
-// Query executes a prepared query statement with the given arguments
+// QueryContext executes a prepared query statement with the given arguments
 // and returns the query results as a *Rows.
-func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
 	s.closemu.RLock()
 	defer s.closemu.RUnlock()
 
 	var rowsi driver.Rows
 	for i := 0; i < maxBadConnRetries; i++ {
-		dc, releaseConn, si, err := s.connStmt()
+		dc, releaseConn, si, err := s.connStmt(ctx)
 		if err != nil {
 			if err == driver.ErrBadConn {
 				continue
@@ -1625,7 +1750,7 @@
 			return nil, err
 		}
 
-		rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
+		rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
 		if err == nil {
 			// Note: ownership of ci passes to the *Rows, to be freed
 			// with releaseConn.
@@ -1650,7 +1775,13 @@
 	return nil, driver.ErrBadConn
 }
 
-func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
+// Query executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+	return s.QueryContext(context.Background(), args...)
+}
+
+func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
 	var want int
 	withLock(ds, func() {
 		want = ds.si.NumInput()
@@ -1670,13 +1801,33 @@
 
 	ds.Lock()
 	defer ds.Unlock()
-	rowsi, err := ds.si.Query(dargs)
+
+	rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
 	if err != nil {
 		return nil, err
 	}
 	return rowsi, nil
 }
 
+// QueryRowContext executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// Example usage:
+//
+//  var name string
+//  err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name)
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
+	rows, err := s.QueryContext(ctx, args...)
+	if err != nil {
+		return &Row{err: err}
+	}
+	return &Row{rows: rows}
+}
+
 // QueryRow executes a prepared query statement with the given arguments.
 // If an error occurs during the execution of the statement, that error will
 // be returned by a call to Scan on the returned *Row, which is always non-nil.
@@ -1689,11 +1840,7 @@
 //  var name string
 //  err := nameByUseridStmt.QueryRow(id).Scan(&name)
 func (s *Stmt) QueryRow(args ...interface{}) *Row {
-	rows, err := s.Query(args...)
-	if err != nil {
-		return &Row{err: err}
-	}
-	return &Row{rows: rows}
+	return s.QueryRowContext(context.Background(), args...)
 }
 
 // Close closes the statement.
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 41afd00..9fcb2e3 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -5,6 +5,7 @@
 package sql
 
 import (
+	"context"
 	"database/sql/driver"
 	"errors"
 	"fmt"
@@ -1159,17 +1160,19 @@
 
 	db.SetMaxOpenConns(3)
 
-	conn0, err := db.conn(cachedOrNewConn)
+	ctx := context.Background()
+
+	conn0, err := db.conn(ctx, cachedOrNewConn)
 	if err != nil {
 		t.Fatalf("db open conn fail: %v", err)
 	}
 
-	conn1, err := db.conn(cachedOrNewConn)
+	conn1, err := db.conn(ctx, cachedOrNewConn)
 	if err != nil {
 		t.Fatalf("db open conn fail: %v", err)
 	}
 
-	conn2, err := db.conn(cachedOrNewConn)
+	conn2, err := db.conn(ctx, cachedOrNewConn)
 	if err != nil {
 		t.Fatalf("db open conn fail: %v", err)
 	}