storage/db: query optimizations for duplicate keys

Queries with multiple parts referring to the same key are now
simplified before sending to the database. If the query can be proven
to return no results, the query is not sent to the database at all.

Change-Id: I2d307d09f463fb0e6e7bd9b9902115916e7ffffa
Reviewed-on: https://go-review.googlesource.com/36953
Run-TryBot: Quentin Smith <quentin@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
diff --git a/storage/db/db.go b/storage/db/db.go
index 89e7cd9..1aae98e 100644
--- a/storage/db/db.go
+++ b/storage/db/db.go
@@ -12,11 +12,11 @@
 	"fmt"
 	"io"
 	"regexp"
+	"sort"
 	"strconv"
 	"strings"
 	"text/template"
 	"time"
-	"unicode"
 
 	"golang.org/x/net/context"
 	"golang.org/x/perf/storage"
@@ -360,30 +360,37 @@
 	return u.tx.Rollback()
 }
 
-// parseQueryPart parses a single query part into a SQL expression and a list of arguments.
-func parseQueryPart(part string) (sql string, args []interface{}, err error) {
-	sepIndex := strings.IndexFunc(part, func(r rune) bool {
-		return r == ':' || r == '>' || r == '<' || unicode.IsSpace(r) || unicode.IsUpper(r)
-	})
-	if sepIndex < 0 {
-		return "", nil, fmt.Errorf("query part %q is missing operator", part)
-	}
-	key, sep, value := part[:sepIndex], part[sepIndex], part[sepIndex+1:]
-	switch sep {
-	case ':':
-		if value == "" {
-			// TODO(quentin): Implement support for searching for missing labels.
-			return "", nil, fmt.Errorf("missing value for query part %q", part)
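+// parseQuery parses a query into a slice of SQL subselects and a slice of arguments.
+// The subselects must be joined with INNER JOIN in the order returned. Words that share a key are merged into one subselect; if merge reports that they can never match together, its io.EOF error is returned.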
+func parseQuery(q string) (sql []string, args []interface{}, err error) {
+	var keys []string
+	parts := make(map[string]part)
+	for _, word := range query.SplitWords(q) {
+		p, err := parseWord(word)
+		if err != nil {
+			return nil, nil, err
 		}
-		return "SELECT UploadID, RecordID FROM RecordLabels WHERE Name = ? AND Value = ?", []interface{}{key, value}, nil
-	case '>', '<':
-		if sep == '>' && value == "" {
-			// Simplify queries for any value.
-			return "SELECT UploadID, RecordID FROM RecordLabels WHERE Name = ?", []interface{}{key}, nil
+		if _, ok := parts[p.key]; ok {
+			parts[p.key], err = parts[p.key].merge(p)
+			if err != nil {
+				return nil, nil, err
+			}
+		} else {
+			keys = append(keys, p.key)
+			parts[p.key] = p
 		}
-		return fmt.Sprintf("SELECT UploadID, RecordID FROM RecordLabels WHERE Name = ? AND Value %c ?", sep), []interface{}{key, value}, nil
 	}
-	return "", nil, fmt.Errorf("query part %q has invalid key", part)
+	// Emit one subselect per key, in sorted key order, so the result does not depend on the order of words in the query.
+	sort.Strings(keys)
+	for _, key := range keys {
+		s, a, err := parts[key].sql()
+		if err != nil {
+			return nil, nil, err
+		}
+		sql = append(sql, s)
+		args = append(args, a...)
+	}
+	return
 }
 
 // Query searches for results matching the given query string.
@@ -394,33 +401,30 @@
 // key>value - value greater than (useful for dates)
 // key<value - value less than (also useful for dates)
 func (db *DB) Query(q string) *Query {
-	qparts := query.SplitWords(q)
-
 	ret := &Query{q: q}
 
-	var args []interface{}
 	query := "SELECT r.Content FROM "
-	for i, part := range qparts {
+
+	sql, args, err := parseQuery(q)
+	if err != nil {
+		ret.err = err
+		return ret
+	}
+	for i, part := range sql {
 		if i > 0 {
 			query += " INNER JOIN "
 		}
-		partSql, partArgs, err := parseQueryPart(part)
-		ret.err = err
-		if err != nil {
-			return ret
-		}
-		query += fmt.Sprintf("(%s) t%d", partSql, i)
-		args = append(args, partArgs...)
+		query += fmt.Sprintf("(%s) t%d", part, i)
 		if i > 0 {
 			query += " USING (UploadID, RecordID)"
 		}
 	}
 
-	if len(qparts) > 0 {
+	if len(sql) > 0 {
 		query += " LEFT JOIN"
 	}
 	query += " Records r"
-	if len(qparts) > 0 {
+	if len(sql) > 0 {
 		query += " USING (UploadID, RecordID)"
 	}
 
@@ -516,7 +520,7 @@
 	if q.rows != nil {
 		return q.rows.Close()
 	}
-	return q.err
+	return q.Err()
 }
 
 // CountUploads returns the number of uploads in the database.
@@ -562,8 +566,6 @@
 // For each label in extraLabels, one unspecified record's value will be obtained for each upload.
 // If limit is non-zero, only the limit most recent uploads will be returned.
 func (db *DB) ListUploads(q string, extraLabels []string, limit int) *UploadList {
-	qparts := query.SplitWords(q)
-
 	var args []interface{}
 	query := "SELECT j.UploadID, rCount"
 	for i, label := range extraLabels {
@@ -571,26 +573,26 @@
 		args = append(args, label)
 	}
 	query += " FROM (SELECT UploadID, COUNT(*) as rCount FROM "
-	for i, part := range qparts {
+	sql, qArgs, err := parseQuery(q)
+	if err != nil {
+		return &UploadList{err: err}
+	}
+	args = append(args, qArgs...)
+	for i, part := range sql {
 		if i > 0 {
 			query += " INNER JOIN "
 		}
-		partSql, partArgs, err := parseQueryPart(part)
-		if err != nil {
-			return &UploadList{err: err}
-		}
-		query += fmt.Sprintf("(%s) t%d", partSql, i)
-		args = append(args, partArgs...)
+		query += fmt.Sprintf("(%s) t%d", part, i)
 		if i > 0 {
 			query += " USING (UploadID, RecordID)"
 		}
 	}
 
-	if len(qparts) > 0 {
+	if len(sql) > 0 {
 		query += " LEFT JOIN"
 	}
 	query += " Records r"
-	if len(qparts) > 0 {
+	if len(sql) > 0 {
 		query += " USING (UploadID, RecordID)"
 	}
 	query += " GROUP BY UploadID) j LEFT JOIN Uploads u USING (UploadID) ORDER BY u.Day DESC, u.Seq DESC, u.UploadID DESC"
diff --git a/storage/db/db_test.go b/storage/db/db_test.go
index 14f5cc6..4206100 100644
--- a/storage/db/db_test.go
+++ b/storage/db/db_test.go
@@ -267,6 +267,10 @@
 		want []int // nil means we want an error
 	}{
 		{"label0:0", []int{0}},
+		{"label0:1 label0:1 label0<2 label0>0", []int{1}},
+		{"label0>0 label0<2 label0:1 label0:1", []int{1}},
+		{"label0<2 label0<1", []int{0}},
+		{"label0>1021 label0>1022 label1:511", []int{1023}},
 		{"label1:0", []int{0, 1}},
 		{"label0:5 name:Name", []int{5}},
 		{"label0:0 label0:5", []int{}},
@@ -290,6 +294,7 @@
 				return
 			}
 			defer func() {
+				t.Logf("q.Debug: %s", q.Debug())
 				if err := q.Close(); err != nil {
 					t.Errorf("Close: %v", err)
 				}
diff --git a/storage/db/query.go b/storage/db/query.go
new file mode 100644
index 0000000..aeaaaca
--- /dev/null
+++ b/storage/db/query.go
@@ -0,0 +1,156 @@
+// Copyright 2017 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package db
+
+import (
+	"fmt"
+	"io"
+	"strings"
+	"unicode"
+)
+
+// operation is the enum for possible query operations.
+type operation rune // the values below are small iota constants, not separator characters
+
+// Query operations.
+// Only equals, lt, and gt can be specified in a query; ltgt is only produced by merge.
+// The order of these operations is used by merge, which swaps its operands so that the smaller operation comes first.
+const (
+	equals operation = iota // Value = value
+	ltgt                    // Value < value AND Value > value2
+	lt                      // Value < value
+	gt                      // Value > value; with an empty value, any Value of the key matches
+)
+
+// A part is a single query part with a key, operator, and value.
+type part struct {
+	key      string    // label name; sql binds it to RecordLabels.Name
+	operator operation // how value (and value2) constrain the label's Value
+	// value and value2 hold the values to compare against. value2 is only used by ltgt, where value is the exclusive upper bound and value2 the exclusive lower bound.
+	value, value2 string
+}
+
+// sepToOperation maps the separator byte found in a query word to its operation value.
+var sepToOperation = map[byte]operation{
+	':': equals, // key:value
+	'<': lt,     // key<value
+	'>': gt,     // key>value
+}
+
+// parseWord parses a single query part (as returned by SplitWords with quoting and escaping already removed) into a part struct. It returns an error if the word contains no separator, or if whitespace or an uppercase letter appears before the first ':', '<' or '>'.
+func parseWord(word string) (part, error) {
+	sepIndex := strings.IndexFunc(word, func(r rune) bool { // byte offset of the first rune that ends the key
+		return r == ':' || r == '>' || r == '<' || unicode.IsSpace(r) || unicode.IsUpper(r)
+	})
+	if sepIndex < 0 {
+		return part{}, fmt.Errorf("query part %q is missing operator", word)
+	}
+	key, sep, value := word[:sepIndex], word[sepIndex], word[sepIndex+1:] // sep is the single byte at sepIndex
+	if oper, ok := sepToOperation[sep]; ok {
+		return part{key, oper, value, ""}, nil // value2 stays empty here; only merge fills it in
+	}
+	return part{}, fmt.Errorf("query part %q has invalid key", word) // the key ended at whitespace or an uppercase rune, not at an operator
+}
+
+// merge merges two query parts together into a single query part that accepts only values accepted by both.
+// The keys of the two parts must be equal; merge does not check this. Values are compared as strings.
+// If the result is a query part that can never match, io.EOF is returned as the error.
+func (p part) merge(p2 part) (part, error) {
+	if p2.operator < p.operator {
+		// Sort the parts so that p.operator <= p2.operator and we only need half the table below.
+		p, p2 = p2, p
+	}
+	switch p.operator {
+	case equals: // p pins a single value; keep p only if p2 accepts that value too
+		switch p2.operator {
+		case equals:
+			if p.value == p2.value {
+				return p, nil
+			}
+			return part{}, io.EOF
+		case lt:
+			if p.value < p2.value {
+				return p, nil
+			}
+			return part{}, io.EOF
+		case gt:
+			if p.value > p2.value {
+				return p, nil
+			}
+			return part{}, io.EOF
+		case ltgt: // p2.value is the upper bound, p2.value2 the lower bound
+			if p.value < p2.value && p.value > p2.value2 {
+				return p, nil
+			}
+			return part{}, io.EOF
+		}
+	case ltgt: // p is already a range; tighten its bounds, then validate it after the switch
+		switch p2.operator {
+		case ltgt:
+			if p2.value < p.value { // keep the smaller upper bound
+				p.value = p2.value
+			}
+			if p2.value2 > p.value2 { // keep the larger lower bound
+				p.value2 = p2.value2
+			}
+		case lt:
+			if p2.value < p.value { // keep the smaller upper bound
+				p.value = p2.value
+			}
+		case gt:
+			if p2.value > p.value2 { // keep the larger lower bound
+				p.value2 = p2.value
+			}
+		}
+	case lt:
+		switch p2.operator {
+		case lt: // two upper bounds: the smaller one wins
+			if p2.value < p.value {
+				return p2, nil
+			}
+			return p, nil
+		case gt:
+			p = part{p.key, ltgt, p.value, p2.value} // combine into a range, validated after the switch
+		}
+	case gt:
+		// p2.operator == gt, because p.operator <= p2.operator; the larger lower bound wins.
+		if p2.value > p.value {
+			return p2, nil
+		}
+		return p, nil
+	}
+	// p.operator == ltgt here; for the four defined operations every other combination has already returned.
+	if p.value <= p.value2 || p.value == "" { // the upper bound is not above the lower bound
+		return part{}, io.EOF
+	}
+	if p.value2 == "" { // an empty lower bound is treated as no bound, so simplify to a plain lt
+		return part{p.key, lt, p.value, ""}, nil
+	}
+	return p, nil
+}
+
+// sql returns a SQL expression and a list of arguments for finding records matching p. The expression selects UploadID, RecordID from RecordLabels, and args fill its ? placeholders in order.
+func (p part) sql() (sql string, args []interface{}, err error) {
+	switch p.operator {
+	case equals:
+		if p.value == "" {
+			// TODO(quentin): Implement support for searching for missing labels.
+			return "", nil, fmt.Errorf("missing value for key %q", p.key)
+		}
+		return "SELECT UploadID, RecordID FROM RecordLabels WHERE Name = ? AND Value = ?", []interface{}{p.key, p.value}, nil
+	case lt:
+		return "SELECT UploadID, RecordID FROM RecordLabels WHERE Name = ? AND Value < ?", []interface{}{p.key, p.value}, nil
+	case gt:
+		if p.value == "" {
+			// An empty bound ("key>") matches any value, so only the label name is tested.
+			return "SELECT UploadID, RecordID FROM RecordLabels WHERE Name = ?", []interface{}{p.key}, nil
+		}
+		return "SELECT UploadID, RecordID FROM RecordLabels WHERE Name = ? AND Value > ?", []interface{}{p.key, p.value}, nil
+	case ltgt: // value is the upper bound, value2 the lower bound
+		return "SELECT UploadID, RecordID FROM RecordLabels WHERE Name = ? AND Value < ? AND Value > ?", []interface{}{p.key, p.value, p.value2}, nil
+	default: // parseWord and merge only build the four operations handled above
+		panic("unknown operator " + string(p.operator)) // string() turns the rune-typed value into a character, not a decimal number
+	}
+}
diff --git a/storage/db/query_test.go b/storage/db/query_test.go
new file mode 100644
index 0000000..dc63900
--- /dev/null
+++ b/storage/db/query_test.go
@@ -0,0 +1,44 @@
+// Copyright 2017 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package db
+
+import "testing"
+
+func TestParseWord(t *testing.T) { // checks parseWord, and that merging each parsed part with "key>" is a no-op
+	tests := []struct {
+		word    string // query word passed to parseWord
+		want    part   // expected result; ignored when wantErr is set
+		wantErr bool   // true if parseWord must return an error
+	}{
+		{"key:value", part{"key", equals, "value", ""}, false},
+		{"key>value", part{"key", gt, "value", ""}, false},
+		{"key<value", part{"key", lt, "value", ""}, false},
+		{"bogus query", part{}, true}, // the space is reached before any operator
+	}
+	for _, test := range tests {
+		t.Run(test.word, func(t *testing.T) {
+			p, err := parseWord(test.word)
+			if test.wantErr {
+				if err == nil {
+					t.Fatalf("have %#v, want error", p)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("have error %v", err)
+			}
+			if p != test.want {
+				t.Fatalf("parseWord = %#v, want %#v", p, test.want)
+			}
+			p, err = p.merge(part{p.key, gt, "", ""}) // gt with an empty value matches any value, so the merge must leave p unchanged
+			if err != nil {
+				t.Fatalf("failed to merge with noop: %v", err)
+			}
+			if p != test.want {
+				t.Fatalf("merge with noop = %#v, want %#v", p, test.want)
+			}
+		})
+	}
+}