proto, internal/impl: zero-length proto2 bytes fields should be non-nil

Fix decoding of zero-length bytes fields to produce a non-nil []byte.

Change-Id: Ifb7791a47df81091700f7226523371d1386fb1ad
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/188765
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go
index 085710d..0bbbbe9 100644
--- a/internal/cmd/generate-types/impl.go
+++ b/internal/cmd/generate-types/impl.go
@@ -152,10 +152,26 @@
 	return b, nil
 }
 
+{{if .ToGoTypeNoZero}}
+// consume{{.Name}}NoZero wire decodes a {{.GoType}} pointer as a {{.Name}}.
+// The zero value is not decoded.
+func consume{{.Name}}NoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+	if wtyp != {{.WireType.Expr}} {
+		return 0, errUnknown
+	}
+	v, n := {{template "Consume" .}}
+	if n < 0 {
+		return 0, wire.ParseError(n)
+	}
+	*p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
+	return n, nil
+}
+{{end}}
+
 var coder{{.Name}}NoZero = pointerCoderFuncs{
 	size:      size{{.Name}}NoZero,
 	marshal:   append{{.Name}}NoZero,
-	unmarshal: consume{{.Name}},
+	unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}},
 }
 
 {{if or (eq .Name "Bytes") (eq .Name "String")}}
@@ -174,10 +190,28 @@
 	return b, nil
 }
 
+{{if .ToGoTypeNoZero}}
+// consume{{.Name}}NoZeroValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}.
+func consume{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+	if wtyp != {{.WireType.Expr}} {
+		return 0, errUnknown
+	}
+	v, n := {{template "Consume" .}}
+	if n < 0 {
+		return 0, wire.ParseError(n)
+	}
+	if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+		return 0, errInvalidUTF8{}
+	}
+	*p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
+	return n, nil
+}
+{{end}}
+
 var coder{{.Name}}NoZeroValidateUTF8 = pointerCoderFuncs{
 	size:      size{{.Name}}NoZero,
 	marshal:   append{{.Name}}NoZeroValidateUTF8,
-	unmarshal: consume{{.Name}}ValidateUTF8,
+	unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}}ValidateUTF8,
 }
 {{end}}
 
@@ -551,6 +585,9 @@
 {{end -}}
 {{end -}}
 
+// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
+var emptyBuf [0]byte
+
 var wireTypes = map[protoreflect.Kind]wire.Type{
 {{range . -}}
 	protoreflect.{{.Name}}Kind: {{.WireType.Expr}},
diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go
index 023cde8..5cbf4ce 100644
--- a/internal/cmd/generate-types/proto.go
+++ b/internal/cmd/generate-types/proto.go
@@ -85,10 +85,11 @@
 	FromValue Expr
 
 	// Conversions to/from generated structures.
-	GoType     GoType
-	ToGoType   Expr
-	FromGoType Expr
-	NoPointer  bool
+	GoType         GoType
+	ToGoType       Expr
+	ToGoTypeNoZero Expr
+	FromGoType     Expr
+	NoPointer      bool
 }
 
 func (k ProtoKind) Expr() Expr {
@@ -229,14 +230,15 @@
 		FromGoType: "v",
 	},
 	{
-		Name:       "Bytes",
-		WireType:   WireBytes,
-		ToValue:    "append(([]byte)(nil), v...)",
-		FromValue:  "v.Bytes()",
-		GoType:     GoBytes,
-		ToGoType:   "append(([]byte)(nil), v...)",
-		FromGoType: "v",
-		NoPointer:  true,
+		Name:           "Bytes",
+		WireType:       WireBytes,
+		ToValue:        "append(([]byte)(nil), v...)",
+		FromValue:      "v.Bytes()",
+		GoType:         GoBytes,
+		ToGoType:       "append(emptyBuf[:], v...)",
+		ToGoTypeNoZero: "append(([]byte)(nil), v...)",
+		FromGoType:     "v",
+		NoPointer:      true,
 	},
 	{
 		Name:      "Message",
diff --git a/internal/impl/codec_gen.go b/internal/impl/codec_gen.go
index 46380f5..f40af3d 100644
--- a/internal/impl/codec_gen.go
+++ b/internal/impl/codec_gen.go
@@ -4395,7 +4395,7 @@
 	if n < 0 {
 		return 0, wire.ParseError(n)
 	}
-	*p.Bytes() = append(([]byte)(nil), v...)
+	*p.Bytes() = append(emptyBuf[:], v...)
 	return n, nil
 }
 
