internal/importers: find structs with embedded types
Add support for using the Go AST package to find exported Go structs
that embed prefixed types. This is needed for a later CL to generate
wrappers for Go generated Java classes.
Change-Id: Ia304a7924a4e09332b74dc42a572932b7498cdca
Reviewed-on: https://go-review.googlesource.com/34775
Reviewed-by: David Crawshaw <crawshaw@golang.org>
diff --git a/internal/importers/ast.go b/internal/importers/ast.go
index e9c4ee1..4a7b920 100644
--- a/internal/importers/ast.go
+++ b/internal/importers/ast.go
@@ -33,6 +33,7 @@
"go/token"
"path"
"path/filepath"
+ "sort"
"strconv"
"strings"
)
@@ -58,6 +59,17 @@
// Useful as a conservative upper bound on the set of identifiers
// referenced from a set of packages.
Names map[string]struct{}
+ // Embedders is a list of struct types with prefixed types
+ // embedded.
+ Embedders []Struct
+}
+
+// Struct is a representation of a struct type with embedded
+// types.
+type Struct struct {
+ Name string
+ Pkg string
+ Refs []PkgRef
}
// PkgRef is a reference to an identifier in a package.
@@ -81,6 +93,7 @@
// Ignore errors (from unknown packages)
pkg, _ := ast.NewPackage(fset, files, visitor.importer(), nil)
ast.Walk(visitor, pkg)
+ visitor.findEmbeddingStructs(pkg)
return &visitor.References, nil
}
@@ -103,10 +116,57 @@
// Ignore errors (from unknown packages)
astpkg, _ := ast.NewPackage(fset, files, imp, nil)
ast.Walk(visitor, astpkg)
+ visitor.findEmbeddingStructs(astpkg)
}
return &visitor.References, nil
}
+// findEmbeddingStructs finds all top level declarations embedding a prefixed type.
+//
+// For example:
+//
+// import "Prefix/some/Package"
+//
+// type T struct {
+// Package.Class
+// }
+func (v *refsSaver) findEmbeddingStructs(pkg *ast.Package) {
+ var names []string
+ for _, obj := range pkg.Scope.Objects {
+ if obj.Kind != ast.Typ || !ast.IsExported(obj.Name) {
+ continue
+ }
+ names = append(names, obj.Name)
+ }
+ sort.Strings(names)
+ for _, name := range names {
+ obj := pkg.Scope.Objects[name]
+
+ t, ok := obj.Decl.(*ast.TypeSpec).Type.(*ast.StructType)
+ if !ok {
+ continue
+ }
+ var refs []PkgRef
+ for _, f := range t.Fields.List {
+ sel, ok := f.Type.(*ast.SelectorExpr)
+ if !ok {
+ continue
+ }
+ if ref, ok := v.parseRef(sel); ok {
+ refs = append(refs, ref)
+ }
+ }
+ if len(refs) > 0 {
+ v.Embedders = append(v.Embedders, Struct{
+ Name: obj.Name,
+ Pkg: pkg.Name,
+
+ Refs: refs,
+ })
+ }
+ }
+}
+
func newRefsSaver(pkgPrefix string) *refsSaver {
s := &refsSaver{
pkgPrefix: pkgPrefix,
@@ -130,27 +190,37 @@
}
}
+func (v *refsSaver) parseRef(sel *ast.SelectorExpr) (PkgRef, bool) {
+ x, ok := sel.X.(*ast.Ident)
+ if !ok || x.Obj == nil {
+ return PkgRef{}, false
+ }
+ imp, ok := x.Obj.Decl.(*ast.ImportSpec)
+ if !ok {
+ return PkgRef{}, false
+ }
+ pkgPath, err := strconv.Unquote(imp.Path.Value)
+ if err != nil {
+ return PkgRef{}, false
+ }
+ if !strings.HasPrefix(pkgPath, v.pkgPrefix) {
+ return PkgRef{}, false
+ }
+ pkgPath = pkgPath[len(v.pkgPrefix):]
+ return PkgRef{Pkg: pkgPath, Name: sel.Sel.Name}, true
+}
+
func (v *refsSaver) Visit(n ast.Node) ast.Visitor {
switch n := n.(type) {
case *ast.SelectorExpr:
v.Names[n.Sel.Name] = struct{}{}
- if x, ok := n.X.(*ast.Ident); ok && x.Obj != nil {
- if imp, ok := x.Obj.Decl.(*ast.ImportSpec); ok {
- pkgPath, err := strconv.Unquote(imp.Path.Value)
- if err != nil {
- return nil
- }
- if strings.HasPrefix(pkgPath, v.pkgPrefix) {
- pkgPath = pkgPath[len(v.pkgPrefix):]
- ref := PkgRef{Pkg: pkgPath, Name: n.Sel.Name}
- if _, exists := v.refMap[ref]; !exists {
- v.refMap[ref] = struct{}{}
- v.Refs = append(v.Refs, ref)
- }
- }
- return nil
+ if ref, ok := v.parseRef(n); ok {
+ if _, exists := v.refMap[ref]; !exists {
+ v.refMap[ref] = struct{}{}
+ v.Refs = append(v.Refs, ref)
}
}
+ return nil
case *ast.FuncDecl:
if n.Recv != nil { // Methods
v.Names[n.Name.Name] = struct{}{}
diff --git a/internal/importers/ast_test.go b/internal/importers/ast_test.go
index 92ba531..ee9bc8c 100644
--- a/internal/importers/ast_test.go
+++ b/internal/importers/ast_test.go
@@ -3,6 +3,7 @@
import (
"go/parser"
"go/token"
+ "reflect"
"testing"
)
@@ -12,6 +13,10 @@
import "Prefix/some/pkg/Name"
const c = Name.Constant
+
+type T struct {
+ Name.Type
+}
`
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "ast_test.go", file, parser.AllErrors)
@@ -22,8 +27,8 @@
if err != nil {
t.Fatal(err)
}
- if len(refs.Refs) != 1 {
- t.Fatalf("expected 1 reference; got %d", len(refs.Refs))
+ if len(refs.Refs) != 2 {
+ t.Fatalf("expected 2 references; got %d", len(refs.Refs))
}
got := refs.Refs[0]
if exp := (PkgRef{"some/pkg/Name", "Constant"}); exp != got {
@@ -32,4 +37,16 @@
if _, exists := refs.Names["Constant"]; !exists {
t.Errorf("expected \"Constant\" in the names set")
}
+ if len(refs.Embedders) != 1 {
+ t.Fatalf("expected 1 struct; got %d", len(refs.Embedders))
+ }
+ s := refs.Embedders[0]
+ exp := Struct{
+ Name: "T",
+ Pkg: "ast_test",
+ Refs: []PkgRef{{"some/pkg/Name", "Type"}},
+ }
+ if !reflect.DeepEqual(exp, s) {
+ t.Errorf("expected struct %v; got %v", exp, got)
+ }
}