protogen: compute package names, import paths, generated filenames

Copy/duplicate the logic in github.com/golang/protobuf for computing
package names and import paths and the names of generated files.

This is all sufficiently complicated that the code is the best
documentation. In practice, users should always set a go_package option
containing an import path in every file and pass the
paths=source_relative generator flag to get reasonable behavior.

Change-Id: I34ae38fcc8db6909a4b25b16c73b982a7bad0463
Reviewed-on: https://go-review.googlesource.com/133876
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index bfae8eb..5792ebe 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -7,8 +7,6 @@
 package main
 
 import (
-	"strings"
-
 	"google.golang.org/proto/protogen"
 )
 
@@ -25,11 +23,11 @@
 }
 
 func genFile(gen *protogen.Plugin, f *protogen.File) {
-	g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.Path(), ".proto")+".pb.go", f.GoImportPath)
+	g := gen.NewGeneratedFile(f.GeneratedFilenamePrefix+".pb.go", f.GoImportPath)
 	g.P("// Code generated by protoc-gen-go. DO NOT EDIT.")
 	g.P("// source: ", f.Desc.Path())
 	g.P()
-	g.P("package TODO")
+	g.P("package ", f.GoPackageName)
 	g.P()
 
 	for _, m := range f.Messages {
diff --git a/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go b/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
index 014d3a4..0ae32d1 100644
--- a/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/nested_messages.pb.go
@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // source: proto2/nested_messages.proto
 
-package TODO
+package proto2
 
 type Layer1 struct {
 }
diff --git a/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go b/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
index 771028e..90ffa17 100644
--- a/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
+++ b/cmd/protoc-gen-go/testdata/proto2/proto2.pb.go
@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // source: proto2/proto2.proto
 
-package TODO
+package proto2
 
 type Message struct {
 }
diff --git a/protogen/protogen.go b/protogen/protogen.go
index ffba172..ffc1d7a 100644
--- a/protogen/protogen.go
+++ b/protogen/protogen.go
@@ -19,6 +19,7 @@
 	"go/token"
 	"io/ioutil"
 	"os"
+	"path"
 	"path/filepath"
 	"sort"
 	"strconv"
@@ -91,10 +92,8 @@
 	Files       []*File
 	filesByName map[string]*File
 
-	fileReg *protoregistry.Files
-
-	packageImportPath string // Go import path of the package we're generating code for.
-
+	fileReg  *protoregistry.Files
+	pathType pathType
 	genFiles []*GeneratedFile
 	err      error
 }
@@ -108,6 +107,9 @@
 	}
 
 	// TODO: Figure out how to pass parameters to the generator.
