internal/impl: faster map fast path

Avoid using protobuf reflection on map values in the fast path. Range
operations in particular are expensive in protoreflect, because the
closure passed to Map.Range escapes.

Iterate maps using a reflect.MapIter when available.

When operating on maps of messages where we have a *MessageInfo for the
message type, directly jump to the fast-path *MessageInfo methods rather
than passing through the proto package.

Benchmarks deltas for a google.protobuf.Struct with JSON represention:
  {"parameters":{"a":{"b":{"c":{"d":{"e":{"f":{"g":{}}}}}}}}}

Compared to previous revision:

  name                      old time/op  new time/op  delta
  NestedStruct/Size         7.22µs ± 2%  4.84µs ± 2%  -32.96%  (p=0.000 n=8+8)
  NestedStruct/Size-8       9.30µs ± 2%  5.89µs ± 2%  -36.60%  (p=0.000 n=8+8)
  NestedStruct/Marshal      77.6µs ±12%   9.8µs ± 4%  -87.33%  (p=0.000 n=8+8)
  NestedStruct/Marshal-8    91.6µs ± 2%  11.9µs ± 2%  -86.99%  (p=0.000 n=8+8)
  NestedStruct/Unmarshal    11.5µs ± 4%   8.7µs ± 2%  -24.76%  (p=0.000 n=8+8)
  NestedStruct/Unmarshal-8  15.4µs ± 4%  11.9µs ± 2%  -22.41%  (p=0.000 n=8+8)

Compared to github.com/golang/protobuf:

  name                      old time/op  new time/op  delta
  NestedStruct/Size         5.42µs ± 1%  4.84µs ± 2%  -10.61%  (p=0.000 n=8+8)
  NestedStruct/Size-8       6.34µs ± 2%  5.89µs ± 2%   -7.10%  (p=0.000 n=8+8)
  NestedStruct/Marshal      12.5µs ± 2%   9.8µs ± 4%  -21.41%  (p=0.000 n=7+8)
  NestedStruct/Marshal-8    14.1µs ± 3%  11.9µs ± 2%  -15.52%  (p=0.000 n=8+8)
  NestedStruct/Unmarshal    9.66µs ± 1%  8.65µs ± 2%  -10.40%  (p=0.000 n=7+8)
  NestedStruct/Unmarshal-8  11.7µs ± 3%  11.9µs ± 2%   +1.95%  (p=0.038 n=8+8)

Change-Id: I0effe6491f30d41f31904777f74eca3ac3694db3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/211737
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index 8b85c69..5f7d9e2 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -6,20 +6,22 @@
 
 import (
 	"reflect"
+	"sort"
 
 	"google.golang.org/protobuf/internal/encoding/wire"
-	"google.golang.org/protobuf/internal/mapsort"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
 type mapInfo struct {
-	goType     reflect.Type
-	keyWiretag uint64
-	valWiretag uint64
-	keyFuncs   valueCoderFuncs
-	valFuncs   valueCoderFuncs
-	keyZero    pref.Value
-	keyKind    pref.Kind
+	goType         reflect.Type
+	keyWiretag     uint64
+	valWiretag     uint64
+	keyFuncs       valueCoderFuncs
+	valFuncs       valueCoderFuncs
+	keyZero        pref.Value
+	keyKind        pref.Kind
+	valMessageInfo *MessageInfo
+	conv           *mapConverter
 }
 
 func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
@@ -30,7 +32,7 @@
 	valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
 	keyFuncs := encoderFuncsForValue(keyField)
 	valFuncs := encoderFuncsForValue(valField)
-	conv := NewConverter(ft, fd)
+	conv := newMapConverter(ft, fd)
 
 	mapi := &mapInfo{
 		goType:     ft,
@@ -40,30 +42,34 @@
 		valFuncs:   valFuncs,
 		keyZero:    keyField.Default(),
 		keyKind:    keyField.Kind(),
+		conv:       conv,
+	}
+	if valField.Kind() == pref.MessageKind {
+		mapi.valMessageInfo = getMessageInfo(ft.Elem())
 	}
 
 	funcs = pointerCoderFuncs{
 		size: func(p pointer, tagsize int, opts marshalOptions) int {
-			mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
-			return sizeMap(mapv, tagsize, mapi, opts)
+			return sizeMap(p.AsValueOf(ft).Elem(), tagsize, mapi, opts)
 		},
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-			mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
-			return appendMap(b, mapv, wiretag, mapi, opts)
+			return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
 		},
 		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
 			mp := p.AsValueOf(ft)
 			if mp.Elem().IsNil() {
 				mp.Elem().Set(reflect.MakeMap(mapi.goType))
 			}
-			mapv := conv.PBValueOf(mp.Elem()).Map()
-			return consumeMap(b, mapv, wtyp, mapi, opts)
+			if mapi.valMessageInfo == nil {
+				return consumeMap(b, mp.Elem(), wtyp, mapi, opts)
+			} else {
+				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, opts)
+			}
 		},
 	}
 	if valFuncs.isInit != nil {
 		funcs.isInit = func(p pointer) error {
-			mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
-			return isInitMap(mapv, mapi)
+			return isInitMap(p.AsValueOf(ft).Elem(), mapi)
 		}
 	}
 	return funcs
