database/sql: add context methods
Add context methods to the sql and sql/driver packages. If
the driver doesn't implement the context methods, the connection
pool will still handle timeouts when a query fails to return
in time or when a connection is not available from the pool
in time.
A follow-up CL will add support for context values that
specify transaction isolation levels and modes that a driver
can use.
Fixes #15123
Change-Id: Ia99f3957aa3f177b23044dd99d4ec217491a30a7
Reviewed-on: https://go-review.googlesource.com/29381
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go
new file mode 100644
index 0000000..65e1652
--- /dev/null
+++ b/src/database/sql/ctxutil.go
@@ -0,0 +1,133 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+	"context"
+	"database/sql/driver"
+	"errors"
+)
+
+// The ctxDriver* helpers call the context-aware form of a driver method
+// when the driver implements it, and otherwise fall back to the legacy
+// method that takes no context.
+//
+// The fallback runs synchronously on the calling goroutine. Callers invoke
+// these helpers while holding the driverConn lock, and a driver.Conn is
+// not used concurrently by multiple goroutines. Running the legacy call in
+// a background goroutine and returning early on cancellation would release
+// that lock while the driver is still using the connection, and would
+// leak any Stmt, Rows or Tx produced by the abandoned call. Instead, a
+// legacy call is not started when ctx is already done, and a Stmt, Rows or
+// Tx that it returns after ctx is done is closed (or rolled back) and
+// ctx.Err() is returned in its place. Exec results are returned as-is,
+// because the statement has already taken effect by then.
+
+// ctxDriverPrepare prepares query on ci, using ConnPrepareContext when ci
+// implements it.
+func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
+	if ciCtx, is := ci.(driver.ConnPrepareContext); is {
+		return ciCtx.PrepareContext(ctx, query)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	si, err := ci.Prepare(query)
+	if err == nil {
+		if cerr := ctx.Err(); cerr != nil {
+			si.Close() // best effort; the caller only sees cerr
+			return nil, cerr
+		}
+	}
+	return si, err
+}
+
+// ctxDriverExec executes query on execer, using ExecerContext when execer
+// implements it.
+func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) {
+	if execerCtx, is := execer.(driver.ExecerContext); is {
+		return execerCtx.ExecContext(ctx, query, dargs)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	return execer.Exec(query, dargs)
+}
+
+// ctxDriverQuery runs query on queryer, using QueryerContext when queryer
+// implements it.
+func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) {
+	if queryerCtx, is := queryer.(driver.QueryerContext); is {
+		return queryerCtx.QueryContext(ctx, query, dargs)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	rowsi, err := queryer.Query(query, dargs)
+	if err == nil {
+		if cerr := ctx.Err(); cerr != nil {
+			rowsi.Close() // best effort; the caller only sees cerr
+			return nil, cerr
+		}
+	}
+	return rowsi, err
+}
+
+// ctxDriverStmtExec executes si with dargs, using StmtExecContext when si
+// implements it.
+func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) {
+	if siCtx, is := si.(driver.StmtExecContext); is {
+		return siCtx.ExecContext(ctx, dargs)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	return si.Exec(dargs)
+}
+
+// ctxDriverStmtQuery queries si with dargs, using StmtQueryContext when si
+// implements it.
+func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) {
+	if siCtx, is := si.(driver.StmtQueryContext); is {
+		return siCtx.QueryContext(ctx, dargs)
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	rowsi, err := si.Query(dargs)
+	if err == nil {
+		if cerr := ctx.Err(); cerr != nil {
+			rowsi.Close() // best effort; the caller only sees cerr
+			return nil, cerr
+		}
+	}
+	return rowsi, err
+}
+
+// errLevelNotSupported is intended for the isolation-level check described
+// in the TODO in ctxDriverBegin; nothing in this file returns it yet.
+var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
+
+// ctxDriverBegin starts a transaction on ci, using ConnBeginContext when ci
+// implements it.
+func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
+	if ciCtx, is := ci.(driver.ConnBeginContext); is {
+		return ciCtx.BeginContext(ctx)
+	}
+	// TODO(kardianos): check the transaction level in ctx. If set and non-default
+	// then return an error here as the BeginContext driver value is not supported.
+
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	txi, err := ci.Begin()
+	if err == nil {
+		if cerr := ctx.Err(); cerr != nil {
+			txi.Rollback() // best effort; the caller only sees cerr
+			return nil, cerr
+		}
+	}
+	return txi, err
+}
diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go
index 4dba85a..ccc283d 100644
--- a/src/database/sql/driver/driver.go
+++ b/src/database/sql/driver/driver.go
@@ -8,7 +8,10 @@
// Most code should use package sql.
package driver
-import "errors"
+import (
+ "context"
+ "errors"
+)
// Value is a value that drivers must be able to handle.
// It is either nil or an instance of one of these types:
@@ -65,6 +68,12 @@
Exec(query string, args []Value) (Result, error)
}
+// ExecerContext is like Execer, but must honor the context timeout and return
+// when the context is canceled.
+type ExecerContext interface {
+	ExecContext(ctx context.Context, query string, args []Value) (Result, error)
+}
+
// Queryer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Queryer, the sql package's DB.Query will
@@ -76,6 +85,12 @@
Query(query string, args []Value) (Rows, error)
}
+// QueryerContext is like Queryer, but must honor the context timeout and return
+// when the context is canceled.
+type QueryerContext interface {
+	QueryContext(ctx context.Context, query string, args []Value) (Rows, error)
+}
+
// Conn is a connection to a database. It is not used concurrently
// by multiple goroutines.
//
@@ -98,6 +113,23 @@
Begin() (Tx, error)
}
+// ConnPrepareContext enhances the Conn interface with context.
+type ConnPrepareContext interface {
+	// PrepareContext returns a prepared statement, bound to this connection.
+	// The context is for the preparation of the statement only; it must not
+	// be stored within the statement itself.
+	PrepareContext(ctx context.Context, query string) (Stmt, error)
+}
+
+// ConnBeginContext enhances the Conn interface with context.
+type ConnBeginContext interface {
+	// BeginContext starts and returns a new transaction.
+	// The provided context should be used to roll the transaction back
+	// if it is canceled. If there is an isolation level in context
+	// that is not supported by the driver, an error must be returned.
+	BeginContext(ctx context.Context) (Tx, error)
+}
+
// Result is the result of a query execution.
type Result interface {
// LastInsertId returns the database's auto-generated ID
@@ -139,6 +171,18 @@
Query(args []Value) (Rows, error)
}
+// StmtExecContext enhances the Stmt interface by providing Exec with context.
+type StmtExecContext interface {
+	// ExecContext must honor the context timeout and return when it is canceled.
+	ExecContext(ctx context.Context, args []Value) (Result, error)
+}
+
+// StmtQueryContext enhances the Stmt interface by providing Query with context.
+type StmtQueryContext interface {
+	// QueryContext must honor the context timeout and return when it is canceled.
+	QueryContext(ctx context.Context, args []Value) (Rows, error)
+}
+
// ColumnConverter may be optionally implemented by Stmt if the
// statement is aware of its own columns' types and can convert from
// any type to a driver Value.
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 1e09a31..4c44e2b 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -13,6 +13,7 @@
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
@@ -297,8 +298,8 @@
return dc.createdAt.Add(timeout).Before(nowFunc())
}
-func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
- si, err := dc.ci.Prepare(query)
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
+ si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err == nil {
// Track each driverConn's open statements, so we can close them
// before closing the conn.
@@ -494,13 +495,13 @@
return db, nil
}
-// Ping verifies a connection to the database is still alive,
+// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
-func (db *DB) Ping() error {
+func (db *DB) PingContext(ctx context.Context) error {
// TODO(bradfitz): give drivers an optional hook to implement
// this in a more efficient or more reliable way, if they
// have one.
- dc, err := db.conn(cachedOrNewConn)
+ dc, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
return err
}
@@ -508,6 +509,12 @@
return nil
}
+// Ping checks that a connection to the database is still alive,
+// opening a new connection first when necessary.
+func (db *DB) Ping() error {
+	ctx := context.Background()
+	return db.PingContext(ctx)
+}
+
// Close closes the database, releasing any open resources.
//
// It is rare to Close a DB, as the DB handle is meant to be
@@ -777,12 +784,18 @@
 var errDBClosed = errors.New("sql: database is closed")
 // conn returns a newly-opened or cached *driverConn.
