protogen: move comment generation into protogen

Most plugins need to copy comments from .proto source files into the
generated code. Move this functionality into protogen to avoid
duplicating it everywhere.

Change-Id: I48a96ba794192e7ddc00281342afd4805ef6fe0f
Reviewed-on: https://go-review.googlesource.com/c/142890
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go-grpc/internal_gengogrpc/grpc.go b/cmd/protoc-gen-go-grpc/internal_gengogrpc/grpc.go
index 1c2fc66..e35f7c3 100644
--- a/cmd/protoc-gen-go-grpc/internal_gengogrpc/grpc.go
+++ b/cmd/protoc-gen-go-grpc/internal_gengogrpc/grpc.go
@@ -10,42 +10,28 @@
 	"strconv"
 	"strings"
 
-	descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
 	"github.com/golang/protobuf/v2/protogen"
 )
 
-type fileInfo struct {
-	*protogen.File
-	locationMap map[string][]*descpb.SourceCodeInfo_Location
-}
-
 // GenerateFile generates a _grpc.pb.go file containing gRPC service definitions.
-func GenerateFile(gen *protogen.Plugin, f *protogen.File) {
-	if len(f.Services) == 0 {
+func GenerateFile(gen *protogen.Plugin, file *protogen.File) {
+	if len(file.Services) == 0 {
 		return
 	}
-	filename := f.GeneratedFilenamePrefix + "_grpc.pb.go"
-	g := gen.NewGeneratedFile(filename, f.GoImportPath)
+	filename := file.GeneratedFilenamePrefix + "_grpc.pb.go"
+	g := gen.NewGeneratedFile(filename, file.GoImportPath)
 	g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.")
 	g.P()
-	g.P("package ", f.GoPackageName)
+	g.P("package ", file.GoPackageName)
 	g.P()
-	GenerateFileContent(gen, f, g)
+	GenerateFileContent(gen, file, g)
 }
 
 // GenerateFileContent generates the gRPC service definitions, excluding the package statement.
-func GenerateFileContent(gen *protogen.Plugin, f *protogen.File, g *protogen.GeneratedFile) {
-	if len(f.Services) == 0 {
+func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) {
+	if len(file.Services) == 0 {
 		return
 	}
-	file := &fileInfo{
-		File:        f,
-		locationMap: make(map[string][]*descpb.SourceCodeInfo_Location),
-	}
-	for _, loc := range file.Proto.GetSourceCodeInfo().GetLocation() {
-		key := pathKey(loc.Path)
-		file.locationMap[key] = append(file.locationMap[key], loc)
-	}
 
 	// TODO: Remove this. We don't need to include these references any more.
 	g.P("// Reference imports to suppress errors if they are not otherwise used.")
@@ -62,7 +48,7 @@
 	}
 }
 
-func genService(gen *protogen.Plugin, file *fileInfo, g *protogen.GeneratedFile, service *protogen.Service) {
+func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) {
 	clientName := service.GoName + "Client"
 
 	g.P("// ", clientName, " is the client API for ", service.GoName, " service.")
@@ -77,7 +63,7 @@
 	g.Annotate(clientName, service.Location)
 	g.P("type ", clientName, " interface {")
 	for _, method := range service.Methods {
-		genComment(g, file, method.Location)
+		g.PrintLeadingComments(method.Location)
 		g.Annotate(clientName+"."+method.GoName, method.Location)
 		g.P(clientSignature(g, method))
 	}
@@ -123,7 +109,7 @@
 	g.Annotate(serverType, service.Location)
 	g.P("type ", serverType, " interface {")
 	for _, method := range service.Methods {
-		genComment(g, file, method.Location)
+		g.PrintLeadingComments(method.Location)
 		g.Annotate(serverType+"."+method.GoName, method.Location)
 		g.P(serverSignature(g, method))
 	}
@@ -199,7 +185,7 @@
 	return s
 }
 
-func genClientMethod(gen *protogen.Plugin, file *fileInfo, g *protogen.GeneratedFile, method *protogen.Method, index int) {
+func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) {
 	service := method.ParentService
 	sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
 
@@ -294,7 +280,7 @@
 	return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
 }
 
-func genServerMethod(gen *protogen.Plugin, file *fileInfo, g *protogen.GeneratedFile, method *protogen.Method) string {
+func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string {
 	service := method.ParentService
 	hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
 
@@ -388,32 +374,6 @@
 	}
 }
 
-func genComment(g *protogen.GeneratedFile, file *fileInfo, loc protogen.Location) (hasComment bool) {
-	for _, loc := range file.locationMap[pathKey(loc.Path)] {
-		if loc.LeadingComments == nil {
-			continue
-		}
-		for _, line := range strings.Split(strings.TrimSuffix(loc.GetLeadingComments(), "\n"), "\n") {
-			hasComment = true
-			g.P("//", line)
-		}
-		break
-	}
-	return hasComment
-}
-
 const deprecationComment = "// Deprecated: Do not use."
 
-// pathKey converts a location path to a string suitable for use as a map key.
-func pathKey(path []int32) string {
-	var buf []byte
-	for i, x := range path {
-		if i != 0 {
-			buf = append(buf, ',')
-		}
-		buf = strconv.AppendInt(buf, int64(x), 10)
-	}
-	return string(buf)
-}
-
 func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] }
