testing/protopack: add Message.UnmarshalAbductive

The protobuf wire format is insufficiently self-decribing such that
it is impossible to know whether for sure whether an unknown bytes value
is a sub-message or not. However, protopack is primarily used for debugging
where a best-effort guess is still very useful.

The Message.UnmarshalAbductive unmarshals an unknown bytes value as a message
if it is syntactically well-formed. Otherwise, it is left as is.

Change-Id: I5e2b4b995e2b5eb60942a242558bf4cea1da9891
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/309669
Trust: Joe Tsai <joetsai@digital-static.net>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/cmd/pbdump/pbdump.go b/internal/cmd/pbdump/pbdump.go
index 87d6030..befb831 100644
--- a/internal/cmd/pbdump/pbdump.go
+++ b/internal/cmd/pbdump/pbdump.go
@@ -128,7 +128,7 @@
 	// Parse and print message structure.
 	defer log.Printf("fatal input: %q", buf) // debug printout if panic occurs
 	var m protopack.Message
-	m.UnmarshalDescriptor(buf, desc)
+	m.UnmarshalAbductive(buf, desc)
 	if *printSource {
 		fmt.Printf("%#v\n", m)
 	} else {
diff --git a/testing/protopack/pack.go b/testing/protopack/pack.go
index d39593a..683ce0b 100644
--- a/testing/protopack/pack.go
+++ b/testing/protopack/pack.go
@@ -270,7 +270,7 @@
 //
 // Unmarshal is useful for debugging the protobuf wire format.
 func (m *Message) Unmarshal(in []byte) {
-	m.UnmarshalDescriptor(in, nil)
+	m.unmarshal(in, nil, false)
 }
 
 // UnmarshalDescriptor parses the input protobuf wire data as a syntax tree
@@ -289,22 +289,40 @@
 // Known sub-messages are parsed as a Message and packed repeated fields are
 // parsed as a LengthPrefix.
 func (m *Message) UnmarshalDescriptor(in []byte, desc protoreflect.MessageDescriptor) {
+	m.unmarshal(in, desc, false)
+}
+
+// UnmarshalAbductive is like UnmarshalDescriptor, but infers abductively
+// whether any unknown bytes values is a message based on whether it is
+// a syntactically well-formed message.
+//
+// Note that the protobuf wire format is not fully self-describing,
+// so abductive inference may attempt to expand a bytes value as a message
+// that is not actually a message. It is a best-effort guess.
+func (m *Message) UnmarshalAbductive(in []byte, desc protoreflect.MessageDescriptor) {
+	m.unmarshal(in, desc, true)
+}
+
+func (m *Message) unmarshal(in []byte, desc protoreflect.MessageDescriptor, inferMessage bool) {
 	p := parser{in: in, out: *m}
-	p.parseMessage(desc, false)
+	p.parseMessage(desc, false, inferMessage)
 	*m = p.out
 }
 
 type parser struct {
 	in  []byte
 	out []Token
+
+	invalid bool
 }
 
-func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group bool) {
+func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group, inferMessage bool) {
 	for len(p.in) > 0 {
 		v, n := protowire.ConsumeVarint(p.in)
 		num, typ := protowire.DecodeTag(v)
-		if n < 0 || num < 0 || v > math.MaxUint32 {
+		if n < 0 || num <= 0 || v > math.MaxUint32 {
 			p.out, p.in = append(p.out, Raw(p.in)), nil
+			p.invalid = true
 			return
 		}
 		if typ == EndGroupType && group {
@@ -341,13 +359,14 @@
 		case Fixed64Type:
 			p.parseFixed64(kind)
 		case BytesType:
-			p.parseBytes(isPacked, kind, subDesc)
+			p.parseBytes(isPacked, kind, subDesc, inferMessage)
 		case StartGroupType:
-			p.parseGroup(subDesc)
+			p.parseGroup(num, subDesc, inferMessage)
 		case EndGroupType:
-			// Handled above.
+			// Handled by p.parseGroup.
 		default:
 			p.out, p.in = append(p.out, Raw(p.in)), nil
+			p.invalid = true
 		}
 	}
 }
@@ -356,6 +375,7 @@
 	v, n := protowire.ConsumeVarint(p.in)
 	if n < 0 {
 		p.out, p.in = append(p.out, Raw(p.in)), nil
+		p.invalid = true
 		return
 	}
 	switch kind {
@@ -384,6 +404,7 @@
 	v, n := protowire.ConsumeFixed32(p.in)
 	if n < 0 {
 		p.out, p.in = append(p.out, Raw(p.in)), nil
+		p.invalid = true
 		return
 	}
 	switch kind {
@@ -400,6 +421,7 @@
 	v, n := protowire.ConsumeFixed64(p.in)
 	if n < 0 {
 		p.out, p.in = append(p.out, Raw(p.in)), nil
+		p.invalid = true
 		return
 	}
 	switch kind {
@@ -412,10 +434,11 @@
 	}
 }
 
-func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor) {
+func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor, inferMessage bool) {
 	v, n := protowire.ConsumeVarint(p.in)
 	if n < 0 {
 		p.out, p.in = append(p.out, Raw(p.in)), nil
+		p.invalid = true
 		return
 	}
 	p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
@@ -424,6 +447,7 @@
 	}
 	if v > uint64(len(p.in)) {
 		p.out, p.in = append(p.out, Raw(p.in)), nil
+		p.invalid = true
 		return
 	}
 	p.out = p.out[:len(p.out)-1] // subsequent tokens contain prefix-length
@@ -434,11 +458,22 @@
 		switch kind {
 		case protoreflect.MessageKind:
 			p2 := parser{in: p.in[:v]}
-			p2.parseMessage(desc, false)
+			p2.parseMessage(desc, false, inferMessage)
 			p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
 		case protoreflect.StringKind:
 			p.out, p.in = append(p.out, String(p.in[:v])), p.in[v:]
+		case protoreflect.BytesKind:
+			p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
 		default:
+			if inferMessage {
+				// Check whether this is a syntactically valid message.
+				p2 := parser{in: p.in[:v]}
+				p2.parseMessage(nil, false, inferMessage)
+				if !p2.invalid {
+					p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
+					break
+				}
+			}
 			p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
 		}
 	}