-func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
+func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
 db.mu.Lock()
 if db.closed {
 db.mu.Unlock()
 return nil, errDBClosed
 }
+	// Check if the context is expired. db.mu is held here, so release it
+	// before returning, exactly as the db.closed path above does.
+	if err := ctx.Err(); err != nil {
+		db.mu.Unlock()
+		return nil, err
+	}
 lifetime := db.maxLifetime
 // Prefer a free connection, if possible.
@@ -808,15 +819,29 @@
 req := make(chan connRequest, 1)
 db.connRequests = append(db.connRequests, req)
 db.mu.Unlock()
- ret, ok := <-req
- if !ok {
- return nil, errDBClosed
+
+		// Timeout the connection request with the context.
+		select {
+		case <-ctx.Done():
+			// req is still queued in db.connRequests, so a connection may be
+			// handed to it after we give up. Receive it in the background and
+			// return it to the pool instead of leaking it.
+			go func() {
+				if ret, ok := <-req; ok && ret.err == nil {
+					db.putConn(ret.conn, nil)
+				}
+			}()
+			return nil, ctx.Err()
+		case ret, ok := <-req:
+			if !ok {
+				return nil, errDBClosed
+			}
+			if ret.err == nil && ret.conn.expired(lifetime) {
+				ret.conn.Close()
+				return nil, driver.ErrBadConn
+			}
+			return ret.conn, ret.err
 }
- if ret.err == nil && ret.conn.expired(lifetime) {
- ret.conn.Close()
- return nil, driver.ErrBadConn
- }
- return ret.conn, ret.err
 }
 db.numOpen++ // optimistically
