internal/impl: faster oneof marshaling

Change size, marshal, and isinit operations on oneofs to look up the
currently-set oneof type in a map rather than testing for each possible
oneof field in turn.

Significantly improves oneof encoding speed for oneofs with a
substantial number of fields:

  go test ./proto -bench=./oneof.*string.*test.TestAll -benchmem -count=8 -cpu=1

  name                                        old time/op    new time/op    delta
  Encode/oneof_(string)_(*test.TestAllTypes)     911ns ± 1%     397ns ± 3%  -56.45%  (p=0.000 n=8+7)
  Decode/oneof_(string)_(*test.TestAllTypes)     899ns ± 1%     922ns ± 1%   +2.49%  (p=0.001 n=7+7)

Fixes golang/protobuf#950

Change-Id: I9393a87975ce09011d885a8af4a63a639ea8452f
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/210281
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 4cfd9b0..366034c 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -20,40 +20,39 @@
 func (errInvalidUTF8) Error() string     { return "string field contains invalid UTF-8" }
 func (errInvalidUTF8) InvalidUTF8() bool { return true }
 
-func makeOneofFieldCoder(fd pref.FieldDescriptor, si structInfo) pointerCoderFuncs {
-	ot := si.oneofWrappersByNumber[fd.Number()]
-	funcs := fieldCoder(fd, ot.Field(0).Type)
-	fs := si.oneofsByName[fd.ContainingOneof().Name()]
-	ft := fs.Type
-	wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
-	tagsize := wire.SizeVarint(wiretag)
-	getInfo := func(p pointer) (pointer, bool) {
-		v := p.AsValueOf(ft).Elem()
-		if v.IsNil() {
-			return pointer{}, false
-		}
-		v = v.Elem() // interface -> *struct
-		if v.IsNil() || v.Elem().Type() != ot {
-			return pointer{}, false
-		}
-		return pointerOfValue(v).Apply(zeroOffset), true
+// initOneofFieldCoders initializes the fast-path functions for the fields in a oneof.
+//
+// For size, marshal, and isInit operations, functions are set only on the first field
+// in the oneof. The functions are called when the oneof is non-nil, and will dispatch
+// to the appropriate field-specific function as necessary.
+//
+// The unmarshal function is set on each field individually as usual.
+func (mi *MessageInfo) initOneofFieldCoders(od pref.OneofDescriptor, si structInfo) {
+	type oneofFieldInfo struct {
+		wiretag uint64 // field tag (number + wire type)
+		tagsize int    // size of the varint-encoded tag
+		funcs   pointerCoderFuncs
 	}
-	pcf := pointerCoderFuncs{
-		size: func(p pointer, _ int, opts marshalOptions) int {
-			v, ok := getInfo(p)
-			if !ok {
-				return 0
-			}
-			return funcs.size(v, tagsize, opts)
-		},
-		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-			v, ok := getInfo(p)
-			if !ok {
-				return b, nil
-			}
-			return funcs.marshal(b, v, wiretag, opts)
-		},
-		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+	fs := si.oneofsByName[od.Name()]
+	ft := fs.Type
+	oneofFields := make(map[reflect.Type]*oneofFieldInfo)
+	needIsInit := false
+	fields := od.Fields()
+	for i, lim := 0, fields.Len(); i < lim; i++ {
+		fd := od.Fields().Get(i)
+		num := fd.Number()
+		cf := mi.coderFields[num]
+		ot := si.oneofWrappersByNumber[num]
+		funcs := fieldCoder(fd, ot.Field(0).Type)
+		oneofFields[ot] = &oneofFieldInfo{
+			wiretag: cf.wiretag,
+			tagsize: cf.tagsize,
+			funcs:   funcs,
+		}
+		if funcs.isInit != nil {
+			needIsInit = true
+		}
+		cf.funcs.unmarshal = func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
 			var vw reflect.Value         // pointer to wrapper type
 			vi := p.AsValueOf(ft).Elem() // oneof field value of interface kind
 			if !vi.IsNil() && !vi.Elem().IsNil() && vi.Elem().Elem().Type() == ot {
@@ -67,18 +66,43 @@
 			}
 			vi.Set(vw)
 			return n, nil
-		},
-	}
-	if funcs.isInit != nil {
-		pcf.isInit = func(p pointer) error {
-			v, ok := getInfo(p)
-			if !ok {
-				return nil
-			}
-			return funcs.isInit(v)
 		}
 	}
-	return pcf
+	getInfo := func(p pointer) (pointer, *oneofFieldInfo) {
+		v := p.AsValueOf(ft).Elem()
+		if v.IsNil() {
+			return pointer{}, nil
+		}
+		v = v.Elem() // interface -> *struct
+		if v.IsNil() {
+			return pointer{}, nil
+		}
+		return pointerOfValue(v).Apply(zeroOffset), oneofFields[v.Elem().Type()]
+	}
+	first := mi.coderFields[od.Fields().Get(0).Number()]
+	first.funcs.size = func(p pointer, tagsize int, opts marshalOptions) int {
+		p, info := getInfo(p)
+		if info == nil || info.funcs.size == nil {
+			return 0
+		}
+		return info.funcs.size(p, info.tagsize, opts)
+	}
+	first.funcs.marshal = func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+		p, info := getInfo(p)
+		if info == nil || info.funcs.marshal == nil {
+			return b, nil
+		}
+		return info.funcs.marshal(b, p, info.wiretag, opts)
+	}
+	if needIsInit {
+		first.funcs.isInit = func(p pointer) error {
+			p, info := getInfo(p)
+			if info == nil || info.funcs.isInit == nil {
+				return nil
+			}
+			return info.funcs.isInit(p)
+		}
+	}
 }
 
 func makeWeakMessageFieldCoder(fd pref.FieldDescriptor) pointerCoderFuncs {
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index d7584d4..6cfb5c7 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -68,7 +68,6 @@
 		switch {
 		case fd.ContainingOneof() != nil:
 			fieldOffset = offsetOf(fs, mi.Exporter)
-			funcs = makeOneofFieldCoder(fd, si)
 		case fd.IsWeak():
 			fieldOffset = si.weakOffset
 			funcs = makeWeakMessageFieldCoder(fd)
@@ -91,6 +90,9 @@
 		mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
 		mi.coderFields[cf.num] = cf
 	}
+	for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
+		mi.initOneofFieldCoders(oneofs.Get(i), si)
+	}
 	if messageset.IsMessageSet(mi.Desc) {
 		if !mi.extensionOffset.IsValid() {
 			panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))