@@ -466,9 +501,9 @@
 	p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[n:]
 }
 
-func (p *parser) parseGroup(desc protoreflect.MessageDescriptor) {
+func (p *parser) parseGroup(startNum protowire.Number, desc protoreflect.MessageDescriptor, inferMessage bool) {
 	p2 := parser{in: p.in}
-	p2.parseMessage(desc, true)
+	p2.parseMessage(desc, true, inferMessage)
 	if len(p2.out) > 0 {
 		p.out = append(p.out, Message(p2.out))
 	}
@@ -476,8 +511,11 @@
 
 	// Append the trailing end group.
 	v, n := protowire.ConsumeVarint(p.in)
-	if num, typ := protowire.DecodeTag(v); typ == EndGroupType {
-		p.out, p.in = append(p.out, Tag{num, typ}), p.in[n:]
+	if endNum, typ := protowire.DecodeTag(v); typ == EndGroupType {
+		if startNum != endNum {
+			p.invalid = true
+		}
+		p.out, p.in = append(p.out, Tag{endNum, typ}), p.in[n:]
 		if m := n - protowire.SizeVarint(v); m > 0 {
 			p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
 		}
diff --git a/testing/protopack/pack_test.go b/testing/protopack/pack_test.go
index 9752549..61ea336 100644
--- a/testing/protopack/pack_test.go
+++ b/testing/protopack/pack_test.go
@@ -15,6 +15,7 @@
 
 	"google.golang.org/protobuf/encoding/prototext"
 	pdesc "google.golang.org/protobuf/reflect/protodesc"
+	"google.golang.org/protobuf/reflect/protoreflect"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 
 	"google.golang.org/protobuf/types/descriptorpb"
@@ -67,8 +68,10 @@
 
 func TestPack(t *testing.T) {
 	tests := []struct {
-		raw []byte
-		msg Message
+		raw      []byte
+		msg      Message
+		msgDesc  protoreflect.MessageDescriptor
+		inferMsg bool
 
 		wantOutCompact string
 		wantOutMulti   string
@@ -81,6 +84,7 @@
 			Tag{1, VarintType}, Denormalized{5, Uvarint(2)},
 			Tag{1, BytesType}, LengthPrefix{Bool(true), Bool(false), Uvarint(2), Denormalized{5, Uvarint(2)}},
 		},
+		msgDesc: msgDesc,
 		wantOutSource: `protopack.Message{
 	protopack.Tag{1, protopack.VarintType}, protopack.Bool(false),
 	protopack.Denormalized{+5, protopack.Tag{1, protopack.VarintType}}, protopack.Uvarint(2),
@@ -88,12 +92,22 @@
 	protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix{protopack.Bool(true), protopack.Bool(false), protopack.Uvarint(2), protopack.Denormalized{+5, protopack.Uvarint(2)}},
 }`,
 	}, {
+		raw: dhex("080088808080800002088280808080000a09010002828080808000"),
+		msg: Message{
+			Tag{1, VarintType}, Uvarint(0),
+			Denormalized{5, Tag{1, VarintType}}, Uvarint(2),
+			Tag{1, VarintType}, Denormalized{5, Uvarint(2)},
+			Tag{1, BytesType}, Bytes(Message{Bool(true), Bool(false), Uvarint(2), Denormalized{5, Uvarint(2)}}.Marshal()),
+		},
+		inferMsg: true,
+	}, {
 		raw: dhex("100010828080808000121980808080808080808001ffffffffffffffff7f828080808000"),
 		msg: Message{
 			Tag{2, VarintType}, Varint(0),
 			Tag{2, VarintType}, Denormalized{5, Varint(2)},
 			Tag{2, BytesType}, LengthPrefix{Varint(math.MinInt64), Varint(math.MaxInt64), Denormalized{5, Varint(2)}},
 		},
+		msgDesc:        msgDesc,
 		wantOutCompact: `Message{Tag{2, Varint}, Varint(0), Tag{2, Varint}, Denormalized{+5, Varint(2)}, Tag{2, Bytes}, LengthPrefix{Varint(-9223372036854775808), Varint(9223372036854775807), Denormalized{+5, Varint(2)}}}`,
 	}, {
 		raw: dhex("1801188180808080001a1affffffffffffffffff01feffffffffffffffff01818080808000"),
@@ -102,6 +116,7 @@
 			Tag{3, VarintType}, Denormalized{5, Svarint(-1)},
 			Tag{3, BytesType}, LengthPrefix{Svarint(math.MinInt64), Svarint(math.MaxInt64), Denormalized{5, Svarint(-1)}},
 		},
+		msgDesc: msgDesc,
 		wantOutMulti: `Message{
 	Tag{3, Varint}, Svarint(-1),
 	Tag{3, Varint}, Denormalized{+5, Svarint(-1)},
@@ -114,6 +129,7 @@
 			Tag{4, VarintType}, Denormalized{5, Uvarint(+1)},
 			Tag{4, BytesType}, LengthPrefix{Uvarint(0), Uvarint(math.MaxUint64), Denormalized{5, Uvarint(+1)}},
 		},
+		msgDesc: msgDesc,
 		wantOutSource: `protopack.Message{
 	protopack.Tag{4, protopack.VarintType}, protopack.Uvarint(1),
 	protopack.Tag{4, protopack.VarintType}, protopack.Denormalized{+5, protopack.Uvarint(1)},
@@ -125,6 +141,7 @@
 			Tag{5, Fixed32Type}, Uint32(+1),
 			Tag{5, BytesType}, LengthPrefix{Uint32(0), Uint32(math.MaxUint32)},
 		},
+		msgDesc:        msgDesc,
 		wantOutCompact: `Message{Tag{5, Fixed32}, Uint32(1), Tag{5, Bytes}, LengthPrefix{Uint32(0), Uint32(4294967295)}}`,
 	}, {
 		raw: dhex("35ffffffff320800000080ffffff7f"),
@@ -132,6 +149,7 @@
 			Tag{6, Fixed32Type}, Int32(-1),
 			Tag{6, BytesType}, LengthPrefix{Int32(math.MinInt32), Int32(math.MaxInt32)},
 		},
+		msgDesc: msgDesc,
 		wantOutMulti: `Message{
 	Tag{6, Fixed32}, Int32(-1),
 	Tag{6, Bytes}, LengthPrefix{Int32(-2147483648), Int32(2147483647)},
@@ -142,6 +160,7 @@
 			Tag{7, Fixed32Type}, Float32(math.Pi),
 			Tag{7, BytesType}, LengthPrefix{Float32(math.SmallestNonzeroFloat32), Float32(math.MaxFloat32), Float32(math.Inf(+1)), Float32(math.Inf(-1))},
 		},
+		msgDesc: msgDesc,
 		wantOutSource: `protopack.Message{
 	protopack.Tag{7, protopack.Fixed32Type}, protopack.Float32(3.1415927),
 	protopack.Tag{7, protopack.BytesType}, protopack.LengthPrefix{protopack.Float32(1e-45), protopack.Float32(3.4028235e+38), protopack.Float32(math.Inf(+1)), protopack.Float32(math.Inf(-1))},
@@ -152,6 +171,7 @@
 			Tag{8, Fixed64Type}, Uint64(+1),
 			Tag{8, BytesType}, LengthPrefix{Uint64(0), Uint64(math.MaxUint64)},
 		},
+		msgDesc:        msgDesc,
 		wantOutCompact: `Message{Tag{8, Fixed64}, Uint64(1), Tag{8, Bytes}, LengthPrefix{Uint64(0), Uint64(18446744073709551615)}}`,
 	}, {
 		raw: dhex("49ffffffffffffffff4a100000000000000080ffffffffffffff7f"),
@@ -159,6 +179,7 @@
 			Tag{9, Fixed64Type}, Int64(-1),
 			Tag{9, BytesType}, LengthPrefix{Int64(math.MinInt64), Int64(math.MaxInt64)},
 		},
+		msgDesc: msgDesc,
 		wantOutMulti: `Message{
 	Tag{9, Fixed64}, Int64(-1),
 	Tag{9, Bytes}, LengthPrefix{Int64(-9223372036854775808), Int64(9223372036854775807)},
@@ -169,6 +190,7 @@
 			Tag{10, Fixed64Type}, Float64(math.Pi),
 			Tag{10, BytesType}, LengthPrefix{Float64(math.SmallestNonzeroFloat64), Float64(math.MaxFloat64), Float64(math.Inf(+1)), Float64(math.Inf(-1))},
 		},
+		msgDesc: msgDesc,
 		wantOutMulti: `Message{
 	Tag{10, Fixed64}, Float64(3.141592653589793),
 	Tag{10, Bytes}, LengthPrefix{Float64(5e-324), Float64(1.7976931348623157e+308), Float64(+Inf), Float64(-Inf)},
@@ -179,6 +201,7 @@
 			Tag{11, BytesType}, String("string"),
 			Tag{11, BytesType}, Denormalized{+5, String("string")},
 		},
+		msgDesc:        msgDesc,
 		wantOutCompact: `Message{Tag{11, Bytes}, String("string"), Tag{11, Bytes}, Denormalized{+5, String("string")}}`,
 	}, {
 		raw: dhex("62056279746573628580808080006279746573"),
@@ -186,6 +209,7 @@
 			Tag{12, BytesType}, Bytes("bytes"),
 			Tag{12, BytesType}, Denormalized{+5, Bytes("bytes")},
 		},
+		msgDesc: msgDesc,
 		wantOutMulti: `Message{
 	Tag{12, Bytes}, Bytes("bytes"),
 	Tag{12, Bytes}, Denormalized{+5, Bytes("bytes")},
@@ -201,6 +225,7 @@
 				Tag{100, StartGroupType}, Tag{100, EndGroupType},
 			}),
 		},
+		msgDesc: msgDesc,
 		wantOutSource: `protopack.Message{
 	protopack.Tag{13, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
 		protopack.Tag{100, protopack.VarintType}, protopack.Uvarint(18446744073709551615),
@@ -212,6 +237,30 @@
 	}),
 }`,
 	}, {
+		raw: dhex("6a28a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a406"),
+		msg: Message{
+			Tag{13, BytesType}, LengthPrefix(Message{
+				Tag{100, VarintType}, Uvarint(math.MaxUint64),
+				Tag{100, Fixed32Type}, Uint32(math.MaxUint32),
+				Tag{100, Fixed64Type}, Uint64(math.MaxUint64),
+				Tag{100, BytesType}, Bytes("bytes"),
+				Tag{100, StartGroupType}, Tag{100, EndGroupType},
+			}),
+		},
+		inferMsg: true,
+	}, {
+		raw: dhex("6a28a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306ac06"),
+		msg: Message{
+			Tag{13, BytesType}, Bytes(Message{
+				Tag{100, VarintType}, Uvarint(math.MaxUint64),
+				Tag{100, Fixed32Type}, Uint32(math.MaxUint32),
+				Tag{100, Fixed64Type}, Uint64(math.MaxUint64),
+				Tag{100, BytesType}, Bytes("bytes"),
+				Tag{100, StartGroupType}, Tag{101, EndGroupType},
+			}.Marshal()),
+		},
+		inferMsg: true,
+	}, {
 		raw: dhex("6aa88080808000a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a406"),
 		msg: Message{
 			Tag{13, BytesType}, Denormalized{5, LengthPrefix(Message{
@@ -222,6 +271,7 @@
 				Tag{100, StartGroupType}, Tag{100, EndGroupType},
 			})},
 		},
+		msgDesc:        msgDesc,
 		wantOutCompact: `Message{Tag{13, Bytes}, Denormalized{+5, LengthPrefix(Message{Tag{100, Varint}, Uvarint(18446744073709551615), Tag{100, Fixed32}, Uint32(4294967295), Tag{100, Fixed64}, Uint64(18446744073709551615), Tag{100, Bytes}, Bytes("bytes"), Tag{100, StartGroup}, Tag{100, EndGroup}})}}`,
 	}, {
 		raw: dhex("73a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a40674"),
@@ -235,6 +285,7 @@
 			},
 			Tag{14, EndGroupType},
 		},
+		msgDesc: msgDesc,
 		wantOutMulti: `Message{
 	Tag{14, StartGroup},
 	Message{
@@ -261,6 +312,7 @@
 			Tag{1706, Type(7)},
 			Raw("\x1an\x98\x11\xc8Z*\xb3"),
 		},
+		msgDesc: msgDesc,
 	}, {
 		raw: dhex("3d08d0e57f"),
 		msg: Message{
@@ -269,6 +321,7 @@
 				func() uint32 { return 0x7fe5d008 }(),
 			)),
 		},