@@ -952,40 +969,51 @@
// connection to be opened.
const maxBadConnRetries = 2
+// PrepareContext creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+// The context is used for the preparation of the statement, not for the
+// execution of the statement.
+func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
+	var stmt *Stmt
+	var err error
+	for i := 0; i < maxBadConnRetries; i++ {
+		stmt, err = db.prepare(ctx, query, cachedOrNewConn)
+		if err != driver.ErrBadConn {
+			break
+		}
+	}
+	if err == driver.ErrBadConn {
+		return db.prepare(ctx, query, alwaysNewConn)
+	}
+	return stmt, err
+}
+
 // Prepare creates a prepared statement for later queries or executions.
 // Multiple queries or executions may be run concurrently from the
 // returned statement.
 // The caller must call the statement's Close method
 // when the statement is no longer needed.
 func (db *DB) Prepare(query string) (*Stmt, error) {
- var stmt *Stmt
- var err error
- for i := 0; i < maxBadConnRetries; i++ {
- stmt, err = db.prepare(query, cachedOrNewConn)
- if err != driver.ErrBadConn {
- break
- }
- }
- if err == driver.ErrBadConn {
- return db.prepare(query, alwaysNewConn)
- }
- return stmt, err
+	return db.PrepareContext(context.Background(), query)
 }
-func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
+func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
// TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound
// to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one.
- dc, err := db.conn(strategy)
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
var si driver.Stmt
withLock(dc, func() {
- si, err = dc.prepareLocked(query)
+ si, err = dc.prepareLocked(ctx, query)
})
if err != nil {
db.putConn(dc, err)
@@ -1002,25 +1030,31 @@
return stmt, nil
}
-// Exec executes a query without returning any rows.
+// ExecContext executes a query without returning any rows.
 // The args are for any placeholder parameters in the query.
+// The context is used while waiting for a connection and is passed to the driver.
-func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
 var res Result
 var err error
 for i := 0; i < maxBadConnRetries; i++ {
- res, err = db.exec(query, args, cachedOrNewConn)
+ res, err = db.exec(ctx, query, args, cachedOrNewConn)
 if err != driver.ErrBadConn {
 break
 }
 }
 if err == driver.ErrBadConn {
- return db.exec(query, args, alwaysNewConn)
+ return db.exec(ctx, query, args, alwaysNewConn)
 }
 return res, err
 }
