internal/impl: improve extension fast path performance

Stash fast-path information for extensions on the ExtensionInfo. In
the usual case where an ExtensionType's underlying implementation is
an *ExtensionInfo, fetching the fast-path information becomes a type
assertion rather than a mutex-guarded map access.

Maintain a global sync.Map for the case where an ExtensionType isn't an
*ExtensionInfo.

Substantially improves performance for fast-path operations on
extensions:

Encode/MessageSet_type_id_before_message_content-12      267ns ± 1%   185ns ± 1%  -30.44%  (p=0.001 n=7+7)
Encode/basic_scalar_types_(*test.TestAllExtensions)-12  1.94µs ± 1%  0.40µs ± 1%  -79.32%  (p=0.000 n=8+7)

Change-Id: If048b521deb3665a090ea3d0a178c61691d4201e
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/210540
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/codec_extension.go b/internal/impl/codec_extension.go
index f2dca63..22a25dc 100644
--- a/internal/impl/codec_extension.go
+++ b/internal/impl/codec_extension.go
@@ -19,24 +19,36 @@
 	funcs               valueCoderFuncs
 }
 
-func (mi *MessageInfo) extensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
-	// As of this time (Go 1.12, linux/amd64), an RWMutex benchmarks as faster
-	// than a sync.Map.
-	mi.extensionFieldInfosMu.RLock()
-	e, ok := mi.extensionFieldInfos[xt]
-	mi.extensionFieldInfosMu.RUnlock()
-	if ok {
-		return e
-	}
+var legacyExtensionFieldInfoCache sync.Map // map[protoreflect.ExtensionType]*extensionFieldInfo
 
-	xd := xt.TypeDescriptor()
+func getExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
+	if xi, ok := xt.(*ExtensionInfo); ok {
+		xi.lazyInit()
+		return xi.info
+	}
+	return legacyLoadExtensionFieldInfo(xt)
+}
+
+// legacyLoadExtensionFieldInfo dynamically loads a *ExtensionInfo for xt.
+func legacyLoadExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
+	if xi, ok := legacyExtensionFieldInfoCache.Load(xt); ok {
+		return xi.(*extensionFieldInfo)
+	}
+	e := makeExtensionFieldInfo(xt.TypeDescriptor())
+	if e, ok := legacyMessageTypeCache.LoadOrStore(xt, e); ok {
+		return e.(*extensionFieldInfo)
+	}
+	return e
+}
+
+func makeExtensionFieldInfo(xd pref.ExtensionDescriptor) *extensionFieldInfo {
 	var wiretag uint64
 	if !xd.IsPacked() {
 		wiretag = wire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
 	} else {
 		wiretag = wire.EncodeTag(xd.Number(), wire.BytesType)
 	}
-	e = &extensionFieldInfo{
+	e := &extensionFieldInfo{
 		wiretag: wiretag,
 		tagsize: wire.SizeVarint(wiretag),
 		funcs:   encoderFuncsForValue(xd),
@@ -52,12 +64,6 @@
 			e.unmarshalNeedsValue = true
 		}
 	}
-	mi.extensionFieldInfosMu.Lock()
-	if mi.extensionFieldInfos == nil {
-		mi.extensionFieldInfos = make(map[pref.ExtensionType]*extensionFieldInfo)
-	}
-	mi.extensionFieldInfos[xt] = e
-	mi.extensionFieldInfosMu.Unlock()
 	return e
 }
 
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index b4d632e..d7584d4 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -8,7 +8,6 @@
 	"fmt"
 	"reflect"
 	"sort"
-	"sync"
 
 	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
@@ -31,9 +30,6 @@
 	extensionOffset    offset
 	needsInitCheck     bool
 	isMessageSet       bool
