database/sql: Retry with a fresh connection after maxBadConnRetries

Previously, if the connection pool was larger than maxBadConnRetries
and there were a lot of bad connections in the pool (for example if
the database server was restarted), a query might have failed with an
ErrBadConn unnecessarily.  Instead of trying to guess how many times
to retry, try maxBadConnRetries times and then force a fresh
connection to be used for the last attempt.  At the same time, lower
maxBadConnRetries to a smaller value now that it's not that important
to retry so many times from the free connection list.

Fixes #8834

Change-Id: I6542f151a766a658980fb396fa4880ecf5874e3d
Reviewed-on: https://go-review.googlesource.com/2034
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
index a993fd4..8cbbb29 100644
--- a/src/database/sql/fakedb_test.go
+++ b/src/database/sql/fakedb_test.go
@@ -89,7 +89,10 @@
 	stmtsMade   int
 	stmtsClosed int
 	numPrepare  int
-	bad         bool
+
+	// bad connection tests; see isBad()
+	bad       bool
+	stickyBad bool
 }
 
 func (c *fakeConn) incrStat(v *int) {
@@ -243,13 +246,15 @@
 }
 
 func (c *fakeConn) isBad() bool {
-	// if not simulating bad conn, do nothing
-	if !c.bad {
+	if c.stickyBad {
+		return true
+	} else if c.bad {
+		// alternate between bad conn and not bad conn
+		c.db.badConn = !c.db.badConn
+		return c.db.badConn
+	} else {
 		return false
 	}
-	// alternate between bad conn and not bad conn
-	c.db.badConn = !c.db.badConn
-	return c.db.badConn
 }
 
 func (c *fakeConn) Begin() (driver.Tx, error) {
@@ -466,7 +471,7 @@
 		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
 	}
 
-	if hookPrepareBadConn != nil && hookPrepareBadConn() {
+	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
 		return nil, driver.ErrBadConn
 	}
 
@@ -529,7 +534,7 @@
 		return nil, errClosed
 	}
 
-	if hookExecBadConn != nil && hookExecBadConn() {
+	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
 		return nil, driver.ErrBadConn
 	}
 
@@ -613,7 +618,7 @@
 		return nil, errClosed
 	}
 
-	if hookQueryBadConn != nil && hookQueryBadConn() {
+	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
 		return nil, driver.ErrBadConn
 	}
 
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 0a84163..96c93ed 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -235,6 +235,18 @@
 	maxOpen  int                    // <= 0 means unlimited
 }
 
+// connReuseStrategy determines how (*DB).conn returns database connections.
+type connReuseStrategy uint8
+
+const (
+	// alwaysNewConn forces a new connection to the database.
+	alwaysNewConn connReuseStrategy = iota
+	// cachedOrNewConn returns a cached connection, if available, else waits
+	// for one to become available (if MaxOpenConns has been reached) or
+	// creates a new database connection.
+	cachedOrNewConn
+)
+
 // driverConn wraps a driver.Conn with a mutex, to
 // be held during all calls into the Conn. (including any calls onto
 // interfaces returned via that Conn, such as calls on Tx, Stmt,
@@ -465,7 +477,7 @@
 	// TODO(bradfitz): give drivers an optional hook to implement
 	// this in a more efficient or more reliable way, if they
 	// have one.
-	dc, err := db.conn()
+	dc, err := db.conn(cachedOrNewConn)
 	if err != nil {
 		return err
 	}
@@ -651,17 +663,28 @@
 
 var errDBClosed = errors.New("sql: database is closed")
 
