crypto/rand: properly handle large Read on windows
Use the batched reader to chunk large Read calls on windows to a max of
1 << 31 - 1 bytes. This prevents an infinite loop when trying to read
more than 1 << 32 -1 bytes, due to how RtlGenRandom works.
This change moves the batched function from rand_unix.go to rand.go,
since it is now needed for both windows and unix implementations.
Fixes #52561
Change-Id: Id98fc4b1427e5cb2132762a445b2aed646a37473
Reviewed-on: https://go-review.googlesource.com/c/go/+/402257
Run-TryBot: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Filippo Valsorda <valsorda@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
index b6248a4..af85b96 100644
--- a/src/crypto/rand/rand.go
+++ b/src/crypto/rand/rand.go
@@ -24,3 +24,21 @@
func Read(b []byte) (n int, err error) {
return io.ReadFull(Reader, b)
}
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+ return func(out []byte) error {
+ for len(out) > 0 {
+ read := len(out)
+ if read > readMax {
+ read = readMax
+ }
+ if err := f(out[:read]); err != nil {
+ return err
+ }
+ out = out[read:]
+ }
+ return nil
+ }
+}
diff --git a/src/crypto/rand/rand_batched_test.go b/src/crypto/rand/rand_batched_test.go
index dfb9517..8995377 100644
--- a/src/crypto/rand/rand_batched_test.go
+++ b/src/crypto/rand/rand_batched_test.go
@@ -23,8 +23,8 @@
}, 5)
p := make([]byte, 13)
- if !fillBatched(p) {
- t.Fatal("batched function returned false")
+ if err := fillBatched(p); err != nil {
+ t.Fatalf("batched function returned error: %s", err)
}
expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2}
if !bytes.Equal(expected, p) {
@@ -55,8 +55,8 @@
max = len(outputMarker)
}
howMuch := prand.Intn(max + 1)
- if !fillBatched(outputMarker[:howMuch]) {
- t.Fatal("batched function returned false")
+ if err := fillBatched(outputMarker[:howMuch]); err != nil {
+ t.Fatalf("batched function returned error: %s", err)
}
outputMarker = outputMarker[howMuch:]
}
@@ -67,14 +67,14 @@
func TestBatchedError(t *testing.T) {
b := batched(func(p []byte) error { return errors.New("failure") }, 5)
- if b(make([]byte, 13)) {
+ if b(make([]byte, 13)) == nil {
t.Fatal("batched function should have returned an error")
}
}
func TestBatchedEmpty(t *testing.T) {
b := batched(func(p []byte) error { return errors.New("failure") }, 5)
- if !b(make([]byte, 0)) {
+ if b(make([]byte, 0)) != nil {
t.Fatal("empty slice should always return successful")
}
}
diff --git a/src/crypto/rand/rand_unix.go b/src/crypto/rand/rand_unix.go
index 87ba9e3..64b8652 100644
--- a/src/crypto/rand/rand_unix.go
+++ b/src/crypto/rand/rand_unix.go
@@ -40,25 +40,7 @@
// altGetRandom if non-nil specifies an OS-specific function to get
// urandom-style randomness.
-var altGetRandom func([]byte) (ok bool)
-
-// batched returns a function that calls f to populate a []byte by chunking it
-// into subslices of, at most, readMax bytes.
-func batched(f func([]byte) error, readMax int) func([]byte) bool {
- return func(out []byte) bool {
- for len(out) > 0 {
- read := len(out)
- if read > readMax {
- read = readMax
- }
- if f(out[:read]) != nil {
- return false
- }
- out = out[read:]
- }
- return true
- }
-}
+var altGetRandom func([]byte) (err error)
func warnBlocked() {
println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel")
@@ -72,7 +54,7 @@
t := time.AfterFunc(time.Minute, warnBlocked)
defer t.Stop()
}
- if altGetRandom != nil && altGetRandom(b) {
+ if altGetRandom != nil && altGetRandom(b) == nil {
return len(b), nil
}
if atomic.LoadUint32(&r.used) != 2 {
diff --git a/src/crypto/rand/rand_windows.go b/src/crypto/rand/rand_windows.go
index 7379f14..6c0655c 100644
--- a/src/crypto/rand/rand_windows.go
+++ b/src/crypto/rand/rand_windows.go
@@ -9,7 +9,6 @@
import (
"internal/syscall/windows"
- "os"
)
func init() { Reader = &rngReader{} }
@@ -17,16 +16,11 @@
type rngReader struct{}
func (r *rngReader) Read(b []byte) (n int, err error) {
- // RtlGenRandom only accepts 2**32-1 bytes at a time, so truncate.
- inputLen := uint32(len(b))
-
- if inputLen == 0 {
- return 0, nil
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
+ return 0, err
}
-
- err = windows.RtlGenRandom(b)
- if err != nil {
- return 0, os.NewSyscallError("RtlGenRandom", err)
- }
- return int(inputLen), nil
+ return len(b), nil
}