proto: validate UTF-8 in proto3 strings

Change-Id: I6a495730c3f438e7b2c4ca86edade7d6f25aa47d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171700
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/encoding/jsonpb/decode_test.go b/encoding/jsonpb/decode_test.go
index 941e416..cb99bcd 100644
--- a/encoding/jsonpb/decode_test.go
+++ b/encoding/jsonpb/decode_test.go
@@ -5,6 +5,7 @@
 package jsonpb_test
 
 import (
+	"bytes"
 	"math"
 	"testing"
 
@@ -2130,14 +2131,14 @@
   "value": "` + "abc\xff" + `"
 }`,
 		wantMessage: func() proto.Message {
-			m := &knownpb.StringValue{Value: "abc\xff"}
+			m := &knownpb.StringValue{Value: "abcd"}
 			b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
 			if err != nil {
 				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
 			}
 			return &knownpb.Any{
 				TypeUrl: "google.protobuf.StringValue",
-				Value:   b,
+				Value:   bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
 			}
 		}(),
 		wantErr: true,
@@ -2216,14 +2217,14 @@
   "value": "` + "abc\xff" + `"
 }`,
 		wantMessage: func() proto.Message {
-			m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abc\xff"}}
+			m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abcd"}}
 			b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
 			if err != nil {
 				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
 			}
 			return &knownpb.Any{
 				TypeUrl: "google.protobuf.Value",
-				Value:   b,
+				Value:   bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
 			}
 		}(),
 		wantErr: true,
@@ -2369,7 +2370,7 @@
   }
 }`,
 		wantMessage: func() proto.Message {
-			m1 := &knownpb.StringValue{Value: "abc\xff"}
+			m1 := &knownpb.StringValue{Value: "abcd"}
 			b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m1)
 			if err != nil {
 				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
@@ -2385,7 +2386,7 @@
 			}
 			return &knownpb.Any{
 				TypeUrl: "pb2.KnownTypes",
-				Value:   b,
+				Value:   bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
 			}
 		}(),
 		wantErr: true,
diff --git a/encoding/jsonpb/encode_test.go b/encoding/jsonpb/encode_test.go
index 1a2858e..4277a3d 100644
--- a/encoding/jsonpb/encode_test.go
+++ b/encoding/jsonpb/encode_test.go
@@ -5,6 +5,7 @@
 package jsonpb_test
 
 import (
+	"bytes"
 	"encoding/hex"
 	"math"
 	"strings"
@@ -1687,14 +1688,14 @@
 			Resolver: preg.NewTypes((&knownpb.StringValue{}).ProtoReflect().Type()),
 		},
 		input: func() proto.Message {
-			m := &knownpb.StringValue{Value: "abc\xff"}
+			m := &knownpb.StringValue{Value: "abcd"}
 			b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
 			if err != nil {
 				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
 			}
 			return &knownpb.Any{
 				TypeUrl: "google.protobuf.StringValue",
-				Value:   b,
+				Value:   bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
 			}
 		}(),
 		want: `{
@@ -1765,14 +1766,14 @@
 			Resolver: preg.NewTypes((&knownpb.Value{}).ProtoReflect().Type()),
 		},
 		input: func() proto.Message {
-			m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abc\xff"}}
+			m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abcd"}}
 			b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
 			if err != nil {
 				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
 			}
 			return &knownpb.Any{
 				TypeUrl: "type.googleapis.com/google.protobuf.Value",
-				Value:   b,
+				Value:   bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
 			}
 		}(),
 		want: `{
diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go
index 3397d66..41cae94 100644
--- a/encoding/textpb/encode_test.go
+++ b/encoding/textpb/encode_test.go
@@ -5,6 +5,7 @@
 package textpb_test
 
 import (
+	"bytes"
 	"encoding/hex"
 	"math"
 	"strings"
@@ -1248,7 +1249,7 @@
 		},
 		input: func() proto.Message {
 			m := &pb3.Nested{
-				SString: "abc\xff",
+				SString: "abcd",
 			}
 			b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
 			if err != nil {
@@ -1256,7 +1257,7 @@
 			}
 			return &knownpb.Any{
 				TypeUrl: string(m.ProtoReflect().Type().FullName()),
-				Value:   b,
+				Value:   bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
 			}
 		}(),
 		want: `[pb3.Nested]: {
diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go
index 2f73872..c8931d4 100644
--- a/internal/cmd/generate-types/main.go
+++ b/internal/cmd/generate-types/main.go
@@ -312,6 +312,7 @@
 		"fmt",
 		"math",
 		"sync",
+		"unicode/utf8",
 		"",
 		"github.com/golang/protobuf/v2/internal/encoding/wire",
 		"github.com/golang/protobuf/v2/internal/errors",
diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go
index 3609be3..dbbc566 100644
--- a/internal/cmd/generate-types/proto.go
+++ b/internal/cmd/generate-types/proto.go
@@ -157,8 +157,8 @@
 // unmarshalScalar decodes a value of the given kind.
 //
 // Message values are decoded into a []byte which aliases the input data.
-func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, kind protoreflect.Kind) (val protoreflect.Value, n int, err error) {
-	switch kind {
+func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
+	switch field.Kind() {
 	{{- range .}}
 	case {{.Expr}}:
 		if wtyp != {{.WireType.Expr}} {
@@ -172,6 +172,13 @@
 		if n < 0 {
 			return val, 0, wire.ParseError(n)
 		}
+		{{if (eq .Name "String") -}}
+		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+			var nerr errors.NonFatal
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+			return protoreflect.ValueOf(string(v)), n, nerr.E
+		}
+		{{end -}}
 		return protoreflect.ValueOf({{.ToValue}}), n, nil
 	{{- end}}
 	default:
@@ -179,9 +186,9 @@
 	}
 }
 
-func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, kind protoreflect.Kind) (n int, err error) {
+func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) {
 	var nerr errors.NonFatal
-	switch kind {
+	switch field.Kind() {
 	{{- range .}}
 	case {{.Expr}}:
 		{{- if .WireType.Packable}}
@@ -212,6 +219,11 @@
 		if n < 0 {
 			return 0, wire.ParseError(n)
 		}
+		{{if (eq .Name "String") -}}
+		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+		}
+		{{end -}}
 		{{if or (eq .Name "Message") (eq .Name "Group") -}}
 		m := list.NewMessage()
 		if err := o.unmarshalMessage(v, m); !nerr.Merge(err) {
@@ -240,12 +252,17 @@
 {{- end}}
 }
 
-func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoreflect.Kind, v protoreflect.Value) ([]byte, error) {
+func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
 	var nerr errors.NonFatal
-	switch kind {
+	switch field.Kind() {
 	{{- range .}}
 	case {{.Expr}}:
-		{{if (eq .Name "Message") -}}
+		{{- if (eq .Name "String") }}
+		if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+		}
+		{{end -}}
+		{{- if (eq .Name "Message") -}}
 		var pos int
 		var err error
 		b, pos = appendSpeculativeLength(b)
@@ -266,7 +283,7 @@
 		{{- end}}
 	{{- end}}
 	default:
-		return b, errors.New("invalid kind %v", kind)
+		return b, errors.New("invalid kind %v", field.Kind())
 	}
 	return b, nerr.E
 }
diff --git a/proto/decode.go b/proto/decode.go
index 3e00074..0b1aa3f 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -86,7 +86,7 @@
 		case fieldType.Cardinality() != protoreflect.Repeated:
 			valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
 		case !fieldType.IsMap():
-			valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
+			valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType)
 		default:
 			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
 		}
