internal/impl: faster oneof marshaling
Change size, marshal, and isinit operations on oneofs to look up the
currently-set oneof type in a map rather than testing for each possible
oneof field in turn.
Significantly improves oneof encoding speed for oneofs with a
substantial number of fields:
go test ./proto -bench=./oneof.*string.*test.TestAll -benchmem -count=8 -cpu=1
name old time/op new time/op delta
Encode/oneof_(string)_(*test.TestAllTypes) 911ns ± 1% 397ns ± 3% -56.45% (p=0.000 n=8+7)
Decode/oneof_(string)_(*test.TestAllTypes) 899ns ± 1% 922ns ± 1% +2.49% (p=0.001 n=7+7)
Fixes golang/protobuf#950
Change-Id: I9393a87975ce09011d885a8af4a63a639ea8452f
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/210281
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 4cfd9b0..366034c 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -20,40 +20,39 @@
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
func (errInvalidUTF8) InvalidUTF8() bool { return true }
-func makeOneofFieldCoder(fd pref.FieldDescriptor, si structInfo) pointerCoderFuncs {
- ot := si.oneofWrappersByNumber[fd.Number()]
- funcs := fieldCoder(fd, ot.Field(0).Type)
- fs := si.oneofsByName[fd.ContainingOneof().Name()]
- ft := fs.Type
- wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
- tagsize := wire.SizeVarint(wiretag)
- getInfo := func(p pointer) (pointer, bool) {
- v := p.AsValueOf(ft).Elem()
- if v.IsNil() {
- return pointer{}, false
- }
- v = v.Elem() // interface -> *struct
- if v.IsNil() || v.Elem().Type() != ot {
- return pointer{}, false
- }
- return pointerOfValue(v).Apply(zeroOffset), true
+// initOneofFieldCoders initializes the fast-path functions for the fields in a oneof.
+//
+// For size, marshal, and isInit operations, functions are set only on the first field
+// in the oneof. The functions are called when the oneof is non-nil, and will dispatch
+// to the appropriate field-specific function as necessary.
+//
+// The unmarshal function is set on each field individually as usual.
+func (mi *MessageInfo) initOneofFieldCoders(od pref.OneofDescriptor, si structInfo) {
+ type oneofFieldInfo struct {
+ wiretag uint64 // field tag (number + wire type)
+ tagsize int // size of the varint-encoded tag
+ funcs pointerCoderFuncs
}
- pcf := pointerCoderFuncs{
- size: func(p pointer, _ int, opts marshalOptions) int {
- v, ok := getInfo(p)
- if !ok {
- return 0
- }
- return funcs.size(v, tagsize, opts)
- },
- marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
- v, ok := getInfo(p)
- if !ok {
- return b, nil
- }
- return funcs.marshal(b, v, wiretag, opts)
- },
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ fs := si.oneofsByName[od.Name()]
+ ft := fs.Type
+ oneofFields := make(map[reflect.Type]*oneofFieldInfo)
+ needIsInit := false
+ fields := od.Fields()
+ for i, lim := 0, fields.Len(); i < lim; i++ {
+ fd := od.Fields().Get(i)
+ num := fd.Number()
+ cf := mi.coderFields[num]
+ ot := si.oneofWrappersByNumber[num]
+ funcs := fieldCoder(fd, ot.Field(0).Type)
+ oneofFields[ot] = &oneofFieldInfo{
+ wiretag: cf.wiretag,
+ tagsize: cf.tagsize,
+ funcs: funcs,
+ }
+ if funcs.isInit != nil {
+ needIsInit = true
+ }
+ cf.funcs.unmarshal = func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
var vw reflect.Value // pointer to wrapper type
vi := p.AsValueOf(ft).Elem() // oneof field value of interface kind
if !vi.IsNil() && !vi.Elem().IsNil() && vi.Elem().Elem().Type() == ot {
@@ -67,18 +66,43 @@
}
vi.Set(vw)
return n, nil
- },
- }
- if funcs.isInit != nil {
- pcf.isInit = func(p pointer) error {
- v, ok := getInfo(p)
- if !ok {
- return nil
- }
- return funcs.isInit(v)
}
}
- return pcf
+ getInfo := func(p pointer) (pointer, *oneofFieldInfo) {
+ v := p.AsValueOf(ft).Elem()
+ if v.IsNil() {
+ return pointer{}, nil
+ }
+ v = v.Elem() // interface -> *struct
+ if v.IsNil() {
+ return pointer{}, nil
+ }
+ return pointerOfValue(v).Apply(zeroOffset), oneofFields[v.Elem().Type()]
+ }
+ first := mi.coderFields[od.Fields().Get(0).Number()]
+ first.funcs.size = func(p pointer, tagsize int, opts marshalOptions) int {
+ p, info := getInfo(p)
+ if info == nil || info.funcs.size == nil {
+ return 0
+ }
+ return info.funcs.size(p, info.tagsize, opts)
+ }
+ first.funcs.marshal = func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+ p, info := getInfo(p)
+ if info == nil || info.funcs.marshal == nil {
+ return b, nil
+ }
+ return info.funcs.marshal(b, p, info.wiretag, opts)
+ }
+ if needIsInit {
+ first.funcs.isInit = func(p pointer) error {
+ p, info := getInfo(p)
+ if info == nil || info.funcs.isInit == nil {
+ return nil
+ }
+ return info.funcs.isInit(p)
+ }
+ }
}
func makeWeakMessageFieldCoder(fd pref.FieldDescriptor) pointerCoderFuncs {
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index d7584d4..6cfb5c7 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -68,7 +68,6 @@
switch {
case fd.ContainingOneof() != nil:
fieldOffset = offsetOf(fs, mi.Exporter)
- funcs = makeOneofFieldCoder(fd, si)
case fd.IsWeak():
fieldOffset = si.weakOffset
funcs = makeWeakMessageFieldCoder(fd)
@@ -91,6 +90,9 @@
mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
mi.coderFields[cf.num] = cf
}
+ for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
+ mi.initOneofFieldCoders(oneofs.Get(i), si)
+ }
if messageset.IsMessageSet(mi.Desc) {
if !mi.extensionOffset.IsValid() {
panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))
diff --git a/internal/impl/encode.go b/internal/impl/encode.go
index 4ce3b1d..c793021 100644
--- a/internal/impl/encode.go
+++ b/internal/impl/encode.go
@@ -82,11 +82,11 @@
size += mi.sizeExtensions(e, opts)
}
for _, f := range mi.orderedCoderFields {
- fptr := p.Apply(f.offset)
- if f.isPointer && fptr.Elem().IsNil() {
+ if f.funcs.size == nil {
continue
}
- if f.funcs.size == nil {
+ fptr := p.Apply(f.offset)
+ if f.isPointer && fptr.Elem().IsNil() {
continue
}
size += f.funcs.size(fptr, f.tagsize, opts)
@@ -131,11 +131,11 @@
}
}
for _, f := range mi.orderedCoderFields {
- fptr := p.Apply(f.offset)
- if f.isPointer && fptr.Elem().IsNil() {
+ if f.funcs.marshal == nil {
continue
}
- if f.funcs.marshal == nil {
+ fptr := p.Apply(f.offset)
+ if f.isPointer && fptr.Elem().IsNil() {
continue
}
b, err = f.funcs.marshal(b, fptr, f.wiretag, opts)