encoding/textpb: marshal google.protobuf.Any in expanded form
Marshal well-known type Any in expanded form by default, else fallback
to marshaling it as a regular message.
Change-Id: Ic7e9e37b47042a163941f8849dc366ffe48103ca
Reviewed-on: https://go-review.googlesource.com/c/156097
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go
index c0b04b6..62ebc63 100644
--- a/encoding/textpb/encode.go
+++ b/encoding/textpb/encode.go
@@ -14,6 +14,7 @@
"github.com/golang/protobuf/v2/internal/pragma"
"github.com/golang/protobuf/v2/proto"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+ "github.com/golang/protobuf/v2/reflect/protoregistry"
)
// Marshal writes the given proto.Message in textproto format using default options.
@@ -28,13 +29,22 @@
// Set Compact to true to have output in a single line with no line breaks.
Compact bool
+
+ // Resolver is the registry used for type lookups when marshaling out
+ // google.protobuf.Any messages in expanded form. If Resolver is not set,
+ // marshaling will default to using protoregistry.GlobalTypes. If a type is
+ // not found, an Any message will be marshaled as a regular message.
+ Resolver *protoregistry.Types
}
// Marshal writes the given proto.Message in textproto format using options in MarshalOptions object.
func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
+ if o.Resolver == nil {
+ o.Resolver = protoregistry.GlobalTypes
+ }
+
var nerr errors.NonFatal
var v text.Value
-
var err error
v, err = o.marshalMessage(m.ProtoReflect())
if !nerr.Merge(err) {
@@ -59,9 +69,22 @@
func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
var nerr errors.NonFatal
var msgFields [][2]text.Value
+ msgType := m.Type()
+
+ // Handle Any expansion.
+ if msgType.FullName() == "google.protobuf.Any" {
+ msg, err := o.marshalAny(m)
+ if err == nil || nerr.Merge(err) {
+ // Return as is for nil or non-fatal error.
+ return msg, nerr.E
+ }
+ if err != protoregistry.NotFound {
+ return text.Value{}, err
+ }
+ // Continue on to marshal Any as a regular message if error is not found.
+ }
// Handle known fields.
- msgType := m.Type()
fieldDescs := msgType.Fields()
knownFields := m.KnownFields()
size := fieldDescs.Len()
@@ -335,3 +358,48 @@
}
return fields
}
+
+// marshalAny converts a google.protobuf.Any protoreflect.Message to a text.Value.
+func (o MarshalOptions) marshalAny(m pref.Message) (text.Value, error) {
+ var nerr errors.NonFatal
+
+ fds := m.Type().Fields()
+ tfd := fds.ByName("type_url")
+ if tfd == nil || tfd.Kind() != pref.StringKind {
+ return text.Value{}, errors.New("invalid google.protobuf.Any message")
+ }
+ vfd := fds.ByName("value")
+ if vfd == nil || vfd.Kind() != pref.BytesKind {
+ return text.Value{}, errors.New("invalid google.protobuf.Any message")
+ }
+
+ knownFields := m.KnownFields()
+ typeURL := knownFields.Get(tfd.Number())
+ value := knownFields.Get(vfd.Number())
+
+ emt, err := o.Resolver.FindMessageByURL(typeURL.String())
+ if !nerr.Merge(err) {
+ return text.Value{}, err
+ }
+ em := emt.New()
+ // TODO: Need to set types registry in binary unmarshaling.
+ err = proto.Unmarshal(value.Bytes(), em)
+ if !nerr.Merge(err) {
+ return text.Value{}, err
+ }
+
+ msg, err := o.marshalMessage(em.ProtoReflect())
+ if !nerr.Merge(err) {
+ return text.Value{}, err
+ }
+ // Expanded Any field value contains only a single field with the embedded
+ // message type as the field name in [] and a text marshaled field value of
+ // the embedded message.
+ msgFields := [][2]text.Value{
+ {
+ text.ValueOf(string(emt.FullName())),
+ msg,
+ },
+ }
+ return text.ValueOf(msgFields), nerr.E
+}
diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go
index 8b7974e..a439df0 100644
--- a/encoding/textpb/encode_test.go
+++ b/encoding/textpb/encode_test.go
@@ -9,14 +9,17 @@
"strings"
"testing"
+ protoV1 "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/protoapi"
"github.com/golang/protobuf/v2/encoding/textpb"
"github.com/golang/protobuf/v2/internal/detrand"
"github.com/golang/protobuf/v2/internal/encoding/pack"
"github.com/golang/protobuf/v2/internal/encoding/wire"
+ "github.com/golang/protobuf/v2/internal/impl"
"github.com/golang/protobuf/v2/internal/legacy"
"github.com/golang/protobuf/v2/internal/scalar"
"github.com/golang/protobuf/v2/proto"
+ preg "github.com/golang/protobuf/v2/reflect/protoregistry"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
@@ -24,6 +27,7 @@
// TODO: Remove this when protoV1 registers these hooks for you.
_ "github.com/golang/protobuf/v2/internal/legacy"
+ anypb "github.com/golang/protobuf/ptypes/any"
"github.com/golang/protobuf/v2/encoding/textpb/testprotos/pb2"
"github.com/golang/protobuf/v2/encoding/textpb/testprotos/pb3"
)
@@ -65,6 +69,7 @@
func TestMarshal(t *testing.T) {
tests := []struct {
desc string
+ mo textpb.MarshalOptions
input proto.Message
want string
wantErr bool
@@ -964,18 +969,161 @@
}
`,
*/
+ }, {
+ desc: "google.protobuf.Any message not expanded",
+ mo: textpb.MarshalOptions{Resolver: preg.NewTypes()},
+ input: func() proto.Message {
+ m := &pb2.Nested{
+ OptString: scalar.String("embedded inside Any"),
+ OptNested: &pb2.Nested{
+ OptString: scalar.String("inception"),
+ },
+ }
+ // TODO: Switch to V2 marshal when ready.
+ b, err := protoV1.Marshal(m)
+ if err != nil {
+ t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+ }
+ return impl.Export{}.MessageOf(&anypb.Any{
+ TypeUrl: string(m.ProtoReflect().Type().FullName()),
+ Value: b,
+ }).Interface()
+ }(),
+ want: `type_url: "pb2.Nested"
+value: "\n\x13embedded inside Any\x12\x0b\n\tinception"
+`,
+ }, {
+ desc: "google.protobuf.Any message expanded",
+ mo: func() textpb.MarshalOptions {
+ m := &pb2.Nested{}
+ resolver := preg.NewTypes(m.ProtoReflect().Type())
+ return textpb.MarshalOptions{Resolver: resolver}
+ }(),
+ input: func() proto.Message {
+ m := &pb2.Nested{
+ OptString: scalar.String("embedded inside Any"),
+ OptNested: &pb2.Nested{
+ OptString: scalar.String("inception"),
+ },
+ }
+ // TODO: Switch to V2 marshal when ready.
+ b, err := protoV1.Marshal(m)
+ if err != nil {
+ t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+ }
+ return impl.Export{}.MessageOf(&anypb.Any{
+ TypeUrl: string(m.ProtoReflect().Type().FullName()),
+ Value: b,
+ }).Interface()
+ }(),
+ want: `[pb2.Nested]: {
+ opt_string: "embedded inside Any"
+ opt_nested: {
+ opt_string: "inception"
+ }
+}
+`,
+ }, {
+ desc: "google.protobuf.Any message expanded with missing required error",
+ mo: func() textpb.MarshalOptions {
+ m := &pb2.PartialRequired{}
+ resolver := preg.NewTypes(m.ProtoReflect().Type())
+ return textpb.MarshalOptions{Resolver: resolver}
+ }(),
+ input: func() proto.Message {
+ m := &pb2.PartialRequired{
+ OptString: scalar.String("embedded inside Any"),
+ }
+ // TODO: Switch to V2 marshal when ready.
+ b, err := protoV1.Marshal(m)
+ // Ignore required not set error.
+ if _, ok := err.(*protoV1.RequiredNotSetError); !ok {
+ t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+ }
+ return impl.Export{}.MessageOf(&anypb.Any{
+ TypeUrl: string(m.ProtoReflect().Type().FullName()),
+ Value: b,
+ }).Interface()
+ }(),
+ want: `[pb2.PartialRequired]: {
+ opt_string: "embedded inside Any"
+}
+`,
+ wantErr: true,
+ }, {
+ desc: "google.protobuf.Any field",
+ mo: textpb.MarshalOptions{Resolver: preg.NewTypes()},
+ input: func() proto.Message {
+ m := &pb2.Nested{
+ OptString: scalar.String("embedded inside Any"),
+ OptNested: &pb2.Nested{
+ OptString: scalar.String("inception"),
+ },
+ }
+ // TODO: Switch to V2 marshal when ready.
+ b, err := protoV1.Marshal(m)
+ if err != nil {
+ t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+ }
+ return &pb2.KnownTypes{
+ OptAny: &anypb.Any{
+ TypeUrl: string(m.ProtoReflect().Type().FullName()),
+ Value: b,
+ },
+ }
+ }(),
+ want: `opt_any: {
+ type_url: "pb2.Nested"
+ value: "\n\x13embedded inside Any\x12\x0b\n\tinception"
+}
+`,
+ }, {
+ desc: "google.protobuf.Any field expanded using given types registry",
+ mo: func() textpb.MarshalOptions {
+ m := &pb2.Nested{}
+ resolver := preg.NewTypes(m.ProtoReflect().Type())
+ return textpb.MarshalOptions{Resolver: resolver}
+ }(),
+ input: func() proto.Message {
+ m := &pb2.Nested{
+ OptString: scalar.String("embedded inside Any"),
+ OptNested: &pb2.Nested{
+ OptString: scalar.String("inception"),
+ },
+ }
+ // TODO: Switch to V2 marshal when ready.
+ b, err := protoV1.Marshal(m)
+ if err != nil {
+ t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+ }
+ return &pb2.KnownTypes{
+ OptAny: &anypb.Any{
+ TypeUrl: string(m.ProtoReflect().Type().FullName()),
+ Value: b,
+ },
+ }
+ }(),
+ want: `opt_any: {
+ [pb2.Nested]: {
+ opt_string: "embedded inside Any"
+ opt_nested: {
+ opt_string: "inception"
+ }
+ }
+}
+`,
}}
for _, tt := range tests {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
t.Parallel()
- b, err := textpb.Marshal(tt.input)
+ b, err := tt.mo.Marshal(tt.input)
if err != nil && !tt.wantErr {
- t.Errorf("Marshal() returned error: %v\n\n", err)
+ t.Errorf("Marshal() returned error: %v\n", err)
}
if err == nil && tt.wantErr {
- t.Error("Marshal() got nil error, want error\n\n")
+ t.Error("Marshal() got nil error, want error\n")
}
got := string(b)
if tt.want != "" && got != tt.want {