all: support enforce_utf8 override
In 2014, when proto3 was being developed, there were a number of early
adopters of the new syntax. Before the finalization of proto3 when
it was released in open-source in July 2016, a decision was made to
strictly validate strings in proto3. However, some of the early adopters
were already using invalid UTF-8 with string fields.
The google.protobuf.FieldOptions.enforce_utf8 option only exists to support
those grandfathered users, allowing them to opt out of the validation logic.
Practical use of that option in open source is impossible even if a user
specifies the proto1_legacy build tag since it requires a hacked
variant of descriptor.proto that is not externally available.
This CL supports enforce_utf8 by modifying internal/filedesc to
expose the flag if it detects it in the raw descriptor.
We add an strs.EnforceUTF8 function as a centralized place to determine
whether to perform validation. Validation opt-out is supported
only in builds with legacy support.
We implement support for validating UTF-8 in all proto3 string fields,
even if they are backed by a Go []byte.
Change-Id: I9c0628b84909bc7181125f09db730c80d490e485
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186002
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go
index a92c7ea..9919243 100644
--- a/internal/cmd/generate-types/impl.go
+++ b/internal/cmd/generate-types/impl.go
@@ -95,7 +95,42 @@
unmarshal: consume{{.Name}},
}
-// size{{.Name}} returns the size of wire encoding a {{.GoType}} pointer as a {{.Name}}.
+{{if or (eq .Name "Bytes") (eq .Name "String")}}
+// append{{.Name}}ValidateUTF8 wire encodes a {{.GoType}} pointer as a {{.Name}}.
+func append{{.Name}}ValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := *p.{{.GoType.PointerMethod}}()
+ b = wire.AppendVarint(b, wiretag)
+ {{template "Append" .}}
+ if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+// consume{{.Name}}ValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}.
+func consume{{.Name}}ValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != {{.WireType.Expr}} {
+ return 0, errUnknown
+ }
+ v, n := {{template "Consume" .}}
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *p.{{.GoType.PointerMethod}}() = {{.ToGoType}}
+ return n, nil
+}
+
+var coder{{.Name}}ValidateUTF8 = pointerCoderFuncs{
+ size: size{{.Name}},
+ marshal: append{{.Name}}ValidateUTF8,
+ unmarshal: consume{{.Name}}ValidateUTF8,
+}
+{{end}}
+
+// size{{.Name}}NoZero returns the size of wire encoding a {{.GoType}} pointer as a {{.Name}}.
// The zero value is not encoded.
func size{{.Name}}NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.{{.GoType.PointerMethod}}()
@@ -105,7 +140,7 @@
return tagsize + {{template "Size" .}}
}
-// append{{.Name}} wire encodes a {{.GoType}} pointer as a {{.Name}}.
+// append{{.Name}}NoZero wire encodes a {{.GoType}} pointer as a {{.Name}}.
// The zero value is not encoded.
func append{{.Name}}NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.{{.GoType.PointerMethod}}()
@@ -123,6 +158,29 @@
unmarshal: consume{{.Name}},
}
+{{if or (eq .Name "Bytes") (eq .Name "String")}}
+// append{{.Name}}NoZeroValidateUTF8 wire encodes a {{.GoType}} pointer as a {{.Name}}.
+// The zero value is not encoded.
+func append{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := *p.{{.GoType.PointerMethod}}()
+ if {{template "IsZero" .}} {
+ return b, nil
+ }
+ b = wire.AppendVarint(b, wiretag)
+ {{template "Append" .}}
+ if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+var coder{{.Name}}NoZeroValidateUTF8 = pointerCoderFuncs{
+ size: size{{.Name}}NoZero,
+ marshal: append{{.Name}}NoZeroValidateUTF8,
+ unmarshal: consume{{.Name}}ValidateUTF8,
+}
+{{end}}
+
{{- if not .NoPointer}}
// size{{.Name}}Ptr returns the size of wire encoding a *{{.GoType}} pointer as a {{.Name}}.
// It panics if the pointer is nil.
@@ -228,6 +286,44 @@
unmarshal: consume{{.Name}}Slice,
}
+{{if or (eq .Name "Bytes") (eq .Name "String")}}
+// append{{.Name}}SliceValidateUTF8 encodes a []{{.GoType}} pointer as a repeated {{.Name}}.
+func append{{.Name}}SliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ s := *p.{{.GoType.PointerMethod}}Slice()
+ for _, v := range s {
+ b = wire.AppendVarint(b, wiretag)
+ {{template "Append" .}}
+ if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+ return b, errInvalidUTF8{}
+ }
+ }
+ return b, nil
+}
+
+// consume{{.Name}}SliceValidateUTF8 wire decodes a []{{.GoType}} pointer as a repeated {{.Name}}.
+func consume{{.Name}}SliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ sp := p.{{.GoType.PointerMethod}}Slice()
+ if wtyp != {{.WireType.Expr}} {
+ return 0, errUnknown
+ }
+ v, n := {{template "Consume" .}}
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *sp = append(*sp, {{.ToGoType}})
+ return n, nil
+}
+
+var coder{{.Name}}SliceValidateUTF8 = pointerCoderFuncs{
+ size: size{{.Name}}Slice,
+ marshal: append{{.Name}}SliceValidateUTF8,
+ unmarshal: consume{{.Name}}SliceValidateUTF8,
+}
+{{end}}
+
{{if or (eq .WireType "Varint") (eq .WireType "Fixed32") (eq .WireType "Fixed64")}}
// size{{.Name}}PackedSlice returns the size of wire encoding a []{{.GoType}} pointer as a packed repeated {{.Name}}.
func size{{.Name}}PackedSlice(p pointer, tagsize int, _ marshalOptions) (size int) {
@@ -309,6 +405,40 @@
unmarshal: consume{{.Name}}Iface,
}
+{{if or (eq .Name "Bytes") (eq .Name "String")}}
+// append{{.Name}}IfaceValidateUTF8 encodes a {{.GoType}} value as a {{.Name}}.
+func append{{.Name}}IfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := ival.({{.GoType}})
+ b = wire.AppendVarint(b, wiretag)
+ {{template "Append" .}}
+ if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+// consume{{.Name}}IfaceValidateUTF8 decodes a {{.GoType}} value as a {{.Name}}.
+func consume{{.Name}}IfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
+ if wtyp != {{.WireType.Expr}} {
+ return nil, 0, errUnknown
+ }
+ v, n := {{template "Consume" .}}
+ if n < 0 {
+ return nil, 0, wire.ParseError(n)
+ }
+ if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
+ return nil, 0, errInvalidUTF8{}
+ }
+ return {{.ToGoType}}, n, nil
+}
+
+var coder{{.Name}}IfaceValidateUTF8 = ifaceCoderFuncs{
+ size: size{{.Name}}Iface,
+ marshal: append{{.Name}}IfaceValidateUTF8,
+ unmarshal: consume{{.Name}}IfaceValidateUTF8,
+}
+{{end}}
+
// size{{.Name}}SliceIface returns the size of wire encoding a []{{.GoType}} value as a repeated {{.Name}}.
func size{{.Name}}SliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) {
s := *ival.(*[]{{.GoType}})
diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go
index 6490b5c..6d19508 100644
--- a/internal/cmd/generate-types/main.go
+++ b/internal/cmd/generate-types/main.go
@@ -191,6 +191,7 @@
"google.golang.org/protobuf/internal/descfmt",
"google.golang.org/protobuf/internal/encoding/wire",
"google.golang.org/protobuf/internal/errors",
+ "google.golang.org/protobuf/internal/strs",
"google.golang.org/protobuf/internal/pragma",
"google.golang.org/protobuf/reflect/protoreflect",
"google.golang.org/protobuf/runtime/protoiface",
diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go
index e507b03..023cde8 100644
--- a/internal/cmd/generate-types/proto.go
+++ b/internal/cmd/generate-types/proto.go
@@ -276,7 +276,7 @@
return val, 0, wire.ParseError(n)
}
{{if (eq .Name "String") -}}
- if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+ if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
}
{{end -}}
@@ -320,7 +320,7 @@
return 0, wire.ParseError(n)
}
{{if (eq .Name "String") -}}
- if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+ if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return 0, errors.InvalidUTF8(string(fd.FullName()))
}
{{end -}}
@@ -357,7 +357,7 @@
{{- range .}}
case {{.Expr}}:
{{- if (eq .Name "String") }}
- if fd.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
+ if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
return b, errors.InvalidUTF8(string(fd.FullName()))
}
b = wire.AppendString(b, {{.FromValue}})
diff --git a/internal/filedesc/desc.go b/internal/filedesc/desc.go
index 59984e8..d42bcd7 100644
--- a/internal/filedesc/desc.go
+++ b/internal/filedesc/desc.go
@@ -200,6 +200,8 @@
IsWeak bool // promoted from google.protobuf.FieldOptions
HasPacked bool // promoted from google.protobuf.FieldOptions
IsPacked bool // promoted from google.protobuf.FieldOptions
+ HasEnforceUTF8 bool // promoted from google.protobuf.FieldOptions
+ EnforceUTF8 bool // promoted from google.protobuf.FieldOptions
Default defaultValue
ContainingOneof pref.OneofDescriptor // must be consistent with Message.Oneofs.Fields
Enum pref.EnumDescriptor
@@ -303,6 +305,20 @@
func (fd *Field) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, fd) }
func (fd *Field) ProtoType(pref.FieldDescriptor) {}
+// EnforceUTF8 is a pseudo-internal API to determine whether to enforce UTF-8
+// validation for the string field. This exists for Google-internal use only
+// since proto3 did not enforce UTF-8 validity prior to the open-source release.
+// If this method does not exist, the default is to enforce valid UTF-8.
+//
+// WARNING: This method is exempt from the compatibility promise and may be
+// removed in the future without warning.
+func (fd *Field) EnforceUTF8() bool {
+ if fd.L1.HasEnforceUTF8 {
+ return fd.L1.EnforceUTF8
+ }
+ return fd.L0.ParentFile.L1.Syntax == pref.Proto3
+}
+
func (od *Oneof) Options() pref.ProtoMessage {
if f := od.L1.Options; f != nil {
return f()
diff --git a/internal/filedesc/desc_lazy.go b/internal/filedesc/desc_lazy.go
index 55104ad..9b54e6d 100644
--- a/internal/filedesc/desc_lazy.go
+++ b/internal/filedesc/desc_lazy.go
@@ -480,6 +480,8 @@
}
func (fd *Field) unmarshalOptions(b []byte) {
+ const FieldOptions_EnforceUTF8 = 13
+
for len(b) > 0 {
num, typ, n := wire.ConsumeTag(b)
b = b[n:]
@@ -493,6 +495,9 @@
fd.L1.IsPacked = wire.DecodeBool(v)
case fieldnum.FieldOptions_Weak:
fd.L1.IsWeak = wire.DecodeBool(v)
+ case FieldOptions_EnforceUTF8:
+ fd.L1.HasEnforceUTF8 = true
+ fd.L1.EnforceUTF8 = wire.DecodeBool(v) // store the option value as-is: Field.EnforceUTF8 returns it directly, so true must mean "enforce"
}
default:
m := wire.ConsumeFieldValue(num, typ, b)
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 8d0e339..94b7d6a 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -6,7 +6,6 @@
import (
"reflect"
- "unicode/utf8"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/proto"
@@ -747,136 +746,6 @@
unmarshal: consumeEnumSliceIface,
}
-// Strings with UTF8 validation.
-
-func appendStringValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
- v := *p.String()
- b = wire.AppendVarint(b, wiretag)
- b = wire.AppendString(b, v)
- if !utf8.ValidString(v) {
- return b, errInvalidUTF8{}
- }
- return b, nil
-}
-
-func consumeStringValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
- if wtyp != wire.BytesType {
- return 0, errUnknown
- }
- v, n := wire.ConsumeString(b)
- if n < 0 {
- return 0, wire.ParseError(n)
- }
- if !utf8.ValidString(v) {
- return 0, errInvalidUTF8{}
- }
- *p.String() = v
- return n, nil
-}
-
-var coderStringValidateUTF8 = pointerCoderFuncs{
- size: sizeString,
- marshal: appendStringValidateUTF8,
- unmarshal: consumeStringValidateUTF8,
-}
-
-func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
- v := *p.String()
- if len(v) == 0 {
- return b, nil
- }
- b = wire.AppendVarint(b, wiretag)
- b = wire.AppendString(b, v)
- if !utf8.ValidString(v) {
- return b, errInvalidUTF8{}
- }
- return b, nil
-}
-
-var coderStringNoZeroValidateUTF8 = pointerCoderFuncs{
- size: sizeStringNoZero,
- marshal: appendStringNoZeroValidateUTF8,
- unmarshal: consumeStringValidateUTF8,
-}
-
-func sizeStringSliceValidateUTF8(p pointer, tagsize int, _ marshalOptions) (size int) {
- s := *p.StringSlice()
- for _, v := range s {
- size += tagsize + wire.SizeBytes(len(v))
- }
- return size
-}
-
-func appendStringSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
- s := *p.StringSlice()
- var err error
- for _, v := range s {
- b = wire.AppendVarint(b, wiretag)
- b = wire.AppendString(b, v)
- if !utf8.ValidString(v) {
- return b, errInvalidUTF8{}
- }
- }
- return b, err
-}
-
-func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
- if wtyp != wire.BytesType {
- return 0, errUnknown
- }
- sp := p.StringSlice()
- v, n := wire.ConsumeString(b)
- if n < 0 {
- return 0, wire.ParseError(n)
- }
- if !utf8.ValidString(v) {
- return 0, errInvalidUTF8{}
- }
- *sp = append(*sp, v)
- return n, nil
-}
-
-var coderStringSliceValidateUTF8 = pointerCoderFuncs{
- size: sizeStringSliceValidateUTF8,
- marshal: appendStringSliceValidateUTF8,
- unmarshal: consumeStringSliceValidateUTF8,
-}
-
-func sizeStringIfaceValidateUTF8(ival interface{}, tagsize int, _ marshalOptions) int {
- v := ival.(string)
- return tagsize + wire.SizeBytes(len(v))
-}
-
-func appendStringIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) {
- v := ival.(string)
- b = wire.AppendVarint(b, wiretag)
- b = wire.AppendString(b, v)
- if !utf8.ValidString(v) {
- return b, errInvalidUTF8{}
- }
- return b, nil
-}
-
-func consumeStringIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
- if wtyp != wire.BytesType {
- return nil, 0, errUnknown
- }
- v, n := wire.ConsumeString(b)
- if n < 0 {
- return nil, 0, wire.ParseError(n)
- }
- if !utf8.ValidString(v) {
- return nil, 0, errInvalidUTF8{}
- }
- return v, n, nil
-}
-
-var coderStringIfaceValidateUTF8 = ifaceCoderFuncs{
- size: sizeStringIfaceValidateUTF8,
- marshal: appendStringIfaceValidateUTF8,
- unmarshal: consumeStringIfaceValidateUTF8,
-}
-
func asMessage(v reflect.Value) pref.ProtoMessage {
if m, ok := v.Interface().(pref.ProtoMessage); ok {
return m
diff --git a/internal/impl/codec_gen.go b/internal/impl/codec_gen.go
index 41bd099..46380f5 100644
--- a/internal/impl/codec_gen.go
+++ b/internal/impl/codec_gen.go
@@ -8,6 +8,7 @@
import (
"math"
+ "unicode/utf8"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/reflect/protoreflect"
@@ -46,7 +47,7 @@
unmarshal: consumeBool,
}
-// sizeBool returns the size of wire encoding a bool pointer as a Bool.
+// sizeBoolNoZero returns the size of wire encoding a bool pointer as a Bool.
// The zero value is not encoded.
func sizeBoolNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Bool()
@@ -56,7 +57,7 @@
return tagsize + wire.SizeVarint(wire.EncodeBool(v))
}
-// appendBool wire encodes a bool pointer as a Bool.
+// appendBoolNoZero wire encodes a bool pointer as a Bool.
// The zero value is not encoded.
func appendBoolNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Bool()
@@ -364,7 +365,7 @@
unmarshal: consumeInt32,
}
-// sizeInt32 returns the size of wire encoding a int32 pointer as a Int32.
+// sizeInt32NoZero returns the size of wire encoding a int32 pointer as a Int32.
// The zero value is not encoded.
func sizeInt32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int32()
@@ -374,7 +375,7 @@
return tagsize + wire.SizeVarint(uint64(v))
}
-// appendInt32 wire encodes a int32 pointer as a Int32.
+// appendInt32NoZero wire encodes a int32 pointer as a Int32.
// The zero value is not encoded.
func appendInt32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int32()
@@ -682,7 +683,7 @@
unmarshal: consumeSint32,
}
-// sizeSint32 returns the size of wire encoding a int32 pointer as a Sint32.
+// sizeSint32NoZero returns the size of wire encoding a int32 pointer as a Sint32.
// The zero value is not encoded.
func sizeSint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int32()
@@ -692,7 +693,7 @@
return tagsize + wire.SizeVarint(wire.EncodeZigZag(int64(v)))
}
-// appendSint32 wire encodes a int32 pointer as a Sint32.
+// appendSint32NoZero wire encodes a int32 pointer as a Sint32.
// The zero value is not encoded.
func appendSint32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int32()
@@ -1000,7 +1001,7 @@
unmarshal: consumeUint32,
}
-// sizeUint32 returns the size of wire encoding a uint32 pointer as a Uint32.
+// sizeUint32NoZero returns the size of wire encoding a uint32 pointer as a Uint32.
// The zero value is not encoded.
func sizeUint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Uint32()
@@ -1010,7 +1011,7 @@
return tagsize + wire.SizeVarint(uint64(v))
}
-// appendUint32 wire encodes a uint32 pointer as a Uint32.
+// appendUint32NoZero wire encodes a uint32 pointer as a Uint32.
// The zero value is not encoded.
func appendUint32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Uint32()
@@ -1318,7 +1319,7 @@
unmarshal: consumeInt64,
}
-// sizeInt64 returns the size of wire encoding a int64 pointer as a Int64.
+// sizeInt64NoZero returns the size of wire encoding a int64 pointer as a Int64.
// The zero value is not encoded.
func sizeInt64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int64()
@@ -1328,7 +1329,7 @@
return tagsize + wire.SizeVarint(uint64(v))
}
-// appendInt64 wire encodes a int64 pointer as a Int64.
+// appendInt64NoZero wire encodes a int64 pointer as a Int64.
// The zero value is not encoded.
func appendInt64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int64()
@@ -1636,7 +1637,7 @@
unmarshal: consumeSint64,
}
-// sizeSint64 returns the size of wire encoding a int64 pointer as a Sint64.
+// sizeSint64NoZero returns the size of wire encoding a int64 pointer as a Sint64.
// The zero value is not encoded.
func sizeSint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int64()
@@ -1646,7 +1647,7 @@
return tagsize + wire.SizeVarint(wire.EncodeZigZag(v))
}
-// appendSint64 wire encodes a int64 pointer as a Sint64.
+// appendSint64NoZero wire encodes a int64 pointer as a Sint64.
// The zero value is not encoded.
func appendSint64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int64()
@@ -1954,7 +1955,7 @@
unmarshal: consumeUint64,
}
-// sizeUint64 returns the size of wire encoding a uint64 pointer as a Uint64.
+// sizeUint64NoZero returns the size of wire encoding a uint64 pointer as a Uint64.
// The zero value is not encoded.
func sizeUint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Uint64()
@@ -1964,7 +1965,7 @@
return tagsize + wire.SizeVarint(v)
}
-// appendUint64 wire encodes a uint64 pointer as a Uint64.
+// appendUint64NoZero wire encodes a uint64 pointer as a Uint64.
// The zero value is not encoded.
func appendUint64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Uint64()
@@ -2272,7 +2273,7 @@
unmarshal: consumeSfixed32,
}
-// sizeSfixed32 returns the size of wire encoding a int32 pointer as a Sfixed32.
+// sizeSfixed32NoZero returns the size of wire encoding a int32 pointer as a Sfixed32.
// The zero value is not encoded.
func sizeSfixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int32()
@@ -2282,7 +2283,7 @@
return tagsize + wire.SizeFixed32()
}
-// appendSfixed32 wire encodes a int32 pointer as a Sfixed32.
+// appendSfixed32NoZero wire encodes a int32 pointer as a Sfixed32.
// The zero value is not encoded.
func appendSfixed32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int32()
@@ -2572,7 +2573,7 @@
unmarshal: consumeFixed32,
}
-// sizeFixed32 returns the size of wire encoding a uint32 pointer as a Fixed32.
+// sizeFixed32NoZero returns the size of wire encoding a uint32 pointer as a Fixed32.
// The zero value is not encoded.
func sizeFixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Uint32()
@@ -2582,7 +2583,7 @@
return tagsize + wire.SizeFixed32()
}
-// appendFixed32 wire encodes a uint32 pointer as a Fixed32.
+// appendFixed32NoZero wire encodes a uint32 pointer as a Fixed32.
// The zero value is not encoded.
func appendFixed32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Uint32()
@@ -2872,7 +2873,7 @@
unmarshal: consumeFloat,
}
-// sizeFloat returns the size of wire encoding a float32 pointer as a Float.
+// sizeFloatNoZero returns the size of wire encoding a float32 pointer as a Float.
// The zero value is not encoded.
func sizeFloatNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Float32()
@@ -2882,7 +2883,7 @@
return tagsize + wire.SizeFixed32()
}
-// appendFloat wire encodes a float32 pointer as a Float.
+// appendFloatNoZero wire encodes a float32 pointer as a Float.
// The zero value is not encoded.
func appendFloatNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Float32()
@@ -3172,7 +3173,7 @@
unmarshal: consumeSfixed64,
}
-// sizeSfixed64 returns the size of wire encoding a int64 pointer as a Sfixed64.
+// sizeSfixed64NoZero returns the size of wire encoding a int64 pointer as a Sfixed64.
// The zero value is not encoded.
func sizeSfixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int64()
@@ -3182,7 +3183,7 @@
return tagsize + wire.SizeFixed64()
}
-// appendSfixed64 wire encodes a int64 pointer as a Sfixed64.
+// appendSfixed64NoZero wire encodes a int64 pointer as a Sfixed64.
// The zero value is not encoded.
func appendSfixed64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int64()
@@ -3472,7 +3473,7 @@
unmarshal: consumeFixed64,
}
-// sizeFixed64 returns the size of wire encoding a uint64 pointer as a Fixed64.
+// sizeFixed64NoZero returns the size of wire encoding a uint64 pointer as a Fixed64.
// The zero value is not encoded.
func sizeFixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Uint64()
@@ -3482,7 +3483,7 @@
return tagsize + wire.SizeFixed64()
}
-// appendFixed64 wire encodes a uint64 pointer as a Fixed64.
+// appendFixed64NoZero wire encodes a uint64 pointer as a Fixed64.
// The zero value is not encoded.
func appendFixed64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Uint64()
@@ -3772,7 +3773,7 @@
unmarshal: consumeDouble,
}
-// sizeDouble returns the size of wire encoding a float64 pointer as a Double.
+// sizeDoubleNoZero returns the size of wire encoding a float64 pointer as a Double.
// The zero value is not encoded.
func sizeDoubleNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Float64()
@@ -3782,7 +3783,7 @@
return tagsize + wire.SizeFixed64()
}
-// appendDouble wire encodes a float64 pointer as a Double.
+// appendDoubleNoZero wire encodes a float64 pointer as a Double.
// The zero value is not encoded.
func appendDoubleNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Float64()
@@ -4072,7 +4073,40 @@
unmarshal: consumeString,
}
-// sizeString returns the size of wire encoding a string pointer as a String.
+// appendStringValidateUTF8 wire encodes a string pointer as a String.
+func appendStringValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := *p.String()
+ b = wire.AppendVarint(b, wiretag)
+ b = wire.AppendString(b, v)
+ if !utf8.ValidString(v) {
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+// consumeStringValidateUTF8 wire decodes a string pointer as a String.
+func consumeStringValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeString(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.ValidString(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *p.String() = v
+ return n, nil
+}
+
+var coderStringValidateUTF8 = pointerCoderFuncs{
+ size: sizeString,
+ marshal: appendStringValidateUTF8,
+ unmarshal: consumeStringValidateUTF8,
+}
+
+// sizeStringNoZero returns the size of wire encoding a string pointer as a String.
// The zero value is not encoded.
func sizeStringNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.String()
@@ -4082,7 +4116,7 @@
return tagsize + wire.SizeBytes(len(v))
}
-// appendString wire encodes a string pointer as a String.
+// appendStringNoZero wire encodes a string pointer as a String.
// The zero value is not encoded.
func appendStringNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.String()
@@ -4100,6 +4134,27 @@
unmarshal: consumeString,
}
+// appendStringNoZeroValidateUTF8 wire encodes a string pointer as a String.
+// The zero value is not encoded.
+func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := *p.String()
+ if len(v) == 0 {
+ return b, nil
+ }
+ b = wire.AppendVarint(b, wiretag)
+ b = wire.AppendString(b, v)
+ if !utf8.ValidString(v) {
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+var coderStringNoZeroValidateUTF8 = pointerCoderFuncs{
+ size: sizeStringNoZero,
+ marshal: appendStringNoZeroValidateUTF8,
+ unmarshal: consumeStringValidateUTF8,
+}
+
// sizeStringPtr returns the size of wire encoding a *string pointer as a String.
// It panics if the pointer is nil.
func sizeStringPtr(p pointer, tagsize int, _ marshalOptions) (size int) {
@@ -4178,6 +4233,42 @@
unmarshal: consumeStringSlice,
}
+// appendStringSliceValidateUTF8 encodes a []string pointer as a repeated String.
+func appendStringSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ s := *p.StringSlice()
+ for _, v := range s {
+ b = wire.AppendVarint(b, wiretag)
+ b = wire.AppendString(b, v)
+ if !utf8.ValidString(v) {
+ return b, errInvalidUTF8{}
+ }
+ }
+ return b, nil
+}
+
+// consumeStringSliceValidateUTF8 wire decodes a []string pointer as a repeated String.
+func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ sp := p.StringSlice()
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeString(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.ValidString(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *sp = append(*sp, v)
+ return n, nil
+}
+
+var coderStringSliceValidateUTF8 = pointerCoderFuncs{
+ size: sizeStringSlice,
+ marshal: appendStringSliceValidateUTF8,
+ unmarshal: consumeStringSliceValidateUTF8,
+}
+
// sizeStringIface returns the size of wire encoding a string value as a String.
func sizeStringIface(ival interface{}, tagsize int, _ marshalOptions) int {
v := ival.(string)
@@ -4210,6 +4301,38 @@
unmarshal: consumeStringIface,
}
+// appendStringIfaceValidateUTF8 encodes a string value as a String.
+func appendStringIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := ival.(string)
+ b = wire.AppendVarint(b, wiretag)
+ b = wire.AppendString(b, v)
+ if !utf8.ValidString(v) {
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+// consumeStringIfaceValidateUTF8 decodes a string value as a String.
+func consumeStringIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
+ if wtyp != wire.BytesType {
+ return nil, 0, errUnknown
+ }
+ v, n := wire.ConsumeString(b)
+ if n < 0 {
+ return nil, 0, wire.ParseError(n)
+ }
+ if !utf8.ValidString(v) {
+ return nil, 0, errInvalidUTF8{}
+ }
+ return v, n, nil
+}
+
+var coderStringIfaceValidateUTF8 = ifaceCoderFuncs{
+ size: sizeStringIface,
+ marshal: appendStringIfaceValidateUTF8,
+ unmarshal: consumeStringIfaceValidateUTF8,
+}
+
// sizeStringSliceIface returns the size of wire encoding a []string value as a repeated String.
func sizeStringSliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) {
s := *ival.(*[]string)
@@ -4282,7 +4405,40 @@
unmarshal: consumeBytes,
}
-// sizeBytes returns the size of wire encoding a []byte pointer as a Bytes.
+// appendBytesValidateUTF8 wire encodes a []byte pointer as a Bytes.
+func appendBytesValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := *p.Bytes()
+ b = wire.AppendVarint(b, wiretag)
+ b = wire.AppendBytes(b, v)
+ if !utf8.Valid(v) {
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+// consumeBytesValidateUTF8 wire decodes a []byte pointer as a Bytes.
+func consumeBytesValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.Valid(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *p.Bytes() = append(([]byte)(nil), v...)
+ return n, nil
+}
+
+var coderBytesValidateUTF8 = pointerCoderFuncs{
+ size: sizeBytes,
+ marshal: appendBytesValidateUTF8,
+ unmarshal: consumeBytesValidateUTF8,
+}
+
+// sizeBytesNoZero returns the size of wire encoding a []byte pointer as a Bytes.
// The zero value is not encoded.
func sizeBytesNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Bytes()
@@ -4292,7 +4448,7 @@
return tagsize + wire.SizeBytes(len(v))
}
-// appendBytes wire encodes a []byte pointer as a Bytes.
+// appendBytesNoZero wire encodes a []byte pointer as a Bytes.
// The zero value is not encoded.
func appendBytesNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Bytes()
@@ -4310,6 +4466,27 @@
unmarshal: consumeBytes,
}
+// appendBytesNoZeroValidateUTF8 wire encodes a []byte pointer as a Bytes.
+// The zero value is not encoded.
+func appendBytesNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := *p.Bytes()
+ if len(v) == 0 {
+ return b, nil
+ }
+ b = wire.AppendVarint(b, wiretag)
+ b = wire.AppendBytes(b, v)
+ if !utf8.Valid(v) {
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+var coderBytesNoZeroValidateUTF8 = pointerCoderFuncs{
+ size: sizeBytesNoZero,
+ marshal: appendBytesNoZeroValidateUTF8,
+ unmarshal: consumeBytesValidateUTF8,
+}
+
// sizeBytesSlice returns the size of wire encoding a [][]byte pointer as a repeated Bytes.
func sizeBytesSlice(p pointer, tagsize int, _ marshalOptions) (size int) {
s := *p.BytesSlice()
@@ -4349,6 +4526,42 @@
unmarshal: consumeBytesSlice,
}
+// appendBytesSliceValidateUTF8 encodes a [][]byte pointer as a repeated Bytes, stopping with errInvalidUTF8 at the first element that is not valid UTF-8.
+func appendBytesSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ s := *p.BytesSlice()
+ for _, v := range s { // every element is written with its own tag
+ b = wire.AppendVarint(b, wiretag)
+ b = wire.AppendBytes(b, v)
+ if !utf8.Valid(v) { // the offending element has already been appended to b at this point
+ return b, errInvalidUTF8{}
+ }
+ }
+ return b, nil
+}
+
+// consumeBytesSliceValidateUTF8 wire decodes a [][]byte pointer as a repeated Bytes, failing with errInvalidUTF8 if the element is not valid UTF-8.
+func consumeBytesSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ sp := p.BytesSlice()
+ if wtyp != wire.BytesType { // any other wire type is reported as errUnknown
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 { // negative n signals a parse failure; wire.ParseError converts it to the returned error
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.Valid(v) { // checked before the append below, so *sp is not modified on invalid input
+ return 0, errInvalidUTF8{}
+ }
+ *sp = append(*sp, append(([]byte)(nil), v...)) // each call adds one element, stored as a fresh copy of v
+ return n, nil
+}
+
+var coderBytesSliceValidateUTF8 = pointerCoderFuncs{ // coder for a repeated string-kind field held as [][]byte when UTF-8 enforcement applies
+ size: sizeBytesSlice,
+ marshal: appendBytesSliceValidateUTF8,
+ unmarshal: consumeBytesSliceValidateUTF8,
+}
+
// sizeBytesIface returns the size of wire encoding a []byte value as a Bytes.
func sizeBytesIface(ival interface{}, tagsize int, _ marshalOptions) int {
v := ival.([]byte)
@@ -4381,6 +4594,38 @@
unmarshal: consumeBytesIface,
}
+// appendBytesIfaceValidateUTF8 encodes a []byte value as a Bytes, reporting errInvalidUTF8 if the value is not valid UTF-8.
+func appendBytesIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) {
+ v := ival.([]byte) // panics if ival does not hold a []byte
+ b = wire.AppendVarint(b, wiretag)
+ b = wire.AppendBytes(b, v)
+ if !utf8.Valid(v) { // checked after appending, so the returned b already contains the encoded field
+ return b, errInvalidUTF8{}
+ }
+ return b, nil
+}
+
+// consumeBytesIfaceValidateUTF8 decodes a []byte value as a Bytes, failing with errInvalidUTF8 if the payload is not valid UTF-8.
+func consumeBytesIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
+ if wtyp != wire.BytesType { // any other wire type is reported as errUnknown
+ return nil, 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 { // negative n signals a parse failure; wire.ParseError converts it to the returned error
+ return nil, 0, wire.ParseError(n)
+ }
+ if !utf8.Valid(v) { // no value is returned for invalid input
+ return nil, 0, errInvalidUTF8{}
+ }
+ return append(([]byte)(nil), v...), n, nil // the result is a fresh copy of v rather than v itself
+}
+
+var coderBytesIfaceValidateUTF8 = ifaceCoderFuncs{ // interface-value coder for a string-kind field held as []byte when UTF-8 enforcement applies
+ size: sizeBytesIface,
+ marshal: appendBytesIfaceValidateUTF8,
+ unmarshal: consumeBytesIfaceValidateUTF8,
+}
+
// sizeBytesSliceIface returns the size of wire encoding a [][]byte value as a repeated Bytes.
func sizeBytesSliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) {
s := *ival.(*[][]byte)
diff --git a/internal/impl/codec_tables.go b/internal/impl/codec_tables.go
index 564187e..3ff4260 100644
--- a/internal/impl/codec_tables.go
+++ b/internal/impl/codec_tables.go
@@ -9,6 +9,7 @@
"reflect"
"google.golang.org/protobuf/internal/encoding/wire"
+ "google.golang.org/protobuf/internal/strs"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
@@ -98,12 +99,15 @@
return coderDoubleSlice
}
case pref.StringKind:
- if ft.Kind() == reflect.String && fd.Syntax() == pref.Proto3 {
+ if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return coderStringSliceValidateUTF8
}
if ft.Kind() == reflect.String {
return coderStringSlice
}
+ if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
+ return coderBytesSliceValidateUTF8
+ }
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return coderBytesSlice
}
@@ -251,9 +255,15 @@
return coderDoubleNoZero
}
case pref.StringKind:
- if ft.Kind() == reflect.String {
+ if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return coderStringNoZeroValidateUTF8
}
+ if ft.Kind() == reflect.String {
+ return coderStringNoZero
+ }
+ if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
+ return coderBytesNoZeroValidateUTF8
+ }
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return coderBytesNoZero
}
@@ -392,12 +402,15 @@
return coderDouble
}
case pref.StringKind:
- if fd.Syntax() == pref.Proto3 && ft.Kind() == reflect.String {
+ if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return coderStringValidateUTF8
}
if ft.Kind() == reflect.String {
return coderString
}
+ if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
+ return coderBytesValidateUTF8
+ }
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return coderBytes
}
@@ -620,12 +633,15 @@
return coderDoubleIface
}
case pref.StringKind:
- if fd.Syntax() == pref.Proto3 && ft.Kind() == reflect.String {
+ if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return coderStringIfaceValidateUTF8
}
if ft.Kind() == reflect.String {
return coderStringIface
}
+ if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
+ return coderBytesIfaceValidateUTF8
+ }
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return coderBytesIface
}
diff --git a/internal/strs/strings.go b/internal/strs/strings.go
index 295bd29..af5f197 100644
--- a/internal/strs/strings.go
+++ b/internal/strs/strings.go
@@ -8,8 +8,21 @@
import (
"strings"
"unicode"
+
+ "google.golang.org/protobuf/internal/flags"
+ "google.golang.org/protobuf/reflect/protoreflect"
)
+// EnforceUTF8 reports whether to enforce strict UTF-8 validation for fd: proto3 fields are validated, except that legacy builds defer to fd's own EnforceUTF8 method when it has one.
+func EnforceUTF8(fd protoreflect.FieldDescriptor) bool {
+ if flags.Proto1Legacy { // the per-field override is only consulted when legacy support is compiled in
+ if fd, ok := fd.(interface{ EnforceUTF8() bool }); ok { // only descriptor types exposing the override take this path
+ return fd.EnforceUTF8()
+ }
+ }
+ return fd.Syntax() == protoreflect.Proto3 // default: validate exactly the proto3 fields
+}
+
// JSONCamelCase converts a snake_case identifier to a camelCase identifier,
// according to the protobuf JSON specification.
func JSONCamelCase(s string) string {
diff --git a/proto/decode_gen.go b/proto/decode_gen.go
index a272242..dbb4c87 100644
--- a/proto/decode_gen.go
+++ b/proto/decode_gen.go
@@ -12,6 +12,7 @@
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
+ "google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
)
@@ -154,7 +155,7 @@
if n < 0 {
return val, 0, wire.ParseError(n)
}
- if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+ if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
}
return protoreflect.ValueOf(string(v)), n, nil
@@ -550,7 +551,7 @@
if n < 0 {
return 0, wire.ParseError(n)
}
- if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+ if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return 0, errors.InvalidUTF8(string(fd.FullName()))
}
list.Append(protoreflect.ValueOf(string(v)))
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 5fa3a0f..ce2e1af 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -12,13 +12,20 @@
protoV1 "github.com/golang/protobuf/proto"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/internal/encoding/pack"
+ "google.golang.org/protobuf/internal/filedesc"
+ "google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protodesc"
+ "google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/reflect/prototype"
+ "google.golang.org/protobuf/runtime/protoimpl"
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
testpb "google.golang.org/protobuf/internal/testprotos/test"
test3pb "google.golang.org/protobuf/internal/testprotos/test3"
+ "google.golang.org/protobuf/types/descriptorpb"
)
type testProto struct {
@@ -85,6 +92,23 @@
}
}
+func TestDecodeNoEnforceUTF8(t *testing.T) { // unmarshals invalid UTF-8 into fields whose descriptor opts out of validation
+ for _, test := range noEnforceUTF8TestProtos {
+ for _, want := range test.decodeTo {
+ t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
+ got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message) // fresh zero message of want's concrete type
+ err := proto.Unmarshal(test.wire, got)
+ switch {
+ case flags.Proto1Legacy && err != nil: // legacy builds honor the opt-out, so decoding must succeed
+ t.Errorf("Unmarshal returned unexpected error: %v\nMessage:\n%v", err, marshalText(want))
+ case !flags.Proto1Legacy && err == nil: // otherwise the opt-out is ignored and proto3 validation must reject the input
+ t.Errorf("Unmarshal did not return expected error for invalid UTF8\nMessage:\n%v", marshalText(want))
+ }
+ })
+ }
+ }
+}
+
var testProtos = []testProto{
{
desc: "basic scalar types",
@@ -1442,6 +1466,129 @@
},
}
+var noEnforceUTF8TestProtos = []testProto{ // one case per TestNoEnforceUTF8 field: invalid UTF-8 wire data and the message it corresponds to
+ {
+ desc: "invalid UTF-8 in optional string field",
+ decodeTo: []proto.Message{&TestNoEnforceUTF8{
+ OptionalString: string("abc\xff"),
+ }},
+ wire: pack.Message{
+ pack.Tag{1, pack.BytesType}, pack.String("abc\xff"), // field numbers match the protobuf struct tags on TestNoEnforceUTF8
+ }.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in optional string field of Go bytes",
+ decodeTo: []proto.Message{&TestNoEnforceUTF8{
+ OptionalBytes: []byte("abc\xff"), // TYPE_STRING in the descriptor, []byte in the Go struct
+ }},
+ wire: pack.Message{
+ pack.Tag{2, pack.BytesType}, pack.String("abc\xff"),
+ }.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in repeated string field",
+ decodeTo: []proto.Message{&TestNoEnforceUTF8{
+ RepeatedString: []string{string("foo"), string("abc\xff")}, // only the second element is invalid
+ }},
+ wire: pack.Message{
+ pack.Tag{3, pack.BytesType}, pack.String("foo"),
+ pack.Tag{3, pack.BytesType}, pack.String("abc\xff"),
+ }.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in repeated string field of Go bytes",
+ decodeTo: []proto.Message{&TestNoEnforceUTF8{
+ RepeatedBytes: [][]byte{[]byte("foo"), []byte("abc\xff")}, // only the second element is invalid
+ }},
+ wire: pack.Message{
+ pack.Tag{4, pack.BytesType}, pack.String("foo"),
+ pack.Tag{4, pack.BytesType}, pack.String("abc\xff"),
+ }.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in oneof string field",
+ decodeTo: []proto.Message{
+ &TestNoEnforceUTF8{OneofField: &TestNoEnforceUTF8_OneofString{string("abc\xff")}},
+ },
+ wire: pack.Message{pack.Tag{5, pack.BytesType}, pack.String("abc\xff")}.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in oneof string field of Go bytes",
+ decodeTo: []proto.Message{
+ &TestNoEnforceUTF8{OneofField: &TestNoEnforceUTF8_OneofBytes{[]byte("abc\xff")}},
+ },
+ wire: pack.Message{pack.Tag{6, pack.BytesType}, pack.String("abc\xff")}.Marshal(),
+ },
+}
+
+type TestNoEnforceUTF8 struct { // message type used by noEnforceUTF8TestProtos; its descriptor is built in messageInfo_TestNoEnforceUTF8
+ OptionalString string `protobuf:"bytes,1,opt,name=optional_string"`
+ OptionalBytes []byte `protobuf:"bytes,2,opt,name=optional_bytes"` // TYPE_STRING in the descriptor, held as []byte
+ RepeatedString []string `protobuf:"bytes,3,rep,name=repeated_string"`
+ RepeatedBytes [][]byte `protobuf:"bytes,4,rep,name=repeated_bytes"` // TYPE_STRING in the descriptor, held as [][]byte
+ OneofField isOneofField `protobuf_oneof:"oneof_field"` // holds one of the two wrapper types declared below
+}
+
+type isOneofField interface{ isOneofField() } // marker interface satisfied by the oneof wrapper types
+
+type TestNoEnforceUTF8_OneofString struct { // wrapper for oneof member 5, held as a Go string
+ OneofString string `protobuf:"bytes,5,opt,name=oneof_string,oneof"`
+}
+type TestNoEnforceUTF8_OneofBytes struct { // wrapper for oneof member 6, held as []byte
+ OneofBytes []byte `protobuf:"bytes,6,opt,name=oneof_bytes,oneof"`
+}
+
+func (*TestNoEnforceUTF8_OneofString) isOneofField() {}
+func (*TestNoEnforceUTF8_OneofBytes) isOneofField() {}
+
+func (m *TestNoEnforceUTF8) ProtoReflect() pref.Message { // wraps m using messageInfo_TestNoEnforceUTF8
+ return messageInfo_TestNoEnforceUTF8.MessageOf(m)
+}
+
+var messageInfo_TestNoEnforceUTF8 = protoimpl.MessageInfo{ // message info for TestNoEnforceUTF8; its descriptor is built during package initialization
+ GoType: reflect.TypeOf((*TestNoEnforceUTF8)(nil)),
+ PBType: &prototype.Message{
+ MessageDescriptor: func() protoreflect.MessageDescriptor { // invoked immediately; the result is stored, not the func
+ pb := new(descriptorpb.FileDescriptorProto)
+ if err := prototext.Unmarshal([]byte(`
+ syntax: "proto3"
+ name: "test.proto"
+ message_type: [{
+ name: "TestNoEnforceUTF8"
+ field: [
+ {name:"optional_string" number:1 label:LABEL_OPTIONAL type:TYPE_STRING},
+ {name:"optional_bytes" number:2 label:LABEL_OPTIONAL type:TYPE_STRING},
+ {name:"repeated_string" number:3 label:LABEL_REPEATED type:TYPE_STRING},
+ {name:"repeated_bytes" number:4 label:LABEL_REPEATED type:TYPE_STRING},
+ {name:"oneof_string" number:5 label:LABEL_OPTIONAL type:TYPE_STRING, oneof_index:0},
+ {name:"oneof_bytes" number:6 label:LABEL_OPTIONAL type:TYPE_STRING, oneof_index:0}
+ ]
+ oneof_decl: [{name:"oneof_field"}]
+ }]
+ `), pb); err != nil {
+ panic(err) // the fixture above is static, so a parse failure is a bug in this test file
+ }
+ fd, err := protodesc.NewFile(pb, nil)
+ if err != nil {
+ panic(err)
+ }
+ md := fd.Messages().Get(0) // the only message in the file: TestNoEnforceUTF8
+ for i := 0; i < md.Fields().Len(); i++ { // patch every field to carry an explicit EnforceUTF8 value of false
+ md.Fields().Get(i).(*filedesc.Field).L1.HasEnforceUTF8 = true // NOTE(review): presumably mirrors an explicit enforce_utf8 option being present; confirm against filedesc
+ md.Fields().Get(i).(*filedesc.Field).L1.EnforceUTF8 = false
+ }
+ return md
+ }(),
+ NewMessage: func() pref.Message {
+ return pref.ProtoMessage(new(TestNoEnforceUTF8)).ProtoReflect()
+ },
+ },
+ OneofWrappers: []interface{}{ // nil pointers of the two oneof wrapper types
+ (*TestNoEnforceUTF8_OneofString)(nil),
+ (*TestNoEnforceUTF8_OneofBytes)(nil),
+ },
+}
+
func build(m proto.Message, opts ...buildOpt) proto.Message {
for _, opt := range opts {
opt(m)
diff --git a/proto/encode_gen.go b/proto/encode_gen.go
index fe977e3..77b6511 100644
--- a/proto/encode_gen.go
+++ b/proto/encode_gen.go
@@ -12,6 +12,7 @@
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
+ "google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
)
@@ -67,7 +68,7 @@
case protoreflect.DoubleKind:
b = wire.AppendFixed64(b, math.Float64bits(v.Float()))
case protoreflect.StringKind:
- if fd.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
+ if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
return b, errors.InvalidUTF8(string(fd.FullName()))
}
b = wire.AppendString(b, v.String())
diff --git a/proto/encode_test.go b/proto/encode_test.go
index f90020a..573a197 100644
--- a/proto/encode_test.go
+++ b/proto/encode_test.go
@@ -10,6 +10,7 @@
"testing"
"github.com/google/go-cmp/cmp"
+ "google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
test3pb "google.golang.org/protobuf/internal/testprotos/test3"
@@ -97,6 +98,22 @@
}
}
+func TestEncodeNoEnforceUTF8(t *testing.T) { // marshals messages holding invalid UTF-8 in fields whose descriptor opts out of validation
+ for _, test := range noEnforceUTF8TestProtos {
+ for _, want := range test.decodeTo {
+ t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
+ _, err := proto.Marshal(want) // only the error is checked; the encoded bytes are discarded
+ switch {
+ case flags.Proto1Legacy && err != nil: // legacy builds honor the opt-out, so encoding must succeed
+ t.Errorf("Marshal returned unexpected error: %v\nMessage:\n%v", err, marshalText(want))
+ case !flags.Proto1Legacy && err == nil: // otherwise the opt-out is ignored and proto3 validation must reject the value
+ t.Errorf("Marshal did not return expected error for invalid UTF8\nMessage:\n%v", marshalText(want))
+ }
+ })
+ }
+ }
+}
+
func TestEncodeRequiredFieldChecks(t *testing.T) {
for _, test := range testProtos {
if !test.partial {