+	packageNames := make(map[string]GoPackageName) // filename -> package name
+	importPaths := make(map[string]GoImportPath)   // filename -> import path
+	var packageImportPath GoImportPath
 	for _, param := range strings.Split(req.GetParameter(), ",") {
 		var value string
 		if i := strings.Index(param, "="); i >= 0 {
@@ -120,9 +122,16 @@
 		case "import_prefix":
 			// TODO
 		case "import_path":
-			gen.packageImportPath = value
+			packageImportPath = GoImportPath(value)
 		case "paths":
-			// TODO
+			switch value {
+			case "import":
+				gen.pathType = pathTypeImport
+			case "source_relative":
+				gen.pathType = pathTypeSourceRelative
+			default:
+				return nil, fmt.Errorf(`unknown path type %q: want "import" or "source_relative"`, value)
+			}
 		case "plugins":
 			// TODO
 		case "annotate_code":
@@ -131,26 +140,121 @@
 			if param[0] != 'M' {
 				return nil, fmt.Errorf("unknown parameter %q", param)
 			}
-			// TODO
+			importPaths[param[1:]] = GoImportPath(value)
+		}
+	}
+
+	// Figure out the import path and package name for each file.
+	//
+	// The rules here are complicated and have grown organically over time.
+	// Interactions between different ways of specifying package information
+	// may be surprising.
+	//
+	// The recommended approach is to include a go_package option in every
+	// .proto source file specifying the full import path of the Go package
+	// associated with this file.
+	//
+	//     option go_package = "github.com/golang/protobuf/ptypes/any";
+	//
+	// Build systems which want to exert full control over import paths may
+	// specify M<filename>=<import_path> flags.
+	//
+	// Other approaches are not recommend.
+	generatedFileNames := make(map[string]bool)
+	for _, name := range gen.Request.FileToGenerate {
+		generatedFileNames[name] = true
+	}
+	// We need to determine the import paths before the package names,
+	// because the Go package name for a file is sometimes derived from
+	// different file in the same package.
+	packageNameForImportPath := make(map[GoImportPath]GoPackageName)
+	for _, fdesc := range gen.Request.ProtoFile {
+		filename := fdesc.GetName()
+		packageName, importPath := goPackageOption(fdesc)
+		switch {
+		case importPaths[filename] != "":
+			// Command line: M=foo.proto=quux/bar
+			//
+			// Explicit mapping of source file to import path.
+		case generatedFileNames[filename] && packageImportPath != "":
+			// Command line: import_path=quux/bar
+			//
+			// The import_path flag sets the import path for every file that
+			// we generate code for.
+			importPaths[filename] = packageImportPath
+		case importPath != "":
+			// Source file: option go_package = "quux/bar";
+			//
+			// The go_package option sets the import path. Most users should use this.
+			importPaths[filename] = importPath
+		default:
+			// Source filename.
+			//
+			// Last resort when nothing else is available.
+			importPaths[filename] = GoImportPath(path.Dir(filename))
+		}
+		if packageName != "" {
+			packageNameForImportPath[importPaths[filename]] = packageName
+		}
+	}
+	for _, fdesc := range gen.Request.ProtoFile {
+		filename := fdesc.GetName()
+		packageName, _ := goPackageOption(fdesc)
+		defaultPackageName := packageNameForImportPath[importPaths[filename]]
+		switch {
+		case packageName != "":
+			// Source file: option go_package = "quux/bar";
+			packageNames[filename] = packageName
+		case defaultPackageName != "":
+			// A go_package option in another file in the same package.
+			//
+			// This is a poor choice in general, since every source file should
+			// contain a go_package option. Supported mainly for historical
+			// compatibility.
+			packageNames[filename] = defaultPackageName
+		case generatedFileNames[filename] && packageImportPath != "":
+			// Command line: import_path=quux/bar
+			packageNames[filename] = cleanPackageName(path.Base(string(packageImportPath)))
+		case fdesc.GetPackage() != "":
+			// Source file: package quux.bar;
+			packageNames[filename] = cleanPackageName(fdesc.GetPackage())
+		default:
+			// Source filename.
+			packageNames[filename] = cleanPackageName(baseName(filename))
+		}
+	}
+
+	// Consistency check: Every file with the same Go import path should have
+	// the same Go package name.
+	packageFiles := make(map[GoImportPath][]string)
+	for filename, importPath := range importPaths {
+		packageFiles[importPath] = append(packageFiles[importPath], filename)
+	}
+	for importPath, filenames := range packageFiles {
+		for i := 1; i < len(filenames); i++ {
+			if a, b := packageNames[filenames[0]], packageNames[filenames[i]]; a != b {
+				return nil, fmt.Errorf("Go package %v has inconsistent names %v (%v) and %v (%v)",
+					importPath, a, filenames[0], b, filenames[i])
+			}
 		}
 	}
 
 	for _, fdesc := range gen.Request.ProtoFile {
-		f, err := newFile(gen, fdesc)
+		filename := fdesc.GetName()
+		if gen.filesByName[filename] != nil {
+			return nil, fmt.Errorf("duplicate file name: %q", filename)
+		}
+		f, err := newFile(gen, fdesc, packageNames[filename], importPaths[filename])
 		if err != nil {
 			return nil, err
 		}
-		name := f.Desc.Path()
-		if gen.filesByName[name] != nil {
-			return nil, fmt.Errorf("duplicate file name: %q", name)
-		}
 		gen.Files = append(gen.Files, f)
-		gen.filesByName[name] = f
+		gen.filesByName[filename] = f
 	}
-	for _, name := range gen.Request.FileToGenerate {
-		f, ok := gen.FileByName(name)
+	for _, filename := range gen.Request.FileToGenerate {
+		f, ok := gen.FileByName(filename)
 		if !ok {
-			return nil, fmt.Errorf("no descriptor for generated file: %v", name)
+			return nil, fmt.Errorf("no descriptor for generated file: %v", filename)
 		}
 		f.Generate = true
 	}
@@ -197,12 +301,20 @@
 type File struct {
 	Desc protoreflect.FileDescriptor
 
-	GoImportPath GoImportPath // import path of this file's Go package
-	Messages     []*Message   // top-level message declarations
-	Generate     bool         // true if we should generate code for this file
+	GoPackageName GoPackageName // name of this file's Go package
+	GoImportPath  GoImportPath  // import path of this file's Go package
+	Messages      []*Message    // top-level message declarations
+	Generate      bool          // true if we should generate code for this file
+
+	// GeneratedFilenamePrefix is used to construct filenames for generated
+	// files associated with this source file.
+	//
+	// For example, the source file "dir/foo.proto" might have a filename prefix
+	// of "dir/foo". Appending ".pb.go" produces an output file of "dir/foo.pb.go".
+	GeneratedFilenamePrefix string
 }
 
-func newFile(gen *Plugin, p *descpb.FileDescriptorProto) (*File, error) {
+func newFile(gen *Plugin, p *descpb.FileDescriptorProto, packageName GoPackageName, importPath GoImportPath) (*File, error) {
 	desc, err := prototype.NewFileFromDescriptorProto(p, gen.fileReg)
 	if err != nil {
 		return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err)
@@ -211,14 +323,55 @@
 		return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err)
 	}
 	f := &File{
-		Desc: desc,
+		Desc:          desc,
+		GoPackageName: packageName,
+		GoImportPath:  importPath,
 	}
+
+	// Determine the prefix for generated Go files.
+	prefix := p.GetName()
+	if ext := path.Ext(prefix); ext == ".proto" || ext == ".protodevel" {
+		prefix = prefix[:len(prefix)-len(ext)]
+	}
+	if gen.pathType == pathTypeImport {
+		// If paths=import (the default) and the file contains a go_package option
+		// with a full import path, the output filename is derived from the Go import
+		// path.
+		//
+		// Pass the paths=source_relative flag to always derive the output filename
+		// from the input filename instead.
+		if _, importPath := goPackageOption(p); importPath != "" {
+			prefix = path.Join(string(importPath), path.Base(prefix))
+		}
+	}
+	f.GeneratedFilenamePrefix = prefix
+
 	for i, mdescs := 0, desc.Messages(); i < mdescs.Len(); i++ {
 		f.Messages = append(f.Messages, newMessage(gen, f, nil, mdescs.Get(i), i))
 	}
 	return f, nil
 }
 
