internal/impl: optimize reflect methods
This change performs two optimizations:
* It uses a pre-constructed rangeInfos slice to iterate over
all the fields. This is more performant since iterating over a slice
is faster than iterating over a map. Furthermore, this slice
does not contain fields that are part of a oneof. If a oneof has
N fields, the time to check presence on the oneof is now O(1)
instead of O(N).
* It uses a dense field info slice that is optmized for the common
case where the field number is relatively low and close in value
to the index itself.
We also fix a minor bug in the construction of oneofInfo where
it wasn't treating a typed nil pointer to a wrapper struct as if
it were unset. This ensures WhichOneof and Has always agree.
name old time/op new time/op delta
Reflect/Has-4 7.81µs ± 3% 6.74µs ± 3% -13.61% (p=0.000 n=9+9)
Reflect/Get-4 12.7µs ± 1% 11.3µs ± 4% -10.85% (p=0.000 n=8+10)
Reflect/Set-4 19.5µs ± 5% 17.8µs ± 2% -8.99% (p=0.000 n=10+10)
Reflect/Clear-4 12.0µs ± 4% 10.2µs ± 3% -14.86% (p=0.000 n=9+10)
Reflect/Range-4 6.58µs ± 1% 4.17µs ± 2% -36.65% (p=0.000 n=8+9)
Change-Id: I2c48b4d3fb6103ab238924950529ded0d37f8c8a
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/196358
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go
index df19b19..b16acbf 100644
--- a/internal/cmd/generate-types/impl.go
+++ b/internal/cmd/generate-types/impl.go
@@ -658,10 +658,20 @@
func (m *{{.}}) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
m.messageInfo().init()
- for _, fi := range m.messageInfo().fields {
- if fi.has(m.pointer()) {
- if !f(fi.fieldDesc, fi.get(m.pointer())) {
- return
+ for _, ri := range m.messageInfo().rangeInfos {
+ switch ri := ri.(type) {
+ case *fieldInfo:
+ if ri.has(m.pointer()) {
+ if !f(ri.fieldDesc, ri.get(m.pointer())) {
+ return
+ }
+ }
+ case *oneofInfo:
+ if n := ri.which(m.pointer()); n > 0 {
+ fi := m.messageInfo().fields[n]
+ if !f(fi.fieldDesc, fi.get(m.pointer())) {
+ return
+ }
}
}
}
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index 04e8f7c..594ed9e 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -16,6 +16,14 @@
fields map[pref.FieldNumber]*fieldInfo
oneofs map[pref.Name]*oneofInfo
+ // denseFields is a subset of fields where:
+ // 0 < fieldDesc.Number() < len(denseFields)
+ // It provides faster access to the fieldInfo, but may be incomplete.
+ denseFields []*fieldInfo
+
+ // rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
+ rangeInfos []interface{} // either *fieldInfo or *oneofInfo
+
getUnknown func(pointer) pref.RawFields
setUnknown func(pointer, pref.RawFields)
extensionMap func(pointer) *extensionMap
@@ -39,8 +47,9 @@
func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
mi.fields = map[pref.FieldNumber]*fieldInfo{}
md := mi.Desc
- for i := 0; i < md.Fields().Len(); i++ {
- fd := md.Fields().Get(i)
+ fds := md.Fields()
+ for i := 0; i < fds.Len(); i++ {
+ fd := fds.Get(i)
fs := si.fieldsByNumber[fd.Number()]
var fi fieldInfo
switch {
@@ -65,6 +74,24 @@
od := md.Oneofs().Get(i)
mi.oneofs[od.Name()] = makeOneofInfo(od, si.oneofsByName[od.Name()], mi.Exporter, si.oneofWrappersByType)
}
+
+ mi.denseFields = make([]*fieldInfo, fds.Len()*2)
+ for i := 0; i < fds.Len(); i++ {
+ if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
+ mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
+ }
+ }
+
+ for i := 0; i < fds.Len(); {
+ fd := fds.Get(i)
+ if od := fd.ContainingOneof(); od != nil {
+ mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
+ i += od.Fields().Len()
+ } else {
+ mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
+ i++
+ }
+ }
}
func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
@@ -273,12 +300,19 @@
// checkField verifies that the provided field descriptor is valid.
// Exactly one of the returned values is populated.
func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
- if fi := mi.fields[fd.Number()]; fi != nil {
+ var fi *fieldInfo
+ if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
+ fi = mi.denseFields[n]
+ } else {
+ fi = mi.fields[n]
+ }
+ if fi != nil {
if fi.fieldDesc != fd {
panic("mismatching field descriptor")
}
return fi, nil
}
+
if fd.IsExtension() {
if fd.ContainingMessage().FullName() != mi.Desc.FullName() {
// TODO: Should this be exact containing message descriptor match?
diff --git a/internal/impl/message_reflect_field.go b/internal/impl/message_reflect_field.go
index e384aa3..63b4055 100644
--- a/internal/impl/message_reflect_field.go
+++ b/internal/impl/message_reflect_field.go
@@ -435,7 +435,11 @@
if rv.IsNil() {
return 0
}
- return wrappersByType[rv.Elem().Type().Elem()]
+ rv = rv.Elem()
+ if rv.IsNil() {
+ return 0
+ }
+ return wrappersByType[rv.Type().Elem()]
},
}
}
diff --git a/internal/impl/message_reflect_gen.go b/internal/impl/message_reflect_gen.go
index 574fb4f..1c56375 100644
--- a/internal/impl/message_reflect_gen.go
+++ b/internal/impl/message_reflect_gen.go
@@ -42,10 +42,20 @@
func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
m.messageInfo().init()
- for _, fi := range m.messageInfo().fields {
- if fi.has(m.pointer()) {
- if !f(fi.fieldDesc, fi.get(m.pointer())) {
- return
+ for _, ri := range m.messageInfo().rangeInfos {
+ switch ri := ri.(type) {
+ case *fieldInfo:
+ if ri.has(m.pointer()) {
+ if !f(ri.fieldDesc, ri.get(m.pointer())) {
+ return
+ }
+ }
+ case *oneofInfo:
+ if n := ri.which(m.pointer()); n > 0 {
+ fi := m.messageInfo().fields[n]
+ if !f(fi.fieldDesc, fi.get(m.pointer())) {
+ return
+ }
}
}
}
@@ -149,10 +159,20 @@
func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
m.messageInfo().init()
- for _, fi := range m.messageInfo().fields {
- if fi.has(m.pointer()) {
- if !f(fi.fieldDesc, fi.get(m.pointer())) {
- return
+ for _, ri := range m.messageInfo().rangeInfos {
+ switch ri := ri.(type) {
+ case *fieldInfo:
+ if ri.has(m.pointer()) {
+ if !f(ri.fieldDesc, ri.get(m.pointer())) {
+ return
+ }
+ }
+ case *oneofInfo:
+ if n := ri.which(m.pointer()); n > 0 {
+ fi := m.messageInfo().fields[n]
+ if !f(fi.fieldDesc, fi.get(m.pointer())) {
+ return
+ }
}
}
}
diff --git a/internal/impl/message_reflect_test.go b/internal/impl/message_reflect_test.go
index a836c89..6148d4d 100644
--- a/internal/impl/message_reflect_test.go
+++ b/internal/impl/message_reflect_test.go
@@ -1485,3 +1485,53 @@
})
runtime.KeepAlive(sink)
}
+
+func BenchmarkReflect(b *testing.B) {
+ m := new(testpb.TestAllTypes).ProtoReflect()
+ fds := m.Descriptor().Fields()
+ vs := make([]pref.Value, fds.Len())
+ for i := range vs {
+ vs[i] = m.NewField(fds.Get(i))
+ }
+
+ b.Run("Has", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ for j := 0; j < fds.Len(); j++ {
+ m.Has(fds.Get(j))
+ }
+ }
+ })
+ b.Run("Get", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ for j := 0; j < fds.Len(); j++ {
+ m.Get(fds.Get(j))
+ }
+ }
+ })
+ b.Run("Set", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ for j := 0; j < fds.Len(); j++ {
+ m.Set(fds.Get(j), vs[j])
+ }
+ }
+ })
+ b.Run("Clear", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ for j := 0; j < fds.Len(); j++ {
+ m.Clear(fds.Get(j))
+ }
+ }
+ })
+ b.Run("Range", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ m.Range(func(pref.FieldDescriptor, pref.Value) bool {
+ return true
+ })
+ }
+ })
+}