diff --git a/cmd/protoc-gen-go/internal_gengo/main.go b/cmd/protoc-gen-go/internal_gengo/main.go
index c3154f5..509bbc0 100644
--- a/cmd/protoc-gen-go/internal_gengo/main.go
+++ b/cmd/protoc-gen-go/internal_gengo/main.go
@@ -32,7 +32,6 @@
 
 type fileInfo struct {
 	*protogen.File
-	locationMap   map[string][]*descpb.SourceCodeInfo_Location
 	descriptorVar string // var containing the gzipped FileDescriptorProto
 	allEnums      []*protogen.Enum
 	allMessages   []*protogen.Message
@@ -42,12 +41,7 @@
 // GenerateFile generates the contents of a .pb.go file.
 func GenerateFile(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) {
 	f := &fileInfo{
-		File:        file,
-		locationMap: make(map[string][]*descpb.SourceCodeInfo_Location),
-	}
-	for _, loc := range file.Proto.GetSourceCodeInfo().GetLocation() {
-		key := pathKey(loc.Path)
-		f.locationMap[key] = append(f.locationMap[key], loc)
+		File: file,
 	}
 
 	// The different order for enums and extensions is to match the output
@@ -76,7 +70,10 @@
 	}
 	g.P()
 	const filePackageField = 2 // FileDescriptorProto.package
-	genComment(g, f, protogen.Location{Path: []int32{filePackageField}})
+	g.PrintLeadingComments(protogen.Location{
+		SourceFile: f.Proto.GetName(),
+		Path:       []int32{filePackageField},
+	})
 	g.P()
 	g.P("package ", f.GoPackageName)
 	g.P()
@@ -237,13 +234,13 @@
 }
 
 func genEnum(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, enum *protogen.Enum) {
-	genComment(g, f, enum.Location)
+	g.PrintLeadingComments(enum.Location)
 	g.Annotate(enum.GoIdent.GoName, enum.Location)
 	g.P("type ", enum.GoIdent, " int32",
 		deprecationComment(enumOptions(gen, enum).GetDeprecated()))
 	g.P("const (")
 	for _, value := range enum.Values {
-		genComment(g, f, value.Location)
+		g.PrintLeadingComments(value.Location)
 		g.Annotate(value.GoIdent.GoName, value.Location)
 		g.P(value.GoIdent, " ", enum.GoIdent, " = ", value.Desc.Number(),
 			deprecationComment(enumValueOptions(gen, value).GetDeprecated()))
@@ -335,7 +332,7 @@
 		return
 	}
 
-	hasComment := genComment(g, f, message.Location)
+	hasComment := g.PrintLeadingComments(message.Location)
 	if messageOptions(gen, message).GetDeprecated() {
 		if hasComment {
 			g.P("//")
@@ -355,7 +352,7 @@
 			}
 			continue
 		}
-		genComment(g, f, field.Location)
+		g.PrintLeadingComments(field.Location)
 		goType, pointer := fieldGoType(g, field)
 		if pointer {
 			goType = "*" + goType
@@ -902,20 +899,6 @@
 	}
 }
 
