slices: rework the APIs of BinarySearch*

For golang/go#50340

Change-Id: If115b2b66d463d5f3788d017924f8dd38867551c
Reviewed-on: https://go-review.googlesource.com/c/exp/+/395414
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Trust: Eli Bendersky‎ <eliben@golang.org>
diff --git a/slices/sort.go b/slices/sort.go
index b2035ab..4a57758 100644
--- a/slices/sort.go
+++ b/slices/sort.go
@@ -46,25 +46,34 @@
 	return true
 }
 
-// BinarySearch searches for target in a sorted slice and returns the smallest
-// index at which target is found. If the target is not found, the index at
-// which it could be inserted into the slice is returned; therefore, if the
-// intention is to find target itself a separate check for equality with the
-// element at the returned index is required.
-func BinarySearch[E constraints.Ordered](x []E, target E) int {
-	return search(len(x), func(i int) bool { return x[i] >= target })
+// BinarySearch searches for target in a sorted slice and returns the position
+// where target is found, or the position where target would appear in the
+// sort order; it also returns a bool saying whether the target is really found
+// in the slice. The slice must be sorted in increasing order.
+func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) {
+	// search returns the leftmost position where f returns true, or len(x) if f
+	// returns false for all x. This is the insertion position for target in x,
+	// and could point to an element that's either == target or not.
+	pos := search(len(x), func(i int) bool { return x[i] >= target })
+	if pos >= len(x) || x[pos] != target {
+		return pos, false
+	} else {
+		return pos, true
+	}
 }
 
-// BinarySearchFunc uses binary search to find and return the smallest index i
-// in [0, n) at which ok(i) is true, assuming that on the range [0, n),
-// ok(i) == true implies ok(i+1) == true. That is, BinarySearchFunc requires
-// that ok is false for some (possibly empty) prefix of the input range [0, n)
-// and then true for the (possibly empty) remainder; BinarySearchFunc returns
-// the first true index. If there is no such index, BinarySearchFunc returns n.
-// (Note that the "not found" return value is not -1 as in, for instance,
-// strings.Index.) Search calls ok(i) only for i in the range [0, n).
-func BinarySearchFunc[E any](x []E, ok func(E) bool) int {
-	return search(len(x), func(i int) bool { return ok(x[i]) })
+// BinarySearchFunc works like BinarySearch, but uses a custom comparison
+// function. The slice must be sorted in increasing order, where "increasing" is
+// defined by cmp. cmp(a, b) is expected to return an integer comparing the two
+// parameters: 0 if a == b, a negative number if a < b and a positive number if
+// a > b.
+func BinarySearchFunc[E any](x []E, target E, cmp func(E, E) int) (int, bool) {
+	pos := search(len(x), func(i int) bool { return cmp(x[i], target) >= 0 })
+	if pos >= len(x) || cmp(x[pos], target) != 0 {
+		return pos, false
+	} else {
+		return pos, true
+	}
 }
 
 // maxDepth returns a threshold at which quicksort should switch
diff --git a/slices/sort_test.go b/slices/sort_test.go
index 4f3145a..3a92579 100644
--- a/slices/sort_test.go
+++ b/slices/sort_test.go
@@ -7,6 +7,8 @@
 import (
 	"math"
 	"math/rand"
+	"strconv"
+	"strings"
 	"testing"
 )
 
@@ -151,31 +153,112 @@
 }
 
 func TestBinarySearch(t *testing.T) {
-	data := []string{"aa", "ad", "ca", "xy"}
+	str1 := []string{"foo"}
+	str2 := []string{"ab", "ca"}
+	str3 := []string{"mo", "qo", "vo"}
+	str4 := []string{"ab", "ad", "ca", "xy"}
+
+	// slice with repeating elements
+	strRepeats := []string{"ba", "ca", "da", "da", "da", "ka", "ma", "ma", "ta"}
+
+	// slice with all element equal
+	strSame := []string{"xx", "xx", "xx"}
+
 	tests := []struct {
-		target string
-		want   int
+		data      []string
+		target    string
+		wantPos   int
+		wantFound bool
 	}{
-		{"aa", 0},
-		{"ab", 1},
-		{"ad", 1},
-		{"ax", 2},
-		{"ca", 2},
-		{"cc", 3},
-		{"dd", 3},
-		{"xy", 3},
-		{"zz", 4},
+		{[]string{}, "foo", 0, false},
+		{[]string{}, "", 0, false},
+
+		{str1, "foo", 0, true},
+		{str1, "bar", 0, false},
+		{str1, "zx", 1, false},
+
+		{str2, "aa", 0, false},
+		{str2, "ab", 0, true},
+		{str2, "ad", 1, false},
+		{str2, "ca", 1, true},
+		{str2, "ra", 2, false},
+
+		{str3, "bb", 0, false},
+		{str3, "mo", 0, true},
+		{str3, "nb", 1, false},
+		{str3, "qo", 1, true},
+		{str3, "tr", 2, false},
+		{str3, "vo", 2, true},
+		{str3, "xr", 3, false},
+
+		{str4, "aa", 0, false},
+		{str4, "ab", 0, true},
+		{str4, "ac", 1, false},
+		{str4, "ad", 1, true},
+		{str4, "ax", 2, false},
+		{str4, "ca", 2, true},
+		{str4, "cc", 3, false},
+		{str4, "dd", 3, false},
+		{str4, "xy", 3, true},
+		{str4, "zz", 4, false},
+
+		{strRepeats, "da", 2, true},
+		{strRepeats, "db", 5, false},
+		{strRepeats, "ma", 6, true},
+		{strRepeats, "mb", 8, false},
+
+		{strSame, "xx", 0, true},
+		{strSame, "ab", 0, false},
+		{strSame, "zz", 3, false},
 	}
 	for _, tt := range tests {
 		t.Run(tt.target, func(t *testing.T) {
-			i := BinarySearch(data, tt.target)
-			if i != tt.want {
-				t.Errorf("BinarySearch want %d, got %d", tt.want, i)
+			{
+				pos, found := BinarySearch(tt.data, tt.target)
+				if pos != tt.wantPos || found != tt.wantFound {
+					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
+				}
 			}
 
-			j := BinarySearchFunc(data, func(s string) bool { return s >= tt.target })
-			if j != tt.want {
-				t.Errorf("BinarySearchFunc want %d, got %d", tt.want, j)
+			{
+				pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare)
+				if pos != tt.wantPos || found != tt.wantFound {
+					t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
+				}
+			}
+		})
+	}
+}
+
+func TestBinarySearchInts(t *testing.T) {
+	data := []int{20, 30, 40, 50, 60, 70, 80, 90}
+	tests := []struct {
+		target    int
+		wantPos   int
+		wantFound bool
+	}{
+		{20, 0, true},
+		{23, 1, false},
+		{43, 3, false},
+		{80, 6, true},
+	}
+	for _, tt := range tests {
+		t.Run(strconv.Itoa(tt.target), func(t *testing.T) {
+			{
+				pos, found := BinarySearch(data, tt.target)
+				if pos != tt.wantPos || found != tt.wantFound {
+					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
+				}
+			}
+
+			{
+				cmp := func(a, b int) int {
+					return a - b
+				}
+				pos, found := BinarySearchFunc(data, tt.target, cmp)
+				if pos != tt.wantPos || found != tt.wantFound {
+					t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
+				}
 			}
 		})
 	}