-func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
- dc, err := db.conn(strategy)
+// Exec executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+// It is shorthand for ExecContext with context.Background().
+func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+	ctx := context.Background()
+	return db.ExecContext(ctx, query, args...)
+}
+
+func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
@@ -1036,7 +1070,7 @@
}
var resi driver.Result
withLock(dc, func() {
- resi, err = execer.Exec(query, dargs)
+ resi, err = ctxDriverExec(ctx, execer, query, dargs)
})
if err != driver.ErrSkip {
if err != nil {
@@ -1048,44 +1082,50 @@
var si driver.Stmt
withLock(dc, func() {
- si, err = dc.ci.Prepare(query)
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
return nil, err
}
defer withLock(dc, func() { si.Close() })
- return resultFromStatement(driverStmt{dc, si}, args...)
+ return resultFromStatement(ctx, driverStmt{dc, si}, args...)
}
-// Query executes a query that returns rows, typically a SELECT.
+// QueryContext executes a query that returns rows, typically a SELECT.
 // The args are for any placeholder parameters in the query.
+// The context is used while waiting for a connection and is passed to the driver.
-func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
 var rows *Rows
 var err error
 for i := 0; i < maxBadConnRetries; i++ {
- rows, err = db.query(query, args, cachedOrNewConn)
+ rows, err = db.query(ctx, query, args, cachedOrNewConn)
 if err != driver.ErrBadConn {
 break
 }
 }
 if err == driver.ErrBadConn {
- return db.query(query, args, alwaysNewConn)
+ return db.query(ctx, query, args, alwaysNewConn)
 }
 return rows, err
 }
-func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
- ci, err := db.conn(strategy)
+// Query executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+ return db.QueryContext(context.Background(), query, args...)
+}
+
+// query makes a single QueryContext attempt on a connection chosen by
+// strategy; queryConn then releases that connection via ci.releaseConn.
+func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
+ ci, err := db.conn(ctx, strategy)
 if err != nil {
 return nil, err
 }
- return db.queryConn(ci, ci.releaseConn, query, args)
+ return db.queryConn(ctx, ci, ci.releaseConn, query, args)
 }
