cmd/protoc-gen-go: register messages and map field types

Move generation of the init function that registers all the types in a
file into a single function.

Take some care to generate the registrations in the same order as the
previous protoc-gen-go, to make it easier to catch unintended
differences in output.

For the same reason, adjust the order of generation to generate all
enums before all messages (matches previous behavior).

Change-Id: Ie0d574004d01a16f8d7b10be3882719a3c41676e
Reviewed-on: https://go-review.googlesource.com/135359
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index 815f080..c7d9f8d 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -14,6 +14,7 @@
 	"flag"
 	"fmt"
 	"math"
+	"sort"
 	"strconv"
 	"strings"
 
@@ -53,7 +54,8 @@
 	*protogen.File
 	locationMap   map[string][]*descpb.SourceCodeInfo_Location
 	descriptorVar string // var containing the gzipped FileDescriptorProto
-	init          []string
+	allEnums      []*protogen.Enum
+	allMessages   []*protogen.Message
 }
 
 func genFile(gen *protogen.Plugin, file *protogen.File) {
@@ -66,6 +68,12 @@
 		f.locationMap[key] = append(f.locationMap[key], loc)
 	}
 
+	f.allEnums = append(f.allEnums, f.File.Enums...)
+	f.allMessages = append(f.allMessages, f.File.Messages...)
+	for _, message := range f.Messages {
+		f.initMessage(message)
+	}
+
 	// Determine the name of the var holding the file descriptor:
 	//
 	//     fileDescriptor_<hash of filename>
@@ -91,25 +99,26 @@
 	}, "// please upgrade the proto package")
 	g.P()
 
-	for _, enum := range f.Enums {
+	for _, enum := range f.allEnums {
 		genEnum(gen, g, f, enum)
 	}
-	for _, message := range f.Messages {
+	for _, message := range f.allMessages {
 		genMessage(gen, g, f, message)
 	}
 
-	if len(f.init) != 0 {
-		g.P("func init() {")
-		for _, s := range f.init {
-			g.P(s)
-		}
-		g.P("}")
-		g.P()
-	}
+	genInitFunction(gen, g, f)
 
 	genFileDescriptor(gen, g, f)
 }
 
+func (f *File) initMessage(message *protogen.Message) {
+	f.allEnums = append(f.allEnums, message.Enums...)
+	f.allMessages = append(f.allMessages, message.Messages...)
+	for _, m := range message.Messages {
+		f.initMessage(m)
+	}
+}
+
 func genFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File) {
 	// Trim the source_code_info from the descriptor.
 	// Marshal and gzip it.
@@ -215,14 +224,6 @@
 	g.P()
 
 	genWellKnownType(g, enum.GoIdent, enum.Desc)
-
-	f.init = append(f.init, fmt.Sprintf("%s(%q, %s, %s)",
-		g.QualifiedGoIdent(protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "RegisterEnum",
-		}),
-		enumRegistryName(enum), nameMap, valueMap,
-	))
 }
 
 // enumRegistryName returns the name used to register an enum with the proto
@@ -250,10 +251,6 @@
 		return
 	}
 
-	for _, e := range message.Enums {
-		genEnum(gen, g, f, e)
-	}
-
 	genComment(g, f, message.Path)
 	// TODO: deprecation
 	g.P("type ", message.GoIdent, " struct {")
@@ -415,9 +412,7 @@
 		g.P()
 	}
 
-	for _, nested := range message.Messages {
-		genMessage(gen, g, f, nested)
-	}
+	genWellKnownType(g, message.GoIdent, message.Desc)
 }
 
 // fieldGoType returns the Go type used for a field.
@@ -590,6 +585,57 @@
 	return string(field.Desc.Name()) + ",omitempty"
 }
 
