slices: rework the APIs of BinarySearch*
For golang/go#50340
Change-Id: If115b2b66d463d5f3788d017924f8dd38867551c
Reviewed-on: https://go-review.googlesource.com/c/exp/+/395414
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Trust: Eli Bendersky <eliben@golang.org>
diff --git a/slices/sort.go b/slices/sort.go
index b2035ab..4a57758 100644
--- a/slices/sort.go
+++ b/slices/sort.go
@@ -46,25 +46,34 @@
return true
}
-// BinarySearch searches for target in a sorted slice and returns the smallest
-// index at which target is found. If the target is not found, the index at
-// which it could be inserted into the slice is returned; therefore, if the
-// intention is to find target itself a separate check for equality with the
-// element at the returned index is required.
-func BinarySearch[E constraints.Ordered](x []E, target E) int {
- return search(len(x), func(i int) bool { return x[i] >= target })
+// BinarySearch searches for target in a sorted slice and returns the position
+// where target is found, or the position where target would appear in the
+// sort order; it also returns a bool saying whether the target is really found
+// in the slice. The slice must be sorted in increasing order.
+func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) {
+ // search returns the leftmost position where f returns true, or len(x) if f
+ // returns false for all x. This is the insertion position for target in x,
+ // and could point to an element that's either == target or not.
+ pos := search(len(x), func(i int) bool { return x[i] >= target })
+ if pos >= len(x) || x[pos] != target {
+ return pos, false
+ } else {
+ return pos, true
+ }
}
-// BinarySearchFunc uses binary search to find and return the smallest index i
-// in [0, n) at which ok(i) is true, assuming that on the range [0, n),
-// ok(i) == true implies ok(i+1) == true. That is, BinarySearchFunc requires
-// that ok is false for some (possibly empty) prefix of the input range [0, n)
-// and then true for the (possibly empty) remainder; BinarySearchFunc returns
-// the first true index. If there is no such index, BinarySearchFunc returns n.
-// (Note that the "not found" return value is not -1 as in, for instance,
-// strings.Index.) Search calls ok(i) only for i in the range [0, n).
-func BinarySearchFunc[E any](x []E, ok func(E) bool) int {
- return search(len(x), func(i int) bool { return ok(x[i]) })
+// BinarySearchFunc works like BinarySearch, but uses a custom comparison
+// function. The slice must be sorted in increasing order, where "increasing" is
+// defined by cmp. cmp(a, b) is expected to return an integer comparing the two
+// parameters: 0 if a == b, a negative number if a < b and a positive number if
+// a > b.
+func BinarySearchFunc[E any](x []E, target E, cmp func(E, E) int) (int, bool) {
+ pos := search(len(x), func(i int) bool { return cmp(x[i], target) >= 0 })
+ if pos >= len(x) || cmp(x[pos], target) != 0 {
+ return pos, false
+ } else {
+ return pos, true
+ }
}
// maxDepth returns a threshold at which quicksort should switch
diff --git a/slices/sort_test.go b/slices/sort_test.go
index 4f3145a..3a92579 100644
--- a/slices/sort_test.go
+++ b/slices/sort_test.go
@@ -7,6 +7,8 @@
import (
"math"
"math/rand"
+ "strconv"
+ "strings"
"testing"
)
@@ -151,31 +153,112 @@
}
func TestBinarySearch(t *testing.T) {
- data := []string{"aa", "ad", "ca", "xy"}
+ str1 := []string{"foo"}
+ str2 := []string{"ab", "ca"}
+ str3 := []string{"mo", "qo", "vo"}
+ str4 := []string{"ab", "ad", "ca", "xy"}
+
+ // slice with repeating elements
+ strRepeats := []string{"ba", "ca", "da", "da", "da", "ka", "ma", "ma", "ta"}
+
+ // slice with all element equal
+ strSame := []string{"xx", "xx", "xx"}
+
tests := []struct {
- target string
- want int
+ data []string
+ target string
+ wantPos int
+ wantFound bool
}{
- {"aa", 0},
- {"ab", 1},
- {"ad", 1},
- {"ax", 2},
- {"ca", 2},
- {"cc", 3},
- {"dd", 3},
- {"xy", 3},
- {"zz", 4},
+ {[]string{}, "foo", 0, false},
+ {[]string{}, "", 0, false},
+
+ {str1, "foo", 0, true},
+ {str1, "bar", 0, false},
+ {str1, "zx", 1, false},
+
+ {str2, "aa", 0, false},
+ {str2, "ab", 0, true},
+ {str2, "ad", 1, false},
+ {str2, "ca", 1, true},
+ {str2, "ra", 2, false},
+
+ {str3, "bb", 0, false},
+ {str3, "mo", 0, true},
+ {str3, "nb", 1, false},
+ {str3, "qo", 1, true},
+ {str3, "tr", 2, false},
+ {str3, "vo", 2, true},
+ {str3, "xr", 3, false},
+
+ {str4, "aa", 0, false},
+ {str4, "ab", 0, true},
+ {str4, "ac", 1, false},
+ {str4, "ad", 1, true},
+ {str4, "ax", 2, false},
+ {str4, "ca", 2, true},
+ {str4, "cc", 3, false},
+ {str4, "dd", 3, false},
+ {str4, "xy", 3, true},
+ {str4, "zz", 4, false},
+
+ {strRepeats, "da", 2, true},
+ {strRepeats, "db", 5, false},
+ {strRepeats, "ma", 6, true},
+ {strRepeats, "mb", 8, false},
+
+ {strSame, "xx", 0, true},
+ {strSame, "ab", 0, false},
+ {strSame, "zz", 3, false},
}
for _, tt := range tests {
t.Run(tt.target, func(t *testing.T) {
- i := BinarySearch(data, tt.target)
- if i != tt.want {
- t.Errorf("BinarySearch want %d, got %d", tt.want, i)
+ {
+ pos, found := BinarySearch(tt.data, tt.target)
+ if pos != tt.wantPos || found != tt.wantFound {
+ t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
+ }
}
- j := BinarySearchFunc(data, func(s string) bool { return s >= tt.target })
- if j != tt.want {
- t.Errorf("BinarySearchFunc want %d, got %d", tt.want, j)
+ {
+ pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare)
+ if pos != tt.wantPos || found != tt.wantFound {
+ t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
+ }
+ }
+ })
+ }
+}
+
+func TestBinarySearchInts(t *testing.T) {
+ data := []int{20, 30, 40, 50, 60, 70, 80, 90}
+ tests := []struct {
+ target int
+ wantPos int
+ wantFound bool
+ }{
+ {20, 0, true},
+ {23, 1, false},
+ {43, 3, false},
+ {80, 6, true},
+ }
+ for _, tt := range tests {
+ t.Run(strconv.Itoa(tt.target), func(t *testing.T) {
+ {
+ pos, found := BinarySearch(data, tt.target)
+ if pos != tt.wantPos || found != tt.wantFound {
+ t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
+ }
+ }
+
+ {
+ cmp := func(a, b int) int {
+ return a - b
+ }
+ pos, found := BinarySearchFunc(data, tt.target, cmp)
+ if pos != tt.wantPos || found != tt.wantFound {
+ t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
+ }
}
})
}