proto, internal/impl: zero-length proto2 bytes fields should be non-nil
Fix decoding of zero-length bytes fields to produce a non-nil []byte.
Change-Id: Ifb7791a47df81091700f7226523371d1386fb1ad
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/188765
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go
index 085710d..0bbbbe9 100644
--- a/internal/cmd/generate-types/impl.go
+++ b/internal/cmd/generate-types/impl.go
@@ -152,10 +152,26 @@
return b, nil
}
+{{if .ToGoTypeNoZero}}
+// consume{{.Name}}NoZero wire decodes a {{.GoType}} pointer as a {{.Name}}.
+// The zero value is not decoded.
+func consume{{.Name}}NoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != {{.WireType.Expr}} {
+ return 0, errUnknown
+ }
+ v, n := {{template "Consume" .}}
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ *p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
+ return n, nil
+}
+{{end}}
+
var coder{{.Name}}NoZero = pointerCoderFuncs{
size: size{{.Name}}NoZero,
marshal: append{{.Name}}NoZero,
- unmarshal: consume{{.Name}},
+ unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}},
}
{{if or (eq .Name "Bytes") (eq .Name "String")}}
@@ -174,10 +190,28 @@
return b, nil
}
+{{if .ToGoTypeNoZero}}
+// consume{{.Name}}NoZeroValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}.
+func consume{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != {{.WireType.Expr}} {
+ return 0, errUnknown
+ }
+ v, n := {{template "Consume" .}}
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
+ return n, nil
+}
+{{end}}
+
var coder{{.Name}}NoZeroValidateUTF8 = pointerCoderFuncs{
size: size{{.Name}}NoZero,
marshal: append{{.Name}}NoZeroValidateUTF8,
- unmarshal: consume{{.Name}}ValidateUTF8,
+ unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}}ValidateUTF8,
}
{{end}}
@@ -551,6 +585,9 @@
{{end -}}
{{end -}}
+// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
+var emptyBuf [0]byte
+
var wireTypes = map[protoreflect.Kind]wire.Type{
{{range . -}}
protoreflect.{{.Name}}Kind: {{.WireType.Expr}},
diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go
index 023cde8..5cbf4ce 100644
--- a/internal/cmd/generate-types/proto.go
+++ b/internal/cmd/generate-types/proto.go
@@ -85,10 +85,11 @@
FromValue Expr
// Conversions to/from generated structures.
- GoType GoType
- ToGoType Expr
- FromGoType Expr
- NoPointer bool
+ GoType GoType
+ ToGoType Expr
+ ToGoTypeNoZero Expr
+ FromGoType Expr
+ NoPointer bool
}
func (k ProtoKind) Expr() Expr {
@@ -229,14 +230,15 @@
FromGoType: "v",
},
{
- Name: "Bytes",
- WireType: WireBytes,
- ToValue: "append(([]byte)(nil), v...)",
- FromValue: "v.Bytes()",
- GoType: GoBytes,
- ToGoType: "append(([]byte)(nil), v...)",
- FromGoType: "v",
- NoPointer: true,
+ Name: "Bytes",
+ WireType: WireBytes,
+ ToValue: "append(([]byte)(nil), v...)",
+ FromValue: "v.Bytes()",
+ GoType: GoBytes,
+ ToGoType: "append(emptyBuf[:], v...)",
+ ToGoTypeNoZero: "append(([]byte)(nil), v...)",
+ FromGoType: "v",
+ NoPointer: true,
},
{
Name: "Message",
diff --git a/internal/impl/codec_gen.go b/internal/impl/codec_gen.go
index 46380f5..f40af3d 100644
--- a/internal/impl/codec_gen.go
+++ b/internal/impl/codec_gen.go
@@ -4395,7 +4395,7 @@
if n < 0 {
return 0, wire.ParseError(n)
}
- *p.Bytes() = append(([]byte)(nil), v...)
+ *p.Bytes() = append(emptyBuf[:], v...)
return n, nil
}
@@ -4428,7 +4428,7 @@
if !utf8.Valid(v) {
return 0, errInvalidUTF8{}
}
- *p.Bytes() = append(([]byte)(nil), v...)
+ *p.Bytes() = append(emptyBuf[:], v...)
return n, nil
}
@@ -4460,10 +4460,24 @@
return b, nil
}
+// consumeBytesNoZero wire decodes a []byte pointer as a Bytes.
+// The zero value is not decoded.
+func consumeBytesNoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ *p.Bytes() = append(([]byte)(nil), v...)
+ return n, nil
+}
+
var coderBytesNoZero = pointerCoderFuncs{
size: sizeBytesNoZero,
marshal: appendBytesNoZero,
- unmarshal: consumeBytes,
+ unmarshal: consumeBytesNoZero,
}
// appendBytesNoZeroValidateUTF8 wire encodes a []byte pointer as a Bytes.
@@ -4481,10 +4495,26 @@
return b, nil
}
+// consumeBytesNoZeroValidateUTF8 wire decodes a []byte pointer as a Bytes.
+func consumeBytesNoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.Valid(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *p.Bytes() = append(([]byte)(nil), v...)
+ return n, nil
+}
+
var coderBytesNoZeroValidateUTF8 = pointerCoderFuncs{
size: sizeBytesNoZero,
marshal: appendBytesNoZeroValidateUTF8,
- unmarshal: consumeBytesValidateUTF8,
+ unmarshal: consumeBytesNoZeroValidateUTF8,
}
// sizeBytesSlice returns the size of wire encoding a [][]byte pointer as a repeated Bytes.
@@ -4516,7 +4546,7 @@
if n < 0 {
return 0, wire.ParseError(n)
}
- *sp = append(*sp, append(([]byte)(nil), v...))
+ *sp = append(*sp, append(emptyBuf[:], v...))
return n, nil
}
@@ -4552,7 +4582,7 @@
if !utf8.Valid(v) {
return 0, errInvalidUTF8{}
}
- *sp = append(*sp, append(([]byte)(nil), v...))
+ *sp = append(*sp, append(emptyBuf[:], v...))
return n, nil
}
@@ -4585,7 +4615,7 @@
if n < 0 {
return nil, 0, wire.ParseError(n)
}
- return append(([]byte)(nil), v...), n, nil
+ return append(emptyBuf[:], v...), n, nil
}
var coderBytesIface = ifaceCoderFuncs{
@@ -4617,7 +4647,7 @@
if !utf8.Valid(v) {
return nil, 0, errInvalidUTF8{}
}
- return append(([]byte)(nil), v...), n, nil
+ return append(emptyBuf[:], v...), n, nil
}
var coderBytesIfaceValidateUTF8 = ifaceCoderFuncs{
@@ -4655,7 +4685,7 @@
if n < 0 {
return nil, 0, wire.ParseError(n)
}
- *sp = append(*sp, append(([]byte)(nil), v...))
+ *sp = append(*sp, append(emptyBuf[:], v...))
return ival, n, nil
}
@@ -4665,6 +4695,9 @@
unmarshal: consumeBytesSliceIface,
}
+// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
+var emptyBuf [0]byte
+
var wireTypes = map[protoreflect.Kind]wire.Type{
protoreflect.BoolKind: wire.VarintType,
protoreflect.EnumKind: wire.VarintType,
diff --git a/proto/decode_test.go b/proto/decode_test.go
index c32c94c..6088eb5 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -108,6 +108,21 @@
}
}
+func TestDecodeZeroLengthBytes(t *testing.T) {
+ // Verify that proto3 bytes fields don't give the mistaken
+ // impression that they preserve presence.
+ wire := pack.Message{
+ pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
+ }.Marshal()
+ m := &test3pb.TestAllTypes{}
+ if err := proto.Unmarshal(wire, m); err != nil {
+ t.Fatal(err)
+ }
+ if m.OptionalBytes != nil {
+ t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
+ }
+}
+
var testProtos = []testProto{
{
desc: "basic scalar types",
@@ -184,6 +199,60 @@
}.Marshal(),
},
{
+ desc: "zero values",
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
+ OptionalInt32: proto.Int32(0),
+ OptionalInt64: proto.Int64(0),
+ OptionalUint32: proto.Uint32(0),
+ OptionalUint64: proto.Uint64(0),
+ OptionalSint32: proto.Int32(0),
+ OptionalSint64: proto.Int64(0),
+ OptionalFixed32: proto.Uint32(0),
+ OptionalFixed64: proto.Uint64(0),
+ OptionalSfixed32: proto.Int32(0),
+ OptionalSfixed64: proto.Int64(0),
+ OptionalFloat: proto.Float32(0),
+ OptionalDouble: proto.Float64(0),
+ OptionalBool: proto.Bool(false),
+ OptionalString: proto.String(""),
+ OptionalBytes: []byte{},
+ }, &test3pb.TestAllTypes{}, build(
+ &testpb.TestAllExtensions{},
+ extend(testpb.E_OptionalInt32Extension, int32(0)),
+ extend(testpb.E_OptionalInt64Extension, int64(0)),
+ extend(testpb.E_OptionalUint32Extension, uint32(0)),
+ extend(testpb.E_OptionalUint64Extension, uint64(0)),
+ extend(testpb.E_OptionalSint32Extension, int32(0)),
+ extend(testpb.E_OptionalSint64Extension, int64(0)),
+ extend(testpb.E_OptionalFixed32Extension, uint32(0)),
+ extend(testpb.E_OptionalFixed64Extension, uint64(0)),
+ extend(testpb.E_OptionalSfixed32Extension, int32(0)),
+ extend(testpb.E_OptionalSfixed64Extension, int64(0)),
+ extend(testpb.E_OptionalFloatExtension, float32(0)),
+ extend(testpb.E_OptionalDoubleExtension, float64(0)),
+ extend(testpb.E_OptionalBoolExtension, bool(false)),
+ extend(testpb.E_OptionalStringExtension, string("")),
+ extend(testpb.E_OptionalBytesExtension, []byte{}),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(0),
+ pack.Tag{2, pack.VarintType}, pack.Varint(0),
+ pack.Tag{3, pack.VarintType}, pack.Uvarint(0),
+ pack.Tag{4, pack.VarintType}, pack.Uvarint(0),
+ pack.Tag{5, pack.VarintType}, pack.Svarint(0),
+ pack.Tag{6, pack.VarintType}, pack.Svarint(0),
+ pack.Tag{7, pack.Fixed32Type}, pack.Uint32(0),
+ pack.Tag{8, pack.Fixed64Type}, pack.Uint64(0),
+ pack.Tag{9, pack.Fixed32Type}, pack.Int32(0),
+ pack.Tag{10, pack.Fixed64Type}, pack.Int64(0),
+ pack.Tag{11, pack.Fixed32Type}, pack.Float32(0),
+ pack.Tag{12, pack.Fixed64Type}, pack.Float64(0),
+ pack.Tag{13, pack.VarintType}, pack.Bool(false),
+ pack.Tag{14, pack.BytesType}, pack.String(""),
+ pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
+ }.Marshal(),
+ },
+ {
desc: "groups",
decodeTo: []proto.Message{&testpb.TestAllTypes{
Optionalgroup: &testpb.TestAllTypes_OptionalGroup{