// queryConn executes a query on the given connection.
// The connection gets released by the releaseConn function.
-func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
@@ -1094,7 +1134,7 @@
}
var rowsi driver.Rows
withLock(dc, func() {
- rowsi, err = queryer.Query(query, dargs)
+ rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
})
if err != driver.ErrSkip {
if err != nil {
@@ -1115,7 +1155,7 @@
var si driver.Stmt
var err error
withLock(dc, func() {
- si, err = dc.ci.Prepare(query)
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
releaseConn(err)
@@ -1123,7 +1163,7 @@
}
ds := driverStmt{dc, si}
- rowsi, err := rowsiFromStatement(ds, args...)
+ rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil {
withLock(dc, func() {
si.Close()
@@ -1143,49 +1183,77 @@
return rows, nil
}
+// QueryRowContext executes a query that is expected to return at most one row.
+// The returned *Row is never nil; any error from the query is held in it and
+// reported when Row's Scan method is called.
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+	row := new(Row)
+	row.rows, row.err = db.QueryContext(ctx, query, args...)
+	return row
+}
+
 // QueryRow executes a query that is expected to return at most one row.
 // QueryRow always returns a non-nil value. Errors are deferred until
 // Row's Scan method is called.
+// It is equivalent to QueryRowContext with context.Background().
 func (db *DB) QueryRow(query string, args ...interface{}) *Row {
- rows, err := db.Query(query, args...)
- return &Row{rows: rows, err: err}
+ return db.QueryRowContext(context.Background(), query, args...)
 }
-// Begin starts a transaction. The isolation level is dependent on
-// the driver.
-func (db *DB) Begin() (*Tx, error) {
+// BeginContext starts a transaction. If a non-default isolation level is used
+// that the driver doesn't support an error will be returned. Different drivers
+// may have slightly different meanings for the same isolation level.
+//
+// The provided context is used until the transaction is committed or rolled
+// back. If the context is canceled before then, the transaction is rolled back.
+func (db *DB) BeginContext(ctx context.Context) (*Tx, error) {
 var tx *Tx
 var err error
 for i := 0; i < maxBadConnRetries; i++ {
- tx, err = db.begin(cachedOrNewConn)
+ tx, err = db.begin(ctx, cachedOrNewConn)
 if err != driver.ErrBadConn {
 break
 }
 }
 if err == driver.ErrBadConn {
- return db.begin(alwaysNewConn)
+ return db.begin(ctx, alwaysNewConn)
 }
 return tx, err
 }
-func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
- dc, err := db.conn(strategy)
+// Begin starts a transaction. The default isolation level is dependent on
+// the driver.
+func (db *DB) Begin() (*Tx, error) {
+	return db.BeginContext(context.Background())
+}
+
+// begin makes one attempt to start a transaction on a connection chosen by
+// strategy. The returned Tx is rolled back automatically if ctx is done
+// before the transaction is committed or rolled back.
+func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, err error) {
+	dc, err := db.conn(ctx, strategy)
 if err != nil {
 return nil, err
 }
 var txi driver.Tx
 withLock(dc, func() {
- txi, err = dc.ci.Begin()
+		txi, err = ctxDriverBegin(ctx, dc.ci)
 })
 if err != nil {
 db.putConn(dc, err)
 return nil, err
 }
- return &Tx{
- db: db,
- dc: dc,
- txi: txi,
- }, nil
+
+	// Schedule the transaction to roll back when the context is canceled.
+	// tx.cancel is called after tx.done is set, so this goroutine also
+	// wakes up (and then does nothing) once the transaction is closed.
+	ctx, cancel := context.WithCancel(ctx)
+	tx = &Tx{
+		db:     db,
+		dc:     dc,
+		txi:    txi,
+		cancel: cancel,
+	}
+	go func() {
+		<-ctx.Done()
+		// NOTE(review): tx.done is a plain field read here without
+		// synchronization. If the parent ctx is canceled while another
+		// goroutine is finishing the transaction, this read and the Rollback
+		// below race with it; making done atomic in Tx would close that window.
+		if !tx.done {
+			tx.Rollback()
+		}
+	}()
+	return tx, nil
 }
// Driver returns the database's underlying driver.
@@ -1222,6 +1290,9 @@
sync.Mutex
v []*Stmt
}
+
+ // cancel is called after done transitions from false to true.
+ cancel func()
}
// ErrTxDone is returned by any operation that is performed on a transaction
@@ -1234,11 +1305,12 @@
}
tx.done = true
tx.db.putConn(tx.dc, err)
+ tx.cancel()
tx.dc = nil
tx.txi = nil
}
-func (tx *Tx) grabConn() (*driverConn, error) {
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
if tx.done {
return nil, ErrTxDone
}
@@ -1292,7 +1364,10 @@
 // be used once the transaction has been committed or rolled back.
 //
 // To use an existing prepared statement on this transaction, see Tx.Stmt.
-func (tx *Tx) Prepare(query string) (*Stmt, error) {
+// The context is used for the preparation of the statement, not for the
+// execution of the returned statement. The returned statement will run in
+// the transaction context.
+func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
 // TODO(bradfitz): We could be more efficient here and either
 // provide a method to take an existing Stmt (created on
 // perhaps a different Conn), and re-create it on this Conn if
@@ -1306,7 +1381,7 @@
// Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count.
- dc, err := tx.grabConn()
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
@@ -1334,7 +1409,17 @@
return stmt, nil
}
-// Stmt returns a transaction-specific prepared statement from
+// Prepare creates a prepared statement for use within a transaction.
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
+// Prepare is PrepareContext called with context.Background().
+func (tx *Tx) Prepare(query string) (*Stmt, error) {
+	ctx := context.Background()
+	return tx.PrepareContext(ctx, query)
+}
+
+// StmtContext returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
@@ -1342,11 +1427,11 @@
// ...
// tx, err := db.Begin()
// ...
-// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
// The returned statement operates within the transaction and can no longer
// be used once the transaction has been committed or rolled back.
-func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
// TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
@@ -1355,7 +1440,7 @@
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
- dc, err := tx.grabConn()
+ dc, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
@@ -1379,10 +1464,26 @@
return txs
}
-// Exec executes a query that doesn't return rows.
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement. It is StmtContext with context.Background().
+//
+// Example:
+//  updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+//  ...
+//  tx, err := db.Begin()
+//  ...
+//  res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+	ctx := context.Background()
+	return tx.StmtContext(ctx, stmt)
+}
+
+// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
-func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
- dc, err := tx.grabConn()
+func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
@@ -1413,25 +1514,43 @@
}
defer withLock(dc, func() { si.Close() })
- return resultFromStatement(driverStmt{dc, si}, args...)
+ return resultFromStatement(ctx, driverStmt{dc, si}, args...)
}
-// Query executes a query that returns rows, typically a SELECT.
-func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
- dc, err := tx.grabConn()
+// Exec executes a query that doesn't return rows.
+// For example: an INSERT and UPDATE.
+func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
+ return tx.ExecContext(context.Background(), query, args...)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+ dc, err := tx.grabConn(ctx)
 if err != nil {
 return nil, err
 }
+	// dc belongs to the transaction and is returned to the pool when the
+	// transaction ends, so the query must not release it.
 releaseConn := func(error) {}
- return tx.db.queryConn(dc, releaseConn, query, args)
+ return tx.db.queryConn(ctx, dc, releaseConn, query, args)
+}
+
+// Query executes a query that returns rows, typically a SELECT.
+func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
+ return tx.QueryContext(context.Background(), query, args...)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := tx.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
 }
 // QueryRow executes a query that is expected to return at most one row.
 // QueryRow always returns a non-nil value. Errors are deferred until
 // Row's Scan method is called.
 func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
