database/sql: add SetMaxOpenConns

Update #4805

Add the ability to set an open connection limit.
Fix a case where the Conn finalCloser was called with db.mu held.
Add separate benchmarks for each Exec and Query path.
Replace the slice-based idle pool with a list-based one.

R=bradfitz
CC=golang-dev
https://golang.org/cl/10726044
diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go
index 4005f15..435d79c 100644
--- a/src/pkg/database/sql/sql_test.go
+++ b/src/pkg/database/sql/sql_test.go
@@ -8,6 +8,7 @@
 	"database/sql/driver"
 	"errors"
 	"fmt"
+	"math/rand"
 	"reflect"
 	"runtime"
 	"strings"
@@ -23,14 +24,12 @@
 	}
 	freedFrom := make(map[dbConn]string)
 	putConnHook = func(db *DB, c *driverConn) {
-		for _, oc := range db.freeConn {
-			if oc == c {
-				// print before panic, as panic may get lost due to conflicting panic
-				// (all goroutines asleep) elsewhere, since we might not unlock
-				// the mutex in freeConn here.
-				println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack())
-				panic("double free of conn.")
-			}
+		if c.listElem != nil {
+			// print before panic, as panic may get lost due to conflicting panic
+			// (all goroutines asleep) elsewhere, since we might not unlock
+			// the mutex in freeConn here.
+			println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack())
+			panic("double free of conn.")
 		}
 		freedFrom[dbConn{db, c}] = stack()
 	}
@@ -80,14 +79,15 @@
 			t.Errorf("Error closing fakeConn: %v", err)
 		}
 	})
-	for i, dc := range db.freeConn {
+	for node, i := db.freeConn.Front(), 0; node != nil; node, i = node.Next(), i+1 {
+		dc := node.Value.(*driverConn)
 		if n := len(dc.openStmt); n > 0 {
 			// Just a sanity check. This is legal in
 			// general, but if we make the tests clean up
 			// their statements first, then we can safely
 			// verify this is always zero here, and any
 			// other value is a leak.
-			t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n)
+			t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, db.freeConn.Len(), n)
 		}
 	}
 	err := db.Close()