+		msgDesc: msgDesc,
 		wantOutSource: `protopack.Message{
 	protopack.Tag{7, protopack.Fixed32Type}, protopack.Float32(math.Float32frombits(0x7fe5d008)),
 }`,
@@ -277,6 +330,7 @@
 		msg: Message{
 			Tag{10, Fixed64Type}, Float64(math.Float64frombits(0x7ff91b771051d6a8)),
 		},
+		msgDesc: msgDesc,
 		wantOutSource: `protopack.Message{
 	protopack.Tag{10, protopack.Fixed64Type}, protopack.Float64(math.Float64frombits(0x7ff91b771051d6a8)),
 }`,
@@ -302,6 +356,7 @@
 			Tag{28856, BytesType},
 			Raw("\xbb"),
 		},
+		msgDesc: msgDesc,
 	}, {
 		raw: dhex("29baa4ac1c1e0a20183393bac434b8d3559337ec940050038770eaa9937f98e4"),
 		msg: Message{
@@ -318,6 +373,7 @@
 				Raw("꩓\u007f\x98\xe4"),
 			},
 		},
+		msgDesc: msgDesc,
 	}}
 
 	equateFloatBits := cmp.Options{
@@ -332,13 +388,13 @@
 		t.Run("", func(t *testing.T) {
 			var msg Message
 			raw := tt.msg.Marshal()
-			msg.UnmarshalDescriptor(tt.raw, msgDesc)
+			msg.unmarshal(tt.raw, tt.msgDesc, tt.inferMsg)
 
 			if !bytes.Equal(raw, tt.raw) {
 				t.Errorf("Marshal() mismatch:\ngot  %x\nwant %x", raw, tt.raw)
 			}
-			if !cmp.Equal(msg, tt.msg, equateFloatBits) {
-				t.Errorf("Unmarshal() mismatch:\ngot  %+v\nwant %+v", msg, tt.msg)
+			if diff := cmp.Diff(tt.msg, msg, equateFloatBits); diff != "" {
+				t.Errorf("Unmarshal() mismatch (-want +got):\n%s", diff)
 			}
 			if got, want := tt.msg.Size(), len(tt.raw); got != want {
 				t.Errorf("Size() = %v, want %v", got, want)