encoding/protodelim: If UnmarshalFrom gets a bufio.Reader, try to reuse its buffer instead of creating a new one

When unmarshalling many messages, this reduces the amount of memory allocated and saves CPU time.

Change-Id: I440b8b223319ba2ed31ce559c125b1d640d5880c
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/491596
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Lasse Folger <lassefolger@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/protodelim/protodelim.go b/encoding/protodelim/protodelim.go
index e2b6cd4..75a50cb 100644
--- a/encoding/protodelim/protodelim.go
+++ b/encoding/protodelim/protodelim.go
@@ -6,6 +6,7 @@
 package protodelim
 
 import (
+	"bufio"
 	"encoding/binary"
 	"fmt"
 	"io"
@@ -116,8 +117,23 @@
 		return errors.Wrap(&SizeTooLargeError{Size: size, MaxSize: uint64(maxSize)}, "")
 	}
 
-	b := make([]byte, size)
-	_, err := io.ReadFull(r, b)
+	var b []byte
+	var err error
+	if br, ok := r.(*bufio.Reader); ok {
+		// Use the []byte from the bufio.Reader instead of having to allocate one.
+		// This reduces CPU usage and allocated bytes.
+		b, err = br.Peek(int(size))
+		if err == nil {
+			defer br.Discard(int(size))
+		} else {
+			b = nil
+		}
+	}
+	if b == nil {
+		b = make([]byte, size)
+		_, err = io.ReadFull(r, b)
+	}
+
 	if err == io.EOF {
 		return io.ErrUnexpectedEOF
 	}
diff --git a/encoding/protodelim/protodelim_test.go b/encoding/protodelim/protodelim_test.go
index 9c2458b..b9f8386 100644
--- a/encoding/protodelim/protodelim_test.go
+++ b/encoding/protodelim/protodelim_test.go
@@ -38,25 +38,130 @@
 		}
 	}
 
-	// Read and collect messages from buf.
-	var got []*test3.TestAllTypes
-	r := bufio.NewReader(buf)
-	for {
-		m := &test3.TestAllTypes{}
-		err := protodelim.UnmarshalFrom(r, m)
-		if errors.Is(err, io.EOF) {
-			break
-		}
+	for _, tc := range []struct {
+		name   string
+		reader protodelim.Reader
+	}{
+		{name: "defaultbuffer", reader: bufio.NewReader(bytes.NewBuffer(buf.Bytes()))},
+		{name: "smallbuffer", reader: bufio.NewReaderSize(bytes.NewBuffer(buf.Bytes()), 0)},
+		{name: "largebuffer", reader: bufio.NewReaderSize(bytes.NewBuffer(buf.Bytes()), 1<<20)},
+		{name: "notbufio", reader: notBufioReader{bufio.NewReader(bytes.NewBuffer(buf.Bytes()))}},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			// Read and collect messages from buf.
+			var got []*test3.TestAllTypes
+			for {
+				m := &test3.TestAllTypes{}
+				err := protodelim.UnmarshalFrom(tc.reader, m)
+				if errors.Is(err, io.EOF) {
+					break
+				}
+				if err != nil {
+					t.Errorf("protodelim.UnmarshalFrom(_) = %v", err)
+					continue
+				}
+				got = append(got, m)
+			}
+
+			want := msgs
+			if diff := cmp.Diff(want, got, protocmp.Transform()); diff != "" {
+				t.Errorf("Unmarshaler collected messages: diff -want +got = %s", diff)
+			}
+		})
+	}
+}
+
+// notBufioReader wraps a *bufio.Reader so that UnmarshalFrom doesn't recognize it as a bufio.Reader.
+type notBufioReader struct {
+	*bufio.Reader
+}
+
+func TestUnmarshalFromBufioAllocations(t *testing.T) {
+	// Use a proto which won't require any additional allocations during unmarshalling.
+	// Write to buf
+	buf := &bytes.Buffer{}
+	m := &test3.TestAllTypes{SingularInt32: 1}
+	if n, err := protodelim.MarshalTo(buf, m); err != nil {
+		t.Errorf("protodelim.MarshalTo(_, %v) = %d, %v", m, n, err)
+	}
+	reader := bufio.NewReaderSize(nil, 1<<20)
+	got := &test3.TestAllTypes{}
+
+	allocs := testing.AllocsPerRun(5, func() {
+		// Read from buf.
+		reader.Reset(bytes.NewBuffer(buf.Bytes()))
+		err := protodelim.UnmarshalFrom(reader, got)
 		if err != nil {
-			t.Errorf("protodelim.UnmarshalFrom(_) = %v", err)
-			continue
+			t.Fatalf("protodelim.UnmarshalFrom(_) = %v", err)
 		}
-		got = append(got, m)
+	})
+	if allocs != 1 {
+		// bytes.NewBuffer should be the only allocation.
+		t.Errorf("Got %v allocs. Wanted 1", allocs)
 	}
 
-	want := msgs
-	if diff := cmp.Diff(want, got, protocmp.Transform()); diff != "" {
-		t.Errorf("Unmarshaler collected messages: diff -want +got = %s", diff)
+	if diff := cmp.Diff(m, got, protocmp.Transform()); diff != "" {
+		t.Errorf("Unmarshaler read: diff -want +got = %s", diff)
+	}
+}
+
+func BenchmarkUnmarshalFrom(b *testing.B) {
+	var manyInt32 []int32
+	for i := int32(0); i < 10000; i++ {
+		manyInt32 = append(manyInt32, i)
+	}
+	var msgs []*test3.TestAllTypes
+	for i := 0; i < 10; i++ {
+		msgs = append(msgs, &test3.TestAllTypes{RepeatedInt32: manyInt32})
+	}
+
+	buf := &bytes.Buffer{}
+
+	// Write all messages to buf.
+	for _, m := range msgs {
+		if n, err := protodelim.MarshalTo(buf, m); err != nil {
+			b.Errorf("protodelim.MarshalTo(_, %v) = %d, %v", m, n, err)
+		}
+	}
+	bufBytes := buf.Bytes()
+
+	type resetReader interface {
+		protodelim.Reader
+		Reset(io.Reader)
+	}
+
+	for _, tc := range []struct {
+		name   string
+		reader resetReader
+	}{
+		{name: "bufio1mib", reader: bufio.NewReaderSize(nil, 1<<20)},
+		{name: "bufio16mib", reader: bufio.NewReaderSize(nil, 1<<24)},
+		{name: "notbufio1mib", reader: notBufioReader{bufio.NewReaderSize(nil, 1<<20)}},
+		{name: "notbufio16mib", reader: notBufioReader{bufio.NewReaderSize(nil, 1<<24)}},
+	} {
+		b.Run(tc.name, func(b *testing.B) {
+			b.ReportAllocs()
+			b.ResetTimer()
+			for i := 0; i < b.N; i++ {
+				tc.reader.Reset(bytes.NewBuffer(bufBytes))
+				var got int
+				m := &test3.TestAllTypes{}
+				for {
+					err := protodelim.UnmarshalFrom(tc.reader, m)
+					if errors.Is(err, io.EOF) {
+						break
+					}
+					if err != nil {
+						b.Errorf("protodelim.UnmarshalFrom(_) = %v", err)
+						continue
+					}
+					got++
+				}
+				if got != len(msgs) {
+					b.Errorf("Got %v messages. Wanted %v", got, len(msgs))
+				}
+			}
+		})
 	}
 }