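encoding/textpb: add UTF-8 validation for string fields

Marshal and Unmarshal now check string fields with utf8.ValidString.
A string containing invalid UTF-8 is still emitted or stored as-is, but
a non-fatal invalid-UTF-8 error naming the field is reported. The map
key and map value unmarshaling paths are updated to propagate this
non-fatal error instead of treating it as a hard failure.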

Change-Id: I15aec2b90efae9366eb496dc221b9e8cacd9d8e6
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171122
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/textpb/decode.go b/encoding/textpb/decode.go
index 59c98b1..218d95f 100644
--- a/encoding/textpb/decode.go
+++ b/encoding/textpb/decode.go
@@ -7,6 +7,7 @@
 import (
 	"fmt"
 	"strings"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/v2/internal/encoding/text"
 	"github.com/golang/protobuf/v2/internal/errors"
@@ -293,7 +294,13 @@
 		}
 	case pref.StringKind:
 		if input.Type() == text.String {
-			return pref.ValueOf(string(input.String())), nil
+			s := input.String()
+			if utf8.ValidString(s) {
+				return pref.ValueOf(s), nil
+			}
+			var nerr errors.NonFatal
+			nerr.AppendInvalidUTF8(string(fd.FullName()))
+			return pref.ValueOf(s), nerr.E
 		}
 	case pref.BytesKind:
 		if input.Type() == text.String {
@@ -421,11 +428,12 @@
 		return fd.Default().MapKey(), nil
 	}
 
+	var nerr errors.NonFatal
 	val, err := unmarshalScalar(input, fd)
-	if err != nil {
+	if !nerr.Merge(err) {
 		return pref.MapKey{}, errors.New("%v contains invalid key: %v", fd.FullName(), input)
 	}
-	return val.MapKey(), nil
+	return val.MapKey(), nerr.E
 }
 
 // unmarshalMapMessageValue unmarshals given message-type text.Value into a protoreflect.Map for
@@ -447,18 +455,19 @@
 // unmarshalMapScalarValue unmarshals given scalar-type text.Value into a protoreflect.Map
 // for the given MapKey.
 func unmarshalMapScalarValue(input text.Value, pkey pref.MapKey, fd pref.FieldDescriptor, mmap pref.Map) error {
+	var nerr errors.NonFatal
 	var val pref.Value
 	if input.Type() == 0 {
 		val = fd.Default()
 	} else {
 		var err error
 		val, err = unmarshalScalar(input, fd)
-		if err != nil {
+		if !nerr.Merge(err) {
 			return err
 		}
 	}
 	mmap.Set(pkey, val)
-	return nil
+	return nerr.E
 }
 
 // isExpandedAny returns true if given [][2]text.Value may be an expanded Any that contains only one
diff --git a/encoding/textpb/decode_test.go b/encoding/textpb/decode_test.go
index 7c45641..e98b0b3 100644
--- a/encoding/textpb/decode_test.go
+++ b/encoding/textpb/decode_test.go
@@ -10,6 +10,7 @@
 
 	protoV1 "github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/v2/encoding/textpb"
+	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/legacy"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
@@ -183,6 +184,14 @@
 			SString:   "谷歌",
 		},
 	}, {
+		desc:         "string with invalid UTF-8",
+		inputMessage: &pb3.Scalars{},
+		inputText:    `s_string: "abc\xff"`,
+		wantMessage: &pb3.Scalars{
+			SString: "abc\xff",
+		},
+		wantErr: true,
+	}, {
 		desc:         "proto2 message contains unknown field",
 		inputMessage: &pb2.Scalars{},
 		inputText:    "unknown_field: 123",
@@ -474,6 +483,19 @@
 			},
 		},
 	}, {
+		desc:         "proto3 nested message contains invalid UTF-8",
+		inputMessage: &pb3.Nests{},
+		inputText: `s_nested: {
+  s_string: "abc\xff"
+}
+`,
+		wantMessage: &pb3.Nests{
+			SNested: &pb3.Nested{
+				SString: "abc\xff",
+			},
+		},
+		wantErr: true,
+	}, {
 		desc:         "oneof set to empty string",
 		inputMessage: &pb3.Oneofs{},
 		inputText:    "oneof_string: ''",
@@ -561,6 +583,14 @@
 			RptBool:   []bool{true, false, true},
 		},
 	}, {
+		desc:         "repeated contains invalid UTF-8",
+		inputMessage: &pb2.Repeats{},
+		inputText:    `rpt_string: "abc\xff"`,
+		wantMessage: &pb2.Repeats{
+			RptString: []string{"abc\xff"},
+		},
+		wantErr: true,
+	}, {
 		desc:         "repeated enums",
 		inputMessage: &pb2.Enums{},
 		inputText: `
@@ -871,6 +901,34 @@
 			},
 		},
 	}, {
+		desc:         "map field value contains invalid UTF-8",
+		inputMessage: &pb3.Maps{},
+		inputText: `int32_to_str: {
+  key: 101
+  value: "abc\xff"
+}
+`,
+		wantMessage: &pb3.Maps{
+			Int32ToStr: map[int32]string{
+				101: "abc\xff",
+			},
+		},
+		wantErr: true,
+	}, {
+		desc:         "map field key contains invalid UTF-8",
+		inputMessage: &pb3.Maps{},
+		inputText: `str_to_nested: {
+  key: "abc\xff"
+  value: {}
+}
+`,
+		wantMessage: &pb3.Maps{
+			StrToNested: map[string]*pb3.Nested{
+				"abc\xff": {},
+			},
+		},
+		wantErr: true,
+	}, {
 		desc:         "map contains unknown field",
 		inputMessage: &pb3.Maps{},
 		inputText: `
@@ -1165,6 +1223,16 @@
 			return m
 		}(),
 	}, {
+		desc:         "extension field contains invalid UTF-8",
+		inputMessage: &pb2.Extensions{},
+		inputText:    `[pb2.opt_ext_string]: "abc\xff"`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_OptExtString, "abc\xff")
+			return m
+		}(),
+		wantErr: true,
+	}, {
 		desc:         "extensions of repeated fields",
 		inputMessage: &pb2.Extensions{},
 		inputText: `[pb2.rpt_ext_enum]: TEN
@@ -1419,6 +1487,32 @@
 		}(),
 		wantErr: true,
 	}, {
+		desc: "Any with invalid UTF-8",
+		umo: textpb.UnmarshalOptions{
+			Resolver: preg.NewTypes((&pb3.Nested{}).ProtoReflect().Type()),
+		},
+		inputMessage: &knownpb.Any{},
+		inputText: `
+[pb3.Nested]: {
+  s_string: "abc\xff"
+}
+`,
+		wantMessage: func() proto.Message {
+			m := &pb3.Nested{
+				SString: "abc\xff",
+			}
+			var nerr errors.NonFatal
+			b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
+			if !nerr.Merge(err) {
+				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+			}
+			return &knownpb.Any{
+				TypeUrl: string(m.ProtoReflect().Type().FullName()),
+				Value:   b,
+			}
+		}(),
+		wantErr: true,
+	}, {
 		desc:         "Any expanded with unregistered type",
 		umo:          textpb.UnmarshalOptions{Resolver: preg.NewTypes()},
 		inputMessage: &knownpb.Any{},
@@ -1459,7 +1553,6 @@
 	for _, tt := range tests {
 		tt := tt
 		t.Run(tt.desc, func(t *testing.T) {
-			t.Parallel()
 			err := tt.umo.Unmarshal(tt.inputMessage, []byte(tt.inputText))
 			if err != nil && !tt.wantErr {
 				t.Errorf("Unmarshal() returned error: %v\n\n", err)
diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go
index e293143..c706898 100644
--- a/encoding/textpb/encode.go
+++ b/encoding/textpb/encode.go
@@ -7,6 +7,7 @@
 import (
 	"fmt"
 	"sort"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/v2/internal/encoding/text"
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
@@ -174,9 +175,18 @@
 		pref.Sfixed32Kind, pref.Fixed32Kind,
 		pref.Sfixed64Kind, pref.Fixed64Kind,
 		pref.FloatKind, pref.DoubleKind,
-		pref.StringKind, pref.BytesKind:
+		pref.BytesKind:
 		return text.ValueOf(val.Interface()), nil
 
+	case pref.StringKind:
+		s := val.String()
+		if utf8.ValidString(s) {
+			return text.ValueOf(s), nil
+		}
+		var nerr errors.NonFatal
+		nerr.AppendInvalidUTF8(string(fd.FullName()))
+		return text.ValueOf(s), nerr.E
+
 	case pref.EnumKind:
 		num := val.Enum()
 		if desc := fd.EnumType().Values().ByNumber(num); desc != nil {
diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go
index 5b9ee38..3397d66 100644
--- a/encoding/textpb/encode_test.go
+++ b/encoding/textpb/encode_test.go
@@ -170,6 +170,14 @@
 opt_string: "谷歌"
 `,
 	}, {
+		desc: "string with invalid UTF-8",
+		input: &pb3.Scalars{
+			SString: "abc\xff",
+		},
+		want: `s_string: "abc\xff"
+`,
+		wantErr: true,
+	}, {
 		desc: "float nan",
 		input: &pb3.Scalars{
 			SFloat: float32(math.NaN()),
@@ -364,6 +372,18 @@
 }
 `,
 	}, {
+		desc: "proto3 nested message contains invalid UTF-8",
+		input: &pb3.Nests{
+			SNested: &pb3.Nested{
+				SString: "abc\xff",
+			},
+		},
+		want: `s_nested: {
+  s_string: "abc\xff"
+}
+`,
+		wantErr: true,
+	}, {
 		desc:  "oneof not set",
 		input: &pb3.Oneofs{},
 		want:  "\n",
@@ -473,6 +493,14 @@
 rpt_bytes: "世界"
 `,
 	}, {
+		desc: "repeated contains invalid UTF-8",
+		input: &pb2.Repeats{
+			RptString: []string{"abc\xff"},
+		},
+		want: `rpt_string: "abc\xff"
+`,
+		wantErr: true,
+	}, {
 		desc: "repeated enums",
 		input: &pb2.Enums{
 			RptEnum:       []pb2.Enum{pb2.Enum_ONE, 2, pb2.Enum_TEN, 42},
@@ -671,6 +699,32 @@
 }
 `,
 	}, {
+		desc: "map field value contains invalid UTF-8",
+		input: &pb3.Maps{
+			Int32ToStr: map[int32]string{
+				101: "abc\xff",
+			},
+		},
+		want: `int32_to_str: {
+  key: 101
+  value: "abc\xff"
+}
+`,
+		wantErr: true,
+	}, {
+		desc: "map field key contains invalid UTF-8",
+		input: &pb3.Maps{
+			StrToNested: map[string]*pb3.Nested{
+				"abc\xff": {},
+			},
+		},
+		want: `str_to_nested: {
+  key: "abc\xff"
+  value: {}
+}
+`,
+		wantErr: true,
+	}, {
 		desc: "map field contains nil value",
 		input: &pb3.Maps{
 			StrToNested: map[string]*pb3.Nested{
@@ -919,6 +973,16 @@
 [pb2.opt_ext_string]: "extension field"
 `,
 	}, {
+		desc: "extension field contains invalid UTF-8",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_OptExtString, "abc\xff")
+			return m
+		}(),
+		want: `[pb2.opt_ext_string]: "abc\xff"
+`,
+		wantErr: true,
+	}, {
 		desc: "extension partial returns error",
 		input: func() proto.Message {
 			m := &pb2.Extensions{}
@@ -1178,6 +1242,29 @@
 `,
 		wantErr: true,
 	}, {
+		desc: "Any with invalid UTF-8",
+		mo: textpb.MarshalOptions{
+			Resolver: preg.NewTypes((&pb3.Nested{}).ProtoReflect().Type()),
+		},
+		input: func() proto.Message {
+			m := &pb3.Nested{
+				SString: "abc\xff",
+			}
+			b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
+			if err != nil {
+				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+			}
+			return &knownpb.Any{
+				TypeUrl: string(m.ProtoReflect().Type().FullName()),
+				Value:   b,
+			}
+		}(),
+		want: `[pb3.Nested]: {
+  s_string: "abc\xff"
+}
+`,
+		wantErr: true,
+	}, {
 		desc: "Any with invalid value",
 		mo: textpb.MarshalOptions{
 			Resolver: preg.NewTypes((&pb2.Nested{}).ProtoReflect().Type()),