-// conn returns a newly-opened or cached *driverConn
-func (db *DB) conn() (*driverConn, error) {
+// conn returns a newly-opened or cached *driverConn.
+func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
 	db.mu.Lock()
 	if db.closed {
 		db.mu.Unlock()
 		return nil, errDBClosed
 	}
 
-	// If db.maxOpen > 0 and the number of open connections is over the limit
-	// and there are no free connection, make a request and wait.
-	if db.maxOpen > 0 && db.numOpen >= db.maxOpen && len(db.freeConn) == 0 {
+	// Prefer a free connection, if possible.
+	numFree := len(db.freeConn)
+	if strategy == cachedOrNewConn && numFree > 0 {
+		conn := db.freeConn[0]
+		copy(db.freeConn, db.freeConn[1:])
+		db.freeConn = db.freeConn[:numFree-1]
+		conn.inUse = true
+		db.mu.Unlock()
+		return conn, nil
+	}
+
+	// Out of free connections or we were asked not to use one.  If we're not
+	// allowed to open any more connections, make a request and wait.
+	if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
 		// Make the connRequest channel. It's buffered so that the
 		// connectionOpener doesn't block while waiting for the req to be read.
 		req := make(chan connRequest, 1)
@@ -671,15 +694,6 @@
 		return ret.conn, ret.err
 	}
 
-	if c := len(db.freeConn); c > 0 {
-		conn := db.freeConn[0]
-		copy(db.freeConn, db.freeConn[1:])
-		db.freeConn = db.freeConn[:c-1]
-		conn.inUse = true
-		db.mu.Unlock()
-		return conn, nil
-	}
-
 	db.numOpen++ // optimistically
 	db.mu.Unlock()
 	ci, err := db.driver.Open(db.dsn)
@@ -808,8 +822,9 @@
 }
 
 // maxBadConnRetries is the number of maximum retries if the driver returns
-// driver.ErrBadConn to signal a broken connection.
-const maxBadConnRetries = 10
+// driver.ErrBadConn to signal a broken connection before forcing a new
+// connection to be opened.
+const maxBadConnRetries = 2
 
 // Prepare creates a prepared statement for later queries or executions.
 // Multiple queries or executions may be run concurrently from the
@@ -818,22 +833,25 @@
 	var stmt *Stmt
 	var err error
 	for i := 0; i < maxBadConnRetries; i++ {
-		stmt, err = db.prepare(query)
+		stmt, err = db.prepare(query, cachedOrNewConn)
 		if err != driver.ErrBadConn {
 			break
 		}
 	}
+	if err == driver.ErrBadConn {
+		return db.prepare(query, alwaysNewConn)
+	}
 	return stmt, err
 }
 
-func (db *DB) prepare(query string) (*Stmt, error) {
+func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
 	// TODO: check if db.driver supports an optional
 	// driver.Preparer interface and call that instead, if so,
 	// otherwise we make a prepared statement that's bound
 	// to a connection, and to execute this prepared statement
 	// we either need to use this connection (if it's free), else
 	// get a new connection + re-prepare + execute on that one.
-	dc, err := db.conn()
+	dc, err := db.conn(strategy)
 	if err != nil {
 		return nil, err
 	}
@@ -861,16 +879,19 @@
 	var res Result
 	var err error
 	for i := 0; i < maxBadConnRetries; i++ {
-		res, err = db.exec(query, args)
+		res, err = db.exec(query, args, cachedOrNewConn)
 		if err != driver.ErrBadConn {
 			break
 		}
 	}
+	if err == driver.ErrBadConn {
+		return db.exec(query, args, alwaysNewConn)
+	}
 	return res, err
 }
 
-func (db *DB) exec(query string, args []interface{}) (res Result, err error) {
-	dc, err := db.conn()
+func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
+	dc, err := db.conn(strategy)
 	if err != nil {
 		return nil, err
 	}
@@ -910,16 +931,19 @@
 	var rows *Rows
 	var err error
 	for i := 0; i < maxBadConnRetries; i++ {
-		rows, err = db.query(query, args)
+		rows, err = db.query(query, args, cachedOrNewConn)
 		if err != driver.ErrBadConn {
 			break
 		}
 	}
+	if err == driver.ErrBadConn {
+		return db.query(query, args, alwaysNewConn)
+	}
 	return rows, err
 }
 
