crypto/rand: properly handle large Read on windows

Use the batched reader to chunk large Read calls on windows to a max of
1 << 31 - 1 bytes. This prevents an infinite loop when trying to read
more than 1 << 32 -1 bytes, due to how RtlGenRandom works.

This change moves the batched function from rand_unix.go to rand.go,
since it is now needed for both windows and unix implementations.

Fixes #52561

Change-Id: Id98fc4b1427e5cb2132762a445b2aed646a37473
Reviewed-on: https://go-review.googlesource.com/c/go/+/402257
Run-TryBot: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Filippo Valsorda <valsorda@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
index b6248a4..af85b96 100644
--- a/src/crypto/rand/rand.go
+++ b/src/crypto/rand/rand.go
@@ -24,3 +24,21 @@
 func Read(b []byte) (n int, err error) {
 	return io.ReadFull(Reader, b)
 }
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+	return func(out []byte) error {
+		for len(out) > 0 {
+			read := len(out)
+			if read > readMax {
+				read = readMax
+			}
+			if err := f(out[:read]); err != nil {
+				return err
+			}
+			out = out[read:]
+		}
+		return nil
+	}
+}
diff --git a/src/crypto/rand/rand_batched_test.go b/src/crypto/rand/rand_batched_test.go
index dfb9517..8995377 100644
--- a/src/crypto/rand/rand_batched_test.go
+++ b/src/crypto/rand/rand_batched_test.go
@@ -23,8 +23,8 @@
 	}, 5)
 
 	p := make([]byte, 13)
-	if !fillBatched(p) {
-		t.Fatal("batched function returned false")
+	if err := fillBatched(p); err != nil {
+		t.Fatalf("batched function returned error: %s", err)
 	}
 	expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2}
 	if !bytes.Equal(expected, p) {
@@ -55,8 +55,8 @@
 			max = len(outputMarker)
 		}
 		howMuch := prand.Intn(max + 1)
-		if !fillBatched(outputMarker[:howMuch]) {
-			t.Fatal("batched function returned false")
+		if err := fillBatched(outputMarker[:howMuch]); err != nil {
+			t.Fatalf("batched function returned error: %s", err)
 		}
 		outputMarker = outputMarker[howMuch:]
 	}
@@ -67,14 +67,14 @@
 
 func TestBatchedError(t *testing.T) {
 	b := batched(func(p []byte) error { return errors.New("failure") }, 5)
-	if b(make([]byte, 13)) {
+	if b(make([]byte, 13)) == nil {
 		t.Fatal("batched function should have returned an error")
 	}
 }
 
 func TestBatchedEmpty(t *testing.T) {
 	b := batched(func(p []byte) error { return errors.New("failure") }, 5)
-	if !b(make([]byte, 0)) {
+	if b(make([]byte, 0)) != nil {
 		t.Fatal("empty slice should always return successful")
 	}
 }
diff --git a/src/crypto/rand/rand_unix.go b/src/crypto/rand/rand_unix.go
index 87ba9e3..64b8652 100644
--- a/src/crypto/rand/rand_unix.go
+++ b/src/crypto/rand/rand_unix.go
@@ -40,25 +40,7 @@
 
 // altGetRandom if non-nil specifies an OS-specific function to get
 // urandom-style randomness.
-var altGetRandom func([]byte) (ok bool)
-
-// batched returns a function that calls f to populate a []byte by chunking it
-// into subslices of, at most, readMax bytes.
-func batched(f func([]byte) error, readMax int) func([]byte) bool {
-	return func(out []byte) bool {
-		for len(out) > 0 {
-			read := len(out)
-			if read > readMax {
-				read = readMax
-			}
-			if f(out[:read]) != nil {
-				return false
-			}
-			out = out[read:]
-		}
-		return true
-	}
-}
+var altGetRandom func([]byte) (err error)
 
 func warnBlocked() {
 	println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel")
@@ -72,7 +54,7 @@
 		t := time.AfterFunc(time.Minute, warnBlocked)
 		defer t.Stop()
 	}
-	if altGetRandom != nil && altGetRandom(b) {
+	if altGetRandom != nil && altGetRandom(b) == nil {
 		return len(b), nil
 	}
 	if atomic.LoadUint32(&r.used) != 2 {
diff --git a/src/crypto/rand/rand_windows.go b/src/crypto/rand/rand_windows.go
index 7379f14..6c0655c 100644
--- a/src/crypto/rand/rand_windows.go
+++ b/src/crypto/rand/rand_windows.go
@@ -9,7 +9,6 @@
 
 import (
 	"internal/syscall/windows"
-	"os"
 )
 
 func init() { Reader = &rngReader{} }
@@ -17,16 +16,11 @@
 type rngReader struct{}
 
 func (r *rngReader) Read(b []byte) (n int, err error) {
-	// RtlGenRandom only accepts 2**32-1 bytes at a time, so truncate.
-	inputLen := uint32(len(b))
-
-	if inputLen == 0 {
-		return 0, nil
+	// RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+	// most 1<<31-1 bytes at a time so that  this works the same on 32-bit
+	// and 64-bit systems.
+	if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
+		return 0, err
 	}
-
-	err = windows.RtlGenRandom(b)
-	if err != nil {
-		return 0, os.NewSyscallError("RtlGenRandom", err)
-	}
-	return int(inputLen), nil
+	return len(b), nil
 }