+// goPackageOption interprets a file's go_package option.
+// If there is no go_package, it returns ("", "").
+// If there's a simple name, it returns (pkg, "").
+// If the option implies an import path, it returns (pkg, impPath).
+func goPackageOption(d *descpb.FileDescriptorProto) (pkg GoPackageName, impPath GoImportPath) {
+	opt := d.GetOptions().GetGoPackage()
+	if opt == "" {
+		return "", ""
+	}
+	// A semicolon-delimited suffix delimits the import path and package name.
+	if i := strings.Index(opt, ";"); i >= 0 {
+		return cleanPackageName(opt[i+1:]), GoImportPath(opt[:i])
+	}
+	// The presence of a slash implies there's an import path.
+	if i := strings.LastIndex(opt, "/"); i >= 0 {
+		return cleanPackageName(opt[i+1:]), GoImportPath(opt)
+	}
+	return cleanPackageName(opt), ""
+}
+
 // A Message describes a message.
 type Message struct {
 	Desc protoreflect.MessageDescriptor
@@ -339,3 +492,10 @@
 	return out.Bytes(), nil
 
 }
+
+type pathType int
+
+const (
+	pathTypeImport pathType = iota
+	pathTypeSourceRelative
+)
diff --git a/protogen/protogen_test.go b/protogen/protogen_test.go
index 4da5b3c..05d0bf2 100644
--- a/protogen/protogen_test.go
+++ b/protogen/protogen_test.go
@@ -5,6 +5,7 @@
 package protogen
 
 import (
+	"fmt"
 	"io/ioutil"
 	"os"
 	"os/exec"
@@ -13,6 +14,7 @@
 	"testing"
 
 	"github.com/golang/protobuf/proto"
+	descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
 	pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin"
 )
 
@@ -45,6 +47,177 @@
 	}
 }
 
