protogen: automatic handling of imports
The GoIdent type is now a tuple of import path and name. Generated files
have an associated import path. Writing a GoIdent to a generated file
qualifies the name if the identifier is from a different package.
All necessary imports are automatically added to generated Go files.
Change-Id: I839e0b7aa8ec967ce178aea4ffb960b62779cf74
Reviewed-on: https://go-review.googlesource.com/133635
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index 7973895..e144aa6 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -25,7 +25,7 @@
}
func genFile(gen *protogen.Plugin, f *protogen.File) {
- g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.GetName(), ".proto") + ".pb.go")
+ g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.GetName(), ".proto")+".pb.go", f.GoImportPath)
g.P("// Code generated by protoc-gen-go. DO NOT EDIT.")
g.P("// source: ", f.Desc.GetName())
g.P()
diff --git a/go.mod b/go.mod
index 6aa64ad..638151c 100644
--- a/go.mod
+++ b/go.mod
@@ -5,4 +5,5 @@
github.com/google/go-cmp v0.2.0
golang.org/x/net v0.0.0-20180821023952-922f4815f713 // indirect
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
+ golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48
)
diff --git a/go.sum b/go.sum
index 1bfafa6..51e6c99 100644
--- a/go.sum
+++ b/go.sum
@@ -6,3 +6,5 @@
golang.org/x/net v0.0.0-20180821023952-922f4815f713/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48 h1:PIz+xUHW4G/jqfFWeKhQ96ZV/t2HDsXfWj923rV0bZY=
+golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
diff --git a/protogen/names.go b/protogen/names.go
index b97c47d..ea3d057 100644
--- a/protogen/names.go
+++ b/protogen/names.go
@@ -1,6 +1,7 @@
package protogen
import (
+ "fmt"
"go/token"
"strconv"
"strings"
@@ -8,8 +9,13 @@
"unicode/utf8"
)
-// A GoIdent is a Go identifier.
-type GoIdent string
+// A GoIdent is a Go identifier, consisting of a name and import path.
+type GoIdent struct {
+ GoName string
+ GoImportPath GoImportPath
+}
+
+func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) }
// A GoImportPath is the import path of a Go package. e.g., "google.golang.org/genproto/protobuf".
type GoImportPath string
@@ -64,7 +70,7 @@
// but it's so remote we're prepared to pretend it's nonexistent - since the
// C++ generator lowercases names, it's extremely unlikely to have two fields
// with different capitalizations.
-func camelCase(s string) GoIdent {
+func camelCase(s string) string {
if s == "" {
return ""
}
@@ -102,7 +108,7 @@
}
}
}
- return GoIdent(t)
+ return string(t)
}
// Is c an ASCII lower-case letter?
diff --git a/protogen/names_test.go b/protogen/names_test.go
index 021e71a..05e698e 100644
--- a/protogen/names_test.go
+++ b/protogen/names_test.go
@@ -8,8 +8,7 @@
func TestCamelCase(t *testing.T) {
tests := []struct {
- in string
- want GoIdent
+ in, want string
}{
{"one", "One"},
{"one_two", "OneTwo"},
diff --git a/protogen/protogen.go b/protogen/protogen.go
index f499a8d..2d65ee0 100644
--- a/protogen/protogen.go
+++ b/protogen/protogen.go
@@ -20,11 +20,14 @@
"io/ioutil"
"os"
"path/filepath"
+ "sort"
+ "strconv"
"strings"
"github.com/golang/protobuf/proto"
descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin"
+ "golang.org/x/tools/go/ast/astutil"
)
// Run executes a function as a protoc plugin.
@@ -168,7 +171,7 @@
}
}
resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
- Name: proto.String(gf.path),
+ Name: proto.String(gf.filename),
Content: proto.String(string(content)),
})
}
@@ -185,16 +188,17 @@
type File struct {
Desc *descpb.FileDescriptorProto // TODO: protoreflect.FileDescriptor
- Messages []*Message // top-level message declartions
- Generate bool // true if we should generate code for this file
+ GoImportPath GoImportPath // import path of this file's Go package
+ Messages []*Message // top-level message declarations
+ Generate bool // true if we should generate code for this file
}
func newFile(gen *Plugin, p *descpb.FileDescriptorProto) *File {
f := &File{
Desc: p,
}
- for _, d := range p.MessageType {
- f.Messages = append(f.Messages, newMessage(gen, nil, d))
+ for i, mdesc := range p.MessageType {
+ f.Messages = append(f.Messages, newMessage(gen, f, nil, mdesc, i))
}
return f
}
@@ -207,30 +211,40 @@
Messages []*Message // nested message declarations
}
-func newMessage(gen *Plugin, parent *Message, p *descpb.DescriptorProto) *Message {
+func newMessage(gen *Plugin, f *File, parent *Message, p *descpb.DescriptorProto, index int) *Message {
m := &Message{
- Desc: p,
- GoIdent: camelCase(p.GetName()),
+ Desc: p,
+ GoIdent: GoIdent{
+ GoName: camelCase(p.GetName()),
+ GoImportPath: f.GoImportPath,
+ },
}
if parent != nil {
- m.GoIdent = parent.GoIdent + "_" + m.GoIdent
+ m.GoIdent.GoName = parent.GoIdent.GoName + "_" + m.GoIdent.GoName
}
- for _, nested := range p.GetNestedType() {
- m.Messages = append(m.Messages, newMessage(gen, m, nested))
+ for i, nested := range p.GetNestedType() {
+ m.Messages = append(m.Messages, newMessage(gen, f, m, nested, i))
}
return m
}
// A GeneratedFile is a generated file.
type GeneratedFile struct {
- path string
- buf bytes.Buffer
+ filename string
+ goImportPath GoImportPath
+ buf bytes.Buffer
+ packageNames map[GoImportPath]GoPackageName
+ usedPackageNames map[GoPackageName]bool
}
-// NewGeneratedFile creates a new generated file with the given path.
-func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile {
+// NewGeneratedFile creates a new generated file with the given filename
+// and import path.
+func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
g := &GeneratedFile{
- path: path,
+ filename: filename,
+ goImportPath: goImportPath,
+ packageNames: make(map[GoImportPath]GoPackageName),
+ usedPackageNames: make(map[GoPackageName]bool),
}
gen.genFiles = append(gen.genFiles, g)
return g
@@ -243,11 +257,33 @@
// TODO: .meta file annotations.
func (g *GeneratedFile) P(v ...interface{}) {
for _, x := range v {
- fmt.Fprint(&g.buf, x)
+ switch x := x.(type) {
+ case GoIdent:
+ if x.GoImportPath != g.goImportPath {
+ fmt.Fprint(&g.buf, g.goPackageName(x.GoImportPath))
+ fmt.Fprint(&g.buf, ".")
+ }
+ fmt.Fprint(&g.buf, x.GoName)
+ default:
+ fmt.Fprint(&g.buf, x)
+ }
}
fmt.Fprintln(&g.buf)
}
+func (g *GeneratedFile) goPackageName(importPath GoImportPath) GoPackageName {
+ if name, ok := g.packageNames[importPath]; ok {
+ return name
+ }
+ name := cleanPackageName(baseName(string(importPath)))
+ for i, orig := 1, name; g.usedPackageNames[name]; i++ {
+ name = orig + GoPackageName(strconv.Itoa(i))
+ }
+ g.packageNames[importPath] = name
+ g.usedPackageNames[name] = true
+ return name
+}
+
// Write implements io.Writer.
func (g *GeneratedFile) Write(p []byte) (n int, err error) {
return g.buf.Write(p)
@@ -255,7 +291,7 @@
// Content returns the contents of the generated file.
func (g *GeneratedFile) Content() ([]byte, error) {
- if !strings.HasSuffix(g.path, ".go") {
+ if !strings.HasSuffix(g.filename, ".go") {
return g.buf.Bytes(), nil
}
@@ -272,13 +308,24 @@
for line := 1; s.Scan(); line++ {
fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
}
- return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.path, err, src.String())
+ return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
}
+
+ // Add imports.
+ var importPaths []string
+ for importPath := range g.packageNames {
+ importPaths = append(importPaths, string(importPath))
+ }
+ sort.Strings(importPaths)
+ for _, importPath := range importPaths {
+ astutil.AddNamedImport(fset, ast, string(g.packageNames[GoImportPath(importPath)]), importPath)
+ }
+
var out bytes.Buffer
if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, ast); err != nil {
- return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.path, err)
+ return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
}
- // TODO: Patch annotation locations.
+ // TODO: Annotations.
return out.Bytes(), nil
}
diff --git a/protogen/protogen_test.go b/protogen/protogen_test.go
index 1d23cc0..4da5b3c 100644
--- a/protogen/protogen_test.go
+++ b/protogen/protogen_test.go
@@ -45,6 +45,57 @@
}
}
+func TestImports(t *testing.T) {
+ gen, err := New(&pluginpb.CodeGeneratorRequest{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ g := gen.NewGeneratedFile("foo.go", "golang.org/x/foo")
+ g.P("package foo")
+ g.P()
+ for _, importPath := range []GoImportPath{
+ "golang.org/x/foo",
+ // Multiple references to the same package.
+ "golang.org/x/bar",
+ "golang.org/x/bar",
+ // Reference to a different package with the same basename.
+ "golang.org/y/bar",
+ "golang.org/x/baz",
+ } {
+ g.P("var _ = ", GoIdent{GoName: "X", GoImportPath: importPath}, " // ", importPath)
+ }
+ want := `package foo
+
+import (
+ bar "golang.org/x/bar"
+ bar1 "golang.org/y/bar"
+ baz "golang.org/x/baz"
+)
+
+var _ = X // "golang.org/x/foo"
+var _ = bar.X // "golang.org/x/bar"
+var _ = bar.X // "golang.org/x/bar"
+var _ = bar1.X // "golang.org/y/bar"
+var _ = baz.X // "golang.org/x/baz"
+`
+ got, err := g.Content()
+ if err != nil {
+ t.Fatalf("g.Content() = %v", err)
+ }
+ if want != string(got) {
+ t.Fatalf(`want:
+==========
+%v
+==========
+
+got:
+==========
+%v
+==========`,
+ want, string(got))
+ }
+}
+
// makeRequest returns a CodeGeneratorRequest for the given protoc inputs.
//
// It does this by running protoc with the current binary as the protoc-gen-go
@@ -86,7 +137,7 @@
func init() {
if os.Getenv("RUN_AS_PROTOC_PLUGIN") != "" {
Run(func(p *Plugin) error {
- g := p.NewGeneratedFile("request")
+ g := p.NewGeneratedFile("request", "")
return proto.MarshalText(g, p.Request)
})
os.Exit(0)