@@ -4428,7 +4428,7 @@
 	if !utf8.Valid(v) {
 		return 0, errInvalidUTF8{}
 	}
-	*p.Bytes() = append(([]byte)(nil), v...)
+	*p.Bytes() = append(emptyBuf[:], v...)
 	return n, nil
 }
 
@@ -4460,10 +4460,24 @@
 	return b, nil
 }
 
+// consumeBytesNoZero wire decodes a []byte pointer as a Bytes.
+// The zero value is not decoded.
+func consumeBytesNoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+	if wtyp != wire.BytesType {
+		return 0, errUnknown
+	}
+	v, n := wire.ConsumeBytes(b)
+	if n < 0 {
+		return 0, wire.ParseError(n)
+	}
+	*p.Bytes() = append(([]byte)(nil), v...)
+	return n, nil
+}
+
 var coderBytesNoZero = pointerCoderFuncs{
 	size:      sizeBytesNoZero,
 	marshal:   appendBytesNoZero,
-	unmarshal: consumeBytes,
+	unmarshal: consumeBytesNoZero,
 }
 
 // appendBytesNoZeroValidateUTF8 wire encodes a []byte pointer as a Bytes.
@@ -4481,10 +4495,26 @@
 	return b, nil
 }
 
+// consumeBytesNoZeroValidateUTF8 wire decodes a []byte pointer as a Bytes.
+func consumeBytesNoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+	if wtyp != wire.BytesType {
+		return 0, errUnknown
+	}
+	v, n := wire.ConsumeBytes(b)
+	if n < 0 {
+		return 0, wire.ParseError(n)
+	}
+	if !utf8.Valid(v) {
+		return 0, errInvalidUTF8{}
+	}
+	*p.Bytes() = append(([]byte)(nil), v...)
+	return n, nil
+}
+
 var coderBytesNoZeroValidateUTF8 = pointerCoderFuncs{
 	size:      sizeBytesNoZero,
 	marshal:   appendBytesNoZeroValidateUTF8,
-	unmarshal: consumeBytesValidateUTF8,
+	unmarshal: consumeBytesNoZeroValidateUTF8,
 }
 
 // sizeBytesSlice returns the size of wire encoding a [][]byte pointer as a repeated Bytes.
@@ -4516,7 +4546,7 @@
 	if n < 0 {
 		return 0, wire.ParseError(n)
 	}
-	*sp = append(*sp, append(([]byte)(nil), v...))
+	*sp = append(*sp, append(emptyBuf[:], v...))
 	return n, nil
 }
 
@@ -4552,7 +4582,7 @@
 	if !utf8.Valid(v) {
 		return 0, errInvalidUTF8{}
 	}
-	*sp = append(*sp, append(([]byte)(nil), v...))
+	*sp = append(*sp, append(emptyBuf[:], v...))
 	return n, nil
 }
 
@@ -4585,7 +4615,7 @@
 	if n < 0 {
 		return nil, 0, wire.ParseError(n)
 	}
-	return append(([]byte)(nil), v...), n, nil
+	return append(emptyBuf[:], v...), n, nil
 }
 
 var coderBytesIface = ifaceCoderFuncs{
@@ -4617,7 +4647,7 @@
 	if !utf8.Valid(v) {
 		return nil, 0, errInvalidUTF8{}
 	}
-	return append(([]byte)(nil), v...), n, nil
+	return append(emptyBuf[:], v...), n, nil
 }
 
 var coderBytesIfaceValidateUTF8 = ifaceCoderFuncs{
@@ -4655,7 +4685,7 @@
 	if n < 0 {
 		return nil, 0, wire.ParseError(n)
 	}
-	*sp = append(*sp, append(([]byte)(nil), v...))
+	*sp = append(*sp, append(emptyBuf[:], v...))
 	return ival, n, nil
 }
 
@@ -4665,6 +4695,9 @@
 	unmarshal: consumeBytesSliceIface,
 }
 
+// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
+var emptyBuf [0]byte
+
 var wireTypes = map[protoreflect.Kind]wire.Type{
 	protoreflect.BoolKind:     wire.VarintType,
 	protoreflect.EnumKind:     wire.VarintType,
diff --git a/proto/decode_test.go b/proto/decode_test.go
index c32c94c..6088eb5 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -108,6 +108,21 @@
 	}
 }
 
+func TestDecodeZeroLengthBytes(t *testing.T) {
+	// Verify that proto3 bytes fields don't give the mistaken
+	// impression that they preserve presence.
+	wire := pack.Message{
+		pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
+	}.Marshal()
+	m := &test3pb.TestAllTypes{}
+	if err := proto.Unmarshal(wire, m); err != nil {
+		t.Fatal(err)
+	}
+	if m.OptionalBytes != nil {
+		t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
+	}
+}
+
 var testProtos = []testProto{
 	{
 		desc: "basic scalar types",
@@ -184,6 +199,60 @@
 		}.Marshal(),
 	},
 	{
+		desc: "zero values",
+		decodeTo: []proto.Message{&testpb.TestAllTypes{
+			OptionalInt32:    proto.Int32(0),
+			OptionalInt64:    proto.Int64(0),
+			OptionalUint32:   proto.Uint32(0),
+			OptionalUint64:   proto.Uint64(0),
+			OptionalSint32:   proto.Int32(0),
+			OptionalSint64:   proto.Int64(0),
+			OptionalFixed32:  proto.Uint32(0),
+			OptionalFixed64:  proto.Uint64(0),
+			OptionalSfixed32: proto.Int32(0),
+			OptionalSfixed64: proto.Int64(0),
+			OptionalFloat:    proto.Float32(0),
+			OptionalDouble:   proto.Float64(0),
+			OptionalBool:     proto.Bool(false),
+			OptionalString:   proto.String(""),
+			OptionalBytes:    []byte{},
+		}, &test3pb.TestAllTypes{}, build(
+			&testpb.TestAllExtensions{},
+			extend(testpb.E_OptionalInt32Extension, int32(0)),
+			extend(testpb.E_OptionalInt64Extension, int64(0)),
+			extend(testpb.E_OptionalUint32Extension, uint32(0)),
+			extend(testpb.E_OptionalUint64Extension, uint64(0)),
+			extend(testpb.E_OptionalSint32Extension, int32(0)),
+			extend(testpb.E_OptionalSint64Extension, int64(0)),
+			extend(testpb.E_OptionalFixed32Extension, uint32(0)),
+			extend(testpb.E_OptionalFixed64Extension, uint64(0)),
+			extend(testpb.E_OptionalSfixed32Extension, int32(0)),
+			extend(testpb.E_OptionalSfixed64Extension, int64(0)),
+			extend(testpb.E_OptionalFloatExtension, float32(0)),
+			extend(testpb.E_OptionalDoubleExtension, float64(0)),
+			extend(testpb.E_OptionalBoolExtension, bool(false)),
+			extend(testpb.E_OptionalStringExtension, string("")),
+			extend(testpb.E_OptionalBytesExtension, []byte{}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.VarintType}, pack.Varint(0),
+			pack.Tag{2, pack.VarintType}, pack.Varint(0),
+			pack.Tag{3, pack.VarintType}, pack.Uvarint(0),
+			pack.Tag{4, pack.VarintType}, pack.Uvarint(0),
+			pack.Tag{5, pack.VarintType}, pack.Svarint(0),
+			pack.Tag{6, pack.VarintType}, pack.Svarint(0),
+			pack.Tag{7, pack.Fixed32Type}, pack.Uint32(0),
+			pack.Tag{8, pack.Fixed64Type}, pack.Uint64(0),
+			pack.Tag{9, pack.Fixed32Type}, pack.Int32(0),
+			pack.Tag{10, pack.Fixed64Type}, pack.Int64(0),
+			pack.Tag{11, pack.Fixed32Type}, pack.Float32(0),
+			pack.Tag{12, pack.Fixed64Type}, pack.Float64(0),
+			pack.Tag{13, pack.VarintType}, pack.Bool(false),
+			pack.Tag{14, pack.BytesType}, pack.String(""),
+			pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
+		}.Marshal(),
+	},
+	{
 		desc: "groups",
 		decodeTo: []proto.Message{&testpb.TestAllTypes{
 			Optionalgroup: &testpb.TestAllTypes_OptionalGroup{