@@ -105,8 +105,9 @@
 }
 
 func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
-	v, n, err := o.unmarshalScalar(b, wtyp, num, field.Kind())
-	if err != nil {
+	var nerr errors.NonFatal
+	v, n, err := o.unmarshalScalar(b, wtyp, num, field)
+	if !nerr.Merge(err) {
 		return 0, err
 	}
 	switch field.Kind() {
@@ -124,12 +125,14 @@
 			knownFields.Set(num, protoreflect.ValueOf(m))
 		}
 		// Pass up errors (fatal and otherwise).
-		err = o.unmarshalMessage(v.Bytes(), m)
+		if err := o.unmarshalMessage(v.Bytes(), m); !nerr.Merge(err) {
+			return n, err
+		}
 	default:
 		// Non-message scalars replace the previous value.
 		knownFields.Set(num, v)
 	}
-	return n, err
+	return n, nerr.E
 }
 
 func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
@@ -164,17 +167,19 @@
 		err = errUnknown
 		switch num {
 		case 1:
-			key, n, err = o.unmarshalScalar(b, wtyp, num, keyField.Kind())
-			if err != nil {
+			key, n, err = o.unmarshalScalar(b, wtyp, num, keyField)
+			if !nerr.Merge(err) {
 				break
 			}
+			err = nil
 			haveKey = true
 		case 2:
 			var v protoreflect.Value
-			v, n, err = o.unmarshalScalar(b, wtyp, num, valField.Kind())
-			if err != nil {
+			v, n, err = o.unmarshalScalar(b, wtyp, num, valField)
+			if !nerr.Merge(err) {
 				break
 			}
+			err = nil
 			switch valField.Kind() {
 			case protoreflect.GroupKind, protoreflect.MessageKind:
 				if err := o.unmarshalMessage(v.Bytes(), val.Message()); !nerr.Merge(err) {
@@ -190,7 +195,7 @@
 			if n < 0 {
 				return 0, wire.ParseError(n)
 			}
-		} else if !nerr.Merge(err) {
+		} else if err != nil {
 			return 0, err
 		}
 		b = b[n:]
diff --git a/proto/decode_gen.go b/proto/decode_gen.go
index 51b85d7..1a3ef15 100644
--- a/proto/decode_gen.go
+++ b/proto/decode_gen.go
@@ -8,6 +8,7 @@
 
 import (
 	"math"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
 	"github.com/golang/protobuf/v2/internal/errors"
@@ -17,8 +18,8 @@
 // unmarshalScalar decodes a value of the given kind.
 //
 // Message values are decoded into a []byte which aliases the input data.
-func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, kind protoreflect.Kind) (val protoreflect.Value, n int, err error) {
-	switch kind {
+func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
+	switch field.Kind() {
 	case protoreflect.BoolKind:
 		if wtyp != wire.VarintType {
 			return val, 0, errUnknown
@@ -153,6 +154,11 @@
 		if n < 0 {
 			return val, 0, wire.ParseError(n)
 		}
+		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+			var nerr errors.NonFatal
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+			return protoreflect.ValueOf(string(v)), n, nerr.E
+		}
 		return protoreflect.ValueOf(string(v)), n, nil
 	case protoreflect.BytesKind:
 		if wtyp != wire.BytesType {
@@ -186,9 +192,9 @@
 	}
 }
 
-func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, kind protoreflect.Kind) (n int, err error) {
+func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) {
 	var nerr errors.NonFatal
-	switch kind {
+	switch field.Kind() {
 	case protoreflect.BoolKind:
 		if wtyp == wire.BytesType {
 			buf, n := wire.ConsumeBytes(b)
@@ -547,6 +553,9 @@
 		if n < 0 {
 			return 0, wire.ParseError(n)
 		}
+		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+		}
 		list.Append(protoreflect.ValueOf(string(v)))
 		return n, nerr.E
 	case protoreflect.BytesKind:
diff --git a/proto/decode_test.go b/proto/decode_test.go
index dda4db1..2c95f6b 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -12,6 +12,7 @@
 	protoV1 "github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/v2/encoding/textpb"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
+	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -80,6 +81,23 @@
 	}
 }
 
+func TestDecodeInvalidUTF8(t *testing.T) {
+	for _, test := range invalidUTF8TestProtos {
+		for _, want := range test.decodeTo {
+			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
+				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
+				err := proto.Unmarshal(test.wire, got)
+				if !isErrInvalidUTF8(err) {
+					t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
+				}
+				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
+				}
+			})
+		}
+	}
+}
+
 var testProtos = []testProto{
 	{
 		desc: "basic scalar types",
@@ -1158,6 +1176,69 @@
 	},
 }
 
+var invalidUTF8TestProtos = []testProto{
+	{
+		desc: "invalid UTF-8 in optional string field",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			OptionalString: "abc\xff",
+		}},
+		wire: pack.Message{
+			pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in repeated string field",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			RepeatedString: []string{"foo", "abc\xff"},
+		}},
+		wire: pack.Message{
+			pack.Tag{44, pack.BytesType}, pack.String("foo"),
+			pack.Tag{44, pack.BytesType}, pack.String("abc\xff"),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in nested message",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			OptionalNestedMessage: &test3pb.TestAllTypes_NestedMessage{
+				Corecursive: &test3pb.TestAllTypes{
+					OptionalString: "abc\xff",
+				},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
+				}),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in map key",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			MapStringString: map[string]string{"key\xff": "val"},
+		}},
+		wire: pack.Message{
+			pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.BytesType}, pack.String("key\xff"),
+				pack.Tag{2, pack.BytesType}, pack.String("val"),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in map value",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			MapStringString: map[string]string{"key": "val\xff"},
+		}},
+		wire: pack.Message{
+			pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.BytesType}, pack.String("key"),
+				pack.Tag{2, pack.BytesType}, pack.String("val\xff"),
+			}),
+		}.Marshal(),
+	},
+}
+
 func build(m proto.Message, opts ...buildOpt) proto.Message {
 	for _, opt := range opts {
 		opt(m)
@@ -1185,3 +1266,17 @@
 	b, _ := textpb.Marshal(m)
 	return string(b)
 }
+
+func isErrInvalidUTF8(err error) bool {
+	nerr, ok := err.(errors.NonFatalErrors)
+	if !ok || len(nerr) == 0 {
+		return false
+	}
+	for _, err := range nerr {
+		if e, ok := err.(interface{ InvalidUTF8() bool }); ok && e.InvalidUTF8() {
+			continue
+		}
+		return false
+	}
+	return true
+}
diff --git a/proto/encode.go b/proto/encode.go
index b294392..8635790 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -182,13 +182,13 @@
 	switch {
 	case field.Cardinality() != protoreflect.Repeated:
 		b = wire.AppendTag(b, num, wireTypes[kind])
-		return o.marshalSingular(b, num, kind, value)
+		return o.marshalSingular(b, num, field, value)
 	case field.IsMap():
 		return o.marshalMap(b, num, kind, field.MessageType(), value.Map())
 	case field.IsPacked():
-		return o.marshalPacked(b, num, kind, value.List())
+		return o.marshalPacked(b, num, field, value.List())
 	default:
-		return o.marshalList(b, num, kind, value.List())
+		return o.marshalList(b, num, field, value.List())
 	}
 }
 
