internal/impl: support legacy unknown extension fields
The unknown fields in legacy messages is split across the XXX_unrecognized
field and also the XXX_InternalExtensions field. Implement support for
wrapping both fields and presenting it as if it were a unified set of
unknown fields.
Change-Id: If274fae2b48962520edd8a640080b6eced747684
Reviewed-on: https://go-review.googlesource.com/c/146517
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go
new file mode 100644
index 0000000..b229581
--- /dev/null
+++ b/internal/impl/legacy_extension.go
@@ -0,0 +1,126 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+ "reflect"
+ "sync"
+ "unsafe"
+
+ protoV1 "github.com/golang/protobuf/proto"
+ pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// TODO: The logic below this is a hack since v1 currently exposes no
+// exported functionality for interacting with these data structures.
+// Eventually make changes to v1 such that v2 can access the necessary
+// fields without relying on unsafe.
+
+var (
+ extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil))
+ extTypeB = reflect.TypeOf(protoV1.XXX_InternalExtensions{})
+)
+
+type legacyExtensionIface interface {
+ Len() int
+ Get(pref.FieldNumber) legacyExtensionEntry
+ Set(pref.FieldNumber, legacyExtensionEntry)
+ Range(f func(pref.FieldNumber, legacyExtensionEntry) bool)
+}
+
+func makeLegacyExtensionMapFunc(t reflect.Type) func(*messageDataType) legacyExtensionIface {
+ fx1, _ := t.FieldByName("XXX_extensions")
+ fx2, _ := t.FieldByName("XXX_InternalExtensions")
+ switch {
+ case fx1.Type == extTypeA:
+ return func(p *messageDataType) legacyExtensionIface {
+ rv := p.p.asType(t).Elem()
+ return (*legacyExtensionMap)(unsafe.Pointer(rv.UnsafeAddr() + fx1.Offset))
+ }
+ case fx2.Type == extTypeB:
+ return func(p *messageDataType) legacyExtensionIface {
+ rv := p.p.asType(t).Elem()
+ return (*legacyExtensionSyncMap)(unsafe.Pointer(rv.UnsafeAddr() + fx2.Offset))
+ }
+ default:
+ return nil
+ }
+}
+
+// legacyExtensionSyncMap is identical to protoV1.XXX_InternalExtensions.
+// It implements legacyExtensionIface.
+type legacyExtensionSyncMap struct {
+ p *struct {
+ mu sync.Mutex
+ m legacyExtensionMap
+ }
+}
+
+func (m legacyExtensionSyncMap) Len() int {
+ if m.p == nil {
+ return 0
+ }
+ m.p.mu.Lock()
+ defer m.p.mu.Unlock()
+ return m.p.m.Len()
+}
+func (m legacyExtensionSyncMap) Get(n pref.FieldNumber) legacyExtensionEntry {
+ if m.p == nil {
+ return legacyExtensionEntry{}
+ }
+ m.p.mu.Lock()
+ defer m.p.mu.Unlock()
+ return m.p.m.Get(n)
+}
+func (m *legacyExtensionSyncMap) Set(n pref.FieldNumber, x legacyExtensionEntry) {
+ if m.p == nil {
+ m.p = new(struct {
+ mu sync.Mutex
+ m legacyExtensionMap
+ })
+ }
+ m.p.mu.Lock()
+ defer m.p.mu.Unlock()
+ m.p.m.Set(n, x)
+}
+func (m legacyExtensionSyncMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) {
+ if m.p == nil {
+ return
+ }
+ m.p.mu.Lock()
+ defer m.p.mu.Unlock()
+ m.p.m.Range(f)
+}
+
+// legacyExtensionMap is identical to map[int32]protoV1.Extension.
+// It implements legacyExtensionIface.
+type legacyExtensionMap map[pref.FieldNumber]legacyExtensionEntry
+
+func (m legacyExtensionMap) Len() int {
+ return len(m)
+}
+func (m legacyExtensionMap) Get(n pref.FieldNumber) legacyExtensionEntry {
+ return m[n]
+}
+func (m *legacyExtensionMap) Set(n pref.FieldNumber, x legacyExtensionEntry) {
+ if *m == nil {
+ *m = make(map[pref.FieldNumber]legacyExtensionEntry)
+ }
+ (*m)[n] = x
+}
+func (m legacyExtensionMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) {
+ for n, x := range m {
+ if !f(n, x) {
+ return
+ }
+ }
+}
+
+// legacyExtensionEntry is identical to protoV1.Extension.
+type legacyExtensionEntry struct {
+ desc *protoV1.ExtensionDesc
+ val interface{}
+ raw []byte
+}
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index d2e71d4..caa0043 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -10,6 +10,7 @@
"reflect"
"testing"
+ protoV1 "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/v2/internal/encoding/pack"
"github.com/golang/protobuf/v2/internal/pragma"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -137,6 +138,15 @@
}
}
+type legacyUnknownMessage struct {
+ XXX_unrecognized []byte
+ protoV1.XXX_InternalExtensions
+}
+
+func (*legacyUnknownMessage) ExtensionRangeArray() []protoV1.ExtensionRange {
+ return []protoV1.ExtensionRange{{Start: 10, End: 20}, {Start: 40, End: 80}}
+}
+
func TestLegacyUnknown(t *testing.T) {
rawOf := func(toks ...pack.Token) pref.RawFields {
return pref.RawFields(pack.Message(toks).Marshal())
@@ -149,6 +159,17 @@
raw3a := rawOf(pack.Tag{3, pack.StartGroupType}, pack.Tag{3, pack.EndGroupType}) // 1b1c
raw3b := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("\xde\xad\xbe\xef")) // 1a04deadbeef
+ raw1 := rawOf(pack.Tag{1, pack.BytesType}, pack.Bytes("1")) // 0a0131
+ raw3 := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("3")) // 1a0133
+ raw10 := rawOf(pack.Tag{10, pack.BytesType}, pack.Bytes("10")) // 52023130 - extension
+ raw15 := rawOf(pack.Tag{15, pack.BytesType}, pack.Bytes("15")) // 7a023135 - extension
+ raw26 := rawOf(pack.Tag{26, pack.BytesType}, pack.Bytes("26")) // d201023236
+ raw32 := rawOf(pack.Tag{32, pack.BytesType}, pack.Bytes("32")) // 8202023332
+ raw45 := rawOf(pack.Tag{45, pack.BytesType}, pack.Bytes("45")) // ea02023435 - extension
+ raw46 := rawOf(pack.Tag{45, pack.BytesType}, pack.Bytes("46")) // ea02023436 - extension
+ raw47 := rawOf(pack.Tag{45, pack.BytesType}, pack.Bytes("47")) // ea02023437 - extension
+ raw99 := rawOf(pack.Tag{99, pack.BytesType}, pack.Bytes("99")) // 9a06023939
+
joinRaw := func(bs ...pref.RawFields) (out []byte) {
for _, b := range bs {
out = append(out, b...)
@@ -156,11 +177,13 @@
return out
}
- var fs legacyUnknownBytes
+ m := new(legacyUnknownMessage)
+ fs := new(MessageType).MessageOf(m).UnknownFields()
+
if got, want := fs.Len(), 0; got != want {
t.Errorf("Len() = %d, want %d", got, want)
}
- if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) {
+ if got, want := m.XXX_unrecognized, joinRaw(); !bytes.Equal(got, want) {
t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
}
@@ -170,7 +193,7 @@
if got, want := fs.Len(), 1; got != want {
t.Errorf("Len() = %d, want %d", got, want)
}
- if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
+ if got, want := m.XXX_unrecognized, joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
}
@@ -178,7 +201,7 @@
if got, want := fs.Len(), 2; got != want {
t.Errorf("Len() = %d, want %d", got, want)
}
- if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) {
+ if got, want := m.XXX_unrecognized, joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) {
t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
}
@@ -196,12 +219,12 @@
if got, want := fs.Len(), 1; got != want {
t.Errorf("Len() = %d, want %d", got, want)
}
- if got, want := []byte(fs), joinRaw(raw2a); !bytes.Equal(got, want) {
+ if got, want := m.XXX_unrecognized, joinRaw(raw2a); !bytes.Equal(got, want) {
t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
}
// Simulate manual appending of raw field data.
- fs = append(fs, joinRaw(raw3a, raw1a, raw1b, raw2b, raw3b, raw1c)...)
+ m.XXX_unrecognized = append(m.XXX_unrecognized, joinRaw(raw3a, raw1a, raw1b, raw2b, raw3b, raw1c)...)
if got, want := fs.Len(), 3; got != want {
t.Errorf("Len() = %d, want %d", got, want)
}
@@ -232,14 +255,14 @@
if got, want := fs.Len(), 3; got != want {
t.Errorf("Len() = %d, want %d", got, want)
}
- if got, want := []byte(fs), joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) {
+ if got, want := m.XXX_unrecognized, joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) {
t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
}
fs.Set(1, nil) // remove field 1
if got, want := fs.Len(), 2; got != want {
t.Errorf("Len() = %d, want %d", got, want)
}
- if got, want := []byte(fs), joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) {
+ if got, want := m.XXX_unrecognized, joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) {
t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
}
@@ -251,7 +274,102 @@
if got, want := fs.Len(), 0; got != want {
t.Errorf("Len() = %d, want %d", got, want)
}
- if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) {
+ if got, want := m.XXX_unrecognized, joinRaw(); !bytes.Equal(got, want) {
t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
}
+
+ fs.Set(1, raw1)
+ if got, want := fs.Len(), 1; got != want {
+ t.Errorf("Len() = %d, want %d", got, want)
+ }
+ if got, want := m.XXX_unrecognized, joinRaw(raw1); !bytes.Equal(got, want) {
+ t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
+ }
+
+ fs.Set(45, raw45)
+ fs.Set(10, raw10) // extension
+ fs.Set(32, raw32)
+ fs.Set(1, nil) // deletion
+ fs.Set(26, raw26)
+ fs.Set(47, raw47) // extension
+ fs.Set(46, raw46) // extension
+ if got, want := fs.Len(), 6; got != want {
+ t.Errorf("Len() = %d, want %d", got, want)
+ }
+ if got, want := m.XXX_unrecognized, joinRaw(raw32, raw26); !bytes.Equal(got, want) {
+ t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
+ }
+
+ // Verify iteration order.
+ i = 0
+ want = []struct {
+ num pref.FieldNumber
+ raw pref.RawFields
+ }{
+ {32, raw32},
+ {26, raw26},
+ {10, raw10}, // extension
+ {45, raw45}, // extension
+ {46, raw46}, // extension
+ {47, raw47}, // extension
+ }
+ fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
+ if i < len(want) {
+ if num != want[i].num || !bytes.Equal(raw, want[i].raw) {
+ t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw)
+ }
+ } else {
+ t.Errorf("unexpected Range iteration: %d", i)
+ }
+ i++
+ return true
+ })
+
+ // Perform partial deletion while iterating.
+ i = 0
+ fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
+ if i%2 == 0 {
+ fs.Set(num, nil)
+ }
+ i++
+ return true
+ })
+
+ if got, want := fs.Len(), 3; got != want {
+ t.Errorf("Len() = %d, want %d", got, want)
+ }
+ if got, want := m.XXX_unrecognized, joinRaw(raw26); !bytes.Equal(got, want) {
+ t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
+ }
+
+ fs.Set(15, raw15) // extension
+ fs.Set(3, raw3)
+ fs.Set(99, raw99)
+ if got, want := fs.Len(), 6; got != want {
+ t.Errorf("Len() = %d, want %d", got, want)
+ }
+ if got, want := m.XXX_unrecognized, joinRaw(raw26, raw3, raw99); !bytes.Equal(got, want) {
+ t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want)
+ }
+
+ // Perform partial iteration.
+ i = 0
+ want = []struct {
+ num pref.FieldNumber
+ raw pref.RawFields
+ }{
+ {26, raw26},
+ {3, raw3},
+ }
+ fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
+ if i < len(want) {
+ if num != want[i].num || !bytes.Equal(raw, want[i].raw) {
+ t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw)
+ }
+ } else {
+ t.Errorf("unexpected Range iteration: %d", i)
+ }
+ i++
+ return i < 2
+ })
}
diff --git a/internal/impl/legacy_unknown.go b/internal/impl/legacy_unknown.go
index a319f05..9ab617b 100644
--- a/internal/impl/legacy_unknown.go
+++ b/internal/impl/legacy_unknown.go
@@ -7,36 +7,104 @@
import (
"container/list"
"reflect"
+ "sort"
- protoV1 "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/v2/internal/encoding/wire"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
)
-var (
- extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil))
- extTypeB = reflect.TypeOf(protoV1.XXX_InternalExtensions{})
-)
-
-func generateLegacyUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) func(p *messageDataType) pref.UnknownFields {
+func makeLegacyUnknownFieldsFunc(t reflect.Type) func(p *messageDataType) pref.UnknownFields {
fu, ok := t.FieldByName("XXX_unrecognized")
if !ok || fu.Type != bytesType {
return nil
}
- fx1, _ := t.FieldByName("XXX_extensions")
- fx2, _ := t.FieldByName("XXX_InternalExtensions")
- if fx1.Type == extTypeA || fx2.Type == extTypeB {
- // TODO: In proto v1, the unknown fields are split between both
- // XXX_unrecognized and XXX_InternalExtensions. If the message supports
- // extensions, then we will need to create a wrapper data structure
- // that presents unknown fields in both lists as a single ordered list.
- panic("not implemented")
- }
fieldOffset := offsetOf(fu)
- return func(p *messageDataType) pref.UnknownFields {
+ unkFunc := func(p *messageDataType) pref.UnknownFields {
rv := p.p.apply(fieldOffset).asType(bytesType)
return (*legacyUnknownBytes)(rv.Interface().(*[]byte))
}
+ extFunc := makeLegacyExtensionMapFunc(t)
+ if extFunc != nil {
+ return func(p *messageDataType) pref.UnknownFields {
+ return &legacyUnknownBytesAndExtensionMap{
+ unkFunc(p), extFunc(p), p.mi.Desc.ExtensionRanges(),
+ }
+ }
+ }
+ return unkFunc
+}
+
+// legacyUnknownBytesAndExtensionMap is a wrapper around both XXX_unrecognized
+// and also the extension field map.
+type legacyUnknownBytesAndExtensionMap struct {
+ u pref.UnknownFields
+ x legacyExtensionIface
+ r pref.FieldRanges
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) Len() int {
+ n := fs.u.Len()
+ fs.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
+ if len(x.raw) > 0 {
+ n++
+ }
+ return true
+ })
+ return n
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) Get(num pref.FieldNumber) (raw pref.RawFields) {
+ if fs.r.Has(num) {
+ return fs.x.Get(num).raw
+ }
+ return fs.u.Get(num)
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) Set(num pref.FieldNumber, raw pref.RawFields) {
+ if fs.r.Has(num) {
+ x := fs.x.Get(num)
+ x.raw = raw
+ fs.x.Set(num, x)
+ return
+ }
+ fs.u.Set(num, raw)
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) Range(f func(pref.FieldNumber, pref.RawFields) bool) {
+ // Range over unknown fields not in the extension range.
+ // Create a closure around f to capture whether iteration terminated early.
+ var stop bool
+ fs.u.Range(func(n pref.FieldNumber, b pref.RawFields) bool {
+ stop = stop || !f(n, b)
+ return !stop
+ })
+ if stop {
+ return
+ }
+
+ // Range over unknown fields in the extension range in ascending order
+ // to ensure protoreflect.UnknownFields.Range remains deterministic.
+ type entry struct {
+ num pref.FieldNumber
+ raw pref.RawFields
+ }
+ var xs []entry
+ fs.x.Range(func(n pref.FieldNumber, x legacyExtensionEntry) bool {
+ if len(x.raw) > 0 {
+ xs = append(xs, entry{n, x.raw})
+ }
+ return true
+ })
+ sort.Slice(xs, func(i, j int) bool { return xs[i].num < xs[j].num })
+ for _, x := range xs {
+ if !f(x.num, x.raw) {
+ return
+ }
+ }
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) IsSupported() bool {
+ return true
}
// legacyUnknownBytes is a wrapper around XXX_unrecognized that implements
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 2b0a9ca..a3e4b9d 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -53,9 +53,8 @@
mi.goType = t
// Derive the message descriptor if unspecified.
- md := mi.Desc
- if md == nil {
- // TODO: derive the message type from the Go struct type
+ if mi.Desc == nil {
+ mi.Desc = loadMessageDesc(t)
}
// Initialize the Go message type wrapper if the Go type does not
@@ -68,7 +67,7 @@
// Generated code ensures that this property holds.
if _, ok := p.(pref.ProtoMessage); !ok {
mi.pbType = ptype.NewGoMessage(&ptype.GoMessage{
- MessageDescriptor: md,
+ MessageDescriptor: mi.Desc,
New: func(pref.MessageType) pref.ProtoMessage {
p := reflect.New(t.Elem()).Interface()
return (*message)(mi.dataTypeOf(p))
@@ -76,9 +75,9 @@
})
}
- mi.generateKnownFieldFuncs(t.Elem(), md)
- mi.generateUnknownFieldFuncs(t.Elem(), md)
- mi.generateExtensionFieldFuncs(t.Elem(), md)
+ mi.makeKnownFieldsFunc(t.Elem())
+ mi.makeUnknownFieldsFunc(t.Elem())
+ mi.makeExtensionFieldsFunc(t.Elem())
})
// TODO: Remove this check? This API is primarily used by generated code,
@@ -90,14 +89,14 @@
}
}
-// generateKnownFieldFuncs generates per-field functions for all operations
+// makeKnownFieldsFunc generates per-field functions for all operations
// to be performed on each field. It takes in a reflect.Type representing the
// Go struct, and a protoreflect.MessageDescriptor to match with the fields
// in the struct.
//
// This code assumes that the struct is well-formed and panics if there are
// any discrepancies.
-func (mi *MessageType) generateKnownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
+func (mi *MessageType) makeKnownFieldsFunc(t reflect.Type) {
// Generate a mapping of field numbers and names to Go struct field or type.
fields := map[pref.FieldNumber]reflect.StructField{}
oneofs := map[pref.Name]reflect.StructField{}
@@ -140,8 +139,8 @@
}
mi.fields = map[pref.FieldNumber]*fieldInfo{}
- for i := 0; i < md.Fields().Len(); i++ {
- fd := md.Fields().Get(i)
+ for i := 0; i < mi.Desc.Fields().Len(); i++ {
+ fd := mi.Desc.Fields().Get(i)
fs := fields[fd.Number()]
var fi fieldInfo
switch {
@@ -162,8 +161,8 @@
}
}
-func (mi *MessageType) generateUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
- if f := generateLegacyUnknownFieldFuncs(t, md); f != nil {
+func (mi *MessageType) makeUnknownFieldsFunc(t reflect.Type) {
+ if f := makeLegacyUnknownFieldsFunc(t); f != nil {
mi.unknownFields = f
return
}
@@ -172,7 +171,7 @@
}
}
-func (mi *MessageType) generateExtensionFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
+func (mi *MessageType) makeExtensionFieldsFunc(t reflect.Type) {
// TODO
mi.extensionFields = func(*messageDataType) pref.KnownFields {
return emptyExtensionFields{}