-
-	extensionFieldInfosMu sync.RWMutex
-	extensionFieldInfos   map[pref.ExtensionType]*extensionFieldInfo
 }
 
 type coderFieldInfo struct {
diff --git a/internal/impl/codec_messageset.go b/internal/impl/codec_messageset.go
index d78afeb..e917c7c 100644
--- a/internal/impl/codec_messageset.go
+++ b/internal/impl/codec_messageset.go
@@ -20,7 +20,7 @@
 
 	ext := *p.Apply(mi.extensionOffset).Extensions()
 	for _, x := range ext {
-		xi := mi.extensionFieldInfo(x.Type())
+		xi := getExtensionFieldInfo(x.Type())
 		if xi.funcs.size == nil {
 			continue
 		}
@@ -79,7 +79,7 @@
 }
 
 func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts marshalOptions) ([]byte, error) {
-	xi := mi.extensionFieldInfo(x.Type())
+	xi := getExtensionFieldInfo(x.Type())
 	num, _ := wire.DecodeTag(xi.wiretag)
 	b = messageset.AppendFieldStart(b, num)
 	b, err := xi.funcs.marshal(b, x.Value(), wire.EncodeTag(messageset.FieldMessage, wire.BytesType), opts)
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 9ccfa7f..e48081e 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -154,7 +154,7 @@
 			return 0, err
 		}
 	}
-	xi := mi.extensionFieldInfo(xt)
+	xi := getExtensionFieldInfo(xt)
 	if xi.funcs.unmarshal == nil {
 		return 0, errUnknown
 	}
diff --git a/internal/impl/encode.go b/internal/impl/encode.go
index cd57998..4ce3b1d 100644
--- a/internal/impl/encode.go
+++ b/internal/impl/encode.go
@@ -155,7 +155,7 @@
 		return 0
 	}
 	for _, x := range *ext {
-		xi := mi.extensionFieldInfo(x.Type())
+		xi := getExtensionFieldInfo(x.Type())
 		if xi.funcs.size == nil {
 			continue
 		}
@@ -176,7 +176,7 @@
 		// Fast-path for one extension: Don't bother sorting the keys.
 		var err error
 		for _, x := range *ext {
-			xi := mi.extensionFieldInfo(x.Type())
+			xi := getExtensionFieldInfo(x.Type())
 			b, err = xi.funcs.marshal(b, x.Value(), xi.wiretag, opts)
 		}
 		return b, err
@@ -191,7 +191,7 @@
 		var err error
 		for _, k := range keys {
 			x := (*ext)[int32(k)]
-			xi := mi.extensionFieldInfo(x.Type())
+			xi := getExtensionFieldInfo(x.Type())
 			b, err = xi.funcs.marshal(b, x.Value(), xi.wiretag, opts)
 			if err != nil {
 				return b, err
diff --git a/internal/impl/extension.go b/internal/impl/extension.go
index 541f4a7..0619d5d 100644
--- a/internal/impl/extension.go
+++ b/internal/impl/extension.go
@@ -38,6 +38,7 @@
 	goType reflect.Type
 	desc   extensionTypeDescriptor
 	conv   Converter
+	info   *extensionFieldInfo // for fast-path method implementations
 
 	// ExtendedType is a typed nil-pointer to the parent message type that
 	// is being extended. It is possible for this to be unpopulated in v2
@@ -136,7 +137,8 @@
 		if xi.ExtensionType == nil {
 			xi.initToLegacy()
 		}
-		xi.conv = NewConverter(xi.goType, xi.desc)
+		xi.conv = NewConverter(xi.goType, xi.desc.ExtensionDescriptor)
+		xi.info = makeExtensionFieldInfo(xi.desc.ExtensionDescriptor)
 	}
 }
 
diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go
index 3fd3674..d3a01bc 100644
--- a/internal/impl/isinit.go
+++ b/internal/impl/isinit.go
@@ -66,7 +66,7 @@
 		return nil
 	}
 	for _, x := range *ext {
-		ei := mi.extensionFieldInfo(x.Type())
+		ei := getExtensionFieldInfo(x.Type())
 		if ei.funcs.isInit == nil {
 			continue
 		}