storage/db: add ReplaceUpload for use in reindexing

ReplaceUpload removes any records for an upload so that those records
can be reinserted in the database.

Change-Id: I8fa0701b72c3ace380d3c7922df0c17b81a0d426
Reviewed-on: https://go-review.googlesource.com/35257
Reviewed-by: Russ Cox <rsc@golang.org>
diff --git a/storage/db/db.go b/storage/db/db.go
index bbbe812..2c99f3c 100644
--- a/storage/db/db.go
+++ b/storage/db/db.go
@@ -12,6 +12,7 @@
 	"errors"
 	"fmt"
 	"io"
+	"regexp"
 	"strconv"
 	"strings"
 	"text/template"
@@ -29,9 +30,11 @@
 type DB struct {
 	sql *sql.DB // underlying database connection
 	// prepared statements
-	lastUpload   *sql.Stmt
-	insertUpload *sql.Stmt
-	insertRecord *sql.Stmt
+	lastUpload    *sql.Stmt
+	insertUpload  *sql.Stmt
+	insertRecord  *sql.Stmt
+	checkUpload   *sql.Stmt
+	deleteRecords *sql.Stmt
 }
 
 // OpenSQL creates a DB backed by a SQL database. The parameters are
@@ -142,6 +145,14 @@
 	if err != nil {
 		return err
 	}
+	db.checkUpload, err = db.sql.Prepare("SELECT 1 FROM Uploads WHERE UploadID = ?")
+	if err != nil {
+		return err
+	}
+	db.deleteRecords, err = db.sql.Prepare("DELETE FROM Records WHERE UploadID = ?")
+	if err != nil {
+		return err
+	}
 	return nil
 }
 
@@ -162,6 +173,42 @@
 // now is a hook for testing
 var now = time.Now
 
+// ReplaceUpload removes the records associated with id if any and
+// allows insertion of new records.
+func (db *DB) ReplaceUpload(id string) (*Upload, error) {
+	if _, err := db.deleteRecords.Exec(id); err != nil {
+		return nil, err
+	}
+	var found bool
+	err := db.checkUpload.QueryRow(id).Scan(&found)
+	switch err {
+	case sql.ErrNoRows:
+		var day sql.NullString
+		var num sql.NullInt64
+		if m := regexp.MustCompile(`^(\d+)\.(\d+)$`).FindStringSubmatch(id); m != nil {
+			day.Valid, num.Valid = true, true
+			day.String = m[1]
+			num.Int64, _ = strconv.ParseInt(m[2], 10, 64)
+		}
+		if _, err := db.insertUpload.Exec(id, day, num); err != nil {
+			return nil, err
+		}
+	case nil:
+	default:
+		return nil, err
+	}
+	tx, err := db.sql.Begin()
+	if err != nil {
+		return nil, err
+	}
+	u := &Upload{
+		ID: id,
+		db: db,
+		tx: tx,
+	}
+	return u, nil
+}
+
 // NewUpload returns an upload for storing new files.
 // All records written to the Upload will have the same upload ID.
 func (db *DB) NewUpload(ctx context.Context) (*Upload, error) {
@@ -210,11 +257,12 @@
 	if err != nil {
 		return nil, err
 	}
-	return &Upload{
+	u := &Upload{
 		ID: id,
 		db: db,
 		tx: utx,
-	}, nil
+	}
+	return u, nil
 }
 
 // InsertRecord inserts a single record in an existing upload.
@@ -352,7 +400,7 @@
 // Query is the result of a query.
 // Use Next to advance through the rows, making sure to call Close when done:
 //
-//   q, err := db.Query("key:value")
+//   q := db.Query("key:value")
 //   defer q.Close()
 //   for q.Next() {
 //     res := q.Result()
diff --git a/storage/db/db_test.go b/storage/db/db_test.go
index 13dd630..13ee26a 100644
--- a/storage/db/db_test.go
+++ b/storage/db/db_test.go
@@ -5,8 +5,12 @@
 package db_test
 
 import (
+	"bytes"
 	"context"
 	"fmt"
+	"io/ioutil"
+	"os"
+	"os/exec"
 	"reflect"
 	"strings"
 	"testing"
@@ -79,6 +83,95 @@
 	}
 }
 