+// genInitFunction generates an init function that registers the types in the
+// generated file with the proto package.
+func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File) {
+	if len(f.allMessages) == 0 && len(f.allEnums) == 0 {
+		return
+	}
+
+	g.P("func init() {")
+	for _, message := range f.allMessages {
+		if message.Desc.IsMapEntry() {
+			continue
+		}
+
+		name := message.GoIdent.GoName
+		g.P(protogen.GoIdent{
+			GoImportPath: protoPackage,
+			GoName:       "RegisterType",
+		}, fmt.Sprintf("((*%s)(nil), %q)", name, message.Desc.FullName()))
+
+		// Types of map fields, sorted by the name of the field message type.
+		var mapFields []*protogen.Field
+		for _, field := range message.Fields {
+			if field.Desc.IsMap() {
+				mapFields = append(mapFields, field)
+			}
+		}
+		sort.Slice(mapFields, func(i, j int) bool {
+			ni := mapFields[i].MessageType.Desc.FullName()
+			nj := mapFields[j].MessageType.Desc.FullName()
+			return ni < nj
+		})
+		for _, field := range mapFields {
+			typeName := string(field.MessageType.Desc.FullName())
+			goType, _ := fieldGoType(g, field)
+			g.P(protogen.GoIdent{
+				GoImportPath: protoPackage,
+				GoName:       "RegisterMapType",
+			}, fmt.Sprintf("((%v)(nil), %q)", goType, typeName))
+		}
+	}
+	for _, enum := range f.allEnums {
+		name := enum.GoIdent.GoName
+		g.P(protogen.GoIdent{
+			GoImportPath: protoPackage,
+			GoName:       "RegisterEnum",
+		}, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
+	}
+	g.P("}")
+	g.P()
+}
+
 func genComment(g *protogen.GeneratedFile, f *File, path []int32) {
 	for _, loc := range f.locationMap[pathKey(path)] {
 		if loc.LeadingComments == nil {
diff --git a/cmd/protoc-gen-go/testdata/comments/comments.pb.go b/cmd/protoc-gen-go/testdata/comments/comments.pb.go
index 6a426fe..675f72c 100644
--- a/cmd/protoc-gen-go/testdata/comments/comments.pb.go
+++ b/cmd/protoc-gen-go/testdata/comments/comments.pb.go
@@ -44,6 +44,37 @@
 
 var xxx_messageInfo_Message1 proto.InternalMessageInfo
 
+// COMMENT: Message2
+type Message2 struct {
+	XXX_NoUnkeyedLiteral struct{} `json:"-"`
+	XXX_unrecognized     []byte   `json:"-"`
+	XXX_sizecache        int32    `json:"-"`
+}
+
+func (m *Message2) Reset()         { *m = Message2{} }
+func (m *Message2) String() string { return proto.CompactTextString(m) }
+func (*Message2) ProtoMessage()    {}
+func (*Message2) Descriptor() ([]byte, []int) {
+	return fileDescriptor_885e8293f1fab554, []int{1}
+}
+func (m *Message2) XXX_Unmarshal(b []byte) error {
+	return xxx_messageInfo_Message2.Unmarshal(m, b)
+}
+func (m *Message2) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+	return xxx_messageInfo_Message2.Marshal(b, m, deterministic)
+}
+func (m *Message2) XXX_Merge(src proto.Message) {
+	xxx_messageInfo_Message2.Merge(m, src)
+}
+func (m *Message2) XXX_Size() int {
+	return xxx_messageInfo_Message2.Size(m)
+}
+func (m *Message2) XXX_DiscardUnknown() {
+	xxx_messageInfo_Message2.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Message2 proto.InternalMessageInfo
+
 // COMMENT: Message1A
 type Message1_Message1A struct {
 	XXX_NoUnkeyedLiteral struct{} `json:"-"`
@@ -106,37 +137,6 @@
 
 var xxx_messageInfo_Message1_Message1B proto.InternalMessageInfo
 
-// COMMENT: Message2
-type Message2 struct {
-	XXX_NoUnkeyedLiteral struct{} `json:"-"`
-	XXX_unrecognized     []byte   `json:"-"`
-	XXX_sizecache        int32    `json:"-"`
-}
-
-func (m *Message2) Reset()         { *m = Message2{} }
-func (m *Message2) String() string { return proto.CompactTextString(m) }
-func (*Message2) ProtoMessage()    {}
-func (*Message2) Descriptor() ([]byte, []int) {
-	return fileDescriptor_885e8293f1fab554, []int{1}
-}
-func (m *Message2) XXX_Unmarshal(b []byte) error {
-	return xxx_messageInfo_Message2.Unmarshal(m, b)
-}
-func (m *Message2) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
-	return xxx_messageInfo_Message2.Marshal(b, m, deterministic)
-}
-func (m *Message2) XXX_Merge(src proto.Message) {
-	xxx_messageInfo_Message2.Merge(m, src)
-}
-func (m *Message2) XXX_Size() int {
-	return xxx_messageInfo_Message2.Size(m)
-}
-func (m *Message2) XXX_DiscardUnknown() {
-	xxx_messageInfo_Message2.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_Message2 proto.InternalMessageInfo
-
 // COMMENT: Message2A
 type Message2_Message2A struct {
 	XXX_NoUnkeyedLiteral struct{} `json:"-"`
@@ -199,6 +199,15 @@
 
 var xxx_messageInfo_Message2_Message2B proto.InternalMessageInfo
 
+func init() {
+	proto.RegisterType((*Message1)(nil), "goproto.protoc.proto2.Message1")
+	proto.RegisterType((*Message2)(nil), "goproto.protoc.proto2.Message2")
+	proto.RegisterType((*Message1_Message1A)(nil), "goproto.protoc.proto2.Message1.Message1A")
+	proto.RegisterType((*Message1_Message1B)(nil), "goproto.protoc.proto2.Message1.Message1B")
+	proto.RegisterType((*Message2_Message2A)(nil), "goproto.protoc.proto2.Message2.Message2A")
+	proto.RegisterType((*Message2_Message2B)(nil), "goproto.protoc.proto2.Message2.Message2B")
+}
+
 func init() { proto.RegisterFile("comments/comments.proto", fileDescriptor_885e8293f1fab554) }
 
 var fileDescriptor_885e8293f1fab554 = []byte{
diff --git a/cmd/protoc-gen-go/testdata/fieldnames/fieldnames.pb.go b/cmd/protoc-gen-go/testdata/fieldnames/fieldnames.pb.go
index b007046..a72a731 100644
--- a/cmd/protoc-gen-go/testdata/fieldnames/fieldnames.pb.go
+++ b/cmd/protoc-gen-go/testdata/fieldnames/fieldnames.pb.go
@@ -183,6 +183,10 @@
 	return ""
 }
 
+func init() {
+	proto.RegisterType((*Message)(nil), "goproto.protoc.fieldnames.Message")
+}
+
 func init() { proto.RegisterFile("fieldnames/fieldnames.proto", fileDescriptor_6bbe3f70febb9403) }
 
 var fileDescriptor_6bbe3f70febb9403 = []byte{
diff --git a/cmd/protoc-gen-go/testdata/proto2/enum.pb.go b/cmd/protoc-gen-go/testdata/proto2/enum.pb.go
index 5a32525..57ba0c4 100644
--- a/cmd/protoc-gen-go/testdata/proto2/enum.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/enum.pb.go
@@ -170,36 +170,6 @@
 	return fileDescriptor_de9f68860d540858, []int{0, 1}
 }
 
-type EnumContainerMessage1 struct {
-	XXX_NoUnkeyedLiteral struct{} `json:"-"`
-	XXX_unrecognized     []byte   `json:"-"`
-	XXX_sizecache        int32    `json:"-"`
-}
-
-func (m *EnumContainerMessage1) Reset()         { *m = EnumContainerMessage1{} }
-func (m *EnumContainerMessage1) String() string { return proto.CompactTextString(m) }
-func (*EnumContainerMessage1) ProtoMessage()    {}
-func (*EnumContainerMessage1) Descriptor() ([]byte, []int) {
-	return fileDescriptor_de9f68860d540858, []int{0}
-}
-func (m *EnumContainerMessage1) XXX_Unmarshal(b []byte) error {
-	return xxx_messageInfo_EnumContainerMessage1.Unmarshal(m, b)
-}
-func (m *EnumContainerMessage1) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
-	return xxx_messageInfo_EnumContainerMessage1.Marshal(b, m, deterministic)
-}
-func (m *EnumContainerMessage1) XXX_Merge(src proto.Message) {
-	xxx_messageInfo_EnumContainerMessage1.Merge(m, src)
-}
-func (m *EnumContainerMessage1) XXX_Size() int {
-	return xxx_messageInfo_EnumContainerMessage1.Size(m)
-}
-func (m *EnumContainerMessage1) XXX_DiscardUnknown() {
-	xxx_messageInfo_EnumContainerMessage1.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_EnumContainerMessage1 proto.InternalMessageInfo
-
 // NestedEnumType2A comment.
 type EnumContainerMessage1_EnumContainerMessage2_NestedEnumType2A int32
 
@@ -276,6 +246,36 @@
 	return fileDescriptor_de9f68860d540858, []int{0, 0, 1}
 }
 
+type EnumContainerMessage1 struct {
+	XXX_NoUnkeyedLiteral struct{} `json:"-"`
+	XXX_unrecognized     []byte   `json:"-"`
+	XXX_sizecache        int32    `json:"-"`
+}
+
+func (m *EnumContainerMessage1) Reset()         { *m = EnumContainerMessage1{} }
+func (m *EnumContainerMessage1) String() string { return proto.CompactTextString(m) }
+func (*EnumContainerMessage1) ProtoMessage()    {}
+func (*EnumContainerMessage1) Descriptor() ([]byte, []int) {
+	return fileDescriptor_de9f68860d540858, []int{0}
+}
+func (m *EnumContainerMessage1) XXX_Unmarshal(b []byte) error {
+	return xxx_messageInfo_EnumContainerMessage1.Unmarshal(m, b)
+}
+func (m *EnumContainerMessage1) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+	return xxx_messageInfo_EnumContainerMessage1.Marshal(b, m, deterministic)
+}
+func (m *EnumContainerMessage1) XXX_Merge(src proto.Message) {
+	xxx_messageInfo_EnumContainerMessage1.Merge(m, src)
+}
+func (m *EnumContainerMessage1) XXX_Size() int {
+	return xxx_messageInfo_EnumContainerMessage1.Size(m)
+}
+func (m *EnumContainerMessage1) XXX_DiscardUnknown() {
+	xxx_messageInfo_EnumContainerMessage1.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EnumContainerMessage1 proto.InternalMessageInfo
+
 type EnumContainerMessage1_EnumContainerMessage2 struct {
 	XXX_NoUnkeyedLiteral struct{} `json:"-"`
 	XXX_unrecognized     []byte   `json:"-"`
@@ -311,6 +311,8 @@
 var xxx_messageInfo_EnumContainerMessage1_EnumContainerMessage2 proto.InternalMessageInfo
 
 func init() {
+	proto.RegisterType((*EnumContainerMessage1)(nil), "goproto.protoc.proto2.EnumContainerMessage1")
+	proto.RegisterType((*EnumContainerMessage1_EnumContainerMessage2)(nil), "goproto.protoc.proto2.EnumContainerMessage1.EnumContainerMessage2")
 	proto.RegisterEnum("goproto.protoc.proto2.EnumType1", EnumType1_name, EnumType1_value)
 	proto.RegisterEnum("goproto.protoc.proto2.EnumType2", EnumType2_name, EnumType2_value)
 	proto.RegisterEnum("goproto.protoc.proto2.EnumContainerMessage1_NestedEnumType1A", EnumContainerMessage1_NestedEnumType1A_name, EnumContainerMessage1_NestedEnumType1A_value)
diff --git a/cmd/protoc-gen-go/testdata/proto2/fields.pb.go b/cmd/protoc-gen-go/testdata/proto2/fields.pb.go
index 9128126..87ab5ed 100644
--- a/cmd/protoc-gen-go/testdata/proto2/fields.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/fields.pb.go
@@ -885,6 +885,14 @@
 var xxx_messageInfo_FieldTestMessage_Message proto.InternalMessageInfo
 
 func init() {
+	proto.RegisterType((*FieldTestMessage)(nil), "goproto.protoc.proto2.FieldTestMessage")
+	proto.RegisterMapType((map[uint64]FieldTestMessage_Enum)(nil), "goproto.protoc.proto2.FieldTestMessage.MapFixed64EnumEntry")
+	proto.RegisterMapType((map[int32]int64)(nil), "goproto.protoc.proto2.FieldTestMessage.MapInt32Int64Entry")
+	proto.RegisterMapType((map[string]*FieldTestMessage_Message)(nil), "goproto.protoc.proto2.FieldTestMessage.MapStringMessageEntry")
+	proto.RegisterType((*FieldTestMessage_OptionalGroup)(nil), "goproto.protoc.proto2.FieldTestMessage.OptionalGroup")
+	proto.RegisterType((*FieldTestMessage_RequiredGroup)(nil), "goproto.protoc.proto2.FieldTestMessage.RequiredGroup")
+	proto.RegisterType((*FieldTestMessage_RepeatedGroup)(nil), "goproto.protoc.proto2.FieldTestMessage.RepeatedGroup")
+	proto.RegisterType((*FieldTestMessage_Message)(nil), "goproto.protoc.proto2.FieldTestMessage.Message")
 	proto.RegisterEnum("goproto.protoc.proto2.FieldTestMessage_Enum", FieldTestMessage_Enum_name, FieldTestMessage_Enum_value)
 }
 
diff --git a/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go b/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
index 92e61cb..64ecad2 100644
--- a/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
@@ -125,6 +125,12 @@
 
 var xxx_messageInfo_Layer1_Layer2_Layer3 proto.InternalMessageInfo
 
+func init() {
+	proto.RegisterType((*Layer1)(nil), "goproto.protoc.proto2.Layer1")
+	proto.RegisterType((*Layer1_Layer2)(nil), "goproto.protoc.proto2.Layer1.Layer2")
+	proto.RegisterType((*Layer1_Layer2_Layer3)(nil), "goproto.protoc.proto2.Layer1.Layer2.Layer3")
+}
+
 func init() { proto.RegisterFile("proto2/nested_messages.proto", fileDescriptor_7417ee157699d191) }
 
 var fileDescriptor_7417ee157699d191 = []byte{
diff --git a/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go b/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
index 2452749..0fe87d6 100644
--- a/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
@@ -57,6 +57,10 @@
 	return nil
 }
 
+func init() {
+	proto.RegisterType((*Message)(nil), "goproto.protoc.proto2.Message")
+}
+
 func init() { proto.RegisterFile("proto2/proto2.proto", fileDescriptor_d756bbe8817c03c1) }
 
 var fileDescriptor_d756bbe8817c03c1 = []byte{
diff --git a/cmd/protoc-gen-go/testdata/proto3/fields.pb.go b/cmd/protoc-gen-go/testdata/proto3/fields.pb.go
index 6ea3da2..4ccfc13 100644
--- a/cmd/protoc-gen-go/testdata/proto3/fields.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto3/fields.pb.go
@@ -254,6 +254,11 @@
 var xxx_messageInfo_FieldTestMessage_Message proto.InternalMessageInfo
 
 func init() {
+	proto.RegisterType((*FieldTestMessage)(nil), "goproto.protoc.proto3.FieldTestMessage")
+	proto.RegisterMapType((map[uint64]FieldTestMessage_Enum)(nil), "goproto.protoc.proto3.FieldTestMessage.MapFixed64EnumEntry")
+	proto.RegisterMapType((map[int32]int64)(nil), "goproto.protoc.proto3.FieldTestMessage.MapInt32Int64Entry")
+	proto.RegisterMapType((map[string]*FieldTestMessage_Message)(nil), "goproto.protoc.proto3.FieldTestMessage.MapStringMessageEntry")
+	proto.RegisterType((*FieldTestMessage_Message)(nil), "goproto.protoc.proto3.FieldTestMessage.Message")
 	proto.RegisterEnum("goproto.protoc.proto3.FieldTestMessage_Enum", FieldTestMessage_Enum_name, FieldTestMessage_Enum_value)
 }