-func genComment(g *protogen.GeneratedFile, f *fileInfo, loc protogen.Location) (hasComment bool) {
-	for _, loc := range f.locationMap[pathKey(loc.Path)] {
-		if loc.LeadingComments == nil {
-			continue
-		}
-		for _, line := range strings.Split(strings.TrimSuffix(loc.GetLeadingComments(), "\n"), "\n") {
-			hasComment = true
-			g.P("//", line)
-		}
-		break
-	}
-	return hasComment
-}
-
 // deprecationComment returns a standard deprecation comment if deprecated is true.
 func deprecationComment(deprecated bool) string {
 	if !deprecated {
@@ -924,18 +907,6 @@
 	return "// Deprecated: Do not use."
 }
 
-// pathKey converts a location path to a string suitable for use as a map key.
-func pathKey(path []int32) string {
-	var buf []byte
-	for i, x := range path {
-		if i != 0 {
-			buf = append(buf, ',')
-		}
-		buf = strconv.AppendInt(buf, int64(x), 10)
-	}
-	return string(buf)
-}
-
 func genWellKnownType(g *protogen.GeneratedFile, ptr string, ident protogen.GoIdent, desc protoreflect.Descriptor) {
 	if wellKnownTypes[desc.FullName()] {
 		g.P("func (", ptr, ident, `) XXX_WellKnownType() string { return "`, desc.Name(), `" }`)
diff --git a/cmd/protoc-gen-go/internal_gengo/oneof.go b/cmd/protoc-gen-go/internal_gengo/oneof.go
index ab06550..26ad1c5 100644
--- a/cmd/protoc-gen-go/internal_gengo/oneof.go
+++ b/cmd/protoc-gen-go/internal_gengo/oneof.go
@@ -15,12 +15,12 @@
 
 // genOneofField generates the struct field for a oneof.
 func genOneofField(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, message *protogen.Message, oneof *protogen.Oneof) {
-	if genComment(g, f, oneof.Location) {
+	if g.PrintLeadingComments(oneof.Location) {
 		g.P("//")
 	}
 	g.P("// Types that are valid to be assigned to ", oneof.GoName, ":")
 	for _, field := range oneof.Fields {
-		genComment(g, f, field.Location)
+		g.PrintLeadingComments(field.Location)
 		g.P("//\t*", fieldOneofType(field))
 	}
 	g.Annotate(message.GoIdent.GoName+"."+oneof.GoName, oneof.Location)
diff --git a/protogen/protogen.go b/protogen/protogen.go
index 22f663e..0a80aaa 100644
--- a/protogen/protogen.go
+++ b/protogen/protogen.go
@@ -13,6 +13,7 @@
 import (
 	"bufio"
 	"bytes"
+	"encoding/binary"
 	"fmt"
 	"go/ast"
 	"go/parser"
@@ -390,6 +391,8 @@
 	// For example, the source file "dir/foo.proto" might have a filename prefix
 	// of "dir/foo". Appending ".pb.go" produces an output file of "dir/foo.pb.go".
 	GeneratedFilenamePrefix string
+
+	sourceInfo map[pathKey][]*descpb.SourceCodeInfo_Location
 }
 
 func newFile(gen *Plugin, p *descpb.FileDescriptorProto, packageName GoPackageName, importPath GoImportPath) (*File, error) {
@@ -405,6 +408,7 @@
 		Proto:         p,
 		GoPackageName: packageName,
 		GoImportPath:  importPath,
+		sourceInfo:    make(map[pathKey][]*descpb.SourceCodeInfo_Location),
 	}
 
 	// Determine the prefix for generated Go files.
@@ -425,6 +429,10 @@
 	}
 	f.GeneratedFilenamePrefix = prefix
 
+	for _, loc := range p.GetSourceCodeInfo().GetLocation() {
+		key := newPathKey(loc.Path)
+		f.sourceInfo[key] = append(f.sourceInfo[key], loc)
+	}
 	for i, mdescs := 0, desc.Messages(); i < mdescs.Len(); i++ {
 		f.Messages = append(f.Messages, newMessage(gen, f, nil, mdescs.Get(i)))
 	}
@@ -854,6 +862,29 @@
 	fmt.Fprintln(&g.buf)
 }
 
+// PrintLeadingComments writes the comment appearing before a location in
+// the .proto source to the generated file.
+//
+// It returns true if a comment was present at the location.
+func (g *GeneratedFile) PrintLeadingComments(loc Location) (hasComment bool) {
+	f := g.gen.filesByName[loc.SourceFile]
+	if f == nil {
+		return false
+	}
+	for _, infoLoc := range f.sourceInfo[newPathKey(loc.Path)] {
+		if infoLoc.LeadingComments == nil {
+			continue
+		}
+		for _, line := range strings.Split(strings.TrimSuffix(infoLoc.GetLeadingComments(), "\n"), "\n") {
+			g.buf.WriteString("//")
+			g.buf.WriteString(line)
+			g.buf.WriteString("\n")
+		}
+		return true
+	}
+	return false
+}
+
 // QualifiedGoIdent returns the string to use for a Go identifier.
 //
 // If the identifier is from a different Go package than the generated file,
@@ -1070,3 +1101,17 @@
 		Path:       n,
 	}
 }
+
+// A pathKey is a representation of a location path suitable for use as a map key.
+type pathKey struct {
+	s string
+}
+
+// newPathKey converts a location path to a pathKey.
+func newPathKey(path []int32) pathKey {
+	buf := make([]byte, 4*len(path))
+	for i, x := range path {
+		binary.LittleEndian.PutUint32(buf[i*4:], uint32(x))
+	}
+	return pathKey{string(buf)}
+}