+// checkQueryResults performs a query on db and verifies that the
+// results as printed by BenchmarkPrinter are equal to results.
+func checkQueryResults(t *testing.T, db *DB, query, results string) {
+	q := db.Query(query)
+	defer q.Close()
+
+	var buf bytes.Buffer
+	bp := benchfmt.NewPrinter(&buf)
+
+	for q.Next() {
+		if err := bp.Print(q.Result()); err != nil {
+			t.Fatalf("Print: %v", err)
+		}
+	}
+	if err := q.Err(); err != nil {
+		t.Fatalf("Err: %v", err)
+	}
+	if diff := diff(buf.String(), results); diff != "" {
+		t.Errorf("wrong results: (- have/+ want)\n%s", diff)
+	}
+}
+
+// TestReplaceUpload verifies that the expected number of rows exist after replacing an upload.
+func TestReplaceUpload(t *testing.T) {
+	SetNow(time.Unix(0, 0))
+	defer SetNow(time.Time{})
+	db, cleanup := dbtest.NewDB(t)
+	defer cleanup()
+
+	ctx := context.Background()
+
+	r := &benchfmt.Result{
+		benchfmt.Labels{"key": "value"},
+		nil,
+		1,
+		"BenchmarkName 1 ns/op",
+	}
+	u, err := db.NewUpload(ctx)
+	if err != nil {
+		t.Fatalf("NewUpload: %v", err)
+	}
+	r.Labels["uploadid"] = u.ID
+	for _, num := range []string{"1", "2"} {
+		r.Labels["num"] = num
+		if err := u.InsertRecord(r); err != nil {
+			t.Fatalf("InsertRecord: %v", err)
+		}
+	}
+
+	if err := u.Commit(); err != nil {
+		t.Fatalf("Commit: %v", err)
+	}
+
+	checkQueryResults(t, db, "key:value",
+		`key: value
+num: 1
+uploadid: 19700101.1
+BenchmarkName 1 ns/op
+num: 2
+BenchmarkName 1 ns/op
+`)
+
+	r.Labels["num"] = "3"
+
+	for _, uploadid := range []string{u.ID, "new"} {
+		u, err := db.ReplaceUpload(uploadid)
+		if err != nil {
+			t.Fatalf("ReplaceUpload: %v", err)
+		}
+		r.Labels["uploadid"] = u.ID
+		if err := u.InsertRecord(r); err != nil {
+			t.Fatalf("InsertRecord: %v", err)
+		}
+
+		if err := u.Commit(); err != nil {
+			t.Fatalf("Commit: %v", err)
+		}
+	}
+
+	checkQueryResults(t, db, "key:value",
+		`key: value
+num: 3
+uploadid: 19700101.1
+BenchmarkName 1 ns/op
+uploadid: new
+BenchmarkName 1 ns/op
+`)
+}
+
 // TestNewUpload verifies that NewUpload and InsertRecord wrote the correct rows to the database.
 func TestNewUpload(t *testing.T) {
 	SetNow(time.Unix(0, 0))
@@ -218,3 +311,36 @@
 		})
 	}
 }
+
+// diff returns the output of unified diff on s1 and s2. If the result
+// is non-empty, the strings differ or the diff command failed.
+func diff(s1, s2 string) string {
+	f1, err := ioutil.TempFile("", "benchfmt_test")
+	if err != nil {
+		return err.Error()
+	}
+	defer os.Remove(f1.Name())
+	defer f1.Close()
+
+	f2, err := ioutil.TempFile("", "benchfmt_test")
+	if err != nil {
+		return err.Error()
+	}
+	defer os.Remove(f2.Name())
+	defer f2.Close()
+
+	f1.Write([]byte(s1))
+	f2.Write([]byte(s2))
+
+	data, err := exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
+	if len(data) > 0 {
+		// diff exits with a non-zero status when the files don't match.
+		// Ignore that failure as long as we get output.
+		err = nil
+	}
+	if err != nil {
+		data = append(data, []byte(err.Error())...)
+	}
+	return string(data)
+
+}