+func TestPackageNamesAndPaths(t *testing.T) {
+	const (
+		filename         = "dir/filename.proto"
+		protoPackageName = "proto.package"
+	)
+	for _, test := range []struct {
+		desc               string
+		parameter          string
+		goPackageOption    string
+		generate           bool
+		wantPackageName    GoPackageName
+		wantImportPath     GoImportPath
+		wantFilenamePrefix string
+	}{
+		{
+			desc:               "no parameters, no go_package option",
+			generate:           true,
+			wantPackageName:    "proto_package",
+			wantImportPath:     "dir",
+			wantFilenamePrefix: "dir/filename",
+		},
+		{
+			desc:               "go_package option sets import path",
+			goPackageOption:    "golang.org/x/foo",
+			generate:           true,
+			wantPackageName:    "foo",
+			wantImportPath:     "golang.org/x/foo",
+			wantFilenamePrefix: "golang.org/x/foo/filename",
+		},
+		{
+			desc:               "go_package option sets import path and package",
+			goPackageOption:    "golang.org/x/foo;bar",
+			generate:           true,
+			wantPackageName:    "bar",
+			wantImportPath:     "golang.org/x/foo",
+			wantFilenamePrefix: "golang.org/x/foo/filename",
+		},
+		{
+			desc:               "go_package option sets package",
+			goPackageOption:    "foo",
+			generate:           true,
+			wantPackageName:    "foo",
+			wantImportPath:     "dir",
+			wantFilenamePrefix: "dir/filename",
+		},
+		{
+			desc:               "command line sets import path for a file",
+			parameter:          "Mdir/filename.proto=golang.org/x/bar",
+			goPackageOption:    "golang.org/x/foo",
+			generate:           true,
+			wantPackageName:    "foo",
+			wantImportPath:     "golang.org/x/bar",
+			wantFilenamePrefix: "golang.org/x/foo/filename",
+		},
+		{
+			desc:               "import_path parameter sets import path of generated files",
+			parameter:          "import_path=golang.org/x/bar",
+			goPackageOption:    "golang.org/x/foo",
+			generate:           true,
+			wantPackageName:    "foo",
+			wantImportPath:     "golang.org/x/bar",
+			wantFilenamePrefix: "golang.org/x/foo/filename",
+		},
+		{
+			desc:               "import_path parameter does not set import path of dependencies",
+			parameter:          "import_path=golang.org/x/bar",
+			goPackageOption:    "golang.org/x/foo",
+			generate:           false,
+			wantPackageName:    "foo",
+			wantImportPath:     "golang.org/x/foo",
+			wantFilenamePrefix: "golang.org/x/foo/filename",
+		},
+	} {
+		context := fmt.Sprintf(`
+TEST: %v
+  --go_out=%v:.
+  file %q: generate=%v
+  option go_package = %q;
+
+  `,
+			test.desc, test.parameter, filename, test.generate, test.goPackageOption)
+
+		req := &pluginpb.CodeGeneratorRequest{
+			Parameter: proto.String(test.parameter),
+			ProtoFile: []*descpb.FileDescriptorProto{
+				{
+					Name:    proto.String(filename),
+					Package: proto.String(protoPackageName),
+					Options: &descpb.FileOptions{
+						GoPackage: proto.String(test.goPackageOption),
+					},
+				},
+			},
+		}
+		if test.generate {
+			req.FileToGenerate = []string{filename}
+		}
+		gen, err := New(req)
+		if err != nil {
+			t.Errorf("%vNew(req) = %v", context, err)
+			continue
+		}
+		gotFile, ok := gen.FileByName(filename)
+		if !ok {
+			t.Errorf("%v%v: missing file info", context, filename)
+			continue
+		}
+		if got, want := gotFile.GoPackageName, test.wantPackageName; got != want {
+			t.Errorf("%vGoPackageName=%v, want %v", context, got, want)
+		}
+		if got, want := gotFile.GoImportPath, test.wantImportPath; got != want {
+			t.Errorf("%vGoImportPath=%v, want %v", context, got, want)
+		}
+		if got, want := gotFile.GeneratedFilenamePrefix, test.wantFilenamePrefix; got != want {
+			t.Errorf("%vGeneratedFilenamePrefix=%v, want %v", context, got, want)
+		}
+	}
+}
+
+func TestPackageNameInference(t *testing.T) {
+	gen, err := New(&pluginpb.CodeGeneratorRequest{
+		ProtoFile: []*descpb.FileDescriptorProto{
+			{
+				Name:    proto.String("dir/file1.proto"),
+				Package: proto.String("proto.package"),
+			},
+			{
+				Name:    proto.String("dir/file2.proto"),
+				Package: proto.String("proto.package"),
+				Options: &descpb.FileOptions{
+					GoPackage: proto.String("foo"),
+				},
+			},
+		},
+		FileToGenerate: []string{"dir/file1.proto", "dir/file2.proto"},
+	})
+	if err != nil {
+		t.Fatalf("New(req) = %v", err)
+	}
+	if f1, ok := gen.FileByName("dir/file1.proto"); !ok {
+		t.Errorf("missing file info for dir/file1.proto")
+	} else if f1.GoPackageName != "foo" {
+		t.Errorf("dir/file1.proto: GoPackageName=%v, want foo; package name should be derived from dir/file2.proto", f1.GoPackageName)
+	}
+}
+
+func TestInconsistentPackageNames(t *testing.T) {
+	_, err := New(&pluginpb.CodeGeneratorRequest{
+		ProtoFile: []*descpb.FileDescriptorProto{
+			{
+				Name:    proto.String("dir/file1.proto"),
+				Package: proto.String("proto.package"),
+				Options: &descpb.FileOptions{
+					GoPackage: proto.String("golang.org/x/foo"),
+				},
+			},
+			{
+				Name:    proto.String("dir/file2.proto"),
+				Package: proto.String("proto.package"),
+				Options: &descpb.FileOptions{
+					GoPackage: proto.String("golang.org/x/foo;bar"),
+				},
+			},
+		},
+		FileToGenerate: []string{"dir/file1.proto", "dir/file2.proto"},
+	})
+	if err == nil {
+		t.Fatalf("inconsistent package names for the same import path: New(req) = nil, want error")
+	}
+}
+
 func TestImports(t *testing.T) {
 	gen, err := New(&pluginpb.CodeGeneratorRequest{})
 	if err != nil {