encoding/textpb: marshal google.protobuf.Any in expanded form

Marshal well-known type Any in expanded form by default, else fallback
to marshaling it as a regular message.

Change-Id: Ic7e9e37b47042a163941f8849dc366ffe48103ca
Reviewed-on: https://go-review.googlesource.com/c/156097
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go
index c0b04b6..62ebc63 100644
--- a/encoding/textpb/encode.go
+++ b/encoding/textpb/encode.go
@@ -14,6 +14,7 @@
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 )
 
 // Marshal writes the given proto.Message in textproto format using default options.
@@ -28,13 +29,22 @@
 
 	// Set Compact to true to have output in a single line with no line breaks.
 	Compact bool
+
+	// Resolver is the registry used for type lookups when marshaling out
+	// google.protobuf.Any messages in expanded form. If Resolver is not set,
+	// marshaling will default to using protoregistry.GlobalTypes.  If a type is
+	// not found, an Any message will be marshaled as a regular message.
+	Resolver *protoregistry.Types
 }
 
 // Marshal writes the given proto.Message in textproto format using options in MarshalOptions object.
 func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
+	if o.Resolver == nil {
+		o.Resolver = protoregistry.GlobalTypes
+	}
+
 	var nerr errors.NonFatal
 	var v text.Value
-
 	var err error
 	v, err = o.marshalMessage(m.ProtoReflect())
 	if !nerr.Merge(err) {
@@ -59,9 +69,22 @@
 func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
 	var nerr errors.NonFatal
 	var msgFields [][2]text.Value
+	msgType := m.Type()
+
+	// Handle Any expansion.
+	if msgType.FullName() == "google.protobuf.Any" {
+		msg, err := o.marshalAny(m)
+		if err == nil || nerr.Merge(err) {
+			// Return as is for nil or non-fatal error.
+			return msg, nerr.E
+		}
+		if err != protoregistry.NotFound {
+			return text.Value{}, err
+		}
+		// Continue on to marshal Any as a regular message if error is not found.
+	}
 
 	// Handle known fields.
-	msgType := m.Type()
 	fieldDescs := msgType.Fields()
 	knownFields := m.KnownFields()
 	size := fieldDescs.Len()
@@ -335,3 +358,48 @@
 	}
 	return fields
 }
+
+// marshalAny converts a google.protobuf.Any protoreflect.Message to a text.Value.
+func (o MarshalOptions) marshalAny(m pref.Message) (text.Value, error) {
+	var nerr errors.NonFatal
+
+	fds := m.Type().Fields()
+	tfd := fds.ByName("type_url")
+	if tfd == nil || tfd.Kind() != pref.StringKind {
+		return text.Value{}, errors.New("invalid google.protobuf.Any message")
+	}
+	vfd := fds.ByName("value")
+	if vfd == nil || vfd.Kind() != pref.BytesKind {
+		return text.Value{}, errors.New("invalid google.protobuf.Any message")
+	}
+
+	knownFields := m.KnownFields()
+	typeURL := knownFields.Get(tfd.Number())
+	value := knownFields.Get(vfd.Number())
+
+	emt, err := o.Resolver.FindMessageByURL(typeURL.String())
+	if !nerr.Merge(err) {
+		return text.Value{}, err
+	}
+	em := emt.New()
+	// TODO: Need to set types registry in binary unmarshaling.
+	err = proto.Unmarshal(value.Bytes(), em)
+	if !nerr.Merge(err) {
+		return text.Value{}, err
+	}
+
+	msg, err := o.marshalMessage(em.ProtoReflect())
+	if !nerr.Merge(err) {
+		return text.Value{}, err
+	}
+	// Expanded Any field value contains only a single field with the embedded
+	// message type as the field name in [] and a text marshaled field value of
+	// the embedded message.
+	msgFields := [][2]text.Value{
+		{
+			text.ValueOf(string(emt.FullName())),
+			msg,
+		},
+	}
+	return text.ValueOf(msgFields), nerr.E
+}
diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go
index 8b7974e..a439df0 100644
--- a/encoding/textpb/encode_test.go
+++ b/encoding/textpb/encode_test.go
@@ -9,14 +9,17 @@
 	"strings"
 	"testing"
 
+	protoV1 "github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/protoapi"
 	"github.com/golang/protobuf/v2/encoding/textpb"
 	"github.com/golang/protobuf/v2/internal/detrand"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	"github.com/golang/protobuf/v2/internal/impl"
 	"github.com/golang/protobuf/v2/internal/legacy"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
+	preg "github.com/golang/protobuf/v2/reflect/protoregistry"
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 
@@ -24,6 +27,7 @@
 	// TODO: Remove this when protoV1 registers these hooks for you.
 	_ "github.com/golang/protobuf/v2/internal/legacy"
 
+	anypb "github.com/golang/protobuf/ptypes/any"
 	"github.com/golang/protobuf/v2/encoding/textpb/testprotos/pb2"
 	"github.com/golang/protobuf/v2/encoding/textpb/testprotos/pb3"
 )
