internal/impl: improve extension fast path performance
Stash fast-path information for extensions on the ExtensionInfo. In
the usual case where an ExtensionType's underlying implementation is
an *ExtensionInfo, fetching the fast-path information becomes a type
assertion rather than a mutex-guarded map access.
Maintain a global sync.Map for the case where an ExtensionType isn't an
*ExtensionInfo.
Substantially improves performance for fast-path operations on
extensions:
Encode/MessageSet_type_id_before_message_content-12 267ns ± 1% 185ns ± 1% -30.44% (p=0.001 n=7+7)
Encode/basic_scalar_types_(*test.TestAllExtensions)-12 1.94µs ± 1% 0.40µs ± 1% -79.32% (p=0.000 n=8+7)
Change-Id: If048b521deb3665a090ea3d0a178c61691d4201e
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/210540
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/codec_extension.go b/internal/impl/codec_extension.go
index f2dca63..22a25dc 100644
--- a/internal/impl/codec_extension.go
+++ b/internal/impl/codec_extension.go
@@ -19,24 +19,36 @@
funcs valueCoderFuncs
}
-func (mi *MessageInfo) extensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
- // As of this time (Go 1.12, linux/amd64), an RWMutex benchmarks as faster
- // than a sync.Map.
- mi.extensionFieldInfosMu.RLock()
- e, ok := mi.extensionFieldInfos[xt]
- mi.extensionFieldInfosMu.RUnlock()
- if ok {
- return e
- }
+var legacyExtensionFieldInfoCache sync.Map // map[protoreflect.ExtensionType]*extensionFieldInfo
- xd := xt.TypeDescriptor()
+func getExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
+ if xi, ok := xt.(*ExtensionInfo); ok {
+ xi.lazyInit()
+ return xi.info
+ }
+ return legacyLoadExtensionFieldInfo(xt)
+}
+
+// legacyLoadExtensionFieldInfo dynamically loads a *ExtensionInfo for xt.
+func legacyLoadExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
+ if xi, ok := legacyExtensionFieldInfoCache.Load(xt); ok {
+ return xi.(*extensionFieldInfo)
+ }
+ e := makeExtensionFieldInfo(xt.TypeDescriptor())
+ if e, ok := legacyMessageTypeCache.LoadOrStore(xt, e); ok {
+ return e.(*extensionFieldInfo)
+ }
+ return e
+}
+
+func makeExtensionFieldInfo(xd pref.ExtensionDescriptor) *extensionFieldInfo {
var wiretag uint64
if !xd.IsPacked() {
wiretag = wire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
} else {
wiretag = wire.EncodeTag(xd.Number(), wire.BytesType)
}
- e = &extensionFieldInfo{
+ e := &extensionFieldInfo{
wiretag: wiretag,
tagsize: wire.SizeVarint(wiretag),
funcs: encoderFuncsForValue(xd),
@@ -52,12 +64,6 @@
e.unmarshalNeedsValue = true
}
}
- mi.extensionFieldInfosMu.Lock()
- if mi.extensionFieldInfos == nil {
- mi.extensionFieldInfos = make(map[pref.ExtensionType]*extensionFieldInfo)
- }
- mi.extensionFieldInfos[xt] = e
- mi.extensionFieldInfosMu.Unlock()
return e
}
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index b4d632e..d7584d4 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -8,7 +8,6 @@
"fmt"
"reflect"
"sort"
- "sync"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
@@ -31,9 +30,6 @@
extensionOffset offset
needsInitCheck bool
isMessageSet bool
-
- extensionFieldInfosMu sync.RWMutex
- extensionFieldInfos map[pref.ExtensionType]*extensionFieldInfo
}
type coderFieldInfo struct {
diff --git a/internal/impl/codec_messageset.go b/internal/impl/codec_messageset.go
index d78afeb..e917c7c 100644
--- a/internal/impl/codec_messageset.go
+++ b/internal/impl/codec_messageset.go
@@ -20,7 +20,7 @@
ext := *p.Apply(mi.extensionOffset).Extensions()
for _, x := range ext {
- xi := mi.extensionFieldInfo(x.Type())
+ xi := getExtensionFieldInfo(x.Type())
if xi.funcs.size == nil {
continue
}
@@ -79,7 +79,7 @@
}
func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts marshalOptions) ([]byte, error) {
- xi := mi.extensionFieldInfo(x.Type())
+ xi := getExtensionFieldInfo(x.Type())
num, _ := wire.DecodeTag(xi.wiretag)
b = messageset.AppendFieldStart(b, num)
b, err := xi.funcs.marshal(b, x.Value(), wire.EncodeTag(messageset.FieldMessage, wire.BytesType), opts)
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 9ccfa7f..e48081e 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -154,7 +154,7 @@
return 0, err
}
}
- xi := mi.extensionFieldInfo(xt)
+ xi := getExtensionFieldInfo(xt)
if xi.funcs.unmarshal == nil {
return 0, errUnknown
}
diff --git a/internal/impl/encode.go b/internal/impl/encode.go
index cd57998..4ce3b1d 100644
--- a/internal/impl/encode.go
+++ b/internal/impl/encode.go
@@ -155,7 +155,7 @@
return 0
}
for _, x := range *ext {
- xi := mi.extensionFieldInfo(x.Type())
+ xi := getExtensionFieldInfo(x.Type())
if xi.funcs.size == nil {
continue
}
@@ -176,7 +176,7 @@
// Fast-path for one extension: Don't bother sorting the keys.
var err error
for _, x := range *ext {
- xi := mi.extensionFieldInfo(x.Type())
+ xi := getExtensionFieldInfo(x.Type())
b, err = xi.funcs.marshal(b, x.Value(), xi.wiretag, opts)
}
return b, err
@@ -191,7 +191,7 @@
var err error
for _, k := range keys {
x := (*ext)[int32(k)]
- xi := mi.extensionFieldInfo(x.Type())
+ xi := getExtensionFieldInfo(x.Type())
b, err = xi.funcs.marshal(b, x.Value(), xi.wiretag, opts)
if err != nil {
return b, err
diff --git a/internal/impl/extension.go b/internal/impl/extension.go
index 541f4a7..0619d5d 100644
--- a/internal/impl/extension.go
+++ b/internal/impl/extension.go
@@ -38,6 +38,7 @@
goType reflect.Type
desc extensionTypeDescriptor
conv Converter
+ info *extensionFieldInfo // for fast-path method implementations
// ExtendedType is a typed nil-pointer to the parent message type that
// is being extended. It is possible for this to be unpopulated in v2
@@ -136,7 +137,8 @@
if xi.ExtensionType == nil {
xi.initToLegacy()
}
- xi.conv = NewConverter(xi.goType, xi.desc)
+ xi.conv = NewConverter(xi.goType, xi.desc.ExtensionDescriptor)
+ xi.info = makeExtensionFieldInfo(xi.desc.ExtensionDescriptor)
}
}
diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go
index 3fd3674..d3a01bc 100644
--- a/internal/impl/isinit.go
+++ b/internal/impl/isinit.go
@@ -66,7 +66,7 @@
return nil
}
for _, x := range *ext {
- ei := mi.extensionFieldInfo(x.Type())
+ ei := getExtensionFieldInfo(x.Type())
if ei.funcs.isInit == nil {
continue
}