protogen: automatic handling of imports

The GoIdent type is now a tuple of import path and name. Generated files
have an associated import path. Writing a GoIdent to a generated file
qualifies the name if the identifier is from a different package.
All necessary imports are automatically added to generated Go files.

Change-Id: I839e0b7aa8ec967ce178aea4ffb960b62779cf74
Reviewed-on: https://go-review.googlesource.com/133635
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index 7973895..e144aa6 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -25,7 +25,7 @@
 }
 
 func genFile(gen *protogen.Plugin, f *protogen.File) {
-	g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.GetName(), ".proto") + ".pb.go")
+	g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.GetName(), ".proto")+".pb.go", f.GoImportPath)
 	g.P("// Code generated by protoc-gen-go. DO NOT EDIT.")
 	g.P("// source: ", f.Desc.GetName())
 	g.P()
diff --git a/go.mod b/go.mod
index 6aa64ad..638151c 100644
--- a/go.mod
+++ b/go.mod
@@ -5,4 +5,5 @@
 	github.com/google/go-cmp v0.2.0
 	golang.org/x/net v0.0.0-20180821023952-922f4815f713 // indirect
 	golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
+	golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48
 )
diff --git a/go.sum b/go.sum
index 1bfafa6..51e6c99 100644
--- a/go.sum
+++ b/go.sum
@@ -6,3 +6,5 @@
 golang.org/x/net v0.0.0-20180821023952-922f4815f713/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48 h1:PIz+xUHW4G/jqfFWeKhQ96ZV/t2HDsXfWj923rV0bZY=
+golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
diff --git a/protogen/names.go b/protogen/names.go
index b97c47d..ea3d057 100644
--- a/protogen/names.go
+++ b/protogen/names.go
@@ -1,6 +1,7 @@
 package protogen
 
 import (
+	"fmt"
 	"go/token"
 	"strconv"
 	"strings"
@@ -8,8 +9,13 @@
 	"unicode/utf8"
 )
 
-// A GoIdent is a Go identifier.
-type GoIdent string
+// A GoIdent is a Go identifier, consisting of a name and import path.
+type GoIdent struct {
+	GoName       string
+	GoImportPath GoImportPath
+}
+
+func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) }
 
 // A GoImportPath is the import path of a Go package. e.g., "google.golang.org/genproto/protobuf".
 type GoImportPath string
@@ -64,7 +70,7 @@
 // but it's so remote we're prepared to pretend it's nonexistent - since the
 // C++ generator lowercases names, it's extremely unlikely to have two fields
 // with different capitalizations.
-func camelCase(s string) GoIdent {
+func camelCase(s string) string {
 	if s == "" {
 		return ""
 	}
@@ -102,7 +108,7 @@
 			}
 		}
 	}
-	return GoIdent(t)
+	return string(t)
 }
 
 // Is c an ASCII lower-case letter?
diff --git a/protogen/names_test.go b/protogen/names_test.go
index 021e71a..05e698e 100644
--- a/protogen/names_test.go
+++ b/protogen/names_test.go
@@ -8,8 +8,7 @@
 
 func TestCamelCase(t *testing.T) {
 	tests := []struct {
-		in   string
-		want GoIdent
+		in, want string
 	}{
 		{"one", "One"},
 		{"one_two", "OneTwo"},
diff --git a/protogen/protogen.go b/protogen/protogen.go
index f499a8d..2d65ee0 100644
--- a/protogen/protogen.go
+++ b/protogen/protogen.go
@@ -20,11 +20,14 @@
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"sort"
+	"strconv"
 	"strings"
 
 	"github.com/golang/protobuf/proto"
 	descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
 	pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin"
+	"golang.org/x/tools/go/ast/astutil"
 )
 
 // Run executes a function as a protoc plugin.
@@ -168,7 +171,7 @@
 			}
 		}
 		resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
-			Name:    proto.String(gf.path),
+			Name:    proto.String(gf.filename),
 			Content: proto.String(string(content)),
 		})
 	}
@@ -185,16 +188,17 @@
 type File struct {
 	Desc *descpb.FileDescriptorProto // TODO: protoreflect.FileDescriptor
 
-	Messages []*Message // top-level message declartions
-	Generate bool       // true if we should generate code for this file
+	GoImportPath GoImportPath // import path of this file's Go package
+	Messages     []*Message   // top-level message declarations
+	Generate     bool         // true if we should generate code for this file
 }
 
 func newFile(gen *Plugin, p *descpb.FileDescriptorProto) *File {
 	f := &File{
 		Desc: p,
 	}
-	for _, d := range p.MessageType {
-		f.Messages = append(f.Messages, newMessage(gen, nil, d))
+	for i, mdesc := range p.MessageType {
+		f.Messages = append(f.Messages, newMessage(gen, f, nil, mdesc, i))
 	}
 	return f
 }
@@ -207,30 +211,40 @@
 	Messages []*Message // nested message declarations
 }
 
