testing/protopack: add Message.UnmarshalAbductive
The protobuf wire format is insufficiently self-decribing such that
it is impossible to know whether for sure whether an unknown bytes value
is a sub-message or not. However, protopack is primarily used for debugging
where a best-effort guess is still very useful.
The Message.UnmarshalAbductive unmarshals an unknown bytes value as a message
if it is syntactically well-formed. Otherwise, it is left as is.
Change-Id: I5e2b4b995e2b5eb60942a242558bf4cea1da9891
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/309669
Trust: Joe Tsai <joetsai@digital-static.net>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/cmd/pbdump/pbdump.go b/internal/cmd/pbdump/pbdump.go
index 87d6030..befb831 100644
--- a/internal/cmd/pbdump/pbdump.go
+++ b/internal/cmd/pbdump/pbdump.go
@@ -128,7 +128,7 @@
// Parse and print message structure.
defer log.Printf("fatal input: %q", buf) // debug printout if panic occurs
var m protopack.Message
- m.UnmarshalDescriptor(buf, desc)
+ m.UnmarshalAbductive(buf, desc)
if *printSource {
fmt.Printf("%#v\n", m)
} else {
diff --git a/testing/protopack/pack.go b/testing/protopack/pack.go
index d39593a..683ce0b 100644
--- a/testing/protopack/pack.go
+++ b/testing/protopack/pack.go
@@ -270,7 +270,7 @@
//
// Unmarshal is useful for debugging the protobuf wire format.
func (m *Message) Unmarshal(in []byte) {
- m.UnmarshalDescriptor(in, nil)
+ m.unmarshal(in, nil, false)
}
// UnmarshalDescriptor parses the input protobuf wire data as a syntax tree
@@ -289,22 +289,40 @@
// Known sub-messages are parsed as a Message and packed repeated fields are
// parsed as a LengthPrefix.
func (m *Message) UnmarshalDescriptor(in []byte, desc protoreflect.MessageDescriptor) {
+ m.unmarshal(in, desc, false)
+}
+
+// UnmarshalAbductive is like UnmarshalDescriptor, but infers abductively
+// whether any unknown bytes values is a message based on whether it is
+// a syntactically well-formed message.
+//
+// Note that the protobuf wire format is not fully self-describing,
+// so abductive inference may attempt to expand a bytes value as a message
+// that is not actually a message. It is a best-effort guess.
+func (m *Message) UnmarshalAbductive(in []byte, desc protoreflect.MessageDescriptor) {
+ m.unmarshal(in, desc, true)
+}
+
+func (m *Message) unmarshal(in []byte, desc protoreflect.MessageDescriptor, inferMessage bool) {
p := parser{in: in, out: *m}
- p.parseMessage(desc, false)
+ p.parseMessage(desc, false, inferMessage)
*m = p.out
}
type parser struct {
in []byte
out []Token
+
+ invalid bool
}
-func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group bool) {
+func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group, inferMessage bool) {
for len(p.in) > 0 {
v, n := protowire.ConsumeVarint(p.in)
num, typ := protowire.DecodeTag(v)
- if n < 0 || num < 0 || v > math.MaxUint32 {
+ if n < 0 || num <= 0 || v > math.MaxUint32 {
p.out, p.in = append(p.out, Raw(p.in)), nil
+ p.invalid = true
return
}
if typ == EndGroupType && group {
@@ -341,13 +359,14 @@
case Fixed64Type:
p.parseFixed64(kind)
case BytesType:
- p.parseBytes(isPacked, kind, subDesc)
+ p.parseBytes(isPacked, kind, subDesc, inferMessage)
case StartGroupType:
- p.parseGroup(subDesc)
+ p.parseGroup(num, subDesc, inferMessage)
case EndGroupType:
- // Handled above.
+ // Handled by p.parseGroup.
default:
p.out, p.in = append(p.out, Raw(p.in)), nil
+ p.invalid = true
}
}
}
@@ -356,6 +375,7 @@
v, n := protowire.ConsumeVarint(p.in)
if n < 0 {
p.out, p.in = append(p.out, Raw(p.in)), nil
+ p.invalid = true
return
}
switch kind {
@@ -384,6 +404,7 @@
v, n := protowire.ConsumeFixed32(p.in)
if n < 0 {
p.out, p.in = append(p.out, Raw(p.in)), nil
+ p.invalid = true
return
}
switch kind {
@@ -400,6 +421,7 @@
v, n := protowire.ConsumeFixed64(p.in)
if n < 0 {
p.out, p.in = append(p.out, Raw(p.in)), nil
+ p.invalid = true
return
}
switch kind {
@@ -412,10 +434,11 @@
}
}
-func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor) {
+func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor, inferMessage bool) {
v, n := protowire.ConsumeVarint(p.in)
if n < 0 {
p.out, p.in = append(p.out, Raw(p.in)), nil
+ p.invalid = true
return
}
p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
@@ -424,6 +447,7 @@
}
if v > uint64(len(p.in)) {
p.out, p.in = append(p.out, Raw(p.in)), nil
+ p.invalid = true
return
}
p.out = p.out[:len(p.out)-1] // subsequent tokens contain prefix-length
@@ -434,11 +458,22 @@
switch kind {
case protoreflect.MessageKind:
p2 := parser{in: p.in[:v]}
- p2.parseMessage(desc, false)
+ p2.parseMessage(desc, false, inferMessage)
p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
case protoreflect.StringKind:
p.out, p.in = append(p.out, String(p.in[:v])), p.in[v:]
+ case protoreflect.BytesKind:
+ p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
default:
+ if inferMessage {
+ // Check whether this is a syntactically valid message.
+ p2 := parser{in: p.in[:v]}
+ p2.parseMessage(nil, false, inferMessage)
+ if !p2.invalid {
+ p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
+ break
+ }
+ }
p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
}
}
@@ -466,9 +501,9 @@
p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[n:]
}
-func (p *parser) parseGroup(desc protoreflect.MessageDescriptor) {
+func (p *parser) parseGroup(startNum protowire.Number, desc protoreflect.MessageDescriptor, inferMessage bool) {
p2 := parser{in: p.in}
- p2.parseMessage(desc, true)
+ p2.parseMessage(desc, true, inferMessage)
if len(p2.out) > 0 {
p.out = append(p.out, Message(p2.out))
}
@@ -476,8 +511,11 @@
// Append the trailing end group.
v, n := protowire.ConsumeVarint(p.in)
- if num, typ := protowire.DecodeTag(v); typ == EndGroupType {
- p.out, p.in = append(p.out, Tag{num, typ}), p.in[n:]
+ if endNum, typ := protowire.DecodeTag(v); typ == EndGroupType {
+ if startNum != endNum {
+ p.invalid = true
+ }
+ p.out, p.in = append(p.out, Tag{endNum, typ}), p.in[n:]
if m := n - protowire.SizeVarint(v); m > 0 {
p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
}
diff --git a/testing/protopack/pack_test.go b/testing/protopack/pack_test.go
index 9752549..61ea336 100644
--- a/testing/protopack/pack_test.go
+++ b/testing/protopack/pack_test.go
@@ -15,6 +15,7 @@
"google.golang.org/protobuf/encoding/prototext"
pdesc "google.golang.org/protobuf/reflect/protodesc"
+ "google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
@@ -67,8 +68,10 @@
func TestPack(t *testing.T) {
tests := []struct {
- raw []byte
- msg Message
+ raw []byte
+ msg Message
+ msgDesc protoreflect.MessageDescriptor
+ inferMsg bool
wantOutCompact string
wantOutMulti string
@@ -81,6 +84,7 @@
Tag{1, VarintType}, Denormalized{5, Uvarint(2)},
Tag{1, BytesType}, LengthPrefix{Bool(true), Bool(false), Uvarint(2), Denormalized{5, Uvarint(2)}},
},
+ msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{1, protopack.VarintType}, protopack.Bool(false),
protopack.Denormalized{+5, protopack.Tag{1, protopack.VarintType}}, protopack.Uvarint(2),
@@ -88,12 +92,22 @@
protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix{protopack.Bool(true), protopack.Bool(false), protopack.Uvarint(2), protopack.Denormalized{+5, protopack.Uvarint(2)}},
}`,
}, {
+ raw: dhex("080088808080800002088280808080000a09010002828080808000"),
+ msg: Message{
+ Tag{1, VarintType}, Uvarint(0),
+ Denormalized{5, Tag{1, VarintType}}, Uvarint(2),
+ Tag{1, VarintType}, Denormalized{5, Uvarint(2)},
+ Tag{1, BytesType}, Bytes(Message{Bool(true), Bool(false), Uvarint(2), Denormalized{5, Uvarint(2)}}.Marshal()),
+ },
+ inferMsg: true,
+ }, {
raw: dhex("100010828080808000121980808080808080808001ffffffffffffffff7f828080808000"),
msg: Message{
Tag{2, VarintType}, Varint(0),
Tag{2, VarintType}, Denormalized{5, Varint(2)},
Tag{2, BytesType}, LengthPrefix{Varint(math.MinInt64), Varint(math.MaxInt64), Denormalized{5, Varint(2)}},
},
+ msgDesc: msgDesc,
wantOutCompact: `Message{Tag{2, Varint}, Varint(0), Tag{2, Varint}, Denormalized{+5, Varint(2)}, Tag{2, Bytes}, LengthPrefix{Varint(-9223372036854775808), Varint(9223372036854775807), Denormalized{+5, Varint(2)}}}`,
}, {
raw: dhex("1801188180808080001a1affffffffffffffffff01feffffffffffffffff01818080808000"),
@@ -102,6 +116,7 @@
Tag{3, VarintType}, Denormalized{5, Svarint(-1)},
Tag{3, BytesType}, LengthPrefix{Svarint(math.MinInt64), Svarint(math.MaxInt64), Denormalized{5, Svarint(-1)}},
},
+ msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{3, Varint}, Svarint(-1),
Tag{3, Varint}, Denormalized{+5, Svarint(-1)},
@@ -114,6 +129,7 @@
Tag{4, VarintType}, Denormalized{5, Uvarint(+1)},
Tag{4, BytesType}, LengthPrefix{Uvarint(0), Uvarint(math.MaxUint64), Denormalized{5, Uvarint(+1)}},
},
+ msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{4, protopack.VarintType}, protopack.Uvarint(1),
protopack.Tag{4, protopack.VarintType}, protopack.Denormalized{+5, protopack.Uvarint(1)},
@@ -125,6 +141,7 @@
Tag{5, Fixed32Type}, Uint32(+1),
Tag{5, BytesType}, LengthPrefix{Uint32(0), Uint32(math.MaxUint32)},
},
+ msgDesc: msgDesc,
wantOutCompact: `Message{Tag{5, Fixed32}, Uint32(1), Tag{5, Bytes}, LengthPrefix{Uint32(0), Uint32(4294967295)}}`,
}, {
raw: dhex("35ffffffff320800000080ffffff7f"),
@@ -132,6 +149,7 @@
Tag{6, Fixed32Type}, Int32(-1),
Tag{6, BytesType}, LengthPrefix{Int32(math.MinInt32), Int32(math.MaxInt32)},
},
+ msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{6, Fixed32}, Int32(-1),
Tag{6, Bytes}, LengthPrefix{Int32(-2147483648), Int32(2147483647)},
@@ -142,6 +160,7 @@
Tag{7, Fixed32Type}, Float32(math.Pi),
Tag{7, BytesType}, LengthPrefix{Float32(math.SmallestNonzeroFloat32), Float32(math.MaxFloat32), Float32(math.Inf(+1)), Float32(math.Inf(-1))},
},
+ msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{7, protopack.Fixed32Type}, protopack.Float32(3.1415927),
protopack.Tag{7, protopack.BytesType}, protopack.LengthPrefix{protopack.Float32(1e-45), protopack.Float32(3.4028235e+38), protopack.Float32(math.Inf(+1)), protopack.Float32(math.Inf(-1))},
@@ -152,6 +171,7 @@
Tag{8, Fixed64Type}, Uint64(+1),
Tag{8, BytesType}, LengthPrefix{Uint64(0), Uint64(math.MaxUint64)},
},
+ msgDesc: msgDesc,
wantOutCompact: `Message{Tag{8, Fixed64}, Uint64(1), Tag{8, Bytes}, LengthPrefix{Uint64(0), Uint64(18446744073709551615)}}`,
}, {
raw: dhex("49ffffffffffffffff4a100000000000000080ffffffffffffff7f"),
@@ -159,6 +179,7 @@
Tag{9, Fixed64Type}, Int64(-1),
Tag{9, BytesType}, LengthPrefix{Int64(math.MinInt64), Int64(math.MaxInt64)},
},
+ msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{9, Fixed64}, Int64(-1),
Tag{9, Bytes}, LengthPrefix{Int64(-9223372036854775808), Int64(9223372036854775807)},
@@ -169,6 +190,7 @@
Tag{10, Fixed64Type}, Float64(math.Pi),
Tag{10, BytesType}, LengthPrefix{Float64(math.SmallestNonzeroFloat64), Float64(math.MaxFloat64), Float64(math.Inf(+1)), Float64(math.Inf(-1))},
},
+ msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{10, Fixed64}, Float64(3.141592653589793),
Tag{10, Bytes}, LengthPrefix{Float64(5e-324), Float64(1.7976931348623157e+308), Float64(+Inf), Float64(-Inf)},
@@ -179,6 +201,7 @@
Tag{11, BytesType}, String("string"),
Tag{11, BytesType}, Denormalized{+5, String("string")},
},
+ msgDesc: msgDesc,
wantOutCompact: `Message{Tag{11, Bytes}, String("string"), Tag{11, Bytes}, Denormalized{+5, String("string")}}`,
}, {
raw: dhex("62056279746573628580808080006279746573"),
@@ -186,6 +209,7 @@
Tag{12, BytesType}, Bytes("bytes"),
Tag{12, BytesType}, Denormalized{+5, Bytes("bytes")},
},
+ msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{12, Bytes}, Bytes("bytes"),
Tag{12, Bytes}, Denormalized{+5, Bytes("bytes")},
@@ -201,6 +225,7 @@
Tag{100, StartGroupType}, Tag{100, EndGroupType},
}),
},
+ msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{13, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
protopack.Tag{100, protopack.VarintType}, protopack.Uvarint(18446744073709551615),
@@ -212,6 +237,30 @@
}),
}`,
}, {
+ raw: dhex("6a28a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a406"),
+ msg: Message{
+ Tag{13, BytesType}, LengthPrefix(Message{
+ Tag{100, VarintType}, Uvarint(math.MaxUint64),
+ Tag{100, Fixed32Type}, Uint32(math.MaxUint32),
+ Tag{100, Fixed64Type}, Uint64(math.MaxUint64),
+ Tag{100, BytesType}, Bytes("bytes"),
+ Tag{100, StartGroupType}, Tag{100, EndGroupType},
+ }),
+ },
+ inferMsg: true,
+ }, {
+ raw: dhex("6a28a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306ac06"),
+ msg: Message{
+ Tag{13, BytesType}, Bytes(Message{
+ Tag{100, VarintType}, Uvarint(math.MaxUint64),
+ Tag{100, Fixed32Type}, Uint32(math.MaxUint32),
+ Tag{100, Fixed64Type}, Uint64(math.MaxUint64),
+ Tag{100, BytesType}, Bytes("bytes"),
+ Tag{100, StartGroupType}, Tag{101, EndGroupType},
+ }.Marshal()),
+ },
+ inferMsg: true,
+ }, {
raw: dhex("6aa88080808000a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a406"),
msg: Message{
Tag{13, BytesType}, Denormalized{5, LengthPrefix(Message{
@@ -222,6 +271,7 @@
Tag{100, StartGroupType}, Tag{100, EndGroupType},
})},
},
+ msgDesc: msgDesc,
wantOutCompact: `Message{Tag{13, Bytes}, Denormalized{+5, LengthPrefix(Message{Tag{100, Varint}, Uvarint(18446744073709551615), Tag{100, Fixed32}, Uint32(4294967295), Tag{100, Fixed64}, Uint64(18446744073709551615), Tag{100, Bytes}, Bytes("bytes"), Tag{100, StartGroup}, Tag{100, EndGroup}})}}`,
}, {
raw: dhex("73a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a40674"),
@@ -235,6 +285,7 @@
},
Tag{14, EndGroupType},
},
+ msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{14, StartGroup},
Message{
@@ -261,6 +312,7 @@
Tag{1706, Type(7)},
Raw("\x1an\x98\x11\xc8Z*\xb3"),
},
+ msgDesc: msgDesc,
}, {
raw: dhex("3d08d0e57f"),
msg: Message{
@@ -269,6 +321,7 @@
func() uint32 { return 0x7fe5d008 }(),
)),
},
+ msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{7, protopack.Fixed32Type}, protopack.Float32(math.Float32frombits(0x7fe5d008)),
}`,
@@ -277,6 +330,7 @@
msg: Message{
Tag{10, Fixed64Type}, Float64(math.Float64frombits(0x7ff91b771051d6a8)),
},
+ msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{10, protopack.Fixed64Type}, protopack.Float64(math.Float64frombits(0x7ff91b771051d6a8)),
}`,
@@ -302,6 +356,7 @@
Tag{28856, BytesType},
Raw("\xbb"),
},
+ msgDesc: msgDesc,
}, {
raw: dhex("29baa4ac1c1e0a20183393bac434b8d3559337ec940050038770eaa9937f98e4"),
msg: Message{
@@ -318,6 +373,7 @@
Raw("꩓\u007f\x98\xe4"),
},
},
+ msgDesc: msgDesc,
}}
equateFloatBits := cmp.Options{
@@ -332,13 +388,13 @@
t.Run("", func(t *testing.T) {
var msg Message
raw := tt.msg.Marshal()
- msg.UnmarshalDescriptor(tt.raw, msgDesc)
+ msg.unmarshal(tt.raw, tt.msgDesc, tt.inferMsg)
if !bytes.Equal(raw, tt.raw) {
t.Errorf("Marshal() mismatch:\ngot %x\nwant %x", raw, tt.raw)
}
- if !cmp.Equal(msg, tt.msg, equateFloatBits) {
- t.Errorf("Unmarshal() mismatch:\ngot %+v\nwant %+v", msg, tt.msg)
+ if diff := cmp.Diff(tt.msg, msg, equateFloatBits); diff != "" {
+ t.Errorf("Unmarshal() mismatch (-want +got):\n%s", diff)
}
if got, want := tt.msg.Size(), len(tt.raw); got != want {
t.Errorf("Size() = %v, want %v", got, want)