internal/impl: optimize reflect methods

This change performs two optimizations:
* It uses a pre-constructed rangeInfos slice to iterate over
all the fields. This is more performant since iterating over a slice
is faster than iterating over a map. Furthermore, this slice
does not contain fields that are part of a oneof. If a oneof has
N fields, the time to check presence on the oneof is now O(1)
instead of O(N).
* It uses a dense field info slice that is optimized for the common
case where the field number is relatively low and close in value
to the index itself.

We also fix a minor bug in the construction of oneofInfo where
it wasn't treating a typed nil pointer to a wrapper struct as if
it were unset. This ensures WhichOneof and Has always agree.

name             old time/op    new time/op    delta
Reflect/Has-4      7.81µs ± 3%    6.74µs ± 3%  -13.61%  (p=0.000 n=9+9)
Reflect/Get-4      12.7µs ± 1%    11.3µs ± 4%  -10.85%  (p=0.000 n=8+10)
Reflect/Set-4      19.5µs ± 5%    17.8µs ± 2%   -8.99%  (p=0.000 n=10+10)
Reflect/Clear-4    12.0µs ± 4%    10.2µs ± 3%  -14.86%  (p=0.000 n=9+10)
Reflect/Range-4    6.58µs ± 1%    4.17µs ± 2%  -36.65%  (p=0.000 n=8+9)

Change-Id: I2c48b4d3fb6103ab238924950529ded0d37f8c8a
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/196358
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go
index df19b19..b16acbf 100644
--- a/internal/cmd/generate-types/impl.go
+++ b/internal/cmd/generate-types/impl.go
@@ -658,10 +658,20 @@
 
 func (m *{{.}}) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
 	m.messageInfo().init()