-func newMessage(gen *Plugin, parent *Message, p *descpb.DescriptorProto) *Message {
+func newMessage(gen *Plugin, f *File, parent *Message, p *descpb.DescriptorProto, index int) *Message {
 	m := &Message{
-		Desc:    p,
-		GoIdent: camelCase(p.GetName()),
+		Desc: p,
+		GoIdent: GoIdent{
+			GoName:       camelCase(p.GetName()),
+			GoImportPath: f.GoImportPath,
+		},
 	}
 	if parent != nil {
-		m.GoIdent = parent.GoIdent + "_" + m.GoIdent
+		m.GoIdent.GoName = parent.GoIdent.GoName + "_" + m.GoIdent.GoName
 	}
-	for _, nested := range p.GetNestedType() {
-		m.Messages = append(m.Messages, newMessage(gen, m, nested))
+	for i, nested := range p.GetNestedType() {
+		m.Messages = append(m.Messages, newMessage(gen, f, m, nested, i))
 	}
 	return m
 }
 
 // A GeneratedFile is a generated file.
 type GeneratedFile struct {
-	path string
-	buf  bytes.Buffer
+	filename         string
+	goImportPath     GoImportPath
+	buf              bytes.Buffer
+	packageNames     map[GoImportPath]GoPackageName
+	usedPackageNames map[GoPackageName]bool
 }
 
-// NewGeneratedFile creates a new generated file with the given path.
-func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile {
+// NewGeneratedFile creates a new generated file with the given filename
+// and import path.
+func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
 	g := &GeneratedFile{
-		path: path,
+		filename:         filename,
+		goImportPath:     goImportPath,
+		packageNames:     make(map[GoImportPath]GoPackageName),
+		usedPackageNames: make(map[GoPackageName]bool),
 	}
 	gen.genFiles = append(gen.genFiles, g)
 	return g
@@ -243,11 +257,33 @@
 // TODO: .meta file annotations.
 func (g *GeneratedFile) P(v ...interface{}) {
 	for _, x := range v {
-		fmt.Fprint(&g.buf, x)
+		switch x := x.(type) {
+		case GoIdent:
+			if x.GoImportPath != g.goImportPath {
+				fmt.Fprint(&g.buf, g.goPackageName(x.GoImportPath))
+				fmt.Fprint(&g.buf, ".")
+			}
+			fmt.Fprint(&g.buf, x.GoName)
+		default:
+			fmt.Fprint(&g.buf, x)
+		}
 	}
 	fmt.Fprintln(&g.buf)
 }
 
+func (g *GeneratedFile) goPackageName(importPath GoImportPath) GoPackageName {
+	if name, ok := g.packageNames[importPath]; ok {
+		return name
+	}
+	name := cleanPackageName(baseName(string(importPath)))
+	for i, orig := 1, name; g.usedPackageNames[name]; i++ {
+		name = orig + GoPackageName(strconv.Itoa(i))
+	}
+	g.packageNames[importPath] = name
+	g.usedPackageNames[name] = true
+	return name
+}
+
 // Write implements io.Writer.
 func (g *GeneratedFile) Write(p []byte) (n int, err error) {
 	return g.buf.Write(p)
@@ -255,7 +291,7 @@
 
 // Content returns the contents of the generated file.
 func (g *GeneratedFile) Content() ([]byte, error) {
-	if !strings.HasSuffix(g.path, ".go") {
+	if !strings.HasSuffix(g.filename, ".go") {
 		return g.buf.Bytes(), nil
 	}
 
@@ -272,13 +308,24 @@
 		for line := 1; s.Scan(); line++ {
 			fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
 		}
-		return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.path, err, src.String())
+		return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
 	}
+
+	// Add imports.
+	var importPaths []string
+	for importPath := range g.packageNames {
+		importPaths = append(importPaths, string(importPath))
+	}
+	sort.Strings(importPaths)
+	for _, importPath := range importPaths {
+		astutil.AddNamedImport(fset, ast, string(g.packageNames[GoImportPath(importPath)]), importPath)
+	}
+
 	var out bytes.Buffer
 	if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, ast); err != nil {
-		return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.path, err)
+		return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
 	}
-	// TODO: Patch annotation locations.
+	// TODO: Annotations.
 	return out.Bytes(), nil
 
 }
diff --git a/protogen/protogen_test.go b/protogen/protogen_test.go
index 1d23cc0..4da5b3c 100644
--- a/protogen/protogen_test.go
+++ b/protogen/protogen_test.go
@@ -45,6 +45,57 @@
 	}
 }
 
+func TestImports(t *testing.T) {
+	gen, err := New(&pluginpb.CodeGeneratorRequest{})
+	if err != nil {
+		t.Fatal(err)
+	}
+	g := gen.NewGeneratedFile("foo.go", "golang.org/x/foo")
+	g.P("package foo")
+	g.P()
+	for _, importPath := range []GoImportPath{
+		"golang.org/x/foo",
+		// Multiple references to the same package.
+		"golang.org/x/bar",
+		"golang.org/x/bar",
+		// Reference to a different package with the same basename.
+		"golang.org/y/bar",
+		"golang.org/x/baz",
+	} {
+		g.P("var _ = ", GoIdent{GoName: "X", GoImportPath: importPath}, " // ", importPath)
+	}
+	want := `package foo
+
+import (
+	bar "golang.org/x/bar"
+	bar1 "golang.org/y/bar"
+	baz "golang.org/x/baz"
+)
+
+var _ = X      // "golang.org/x/foo"
+var _ = bar.X  // "golang.org/x/bar"
+var _ = bar.X  // "golang.org/x/bar"
+var _ = bar1.X // "golang.org/y/bar"
+var _ = baz.X  // "golang.org/x/baz"
+`
+	got, err := g.Content()
+	if err != nil {
+		t.Fatalf("g.Content() = %v", err)
+	}
+	if want != string(got) {
+		t.Fatalf(`want:
+==========
+%v
+==========
+
+got:
+==========
+%v
+==========`,
+			want, string(got))
+	}
+}
+
 // makeRequest returns a CodeGeneratorRequest for the given protoc inputs.
 //
 // It does this by running protoc with the current binary as the protoc-gen-go
@@ -86,7 +137,7 @@
 func init() {
 	if os.Getenv("RUN_AS_PROTOC_PLUGIN") != "" {
 		Run(func(p *Plugin) error {
-			g := p.NewGeneratedFile("request")
+			g := p.NewGeneratedFile("request", "")
 			return proto.MarshalText(g, p.Request)
 		})
 		os.Exit(0)