- rows, err := tx.Query(query, args...)
- return &Row{rows: rows, err: err}
+ return tx.QueryRowContext(context.Background(), query, args...)
 }
// connStmt is a prepared statement on a particular connection.
@@ -1468,15 +1587,15 @@
lastNumClosed uint64
}
-// Exec executes a prepared statement with the given arguments and
+// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
-func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var res Result
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt()
+ dc, releaseConn, si, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1484,7 +1603,7 @@
return nil, err
}
- res, err = resultFromStatement(driverStmt{dc, si}, args...)
+ res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
@@ -1493,13 +1612,19 @@
return nil, driver.ErrBadConn
}
+// Exec executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+// It runs with context.Background(); see ExecContext.
+func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+	ctx := context.Background()
+	return s.ExecContext(ctx, args...)
+}
+
+// driverNumInput returns the input count ds.si reports via NumInput, holding
+// ds's lock while asking (-1 means the driver doesn't know the count).
 func driverNumInput(ds driverStmt) int {
 ds.Lock()
 defer ds.Unlock() // in case NumInput panics
 return ds.si.NumInput()
 }
-func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
@@ -1516,7 +1641,8 @@
ds.Lock()
defer ds.Unlock()
- resi, err := ds.si.Exec(dargs)
+
+ resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
@@ -1552,7 +1678,7 @@
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
if err = s.stickyErr; err != nil {
return
}
@@ -1567,7 +1693,7 @@
// transaction was created on.
if s.tx != nil {
s.mu.Unlock()
- ci, err = s.tx.grabConn() // blocks, waiting for the connection.
+ ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
@@ -1578,8 +1704,7 @@
s.removeClosedStmtLocked()
s.mu.Unlock()
- // TODO(bradfitz): or always wait for one? make configurable later?
- dc, err := s.db.conn(cachedOrNewConn)
+ dc, err := s.db.conn(ctx, cachedOrNewConn)
if err != nil {
return nil, nil, nil, err
}
@@ -1595,7 +1720,7 @@
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
- si, err = dc.prepareLocked(s.query)
+ si, err = dc.prepareLocked(ctx, s.query)
})
if err != nil {
s.db.putConn(dc, err)
@@ -1609,15 +1734,15 @@
return dc, dc.releaseConn, si, nil
}
-// Query executes a prepared query statement with the given arguments
+// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
-func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt()
+ dc, releaseConn, si, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1625,7 +1750,7 @@
return nil, err
}
- rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
+ rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
@@ -1650,7 +1775,13 @@
return nil, driver.ErrBadConn
}
-func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
+// Query executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+// It runs with context.Background(); see QueryContext.
+func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+	ctx := context.Background()
+	return s.QueryContext(ctx, args...)
+}
+
+func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
var want int
withLock(ds, func() {
want = ds.si.NumInput()
@@ -1670,13 +1801,33 @@
ds.Lock()
defer ds.Unlock()
- rowsi, err := ds.si.Query(dargs)
+
+ rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
return rowsi, nil
}
+// QueryRowContext executes a prepared query statement with the given arguments.
+// The returned *Row is always non-nil. If executing the statement fails, that
+// error is reported by a call to Scan on the returned *Row. If the query
+// selects no rows, Scan returns ErrNoRows; otherwise Scan reads the first
+// selected row and discards the rest.
+//
+// Example usage:
+//
+//  var name string
+//  err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name)
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
+	rows, err := s.QueryContext(ctx, args...)
+	row := &Row{err: err}
+	if err == nil {
+		row.rows = rows
+	}
+	return row
+}
+
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
@@ -1689,11 +1840,7 @@
// var name string
// err := nameByUseridStmt.QueryRow(id).Scan(&name)
func (s *Stmt) QueryRow(args ...interface{}) *Row {
- rows, err := s.Query(args...)
- if err != nil {
- return &Row{err: err}
- }
- return &Row{rows: rows}
+ return s.QueryRowContext(context.Background(), args...)
}
// Close closes the statement.
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 41afd00..9fcb2e3 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -5,6 +5,7 @@
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
@@ -1159,17 +1160,19 @@
db.SetMaxOpenConns(3)
- conn0, err := db.conn(cachedOrNewConn)
+ ctx := context.Background()
+
+ conn0, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
- conn1, err := db.conn(cachedOrNewConn)
+ conn1, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
- conn2, err := db.conn(cachedOrNewConn)
+ conn2, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}