@@ -229,13 +229,13 @@
 	mapsort.Range(mapv, kind, f)
 }
 
-func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
+func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, field protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
 	b = wire.AppendTag(b, num, wire.BytesType)
 	b, pos := appendSpeculativeLength(b)
 	var nerr errors.NonFatal
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		var err error
-		b, err = o.marshalSingular(b, num, kind, list.Get(i))
+		b, err = o.marshalSingular(b, num, field, list.Get(i))
 		if !nerr.Merge(err) {
 			return b, err
 		}
@@ -244,12 +244,13 @@
 	return b, nerr.E
 }
 
-func (o MarshalOptions) marshalList(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
+func (o MarshalOptions) marshalList(b []byte, num wire.Number, field protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
+	kind := field.Kind()
 	var nerr errors.NonFatal
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		var err error
 		b = wire.AppendTag(b, num, wireTypes[kind])
-		b, err = o.marshalSingular(b, num, kind, list.Get(i))
+		b, err = o.marshalSingular(b, num, field, list.Get(i))
 		if !nerr.Merge(err) {
 			return b, err
 		}
diff --git a/proto/encode_gen.go b/proto/encode_gen.go
index 46621c8..4919b96 100644
--- a/proto/encode_gen.go
+++ b/proto/encode_gen.go
@@ -8,6 +8,7 @@
 
 import (
 	"math"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
 	"github.com/golang/protobuf/v2/internal/errors"
@@ -35,9 +36,9 @@
 	protoreflect.GroupKind:    wire.StartGroupType,
 }
 
-func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoreflect.Kind, v protoreflect.Value) ([]byte, error) {
+func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
 	var nerr errors.NonFatal
-	switch kind {
+	switch field.Kind() {
 	case protoreflect.BoolKind:
 		b = wire.AppendVarint(b, wire.EncodeBool(v.Bool()))
 	case protoreflect.EnumKind:
@@ -67,6 +68,9 @@
 	case protoreflect.DoubleKind:
 		b = wire.AppendFixed64(b, math.Float64bits(v.Float()))
 	case protoreflect.StringKind:
+		if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+		}
 		b = wire.AppendBytes(b, []byte(v.String()))
 	case protoreflect.BytesKind:
 		b = wire.AppendBytes(b, v.Bytes())
@@ -87,7 +91,7 @@
 		}
 		b = wire.AppendVarint(b, wire.EncodeTag(num, wire.EndGroupType))
 	default:
-		return b, errors.New("invalid kind %v", kind)
+		return b, errors.New("invalid kind %v", field.Kind())
 	}
 	return b, nerr.E
 }
diff --git a/proto/encode_test.go b/proto/encode_test.go
index 30722e0..d670edf 100644
--- a/proto/encode_test.go
+++ b/proto/encode_test.go
@@ -92,6 +92,27 @@
 	}
 }
 
+func TestEncodeInvalidUTF8(t *testing.T) {
+	for _, test := range invalidUTF8TestProtos {
+		for _, want := range test.decodeTo {
+			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
+				wire, err := proto.Marshal(want)
+				if !isErrInvalidUTF8(err) {
+					t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
+				}
+				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
+				if err := proto.Unmarshal(wire, got); !isErrInvalidUTF8(err) {
+					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
+					return
+				}
+				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
+				}
+			})
+		}
+	}
+}
+
 func TestEncodeRequiredFieldChecks(t *testing.T) {
 	for _, test := range testProtos {
 		if !test.partial {