@@ -99,10 +99,10 @@
 // numPrepares assumes that db has exactly 1 idle conn and returns
 // its count of calls to Prepare
 func numPrepares(t *testing.T, db *DB) int {
-	if n := len(db.freeConn); n != 1 {
+	if n := db.freeConn.Len(); n != 1 {
 		t.Fatalf("free conns = %d; want 1", n)
 	}
-	return db.freeConn[0].ci.(*fakeConn).numPrepare
+	return db.freeConn.Front().Value.(*driverConn).ci.(*fakeConn).numPrepare
 }
 
 func (db *DB) numDeps() int {
@@ -127,7 +127,7 @@
 func (db *DB) numFreeConns() int {
 	db.mu.Lock()
 	defer db.mu.Unlock()
-	return len(db.freeConn)
+	return db.freeConn.Len()
 }
 
 func (db *DB) dumpDeps(t *testing.T) {
@@ -642,10 +642,10 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	if len(db.freeConn) != 1 {
+	if db.freeConn.Len() != 1 {
 		t.Fatalf("expected 1 free conn")
 	}
-	fakeConn := db.freeConn[0].ci.(*fakeConn)
+	fakeConn := (db.freeConn.Front().Value.(*driverConn)).ci.(*fakeConn)
 	if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
 		t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
 	}
@@ -841,13 +841,13 @@
 		t.Fatal(err)
 	}
 	tx.Commit()
-	if got := len(db.freeConn); got != 1 {
+	if got := db.freeConn.Len(); got != 1 {
 		t.Errorf("freeConns = %d; want 1", got)
 	}
 
 	db.SetMaxIdleConns(0)
 
-	if got := len(db.freeConn); got != 0 {
+	if got := db.freeConn.Len(); got != 0 {
 		t.Errorf("freeConns after set to zero = %d; want 0", got)
 	}
 
@@ -856,11 +856,150 @@
 		t.Fatal(err)
 	}
 	tx.Commit()
-	if got := len(db.freeConn); got != 0 {
+	if got := db.freeConn.Len(); got != 0 {
 		t.Errorf("freeConns = %d; want 0", got)
 	}
 }
 
+// TestMaxOpenConns runs batches of slow parallel queries with
+// SetMaxOpenConns(10) and checks that no more than 10 driver connections
+// were opened. It then checks the idle pool and dependency counts after
+// lowering the limit to 5, setting it to 0, and dropping max idle to 0.
+func TestMaxOpenConns(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping in short mode")
+	}
+	defer setHookpostCloseConn(nil)
+	setHookpostCloseConn(func(_ *fakeConn, err error) {
+		if err != nil {
+			t.Errorf("Error closing fakeConn: %v", err)
+		}
+	})
+
+	db := newTestDB(t, "magicquery")
+	defer closeDB(t, db)
+
+	driver := db.driver.(*fakeDriver)
+
+	// Drop every idle connection so the driver open/close counts taken
+	// below start from a known baseline.
+	db.SetMaxIdleConns(0)
+
+	if g, w := db.numFreeConns(), 0; g != w {
+		t.Errorf("free conns = %d; want %d", g, w)
+	}
+
+	if n := db.numDepsPollUntil(0, time.Second); n > 0 {
+		t.Errorf("number of dependencies = %d; expected 0", n)
+		db.dumpDeps(t)
+	}
+
+	driver.mu.Lock()
+	opens0 := driver.openCount
+	closes0 := driver.closeCount
+	driver.mu.Unlock()
+
+	db.SetMaxIdleConns(10)
+	db.SetMaxOpenConns(10)
+
+	stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Start nbatch batches of nquery parallel slow queries.
+	const (
+		nquery      = 50
+		sleepMillis = 25
+		nbatch      = 2
+	)
+	var wg sync.WaitGroup
+	for batch := 0; batch < nbatch; batch++ {
+		for i := 0; i < nquery; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				var op string
+				if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
+					t.Error(err)
+				}
+			}()
+		}
+		// Sleep for twice the time each query sleeps, giving the batch
+		// of nquery queries above a chance to finish before the next
+		// round starts.
+		time.Sleep(2 * sleepMillis * time.Millisecond)
+	}
+	wg.Wait()
+
+	if g, w := db.numFreeConns(), 10; g != w {
+		t.Errorf("free conns = %d; want %d", g, w)
+	}
+
+	if n := db.numDepsPollUntil(20, time.Second); n > 20 {
+		t.Errorf("number of dependencies = %d; expected <= 20", n)
+		db.dumpDeps(t)
+	}
+
+	driver.mu.Lock()
+	opens := driver.openCount - opens0
+	closes := driver.closeCount - closes0
+	driver.mu.Unlock()
+
+	if opens > 10 {
+		t.Logf("open calls = %d", opens)
+		t.Logf("close calls = %d", closes)
+		t.Errorf("db connections opened = %d; want <= 10", opens)
+		db.dumpDeps(t)
+	}
+
+	if err := stmt.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	if g, w := db.numFreeConns(), 10; g != w {
+		t.Errorf("free conns = %d; want %d", g, w)
+	}
+
+	if n := db.numDepsPollUntil(10, time.Second); n > 10 {
+		t.Errorf("number of dependencies = %d; expected <= 10", n)
+		db.dumpDeps(t)
+	}
+
+	db.SetMaxOpenConns(5)
+
+	if g, w := db.numFreeConns(), 5; g != w {
+		t.Errorf("free conns = %d; want %d", g, w)
+	}
+
+	if n := db.numDepsPollUntil(5, time.Second); n > 5 {
+		t.Errorf("number of dependencies = %d; expected <= 5", n)
+		db.dumpDeps(t)
+	}
+
+	db.SetMaxOpenConns(0)
+
+	if g, w := db.numFreeConns(), 5; g != w {
+		t.Errorf("free conns = %d; want %d", g, w)
+	}
+
+	if n := db.numDepsPollUntil(5, time.Second); n > 5 {
+		t.Errorf("number of dependencies = %d; expected <= 5", n)
+		db.dumpDeps(t)
+	}
+
+	db.SetMaxIdleConns(0)
+
+	if g, w := db.numFreeConns(), 0; g != w {
+		t.Errorf("free conns = %d; want %d", g, w)
+	}
+
+	if n := db.numDepsPollUntil(0, time.Second); n > 0 {
+		t.Errorf("number of dependencies = %d; expected 0", n)
+		db.dumpDeps(t)
+	}
+}
+
 // golang.org/issue/5323
 func TestStmtCloseDeps(t *testing.T) {
 	if testing.Short() {
@@ -926,8 +1061,8 @@
 	driver.mu.Lock()
 	opens := driver.openCount - opens0
 	closes := driver.closeCount - closes0
-	driver.mu.Unlock()
 	openDelta := (driver.openCount - driver.closeCount) - openDelta0
+	driver.mu.Unlock()
 
 	if openDelta > 2 {
 		t.Logf("open calls = %d", opens)
@@ -985,10 +1120,10 @@
 		t.Fatal(err)
 	}
 
-	if len(db.freeConn) != 1 {
-		t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
+	if db.freeConn.Len() != 1 {
+		t.Fatalf("expected 1 freeConn; got %d", db.freeConn.Len())
 	}
-	dc := db.freeConn[0]
+	dc := db.freeConn.Front().Value.(*driverConn)
 	if dc.closed {
 		t.Errorf("conn shouldn't be closed")
 	}
@@ -1082,6 +1217,376 @@
 	}
 }
 
+// concurrentTest is one kind of database workload driven by
+// doConcurrentTest. init prepares any shared state (statements,
+// transactions) against db, test performs a single request and may be
+// called from many goroutines at once, and finish releases what init
+// acquired. test reports failures through t and also returns the error.
+type concurrentTest interface {
+	init(t testing.TB, db *DB)
+	finish(t testing.TB)
+	test(t testing.TB) error
+}
+
+// concurrentDBQueryTest issues each query directly through DB.Query.
+type concurrentDBQueryTest struct {
+	db *DB
+}
+
+func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) {
+	c.db = db
+}
+
+func (c *concurrentDBQueryTest) finish(t testing.TB) {
+	c.db = nil
+}
+
+func (c *concurrentDBQueryTest) test(t testing.TB) error {
+	rows, err := c.db.Query("SELECT|people|name|")
+	if err != nil {
+		t.Error(err)
+		return err
+	}
+	var name string
+	for rows.Next() {
+		rows.Scan(&name)
+	}
+	rows.Close()
+	return nil
+}
+
+// concurrentDBExecTest issues each NOSERT directly through DB.Exec.
+type concurrentDBExecTest struct {
+	db *DB
+}
+
+func (c *concurrentDBExecTest) init(t testing.TB, db *DB) {
+	c.db = db
+}
+
+func (c *concurrentDBExecTest) finish(t testing.TB) {
+	c.db = nil
+}
+
+func (c *concurrentDBExecTest) test(t testing.TB) error {
+	_, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
+	if err != nil {
+		t.Error(err)
+		return err
+	}
+	return nil
+}
+
+// concurrentStmtQueryTest prepares one statement on the DB in init and
+// runs each query through that shared *Stmt.
+type concurrentStmtQueryTest struct {
+	db   *DB
+	stmt *Stmt
+}
+
+func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) {
+	c.db = db
+	var err error
+	c.stmt, err = db.Prepare("SELECT|people|name|")
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func (c *concurrentStmtQueryTest) finish(t testing.TB) {
+	if c.stmt != nil {
+		c.stmt.Close()
+		c.stmt = nil
+	}
+	c.db = nil
+}
+
+func (c *concurrentStmtQueryTest) test(t testing.TB) error {
+	rows, err := c.stmt.Query()
+	if err != nil {
+		t.Errorf("error on query:  %v", err)
+		return err
+	}
+
+	var name string
+	for rows.Next() {
+		rows.Scan(&name)
+	}
+	rows.Close()
+	return nil
+}
+
+// concurrentStmtExecTest prepares one NOSERT statement on the DB in init
+// and runs each exec through that shared *Stmt.
+type concurrentStmtExecTest struct {
+	db   *DB
+	stmt *Stmt
+}
+
+func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) {
+	c.db = db
+	var err error
+	c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func (c *concurrentStmtExecTest) finish(t testing.TB) {
+	if c.stmt != nil {
+		c.stmt.Close()
+		c.stmt = nil
+	}
+	c.db = nil
+}
+
+func (c *concurrentStmtExecTest) test(t testing.TB) error {
+	_, err := c.stmt.Exec(3, chrisBirthday)
+	if err != nil {
+		t.Errorf("error on exec:  %v", err)
+		return err
+	}
+	return nil
+}
+
+// concurrentTxQueryTest begins one transaction in init and runs each
+// query through Tx.Query on that shared transaction.
+type concurrentTxQueryTest struct {
+	db *DB
+	tx *Tx
+}
+
+func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) {
+	c.db = db
+	var err error
+	c.tx, err = c.db.Begin()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func (c *concurrentTxQueryTest) finish(t testing.TB) {
+	if c.tx != nil {
+		c.tx.Rollback()
+		c.tx = nil
+	}
+	c.db = nil
+}
+
+func (c *concurrentTxQueryTest) test(t testing.TB) error {
+	rows, err := c.tx.Query("SELECT|people|name|")
+	if err != nil {
+		t.Error(err)
+		return err
+	}
+	var name string
+	for rows.Next() {
+		rows.Scan(&name)
+	}
+	rows.Close()
+	return nil
+}
+
+// concurrentTxExecTest begins one transaction in init and runs each
+// NOSERT through Tx.Exec on that shared transaction.
+type concurrentTxExecTest struct {
+	db *DB
+	tx *Tx
+}
+
+func (c *concurrentTxExecTest) init(t testing.TB, db *DB) {
+	c.db = db
+	var err error
+	c.tx, err = c.db.Begin()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func (c *concurrentTxExecTest) finish(t testing.TB) {
+	if c.tx != nil {
+		c.tx.Rollback()
+		c.tx = nil
+	}
+	c.db = nil
+}
+
+func (c *concurrentTxExecTest) test(t testing.TB) error {
+	_, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
+	if err != nil {
+		t.Error(err)
+		return err
+	}
+	return nil
+}
+
+// concurrentTxStmtQueryTest begins one transaction and prepares one
+// statement on it in init, then runs each query through that *Stmt.
+type concurrentTxStmtQueryTest struct {
+	db   *DB
+	tx   *Tx
+	stmt *Stmt
+}
+
+func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) {
+	c.db = db
+	var err error
+	c.tx, err = c.db.Begin()
+	if err != nil {
+		t.Fatal(err)
+	}
+	c.stmt, err = c.tx.Prepare("SELECT|people|name|")
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func (c *concurrentTxStmtQueryTest) finish(t testing.TB) {
+	if c.stmt != nil {
+		c.stmt.Close()
+		c.stmt = nil
+	}
+	if c.tx != nil {
+		c.tx.Rollback()
+		c.tx = nil
+	}
+	c.db = nil
+}
+
+func (c *concurrentTxStmtQueryTest) test(t testing.TB) error {
+	rows, err := c.stmt.Query()
+	if err != nil {
+		t.Errorf("error on query:  %v", err)
+		return err
+	}
+
+	var name string
+	for rows.Next() {
+		rows.Scan(&name)
+	}
+	rows.Close()
+	return nil
+}
+
+// concurrentTxStmtExecTest begins one transaction and prepares one NOSERT
+// statement on it in init, then runs each exec through that *Stmt.
+type concurrentTxStmtExecTest struct {
+	db   *DB
+	tx   *Tx
+	stmt *Stmt
+}
+
+func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) {
+	c.db = db
+	var err error
+	c.tx, err = c.db.Begin()
+	if err != nil {
+		t.Fatal(err)
+	}
+	c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func (c *concurrentTxStmtExecTest) finish(t testing.TB) {
+	if c.stmt != nil {
+		c.stmt.Close()
+		c.stmt = nil
+	}
+	if c.tx != nil {
+		c.tx.Rollback()
+		c.tx = nil
+	}
+	c.db = nil
+}
+
+func (c *concurrentTxStmtExecTest) test(t testing.TB) error {
+	_, err := c.stmt.Exec(3, chrisBirthday)
+	if err != nil {
+		t.Errorf("error on exec:  %v", err)
+		return err
+	}
+	return nil
+}
+
+// concurrentRandomTest initializes one of each other concurrentTest kind
+// against the same DB and sends each request to one picked at random.
+type concurrentRandomTest struct {
+	tests []concurrentTest
+}
+
+func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
+	c.tests = []concurrentTest{
+		new(concurrentDBQueryTest),
+		new(concurrentDBExecTest),
+		new(concurrentStmtQueryTest),
+		new(concurrentStmtExecTest),
+		new(concurrentTxQueryTest),
+		new(concurrentTxExecTest),
+		new(concurrentTxStmtQueryTest),
+		new(concurrentTxStmtExecTest),
+	}
+	for _, ct := range c.tests {
+		ct.init(t, db)
+	}
+}
+
+func (c *concurrentRandomTest) finish(t testing.TB) {
+	for _, ct := range c.tests {
+		ct.finish(t)
+	}
+}
+
+func (c *concurrentRandomTest) test(t testing.TB) error {
+	ct := c.tests[rand.Intn(len(c.tests))]
+	return ct.test(t)
+}
+
+// doConcurrentTest opens a "people" test database via newTestDB, runs
+// ct.init against it, and then issues numReqs calls to ct.test from a pool
+// of 2*maxProcs worker goroutines, returning once every call has finished.
+// GOMAXPROCS is set to maxProcs for the duration of the call.
+func doConcurrentTest(t testing.TB, ct concurrentTest) {
+	// NOTE(review): the full run uses fewer procs (1) than short mode (4),
+	// whereas manyConcurrentQueries uses 16 for its full run; confirm intent.
+	maxProcs, numReqs := 1, 500
+	if testing.Short() {
+		maxProcs, numReqs = 4, 50
+	}
+	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+
+	db := newTestDB(t, "people")
+	defer closeDB(t, db)
+
+	ct.init(t, db)
+	defer ct.finish(t)
+
+	var wg sync.WaitGroup
+	wg.Add(numReqs)
+
+	// Closing reqs on return makes the idle workers' range loops exit.
+	reqs := make(chan bool)
+	defer close(reqs)
+
+	for i := 0; i < maxProcs*2; i++ {
+		go func() {
+			for _ = range reqs {
+				// ct.test already reports failures through t, so its
+				// returned error needs no further handling here.
+				ct.test(t)
+				wg.Done()
+			}
+		}()
+	}
+
+	for i := 0; i < numReqs; i++ {
+		reqs <- true
+	}
+
+	wg.Wait()
+}
+
 func manyConcurrentQueries(t testing.TB) {
 	maxProcs, numReqs := 16, 500
 	if testing.Short() {
@@ -1178,12 +1657,61 @@
 }
 
+// TestConcurrency runs doConcurrentTest once for each concurrentTest
+// implementation, including the randomized mix.
 func TestConcurrency(t *testing.T) {
-	manyConcurrentQueries(t)
+	doConcurrentTest(t, new(concurrentDBQueryTest))
+	doConcurrentTest(t, new(concurrentDBExecTest))
+	doConcurrentTest(t, new(concurrentStmtQueryTest))
+	doConcurrentTest(t, new(concurrentStmtExecTest))
+	doConcurrentTest(t, new(concurrentTxQueryTest))
+	doConcurrentTest(t, new(concurrentTxExecTest))
+	doConcurrentTest(t, new(concurrentTxStmtQueryTest))
+	doConcurrentTest(t, new(concurrentTxStmtExecTest))
+	doConcurrentTest(t, new(concurrentRandomTest))
 }
 
-func BenchmarkConcurrency(b *testing.B) {
+func BenchmarkConcurrentDBQuery(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentDBQueryTest))
+}
+
+func BenchmarkConcurrentDBExec(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentDBExecTest))
+}
+
+func BenchmarkConcurrentStmtQuery(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentStmtQueryTest))
+}
+
+func BenchmarkConcurrentStmtExec(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentStmtExecTest))
+}
+
+func BenchmarkConcurrentTxQuery(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentTxQueryTest))
+}
+
+func BenchmarkConcurrentTxExec(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentTxExecTest))
+}
+
+func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentTxStmtQueryTest))
+}
+
+func BenchmarkConcurrentTxStmtExec(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentTxStmtExecTest))
+}
+
+func BenchmarkConcurrentRandom(b *testing.B) {
+	benchmarkConcurrent(b, new(concurrentRandomTest))
+}
+
+// benchmarkConcurrent runs doConcurrentTest with ct b.N times, reporting
+// allocations. The same ct is reused for every iteration; doConcurrentTest
+// calls its init and finish on each run.
+func benchmarkConcurrent(b *testing.B, ct concurrentTest) {
 	b.ReportAllocs()
 	for i := 0; i < b.N; i++ {
-		manyConcurrentQueries(b)
+		doConcurrentTest(b, ct)
 	}
 }