@@ -74,21 +80,30 @@
 	mapValTagSize = 1 // field 2, tag size 2.
 )
 
-func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
+func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions) int {
 	if mapv.Len() == 0 {
 		return 0
 	}
 	n := 0
-	mapv.Range(func(key pref.MapKey, value pref.Value) bool {
-		n += tagsize + wire.SizeBytes(
-			mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+
-				mapi.valFuncs.size(value, mapValTagSize, opts))
-		return true
-	})
+	iter := mapRange(mapv)
+	for iter.Next() {
+		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
+		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
+		var valSize int
+		value := mapi.conv.valConv.PBValueOf(iter.Value())
+		if mapi.valMessageInfo == nil {
+			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
+		} else {
+			p := pointerOfValue(iter.Value())
+			valSize += mapValTagSize
+			valSize += wire.SizeBytes(mapi.valMessageInfo.sizePointer(p, opts))
+		}
+		n += tagsize + wire.SizeBytes(keySize+valSize)
+	}
 	return n
 }
 
-func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
+func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
 	if wtyp != wire.BytesType {
 		return 0, errUnknown
 	}
@@ -98,7 +113,7 @@
 	}
 	var (
 		key = mapi.keyZero
-		val = mapv.NewValue()
+		val = mapi.conv.valConv.New()
 	)
 	for len(b) > 0 {
 		num, wtyp, n := wire.ConsumeTag(b)
@@ -133,44 +148,161 @@
 		}
 		b = b[n:]
 	}
-	mapv.Set(key.MapKey(), val)
+	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
 	return n, nil
 }
 