-	for _, fi := range m.messageInfo().fields {
-		if fi.has(m.pointer()) {
-			if !f(fi.fieldDesc, fi.get(m.pointer())) {
-				return
+	for _, ri := range m.messageInfo().rangeInfos {
+		switch ri := ri.(type) {
+		case *fieldInfo:
+			if ri.has(m.pointer()) {
+				if !f(ri.fieldDesc, ri.get(m.pointer())) {
+					return
+				}
+			}
+		case *oneofInfo:
+			if n := ri.which(m.pointer()); n > 0 {
+				fi := m.messageInfo().fields[n]
+				if !f(fi.fieldDesc, fi.get(m.pointer())) {
+					return
+				}
 			}
 		}
 	}
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index 04e8f7c..594ed9e 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -16,6 +16,14 @@
 	fields map[pref.FieldNumber]*fieldInfo
 	oneofs map[pref.Name]*oneofInfo
 
+	// denseFields is a subset of fields where:
+	//	0 < fieldDesc.Number() < len(denseFields)
+	// It provides faster access to the fieldInfo, but may be incomplete.
+	denseFields []*fieldInfo
+
+	// rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
+	rangeInfos []interface{} // either *fieldInfo or *oneofInfo
+
 	getUnknown   func(pointer) pref.RawFields
 	setUnknown   func(pointer, pref.RawFields)
 	extensionMap func(pointer) *extensionMap
@@ -39,8 +47,9 @@
 func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
 	mi.fields = map[pref.FieldNumber]*fieldInfo{}
 	md := mi.Desc
-	for i := 0; i < md.Fields().Len(); i++ {
-		fd := md.Fields().Get(i)
+	fds := md.Fields()
+	for i := 0; i < fds.Len(); i++ {
+		fd := fds.Get(i)
 		fs := si.fieldsByNumber[fd.Number()]
 		var fi fieldInfo
 		switch {
@@ -65,6 +74,24 @@
 		od := md.Oneofs().Get(i)
 		mi.oneofs[od.Name()] = makeOneofInfo(od, si.oneofsByName[od.Name()], mi.Exporter, si.oneofWrappersByType)
 	}
+
+	mi.denseFields = make([]*fieldInfo, fds.Len()*2)
+	for i := 0; i < fds.Len(); i++ {
+		if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
+			mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
+		}
+	}
+
+	for i := 0; i < fds.Len(); {
+		fd := fds.Get(i)
+		if od := fd.ContainingOneof(); od != nil {
+			mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
+			i += od.Fields().Len()
+		} else {
+			mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
+			i++
+		}
+	}
 }
 
 func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
@@ -273,12 +300,19 @@
 // checkField verifies that the provided field descriptor is valid.
 // Exactly one of the returned values is populated.
 func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
-	if fi := mi.fields[fd.Number()]; fi != nil {
+	var fi *fieldInfo
+	if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
+		fi = mi.denseFields[n]
+	} else {
+		fi = mi.fields[n]
+	}
+	if fi != nil {
 		if fi.fieldDesc != fd {
 			panic("mismatching field descriptor")
 		}
 		return fi, nil
 	}
+
 	if fd.IsExtension() {
 		if fd.ContainingMessage().FullName() != mi.Desc.FullName() {
 			// TODO: Should this be exact containing message descriptor match?
diff --git a/internal/impl/message_reflect_field.go b/internal/impl/message_reflect_field.go
index e384aa3..63b4055 100644
--- a/internal/impl/message_reflect_field.go
+++ b/internal/impl/message_reflect_field.go
@@ -435,7 +435,11 @@
 			if rv.IsNil() {
 				return 0
 			}
-			return wrappersByType[rv.Elem().Type().Elem()]
+			rv = rv.Elem()
+			if rv.IsNil() {
+				return 0
+			}
+			return wrappersByType[rv.Type().Elem()]
 		},
 	}
 }
diff --git a/internal/impl/message_reflect_gen.go b/internal/impl/message_reflect_gen.go
index 574fb4f..1c56375 100644
--- a/internal/impl/message_reflect_gen.go
+++ b/internal/impl/message_reflect_gen.go
@@ -42,10 +42,20 @@
 
 func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
 	m.messageInfo().init()
-	for _, fi := range m.messageInfo().fields {
-		if fi.has(m.pointer()) {
-			if !f(fi.fieldDesc, fi.get(m.pointer())) {
-				return
+	for _, ri := range m.messageInfo().rangeInfos {
+		switch ri := ri.(type) {
+		case *fieldInfo:
+			if ri.has(m.pointer()) {
+				if !f(ri.fieldDesc, ri.get(m.pointer())) {
+					return
+				}
+			}
+		case *oneofInfo:
+			if n := ri.which(m.pointer()); n > 0 {
+				fi := m.messageInfo().fields[n]
+				if !f(fi.fieldDesc, fi.get(m.pointer())) {
+					return
+				}
 			}
 		}
 	}
@@ -149,10 +159,20 @@
 
 func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
 	m.messageInfo().init()
-	for _, fi := range m.messageInfo().fields {
-		if fi.has(m.pointer()) {
-			if !f(fi.fieldDesc, fi.get(m.pointer())) {
-				return
+	for _, ri := range m.messageInfo().rangeInfos {
+		switch ri := ri.(type) {
+		case *fieldInfo:
+			if ri.has(m.pointer()) {
+				if !f(ri.fieldDesc, ri.get(m.pointer())) {
+					return
+				}
+			}
+		case *oneofInfo:
+			if n := ri.which(m.pointer()); n > 0 {
+				fi := m.messageInfo().fields[n]
+				if !f(fi.fieldDesc, fi.get(m.pointer())) {
+					return
+				}
 			}
 		}
 	}
diff --git a/internal/impl/message_reflect_test.go b/internal/impl/message_reflect_test.go
index a836c89..6148d4d 100644
--- a/internal/impl/message_reflect_test.go
+++ b/internal/impl/message_reflect_test.go
@@ -1485,3 +1485,53 @@
 	})
 	runtime.KeepAlive(sink)
 }
+
+func BenchmarkReflect(b *testing.B) {
+	m := new(testpb.TestAllTypes).ProtoReflect()
+	fds := m.Descriptor().Fields()
+	vs := make([]pref.Value, fds.Len())
+	for i := range vs {
+		vs[i] = m.NewField(fds.Get(i))
+	}
+
+	b.Run("Has", func(b *testing.B) {
+		b.ReportAllocs()
+		for i := 0; i < b.N; i++ {
+			for j := 0; j < fds.Len(); j++ {
+				m.Has(fds.Get(j))
+			}
+		}
+	})
+	b.Run("Get", func(b *testing.B) {
+		b.ReportAllocs()
+		for i := 0; i < b.N; i++ {
+			for j := 0; j < fds.Len(); j++ {
+				m.Get(fds.Get(j))
+			}
+		}
+	})
+	b.Run("Set", func(b *testing.B) {
+		b.ReportAllocs()
+		for i := 0; i < b.N; i++ {
+			for j := 0; j < fds.Len(); j++ {
+				m.Set(fds.Get(j), vs[j])
+			}
+		}
+	})
+	b.Run("Clear", func(b *testing.B) {
+		b.ReportAllocs()
+		for i := 0; i < b.N; i++ {
+			for j := 0; j < fds.Len(); j++ {
+				m.Clear(fds.Get(j))
+			}
+		}
+	})
+	b.Run("Range", func(b *testing.B) {
+		b.ReportAllocs()
+		for i := 0; i < b.N; i++ {
+			m.Range(func(pref.FieldDescriptor, pref.Value) bool {
+				return true
+			})
+		}
+	})
+}