cmd/protoc-gen-go: register messages and map field types
Move generation of the init function that registers all the types in a
file into a single function.
Take some care to generate the registrations in the same order as the
previous protoc-gen-go, to make it easier to catch unintended
differences in output.
For the same reason, adjust the order of generation to generate all
enums before all messages (matches previous behavior).
Change-Id: Ie0d574004d01a16f8d7b10be3882719a3c41676e
Reviewed-on: https://go-review.googlesource.com/135359
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index 815f080..c7d9f8d 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -14,6 +14,7 @@
"flag"
"fmt"
"math"
+ "sort"
"strconv"
"strings"
@@ -53,7 +54,8 @@
*protogen.File
locationMap map[string][]*descpb.SourceCodeInfo_Location
descriptorVar string // var containing the gzipped FileDescriptorProto
- init []string
+ allEnums []*protogen.Enum
+ allMessages []*protogen.Message
}
func genFile(gen *protogen.Plugin, file *protogen.File) {
@@ -66,6 +68,12 @@
f.locationMap[key] = append(f.locationMap[key], loc)
}
+ f.allEnums = append(f.allEnums, f.File.Enums...)
+ f.allMessages = append(f.allMessages, f.File.Messages...)
+ for _, message := range f.Messages {
+ f.initMessage(message)
+ }
+
// Determine the name of the var holding the file descriptor:
//
// fileDescriptor_<hash of filename>
@@ -91,25 +99,26 @@
}, "// please upgrade the proto package")
g.P()
- for _, enum := range f.Enums {
+ for _, enum := range f.allEnums {
genEnum(gen, g, f, enum)
}
- for _, message := range f.Messages {
+ for _, message := range f.allMessages {
genMessage(gen, g, f, message)
}
- if len(f.init) != 0 {
- g.P("func init() {")
- for _, s := range f.init {
- g.P(s)
- }
- g.P("}")
- g.P()
- }
+ genInitFunction(gen, g, f)
genFileDescriptor(gen, g, f)
}
+func (f *File) initMessage(message *protogen.Message) {
+ f.allEnums = append(f.allEnums, message.Enums...)
+ f.allMessages = append(f.allMessages, message.Messages...)
+ for _, m := range message.Messages {
+ f.initMessage(m)
+ }
+}
+
func genFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File) {
// Trim the source_code_info from the descriptor.
// Marshal and gzip it.
@@ -215,14 +224,6 @@
g.P()
genWellKnownType(g, enum.GoIdent, enum.Desc)
-
- f.init = append(f.init, fmt.Sprintf("%s(%q, %s, %s)",
- g.QualifiedGoIdent(protogen.GoIdent{
- GoImportPath: protoPackage,
- GoName: "RegisterEnum",
- }),
- enumRegistryName(enum), nameMap, valueMap,
- ))
}
// enumRegistryName returns the name used to register an enum with the proto
@@ -250,10 +251,6 @@
return
}
- for _, e := range message.Enums {
- genEnum(gen, g, f, e)
- }
-
genComment(g, f, message.Path)
// TODO: deprecation
g.P("type ", message.GoIdent, " struct {")
@@ -415,9 +412,7 @@
g.P()
}
- for _, nested := range message.Messages {
- genMessage(gen, g, f, nested)
- }
+ genWellKnownType(g, message.GoIdent, message.Desc)
}
// fieldGoType returns the Go type used for a field.
@@ -590,6 +585,57 @@
return string(field.Desc.Name()) + ",omitempty"
}
+// genInitFunction generates an init function that registers the types in the
+// generated file with the proto package.
+func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File) {
+ if len(f.allMessages) == 0 && len(f.allEnums) == 0 {
+ return
+ }
+
+ g.P("func init() {")
+ for _, message := range f.allMessages {
+ if message.Desc.IsMapEntry() {
+ continue
+ }
+
+ name := message.GoIdent.GoName
+ g.P(protogen.GoIdent{
+ GoImportPath: protoPackage,
+ GoName: "RegisterType",
+ }, fmt.Sprintf("((*%s)(nil), %q)", name, message.Desc.FullName()))
+
+ // Types of map fields, sorted by the name of the field message type.
+ var mapFields []*protogen.Field
+ for _, field := range message.Fields {
+ if field.Desc.IsMap() {
+ mapFields = append(mapFields, field)
+ }
+ }
+ sort.Slice(mapFields, func(i, j int) bool {
+ ni := mapFields[i].MessageType.Desc.FullName()
+ nj := mapFields[j].MessageType.Desc.FullName()
+ return ni < nj
+ })
+ for _, field := range mapFields {
+ typeName := string(field.MessageType.Desc.FullName())
+ goType, _ := fieldGoType(g, field)
+ g.P(protogen.GoIdent{
+ GoImportPath: protoPackage,
+ GoName: "RegisterMapType",
+ }, fmt.Sprintf("((%v)(nil), %q)", goType, typeName))
+ }
+ }
+ for _, enum := range f.allEnums {
+ name := enum.GoIdent.GoName
+ g.P(protogen.GoIdent{
+ GoImportPath: protoPackage,
+ GoName: "RegisterEnum",
+ }, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
+ }
+ g.P("}")
+ g.P()
+}
+
func genComment(g *protogen.GeneratedFile, f *File, path []int32) {
for _, loc := range f.locationMap[pathKey(path)] {
if loc.LeadingComments == nil {
diff --git a/cmd/protoc-gen-go/testdata/comments/comments.pb.go b/cmd/protoc-gen-go/testdata/comments/comments.pb.go
index 6a426fe..675f72c 100644
--- a/cmd/protoc-gen-go/testdata/comments/comments.pb.go
+++ b/cmd/protoc-gen-go/testdata/comments/comments.pb.go
@@ -44,6 +44,37 @@
var xxx_messageInfo_Message1 proto.InternalMessageInfo
+// COMMENT: Message2
+type Message2 struct {
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Message2) Reset() { *m = Message2{} }
+func (m *Message2) String() string { return proto.CompactTextString(m) }
+func (*Message2) ProtoMessage() {}
+func (*Message2) Descriptor() ([]byte, []int) {
+ return fileDescriptor_885e8293f1fab554, []int{1}
+}
+func (m *Message2) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Message2.Unmarshal(m, b)
+}
+func (m *Message2) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Message2.Marshal(b, m, deterministic)
+}
+func (m *Message2) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Message2.Merge(m, src)
+}
+func (m *Message2) XXX_Size() int {
+ return xxx_messageInfo_Message2.Size(m)
+}
+func (m *Message2) XXX_DiscardUnknown() {
+ xxx_messageInfo_Message2.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Message2 proto.InternalMessageInfo
+
// COMMENT: Message1A
type Message1_Message1A struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
@@ -106,37 +137,6 @@
var xxx_messageInfo_Message1_Message1B proto.InternalMessageInfo
-// COMMENT: Message2
-type Message2 struct {
- XXX_NoUnkeyedLiteral struct{} `json:"-"`
- XXX_unrecognized []byte `json:"-"`
- XXX_sizecache int32 `json:"-"`
-}
-
-func (m *Message2) Reset() { *m = Message2{} }
-func (m *Message2) String() string { return proto.CompactTextString(m) }
-func (*Message2) ProtoMessage() {}
-func (*Message2) Descriptor() ([]byte, []int) {
- return fileDescriptor_885e8293f1fab554, []int{1}
-}
-func (m *Message2) XXX_Unmarshal(b []byte) error {
- return xxx_messageInfo_Message2.Unmarshal(m, b)
-}
-func (m *Message2) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
- return xxx_messageInfo_Message2.Marshal(b, m, deterministic)
-}
-func (m *Message2) XXX_Merge(src proto.Message) {
- xxx_messageInfo_Message2.Merge(m, src)
-}
-func (m *Message2) XXX_Size() int {
- return xxx_messageInfo_Message2.Size(m)
-}
-func (m *Message2) XXX_DiscardUnknown() {
- xxx_messageInfo_Message2.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_Message2 proto.InternalMessageInfo
-
// COMMENT: Message2A
type Message2_Message2A struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
@@ -199,6 +199,15 @@
var xxx_messageInfo_Message2_Message2B proto.InternalMessageInfo
+func init() {
+ proto.RegisterType((*Message1)(nil), "goproto.protoc.proto2.Message1")
+ proto.RegisterType((*Message2)(nil), "goproto.protoc.proto2.Message2")
+ proto.RegisterType((*Message1_Message1A)(nil), "goproto.protoc.proto2.Message1.Message1A")
+ proto.RegisterType((*Message1_Message1B)(nil), "goproto.protoc.proto2.Message1.Message1B")
+ proto.RegisterType((*Message2_Message2A)(nil), "goproto.protoc.proto2.Message2.Message2A")
+ proto.RegisterType((*Message2_Message2B)(nil), "goproto.protoc.proto2.Message2.Message2B")
+}
+
func init() { proto.RegisterFile("comments/comments.proto", fileDescriptor_885e8293f1fab554) }
var fileDescriptor_885e8293f1fab554 = []byte{
diff --git a/cmd/protoc-gen-go/testdata/fieldnames/fieldnames.pb.go b/cmd/protoc-gen-go/testdata/fieldnames/fieldnames.pb.go
index b007046..a72a731 100644
--- a/cmd/protoc-gen-go/testdata/fieldnames/fieldnames.pb.go
+++ b/cmd/protoc-gen-go/testdata/fieldnames/fieldnames.pb.go
@@ -183,6 +183,10 @@
return ""
}
+func init() {
+ proto.RegisterType((*Message)(nil), "goproto.protoc.fieldnames.Message")
+}
+
func init() { proto.RegisterFile("fieldnames/fieldnames.proto", fileDescriptor_6bbe3f70febb9403) }
var fileDescriptor_6bbe3f70febb9403 = []byte{
diff --git a/cmd/protoc-gen-go/testdata/proto2/enum.pb.go b/cmd/protoc-gen-go/testdata/proto2/enum.pb.go
index 5a32525..57ba0c4 100644
--- a/cmd/protoc-gen-go/testdata/proto2/enum.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/enum.pb.go
@@ -170,36 +170,6 @@
return fileDescriptor_de9f68860d540858, []int{0, 1}
}
-type EnumContainerMessage1 struct {
- XXX_NoUnkeyedLiteral struct{} `json:"-"`
- XXX_unrecognized []byte `json:"-"`
- XXX_sizecache int32 `json:"-"`
-}
-
-func (m *EnumContainerMessage1) Reset() { *m = EnumContainerMessage1{} }
-func (m *EnumContainerMessage1) String() string { return proto.CompactTextString(m) }
-func (*EnumContainerMessage1) ProtoMessage() {}
-func (*EnumContainerMessage1) Descriptor() ([]byte, []int) {
- return fileDescriptor_de9f68860d540858, []int{0}
-}
-func (m *EnumContainerMessage1) XXX_Unmarshal(b []byte) error {
- return xxx_messageInfo_EnumContainerMessage1.Unmarshal(m, b)
-}
-func (m *EnumContainerMessage1) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
- return xxx_messageInfo_EnumContainerMessage1.Marshal(b, m, deterministic)
-}
-func (m *EnumContainerMessage1) XXX_Merge(src proto.Message) {
- xxx_messageInfo_EnumContainerMessage1.Merge(m, src)
-}
-func (m *EnumContainerMessage1) XXX_Size() int {
- return xxx_messageInfo_EnumContainerMessage1.Size(m)
-}
-func (m *EnumContainerMessage1) XXX_DiscardUnknown() {
- xxx_messageInfo_EnumContainerMessage1.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_EnumContainerMessage1 proto.InternalMessageInfo
-
// NestedEnumType2A comment.
type EnumContainerMessage1_EnumContainerMessage2_NestedEnumType2A int32
@@ -276,6 +246,36 @@
return fileDescriptor_de9f68860d540858, []int{0, 0, 1}
}
+type EnumContainerMessage1 struct {
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EnumContainerMessage1) Reset() { *m = EnumContainerMessage1{} }
+func (m *EnumContainerMessage1) String() string { return proto.CompactTextString(m) }
+func (*EnumContainerMessage1) ProtoMessage() {}
+func (*EnumContainerMessage1) Descriptor() ([]byte, []int) {
+ return fileDescriptor_de9f68860d540858, []int{0}
+}
+func (m *EnumContainerMessage1) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EnumContainerMessage1.Unmarshal(m, b)
+}
+func (m *EnumContainerMessage1) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EnumContainerMessage1.Marshal(b, m, deterministic)
+}
+func (m *EnumContainerMessage1) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EnumContainerMessage1.Merge(m, src)
+}
+func (m *EnumContainerMessage1) XXX_Size() int {
+ return xxx_messageInfo_EnumContainerMessage1.Size(m)
+}
+func (m *EnumContainerMessage1) XXX_DiscardUnknown() {
+ xxx_messageInfo_EnumContainerMessage1.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EnumContainerMessage1 proto.InternalMessageInfo
+
type EnumContainerMessage1_EnumContainerMessage2 struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
@@ -311,6 +311,8 @@
var xxx_messageInfo_EnumContainerMessage1_EnumContainerMessage2 proto.InternalMessageInfo
func init() {
+ proto.RegisterType((*EnumContainerMessage1)(nil), "goproto.protoc.proto2.EnumContainerMessage1")
+ proto.RegisterType((*EnumContainerMessage1_EnumContainerMessage2)(nil), "goproto.protoc.proto2.EnumContainerMessage1.EnumContainerMessage2")
proto.RegisterEnum("goproto.protoc.proto2.EnumType1", EnumType1_name, EnumType1_value)
proto.RegisterEnum("goproto.protoc.proto2.EnumType2", EnumType2_name, EnumType2_value)
proto.RegisterEnum("goproto.protoc.proto2.EnumContainerMessage1_NestedEnumType1A", EnumContainerMessage1_NestedEnumType1A_name, EnumContainerMessage1_NestedEnumType1A_value)
diff --git a/cmd/protoc-gen-go/testdata/proto2/fields.pb.go b/cmd/protoc-gen-go/testdata/proto2/fields.pb.go
index 9128126..87ab5ed 100644
--- a/cmd/protoc-gen-go/testdata/proto2/fields.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/fields.pb.go
@@ -885,6 +885,14 @@
var xxx_messageInfo_FieldTestMessage_Message proto.InternalMessageInfo
func init() {
+ proto.RegisterType((*FieldTestMessage)(nil), "goproto.protoc.proto2.FieldTestMessage")
+ proto.RegisterMapType((map[uint64]FieldTestMessage_Enum)(nil), "goproto.protoc.proto2.FieldTestMessage.MapFixed64EnumEntry")
+ proto.RegisterMapType((map[int32]int64)(nil), "goproto.protoc.proto2.FieldTestMessage.MapInt32Int64Entry")
+ proto.RegisterMapType((map[string]*FieldTestMessage_Message)(nil), "goproto.protoc.proto2.FieldTestMessage.MapStringMessageEntry")
+ proto.RegisterType((*FieldTestMessage_OptionalGroup)(nil), "goproto.protoc.proto2.FieldTestMessage.OptionalGroup")
+ proto.RegisterType((*FieldTestMessage_RequiredGroup)(nil), "goproto.protoc.proto2.FieldTestMessage.RequiredGroup")
+ proto.RegisterType((*FieldTestMessage_RepeatedGroup)(nil), "goproto.protoc.proto2.FieldTestMessage.RepeatedGroup")
+ proto.RegisterType((*FieldTestMessage_Message)(nil), "goproto.protoc.proto2.FieldTestMessage.Message")
proto.RegisterEnum("goproto.protoc.proto2.FieldTestMessage_Enum", FieldTestMessage_Enum_name, FieldTestMessage_Enum_value)
}
diff --git a/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go b/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
index 92e61cb..64ecad2 100644
--- a/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
@@ -125,6 +125,12 @@
var xxx_messageInfo_Layer1_Layer2_Layer3 proto.InternalMessageInfo
+func init() {
+ proto.RegisterType((*Layer1)(nil), "goproto.protoc.proto2.Layer1")
+ proto.RegisterType((*Layer1_Layer2)(nil), "goproto.protoc.proto2.Layer1.Layer2")
+ proto.RegisterType((*Layer1_Layer2_Layer3)(nil), "goproto.protoc.proto2.Layer1.Layer2.Layer3")
+}
+
func init() { proto.RegisterFile("proto2/nested_messages.proto", fileDescriptor_7417ee157699d191) }
var fileDescriptor_7417ee157699d191 = []byte{
diff --git a/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go b/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
index 2452749..0fe87d6 100644
--- a/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
@@ -57,6 +57,10 @@
return nil
}
+func init() {
+ proto.RegisterType((*Message)(nil), "goproto.protoc.proto2.Message")
+}
+
func init() { proto.RegisterFile("proto2/proto2.proto", fileDescriptor_d756bbe8817c03c1) }
var fileDescriptor_d756bbe8817c03c1 = []byte{
diff --git a/cmd/protoc-gen-go/testdata/proto3/fields.pb.go b/cmd/protoc-gen-go/testdata/proto3/fields.pb.go
index 6ea3da2..4ccfc13 100644
--- a/cmd/protoc-gen-go/testdata/proto3/fields.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto3/fields.pb.go
@@ -254,6 +254,11 @@
var xxx_messageInfo_FieldTestMessage_Message proto.InternalMessageInfo
func init() {
+ proto.RegisterType((*FieldTestMessage)(nil), "goproto.protoc.proto3.FieldTestMessage")
+ proto.RegisterMapType((map[uint64]FieldTestMessage_Enum)(nil), "goproto.protoc.proto3.FieldTestMessage.MapFixed64EnumEntry")
+ proto.RegisterMapType((map[int32]int64)(nil), "goproto.protoc.proto3.FieldTestMessage.MapInt32Int64Entry")
+ proto.RegisterMapType((map[string]*FieldTestMessage_Message)(nil), "goproto.protoc.proto3.FieldTestMessage.MapStringMessageEntry")
+ proto.RegisterType((*FieldTestMessage_Message)(nil), "goproto.protoc.proto3.FieldTestMessage.Message")
proto.RegisterEnum("goproto.protoc.proto3.FieldTestMessage_Enum", FieldTestMessage_Enum_name, FieldTestMessage_Enum_value)
}