internal/impl: validate UTF-8 for proto3 optional strings

Change-Id: I090e7c5adac47818831c63d3d999cb7fea5ac696
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/231357
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go
index 932c8b0..4b362c6 100644
--- a/internal/cmd/generate-types/impl.go
+++ b/internal/cmd/generate-types/impl.go
@@ -299,6 +299,48 @@
 }
 {{end}}
 
+{{if (eq .Name "String")}}{{/* Only the String kind reaches this section, so utf8.ValidString is called unconditionally below. */}}
+// append{{.Name}}PtrValidateUTF8 wire encodes a *{{.GoType}} pointer as a {{.Name}}.
+// It panics if the pointer is nil.
+func append{{.Name}}PtrValidateUTF8(b []byte, p pointer, f *coderFieldInfo, _ marshalOptions) ([]byte, error) {
+	v := **p.{{.GoType.PointerMethod}}Ptr()
+	b = protowire.AppendVarint(b, f.wiretag)
+	{{template "Append" .}}
+	if !utf8.ValidString(v) {
+		return b, errInvalidUTF8{}
+	}
+	return b, nil
+}
+
+// consume{{.Name}}PtrValidateUTF8 wire decodes a *{{.GoType}} pointer as a {{.Name}}.
+func consume{{.Name}}PtrValidateUTF8(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, _ unmarshalOptions) (out unmarshalOutput, err error) {
+	if wtyp != {{.WireType.Expr}} {
+		return out, errUnknown
+	}
+	{{template "Consume" .}}
+	if n < 0 {
+		return out, protowire.ParseError(n)
+	}
+	if !utf8.ValidString(v) {
+		return out, errInvalidUTF8{}
+	}
+	vp := p.{{.GoType.PointerMethod}}Ptr()
+	if *vp == nil {
+		*vp = new({{.GoType}})
+	}
+	**vp = {{.ToGoType}}
+	out.n = n
+	return out, nil
+}
+
+var coder{{.Name}}PtrValidateUTF8 = pointerCoderFuncs{
+	size:      size{{.Name}}Ptr,
+	marshal:   append{{.Name}}PtrValidateUTF8,
+	unmarshal: consume{{.Name}}PtrValidateUTF8,
+	merge:     merge{{.GoType.PointerMethod}}Ptr,
+}
+{{end}}
+
 // size{{.Name}}Slice returns the size of wire encoding a []{{.GoType}} pointer as a repeated {{.Name}}.
 func size{{.Name}}Slice(p pointer, f *coderFieldInfo, _ marshalOptions) (size int) {
 	s := *p.{{.GoType.PointerMethod}}Slice()
diff --git a/internal/impl/codec_gen.go b/internal/impl/codec_gen.go
index 2c43b11..ff198d0 100644
--- a/internal/impl/codec_gen.go
+++ b/internal/impl/codec_gen.go
@@ -5078,6 +5078,46 @@
 	merge:     mergeStringPtr,
 }
 
+// appendStringPtrValidateUTF8 wire encodes a *string pointer as a String.
+// It panics if the pointer is nil.
+func appendStringPtrValidateUTF8(b []byte, p pointer, f *coderFieldInfo, _ marshalOptions) ([]byte, error) {
+	v := **p.StringPtr()
+	b = protowire.AppendVarint(b, f.wiretag)
+	b = protowire.AppendString(b, v)
+	if !utf8.ValidString(v) {
+		return b, errInvalidUTF8{}
+	}
+	return b, nil
+}
+
+// consumeStringPtrValidateUTF8 wire decodes a *string pointer as a String.
+func consumeStringPtrValidateUTF8(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, _ unmarshalOptions) (out unmarshalOutput, err error) {
+	if wtyp != protowire.BytesType {
+		return out, errUnknown
+	}
+	v, n := protowire.ConsumeString(b)
+	if n < 0 {
+		return out, protowire.ParseError(n)
+	}
+	if !utf8.ValidString(v) {
+		return out, errInvalidUTF8{}
+	}
+	vp := p.StringPtr()
+	if *vp == nil {
+		*vp = new(string)
+	}
+	**vp = v
+	out.n = n
+	return out, nil
+}
+
+var coderStringPtrValidateUTF8 = pointerCoderFuncs{
+	size:      sizeStringPtr,
+	marshal:   appendStringPtrValidateUTF8,
+	unmarshal: consumeStringPtrValidateUTF8,
+	merge:     mergeStringPtr,
+}
+
 // sizeStringSlice returns the size of wire encoding a []string pointer as a repeated String.
 func sizeStringSlice(p pointer, f *coderFieldInfo, _ marshalOptions) (size int) {
 	s := *p.StringSlice()
diff --git a/internal/impl/codec_tables.go b/internal/impl/codec_tables.go
index c934c8d..e899712 100644
--- a/internal/impl/codec_tables.go
+++ b/internal/impl/codec_tables.go
@@ -338,6 +338,9 @@
 				return nil, coderDoublePtr
 			}
 		case pref.StringKind:
+			if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
+				return nil, coderStringPtrValidateUTF8
+			}
 			if ft.Kind() == reflect.String {
 				return nil, coderStringPtr
 			}
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index 441f646..9979d77 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -1559,6 +1559,15 @@
 	{
 		desc: "invalid UTF-8 in optional string field",
 		decodeTo: makeMessages(protobuild.Message{
+			"optional_string": "abc\xff",
+		}, &test3pb.TestAllTypes{}),
+		wire: protopack.Message{
+			protopack.Tag{14, protopack.BytesType}, protopack.String("abc\xff"),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in singular string field",
+		decodeTo: makeMessages(protobuild.Message{
 			"singular_string": "abc\xff",
 		}, &test3pb.TestAllTypes{}),
 		wire: protopack.Message{