diff --git a/internal/impl/encode.go b/internal/impl/encode.go
index 4ce3b1d..c793021 100644
--- a/internal/impl/encode.go
+++ b/internal/impl/encode.go
@@ -82,11 +82,11 @@
 		size += mi.sizeExtensions(e, opts)
 	}
 	for _, f := range mi.orderedCoderFields {
-		fptr := p.Apply(f.offset)
-		if f.isPointer && fptr.Elem().IsNil() {
+		if f.funcs.size == nil {
 			continue
 		}
-		if f.funcs.size == nil {
+		fptr := p.Apply(f.offset)
+		if f.isPointer && fptr.Elem().IsNil() {
 			continue
 		}
 		size += f.funcs.size(fptr, f.tagsize, opts)
@@ -131,11 +131,11 @@
 		}
 	}
 	for _, f := range mi.orderedCoderFields {
-		fptr := p.Apply(f.offset)
-		if f.isPointer && fptr.Elem().IsNil() {
+		if f.funcs.marshal == nil {
 			continue
 		}
-		if f.funcs.marshal == nil {
+		fptr := p.Apply(f.offset)
+		if f.isPointer && fptr.Elem().IsNil() {
 			continue
 		}
 		b, err = f.funcs.marshal(b, fptr, f.wiretag, opts)