database/sql: ensure Stmts are correctly closed.
To make sure that there is no resource leak,
I suggest fixing the 'fakedb' driver so that it fails when any
Stmt is not closed.
First, add a check in fakeConn.Close().
Then, fix all missing Stmt.Close()/Rows.Close().
I am not sure that the strategy chosen in fakeConn.Prepare/prepare* is ok.
The weak point in this patch is the change in Tx.Query:
- Tests pass without this change,
- I found it by manually analyzing the code,
- I just tried to make Tx.Query look like DB.Query.
R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/5759050
diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go
index 8732d02..184e775 100644
--- a/src/pkg/database/sql/fakedb_test.go
+++ b/src/pkg/database/sql/fakedb_test.go
@@ -214,6 +214,9 @@
if c.db == nil {
return errors.New("can't close fakeConn; already closed")
}
+ if c.stmtsMade > c.stmtsClosed {
+ return errors.New("can't close; dangling statement(s)")
+ }
c.db = nil
return nil
}
@@ -250,6 +253,7 @@
// just a limitation for fakedb)
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 3 {
+ stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
}
stmt.table = parts[0]
@@ -260,14 +264,17 @@
}
nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 {
+ stmt.Close()
return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
}
column, value := nameVal[0], nameVal[1]
_, ok := c.db.columnType(stmt.table, column)
if !ok {
+ stmt.Close()
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
}
if value != "?" {
+ stmt.Close()
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
stmt.table, column)
}
@@ -280,12 +287,14 @@
// parts are table|col=type,col2=type2
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 2 {
+ stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
}
stmt.table = parts[0]
for n, colspec := range strings.Split(parts[1], ",") {
nameType := strings.Split(colspec, "=")
if len(nameType) != 2 {
+ stmt.Close()
return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
}
stmt.colName = append(stmt.colName, nameType[0])
@@ -297,17 +306,20 @@
// parts are table|col=?,col2=val
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 2 {
+ stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
}
stmt.table = parts[0]
for n, colspec := range strings.Split(parts[1], ",") {
nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 {
+ stmt.Close()
return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
}
column, value := nameVal[0], nameVal[1]
ctype, ok := c.db.columnType(stmt.table, column)
if !ok {
+ stmt.Close()
return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
}
stmt.colName = append(stmt.colName, column)
@@ -323,10 +335,12 @@
case "int32":
i, err := strconv.Atoi(value)
if err != nil {
+ stmt.Close()
return nil, errf("invalid conversion to int32 from %q", value)
}
subsetVal = int64(i) // int64 is a subset type, but not int32
default:
+ stmt.Close()
return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
}
stmt.colValue = append(stmt.colValue, subsetVal)
@@ -362,6 +376,7 @@
case "INSERT":
return c.prepareInsert(stmt, parts)
default:
+ stmt.Close()
return nil, errf("unsupported command type %q", cmd)
}
return stmt, nil