internal/impl: add runtime support for *[]byte unknown representation
This CL adds runtime support for unknown fields to be represented
as *[]byte in addition to the current representation as []byte.
This CL does not change generated code to use *[]byte.
Comparison between using *[]byte and []byte:
• Every message supports unknown fields, so use of []byte
expands a message size by 24B (for 64-bit systems).
In contrast, *[]byte only expands a message by 8B.
This has significant memory implications for small messages.
• If unknown fields are encountered, *[]byte has extra overhead
allocating the 24B slice header. However, it is assumed
that messages rarely see any unknown fields at runtime,
or generally do so for a temporary period of time.
Change-Id: I81935e4ea7394166e61ff4579f76f59fa792dfc9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/244937
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index 733ad9f..e108642 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -27,6 +27,7 @@
coderFields map[protowire.Number]*coderFieldInfo
sizecacheOffset offset
unknownOffset offset
+ unknownPtrKind bool
extensionOffset offset
needsInitCheck bool
isMessageSet bool
@@ -47,9 +48,20 @@
}
func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
- mi.sizecacheOffset = si.sizecacheOffset
- mi.unknownOffset = si.unknownOffset
- mi.extensionOffset = si.extensionOffset
+ mi.sizecacheOffset = invalidOffset
+ mi.unknownOffset = invalidOffset
+ mi.extensionOffset = invalidOffset
+
+ if si.sizecacheOffset.IsValid() && si.sizecacheType == sizecacheType {
+ mi.sizecacheOffset = si.sizecacheOffset
+ }
+ if si.unknownOffset.IsValid() && (si.unknownType == unknownFieldsAType || si.unknownType == unknownFieldsBType) {
+ mi.unknownOffset = si.unknownOffset
+ mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr
+ }
+ if si.extensionOffset.IsValid() && si.extensionType == extensionFieldsType {
+ mi.extensionOffset = si.extensionOffset
+ }
mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
fields := mi.Desc.Fields()
@@ -157,3 +169,28 @@
mi.methods.Merge = mi.merge
}
}
+
+// getUnknownBytes returns a *[]byte for the unknown fields.
+// It is the caller's responsibility to check whether the pointer is nil.
+// This function is specially designed to be inlineable.
+func (mi *MessageInfo) getUnknownBytes(p pointer) *[]byte {
+ if mi.unknownPtrKind {
+ return *p.Apply(mi.unknownOffset).BytesPtr()
+ } else {
+ return p.Apply(mi.unknownOffset).Bytes()
+ }
+}
+
+// mutableUnknownBytes returns a *[]byte for the unknown fields.
+// The returned pointer is guaranteed to not be nil.
+func (mi *MessageInfo) mutableUnknownBytes(p pointer) *[]byte {
+ if mi.unknownPtrKind {
+ bp := p.Apply(mi.unknownOffset).BytesPtr()
+ if *bp == nil {
+ *bp = new([]byte)
+ }
+ return *bp
+ } else {
+ return p.Apply(mi.unknownOffset).Bytes()
+ }
+}
diff --git a/internal/impl/codec_messageset.go b/internal/impl/codec_messageset.go
index cfb68e1..b7a23fa 100644
--- a/internal/impl/codec_messageset.go
+++ b/internal/impl/codec_messageset.go
@@ -29,8 +29,9 @@
size += xi.funcs.size(x.Value(), protowire.SizeTag(messageset.FieldMessage), opts)
}
- unknown := *p.Apply(mi.unknownOffset).Bytes()
- size += messageset.SizeUnknown(unknown)
+ if u := mi.getUnknownBytes(p); u != nil {
+ size += messageset.SizeUnknown(*u)
+ }
return size
}
@@ -69,10 +70,12 @@
}
}
- unknown := *p.Apply(mi.unknownOffset).Bytes()
- b, err := messageset.AppendUnknown(b, unknown)
- if err != nil {
- return b, err
+ if u := mi.getUnknownBytes(p); u != nil {
+ var err error
+ b, err = messageset.AppendUnknown(b, *u)
+ if err != nil {
+ return b, err
+ }
}
return b, nil
@@ -100,13 +103,13 @@
*ep = make(map[int32]ExtensionField)
}
ext := *ep
- unknown := p.Apply(mi.unknownOffset).Bytes()
initialized := true
err = messageset.Unmarshal(b, true, func(num protowire.Number, v []byte) error {
o, err := mi.unmarshalExtension(v, num, protowire.BytesType, ext, opts)
if err == errUnknown {
- *unknown = protowire.AppendTag(*unknown, num, protowire.BytesType)
- *unknown = append(*unknown, v...)
+ u := mi.mutableUnknownBytes(p)
+ *u = protowire.AppendTag(*u, num, protowire.BytesType)
+ *u = append(*u, v...)
return nil
}
if !o.initialized {
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 2dcbdec..949dc49 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -175,7 +175,7 @@
return out, errDecode
}
if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
- u := p.Apply(mi.unknownOffset).Bytes()
+ u := mi.mutableUnknownBytes(p)
*u = protowire.AppendTag(*u, num, wtyp)
*u = append(*u, b[:n]...)
}
diff --git a/internal/impl/encode.go b/internal/impl/encode.go
index 8c8a794..845c67d 100644
--- a/internal/impl/encode.go
+++ b/internal/impl/encode.go
@@ -79,8 +79,9 @@
size += f.funcs.size(fptr, f, opts)
}
if mi.unknownOffset.IsValid() {
- u := *p.Apply(mi.unknownOffset).Bytes()
- size += len(u)
+ if u := mi.getUnknownBytes(p); u != nil {
+ size += len(*u)
+ }
}
if mi.sizecacheOffset.IsValid() {
if size > math.MaxInt32 {
@@ -141,8 +142,9 @@
}
}
if mi.unknownOffset.IsValid() && !mi.isMessageSet {
- u := *p.Apply(mi.unknownOffset).Bytes()
- b = append(b, u...)
+ if u := mi.getUnknownBytes(p); u != nil {
+ b = append(b, (*u)...)
+ }
}
return b, nil
}
diff --git a/internal/impl/merge.go b/internal/impl/merge.go
index cdc4267..c65bbc0 100644
--- a/internal/impl/merge.go
+++ b/internal/impl/merge.go
@@ -77,9 +77,9 @@
}
}
if mi.unknownOffset.IsValid() {
- du := dst.Apply(mi.unknownOffset).Bytes()
- su := src.Apply(mi.unknownOffset).Bytes()
- if len(*su) > 0 {
+ su := mi.getUnknownBytes(src)
+ if su != nil && len(*su) > 0 {
+ du := mi.mutableUnknownBytes(dst)
*du = append(*du, *su...)
}
}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index c8c3859..a104e28 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -110,22 +110,29 @@
type (
SizeCache = int32
WeakFields = map[int32]protoreflect.ProtoMessage
- UnknownFields = []byte
+ UnknownFields = unknownFieldsA // TODO: switch to unknownFieldsB
+ unknownFieldsA = []byte
+ unknownFieldsB = *[]byte
ExtensionFields = map[int32]ExtensionField
)
var (
sizecacheType = reflect.TypeOf(SizeCache(0))
weakFieldsType = reflect.TypeOf(WeakFields(nil))
- unknownFieldsType = reflect.TypeOf(UnknownFields(nil))
+ unknownFieldsAType = reflect.TypeOf(unknownFieldsA(nil))
+ unknownFieldsBType = reflect.TypeOf(unknownFieldsB(nil))
extensionFieldsType = reflect.TypeOf(ExtensionFields(nil))
)
type structInfo struct {
sizecacheOffset offset
+ sizecacheType reflect.Type
weakOffset offset
+ weakType reflect.Type
unknownOffset offset
+ unknownType reflect.Type
extensionOffset offset
+ extensionType reflect.Type
fieldsByNumber map[pref.FieldNumber]reflect.StructField
oneofsByName map[pref.Name]reflect.StructField
@@ -152,18 +159,22 @@
case genid.SizeCache_goname, genid.SizeCacheA_goname:
if f.Type == sizecacheType {
si.sizecacheOffset = offsetOf(f, mi.Exporter)
+ si.sizecacheType = f.Type
}
case genid.WeakFields_goname, genid.WeakFieldsA_goname:
if f.Type == weakFieldsType {
si.weakOffset = offsetOf(f, mi.Exporter)
+ si.weakType = f.Type
}
case genid.UnknownFields_goname, genid.UnknownFieldsA_goname:
- if f.Type == unknownFieldsType {
+ if f.Type == unknownFieldsAType || f.Type == unknownFieldsBType {
si.unknownOffset = offsetOf(f, mi.Exporter)
+ si.unknownType = f.Type
}
case genid.ExtensionFields_goname, genid.ExtensionFieldsA_goname, genid.ExtensionFieldsB_goname:
if f.Type == extensionFieldsType {
si.extensionOffset = offsetOf(f, mi.Exporter)
+ si.extensionType = f.Type
}
default:
for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") {
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index 41cd090..0c6f106 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -108,24 +108,44 @@
}
func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
- mi.getUnknown = func(pointer) pref.RawFields { return nil }
- mi.setUnknown = func(pointer, pref.RawFields) { return }
- if si.unknownOffset.IsValid() {
+ switch {
+ case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
+ // Handle as []byte.
mi.getUnknown = func(p pointer) pref.RawFields {
if p.IsNil() {
return nil
}
- rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType)
- return pref.RawFields(*rv.Interface().(*[]byte))
+ return *p.Apply(mi.unknownOffset).Bytes()
}
mi.setUnknown = func(p pointer, b pref.RawFields) {
if p.IsNil() {
panic("invalid SetUnknown on nil Message")
}
- rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType)
- *rv.Interface().(*[]byte) = []byte(b)
+ *p.Apply(mi.unknownOffset).Bytes() = b
}
- } else {
+ case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
+ // Handle as *[]byte.
+ mi.getUnknown = func(p pointer) pref.RawFields {
+ if p.IsNil() {
+ return nil
+ }
+ bp := p.Apply(mi.unknownOffset).BytesPtr()
+ if *bp == nil {
+ return nil
+ }
+ return **bp
+ }
+ mi.setUnknown = func(p pointer, b pref.RawFields) {
+ if p.IsNil() {
+ panic("invalid SetUnknown on nil Message")
+ }
+ bp := p.Apply(mi.unknownOffset).BytesPtr()
+ if *bp == nil {
+ *bp = new([]byte)
+ }
+ **bp = b
+ }
+ default:
mi.getUnknown = func(pointer) pref.RawFields {
return nil
}
diff --git a/internal/impl/message_reflect_test.go b/internal/impl/message_reflect_test.go
index 66c2536..96c7e1b 100644
--- a/internal/impl/message_reflect_test.go
+++ b/internal/impl/message_reflect_test.go
@@ -22,6 +22,7 @@
pdesc "google.golang.org/protobuf/reflect/protodesc"
pref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
+ "google.golang.org/protobuf/testing/protopack"
proto2_20180125 "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20180125_92554152"
testpb "google.golang.org/protobuf/internal/testprotos/test"
@@ -1425,6 +1426,43 @@
return strings.Join(ss, ".")
}
+type UnknownFieldsA struct {
+ XXX_unrecognized []byte
+}
+
+var unknownFieldsAType = pimpl.MessageInfo{
+ GoReflectType: reflect.TypeOf(new(UnknownFieldsA)),
+ Desc: mustMakeMessageDesc("unknown.proto", pref.Proto2, "", `name: "UnknownFieldsA"`, nil),
+}
+
+func (m *UnknownFieldsA) ProtoReflect() pref.Message { return unknownFieldsAType.MessageOf(m) }
+
+type UnknownFieldsB struct {
+ XXX_unrecognized *[]byte
+}
+
+var unknownFieldsBType = pimpl.MessageInfo{
+ GoReflectType: reflect.TypeOf(new(UnknownFieldsB)),
+ Desc: mustMakeMessageDesc("unknown.proto", pref.Proto2, "", `name: "UnknownFieldsB"`, nil),
+}
+
+func (m *UnknownFieldsB) ProtoReflect() pref.Message { return unknownFieldsBType.MessageOf(m) }
+
+func TestUnknownFields(t *testing.T) {
+ for _, m := range []proto.Message{new(UnknownFieldsA), new(UnknownFieldsB)} {
+ t.Run(reflect.TypeOf(m).Elem().Name(), func(t *testing.T) {
+ want := protopack.Message{
+ protopack.Tag{1, protopack.BytesType}, protopack.String("Hello, world!"),
+ }.Marshal()
+ m.ProtoReflect().SetUnknown(want)
+ got := []byte(m.ProtoReflect().GetUnknown())
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("UnknownFields mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
func TestReset(t *testing.T) {
mi := new(testpb.TestAllTypes)
diff --git a/internal/impl/pointer_reflect.go b/internal/impl/pointer_reflect.go
index 67b4ede..9e3ed82 100644
--- a/internal/impl/pointer_reflect.go
+++ b/internal/impl/pointer_reflect.go
@@ -121,6 +121,7 @@
func (p pointer) StringPtr() **string { return p.v.Interface().(**string) }
func (p pointer) StringSlice() *[]string { return p.v.Interface().(*[]string) }
func (p pointer) Bytes() *[]byte { return p.v.Interface().(*[]byte) }
+func (p pointer) BytesPtr() **[]byte { return p.v.Interface().(**[]byte) }
func (p pointer) BytesSlice() *[][]byte { return p.v.Interface().(*[][]byte) }
func (p pointer) WeakFields() *weakFields { return (*weakFields)(p.v.Interface().(*WeakFields)) }
func (p pointer) Extensions() *map[int32]ExtensionField {
diff --git a/internal/impl/pointer_unsafe.go b/internal/impl/pointer_unsafe.go
index 088aa85..9ecf23a 100644
--- a/internal/impl/pointer_unsafe.go
+++ b/internal/impl/pointer_unsafe.go
@@ -109,6 +109,7 @@
func (p pointer) StringPtr() **string { return (**string)(p.p) }
func (p pointer) StringSlice() *[]string { return (*[]string)(p.p) }
func (p pointer) Bytes() *[]byte { return (*[]byte)(p.p) }
+func (p pointer) BytesPtr() **[]byte { return (**[]byte)(p.p) }
func (p pointer) BytesSlice() *[][]byte { return (*[][]byte)(p.p) }
func (p pointer) WeakFields() *weakFields { return (*weakFields)(p.p) }
func (p pointer) Extensions() *map[int32]ExtensionField { return (*map[int32]ExtensionField)(p.p) }