@@ -65,6 +69,7 @@
 func TestMarshal(t *testing.T) {
 	tests := []struct {
 		desc    string
+		mo      textpb.MarshalOptions
 		input   proto.Message
 		want    string
 		wantErr bool
@@ -964,18 +969,161 @@
 		   }
 		   `,
 		*/
+	}, {
+		desc: "google.protobuf.Any message not expanded",
+		mo:   textpb.MarshalOptions{Resolver: preg.NewTypes()},
+		input: func() proto.Message {
+			m := &pb2.Nested{
+				OptString: scalar.String("embedded inside Any"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("inception"),
+				},
+			}
+			// TODO: Switch to V2 marshal when ready.
+			b, err := protoV1.Marshal(m)
+			if err != nil {
+				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+			}
+			return impl.Export{}.MessageOf(&anypb.Any{
+				TypeUrl: string(m.ProtoReflect().Type().FullName()),
+				Value:   b,
+			}).Interface()
+		}(),
+		want: `type_url: "pb2.Nested"
+value: "\n\x13embedded inside Any\x12\x0b\n\tinception"
+`,
+	}, {
+		desc: "google.protobuf.Any message expanded",
+		mo: func() textpb.MarshalOptions {
+			m := &pb2.Nested{}
+			resolver := preg.NewTypes(m.ProtoReflect().Type())
+			return textpb.MarshalOptions{Resolver: resolver}
+		}(),
+		input: func() proto.Message {
+			m := &pb2.Nested{
+				OptString: scalar.String("embedded inside Any"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("inception"),
+				},
+			}
+			// TODO: Switch to V2 marshal when ready.
+			b, err := protoV1.Marshal(m)
+			if err != nil {
+				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+			}
+			return impl.Export{}.MessageOf(&anypb.Any{
+				TypeUrl: string(m.ProtoReflect().Type().FullName()),
+				Value:   b,
+			}).Interface()
+		}(),
+		want: `[pb2.Nested]: {
+  opt_string: "embedded inside Any"
+  opt_nested: {
+    opt_string: "inception"
+  }
+}
+`,
+	}, {
+		desc: "google.protobuf.Any message expanded with missing required error",
+		mo: func() textpb.MarshalOptions {
+			m := &pb2.PartialRequired{}
+			resolver := preg.NewTypes(m.ProtoReflect().Type())
+			return textpb.MarshalOptions{Resolver: resolver}
+		}(),
+		input: func() proto.Message {
+			m := &pb2.PartialRequired{
+				OptString: scalar.String("embedded inside Any"),
+			}
+			// TODO: Switch to V2 marshal when ready.
+			b, err := protoV1.Marshal(m)
+			// Ignore required not set error.
+			if _, ok := err.(*protoV1.RequiredNotSetError); !ok {
+				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+			}
+			return impl.Export{}.MessageOf(&anypb.Any{
+				TypeUrl: string(m.ProtoReflect().Type().FullName()),
+				Value:   b,
+			}).Interface()
+		}(),
+		want: `[pb2.PartialRequired]: {
+  opt_string: "embedded inside Any"
+}
+`,
+		wantErr: true,
+	}, {
+		desc: "google.protobuf.Any field",
+		mo:   textpb.MarshalOptions{Resolver: preg.NewTypes()},
+		input: func() proto.Message {
+			m := &pb2.Nested{
+				OptString: scalar.String("embedded inside Any"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("inception"),
+				},
+			}
+			// TODO: Switch to V2 marshal when ready.
+			b, err := protoV1.Marshal(m)
+			if err != nil {
+				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+			}
+			return &pb2.KnownTypes{
+				OptAny: &anypb.Any{
+					TypeUrl: string(m.ProtoReflect().Type().FullName()),
+					Value:   b,
+				},
+			}
+		}(),
+		want: `opt_any: {
+  type_url: "pb2.Nested"
+  value: "\n\x13embedded inside Any\x12\x0b\n\tinception"
+}
+`,
+	}, {
+		desc: "google.protobuf.Any field expanded using given types registry",
+		mo: func() textpb.MarshalOptions {
+			m := &pb2.Nested{}
+			resolver := preg.NewTypes(m.ProtoReflect().Type())
+			return textpb.MarshalOptions{Resolver: resolver}
+		}(),
+		input: func() proto.Message {
+			m := &pb2.Nested{
+				OptString: scalar.String("embedded inside Any"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("inception"),
+				},
+			}
+			// TODO: Switch to V2 marshal when ready.
+			b, err := protoV1.Marshal(m)
+			if err != nil {
+				t.Fatalf("error in binary marshaling message for Any.value: %v", err)
+			}
+			return &pb2.KnownTypes{
+				OptAny: &anypb.Any{
+					TypeUrl: string(m.ProtoReflect().Type().FullName()),
+					Value:   b,
+				},
+			}
+		}(),
+		want: `opt_any: {
+  [pb2.Nested]: {
+    opt_string: "embedded inside Any"
+    opt_nested: {
+      opt_string: "inception"
+    }
+  }
+}
+`,
 	}}
 
 	for _, tt := range tests {
 		tt := tt
 		t.Run(tt.desc, func(t *testing.T) {
 			t.Parallel()
-			b, err := textpb.Marshal(tt.input)
+			b, err := tt.mo.Marshal(tt.input)
 			if err != nil && !tt.wantErr {
-				t.Errorf("Marshal() returned error: %v\n\n", err)
+				t.Errorf("Marshal() returned error: %v\n", err)
 			}
 			if err == nil && tt.wantErr {
-				t.Error("Marshal() got nil error, want error\n\n")
+				t.Error("Marshal() got nil error, want error\n")
 			}
 			got := string(b)
 			if tt.want != "" && got != tt.want {