-func (db *DB) query(query string, args []interface{}) (*Rows, error) {
-	ci, err := db.conn()
+func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
+	ci, err := db.conn(strategy)
 	if err != nil {
 		return nil, err
 	}
@@ -998,16 +1022,19 @@
 	var tx *Tx
 	var err error
 	for i := 0; i < maxBadConnRetries; i++ {
-		tx, err = db.begin()
+		tx, err = db.begin(cachedOrNewConn)
 		if err != driver.ErrBadConn {
 			break
 		}
 	}
+	if err == driver.ErrBadConn {
+		return db.begin(alwaysNewConn)
+	}
 	return tx, err
 }
 
-func (db *DB) begin() (tx *Tx, err error) {
-	dc, err := db.conn()
+func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
+	dc, err := db.conn(strategy)
 	if err != nil {
 		return nil, err
 	}
@@ -1396,7 +1423,7 @@
 	s.mu.Unlock()
 
 	// TODO(bradfitz): or always wait for one? make configurable later?
-	dc, err := s.db.conn()
+	dc, err := s.db.conn(cachedOrNewConn)
 	if err != nil {
 		return nil, nil, nil, err
 	}
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 45554c6..94f80a6 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -1085,17 +1085,17 @@
 
 	db.SetMaxOpenConns(3)
 
-	conn0, err := db.conn()
+	conn0, err := db.conn(cachedOrNewConn)
 	if err != nil {
 		t.Fatalf("db open conn fail: %v", err)
 	}
 
-	conn1, err := db.conn()
+	conn1, err := db.conn(cachedOrNewConn)
 	if err != nil {
 		t.Fatalf("db open conn fail: %v", err)
 	}
 
-	conn2, err := db.conn()
+	conn2, err := db.conn(cachedOrNewConn)
 	if err != nil {
 		t.Fatalf("db open conn fail: %v", err)
 	}
@@ -1385,6 +1385,79 @@
 	}
 }
 
+// Test cases where there are more than maxBadConnRetries bad connections in
+// the pool (issue 8834).
+func TestManyErrBadConn(t *testing.T) {
+	manyErrBadConnSetup := func() *DB {
+		db := newTestDB(t, "people")
+
+		nconn := maxBadConnRetries + 1
+		db.SetMaxIdleConns(nconn)
+		db.SetMaxOpenConns(nconn)
+		// open enough connections
+		func() {
+			for i := 0; i < nconn; i++ {
+				rows, err := db.Query("SELECT|people|age,name|")
+				if err != nil {
+					t.Fatal(err)
+				}
+				defer rows.Close()
+			}
+		}()
+
+		if db.numOpen != nconn {
+			t.Fatalf("unexpected numOpen %d (was expecting %d)", db.numOpen, nconn)
+		} else if len(db.freeConn) != nconn {
+			t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn)
+		}
+		for _, conn := range db.freeConn {
+			conn.ci.(*fakeConn).stickyBad = true
+		}
+		return db
+	}
+
+	// Query
+	db := manyErrBadConnSetup()
+	defer closeDB(t, db)
+	rows, err := db.Query("SELECT|people|age,name|")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if err = rows.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Exec
+	db = manyErrBadConnSetup()
+	defer closeDB(t, db)
+	_, err = db.Exec("INSERT|people|name=Julia,age=19")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Begin
+	db = manyErrBadConnSetup()
+	defer closeDB(t, db)
+	tx, err := db.Begin()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if err = tx.Rollback(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Prepare
+	db = manyErrBadConnSetup()
+	defer closeDB(t, db)
+	stmt, err := db.Prepare("SELECT|people|age,name|")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if err = stmt.Close(); err != nil {
+		t.Fatal(err)
+	}
+}
+
 // golang.org/issue/5781
 func TestErrBadConnReconnect(t *testing.T) {
 	db := newTestDB(t, "foo")