-func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
+	if wtyp != wire.BytesType {
+		return 0, errUnknown
+	}
+	b, n := wire.ConsumeBytes(b)
+	if n < 0 {
+		return 0, wire.ParseError(n)
+	}
+	var (
+		key = mapi.keyZero
+		val = reflect.New(mapi.valMessageInfo.GoReflectType.Elem())
+	)
+	for len(b) > 0 {
+		num, wtyp, n := wire.ConsumeTag(b)
+		if n < 0 {
+			return 0, wire.ParseError(n)
+		}
+		b = b[n:]
+		err := errUnknown
+		switch num {
+		case 1:
+			var v pref.Value
+			v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
+			if err != nil {
+				break
+			}
+			key = v
+		case 2:
+			if wtyp != wire.BytesType {
+				break
+			}
+			v, n := wire.ConsumeBytes(b)
+			if n < 0 {
+				return 0, wire.ParseError(n)
+			}
+			n, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
+		}
+		if err == errUnknown {
+			n = wire.ConsumeFieldValue(num, wtyp, b)
+			if n < 0 {
+				return 0, wire.ParseError(n)
+			}
+		} else if err != nil {
+			return 0, err
+		}
+		b = b[n:]
+	}
+	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
+	return n, nil
+}
+
+func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+	if mapi.valMessageInfo == nil {
+		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
+		val := mapi.conv.valConv.PBValueOf(valrv)
+		size := 0
+		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
+		size += mapi.valFuncs.size(val, mapValTagSize, opts)
+		b = wire.AppendVarint(b, uint64(size))
+		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
+		if err != nil {
+			return nil, err
+		}
+		return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
+	} else {
+		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
+		val := pointerOfValue(valrv)
+		valSize := mapi.valMessageInfo.sizePointer(val, opts)
+		size := 0
+		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
+		size += mapValTagSize + wire.SizeBytes(valSize)
+		b = wire.AppendVarint(b, uint64(size))
+		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
+		if err != nil {
+			return nil, err
+		}
+		b = wire.AppendVarint(b, mapi.valWiretag)
+		b = wire.AppendVarint(b, uint64(valSize))
+		return mapi.valMessageInfo.marshalAppendPointer(b, val, opts)
+	}
+}
+
+func appendMap(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
 	if mapv.Len() == 0 {
 		return b, nil
 	}
-	var err error
-	fn := func(key pref.MapKey, value pref.Value) bool {
-		b = wire.AppendVarint(b, wiretag)
-		size := 0
-		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
-		size += mapi.valFuncs.size(value, mapValTagSize, opts)
-		b = wire.AppendVarint(b, uint64(size))
-		b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
-		if err != nil {
-			return false
-		}
-		b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts)
-		if err != nil {
-			return false
-		}
-		return true
-	}
 	if opts.Deterministic() {
-		mapsort.Range(mapv, mapi.keyKind, fn)
-	} else {
-		mapv.Range(fn)
+		return appendMapDeterministic(b, mapv, wiretag, mapi, opts)
 	}
-	return b, err
+	iter := mapRange(mapv)
+	for iter.Next() {
+		var err error
+		b = wire.AppendVarint(b, wiretag)
+		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, opts)
+		if err != nil {
+			return b, err
+		}
+	}
+	return b, nil
 }
 
-func isInitMap(mapv pref.Map, mapi *mapInfo) error {
-	var err error
-	mapv.Range(func(_ pref.MapKey, value pref.Value) bool {
-		err = mapi.valFuncs.isInit(value)
-		return err == nil
+func appendMapDeterministic(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+	keys := mapv.MapKeys()
+	sort.Slice(keys, func(i, j int) bool {
+		switch keys[i].Kind() {
+		case reflect.Bool:
+			return !keys[i].Bool() && keys[j].Bool()
+		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+			return keys[i].Int() < keys[j].Int()
+		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+			return keys[i].Uint() < keys[j].Uint()
+		case reflect.Float32, reflect.Float64:
+			return keys[i].Float() < keys[j].Float()
+		case reflect.String:
+			return keys[i].String() < keys[j].String()
+		default:
+			panic("invalid kind: " + keys[i].Kind().String())
+		}
 	})
-	return err
+	for _, key := range keys {
+		var err error
+		b = wire.AppendVarint(b, wiretag)
+		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, opts)
+		if err != nil {
+			return b, err
+		}
+	}
+	return b, nil
+}
+
+func isInitMap(mapv reflect.Value, mapi *mapInfo) error {
+	if mi := mapi.valMessageInfo; mi != nil {
+		mi.init()
+		if !mi.needsInitCheck {
+			return nil
+		}
+		iter := mapRange(mapv)
+		for iter.Next() {
+			val := pointerOfValue(iter.Value())
+			if err := mi.isInitializedPointer(val); err != nil {
+				return err
+			}
+		}
+	} else {
+		iter := mapRange(mapv)
+		for iter.Next() {
+			val := mapi.conv.valConv.PBValueOf(iter.Value())
+			if err := mapi.valFuncs.isInit(val); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
 }
diff --git a/internal/impl/convert_map.go b/internal/impl/convert_map.go
index 34c3fcd..fcb1450 100644
--- a/internal/impl/convert_map.go
+++ b/internal/impl/convert_map.go
@@ -16,7 +16,7 @@
 	keyConv, valConv Converter
 }
 
-func newMapConverter(t reflect.Type, fd pref.FieldDescriptor) Converter {
+func newMapConverter(t reflect.Type, fd pref.FieldDescriptor) *mapConverter {
 	if t.Kind() != reflect.Map {
 		panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
 	}