internal/importers: find structs with embedded types

Add support for using the Go AST package to find exported Go structs
that embed prefixed types. This is needed for a later CL to generate
wrappers for Go generated Java classes.

Change-Id: Ia304a7924a4e09332b74dc42a572932b7498cdca
Reviewed-on: https://go-review.googlesource.com/34775
Reviewed-by: David Crawshaw <crawshaw@golang.org>
diff --git a/internal/importers/ast.go b/internal/importers/ast.go
index e9c4ee1..4a7b920 100644
--- a/internal/importers/ast.go
+++ b/internal/importers/ast.go
@@ -33,6 +33,7 @@
 	"go/token"
 	"path"
 	"path/filepath"
+	"sort"
 	"strconv"
 	"strings"
 )
@@ -58,6 +59,17 @@
 	// Useful as a conservative upper bound on the set of identifiers
 	// referenced from a set of packages.
 	Names map[string]struct{}
+	// Embedders is a list of struct types with prefixed types
+	// embedded.
+	Embedders []Struct
+}
+
+// Struct is a representation of a struct type with embedded
+// types.
+type Struct struct {
+	Name string
+	Pkg  string
+	Refs []PkgRef
 }
 
 // PkgRef is a reference to an identifier in a package.
@@ -81,6 +93,7 @@
 	// Ignore errors (from unknown packages)
 	pkg, _ := ast.NewPackage(fset, files, visitor.importer(), nil)
 	ast.Walk(visitor, pkg)
+	visitor.findEmbeddingStructs(pkg)
 	return &visitor.References, nil
 }
 
@@ -103,10 +116,57 @@
 		// Ignore errors (from unknown packages)
 		astpkg, _ := ast.NewPackage(fset, files, imp, nil)
 		ast.Walk(visitor, astpkg)
+		visitor.findEmbeddingStructs(astpkg)
 	}
 	return &visitor.References, nil
 }
 
+// findEmbeddingStructs finds all top level declarations embedding a prefixed type.
+//
+// For example:
+//
+// import "Prefix/some/Package"
+//
+// type T struct {
+//     Package.Class
+// }
+func (v *refsSaver) findEmbeddingStructs(pkg *ast.Package) {
+	var names []string
+	for _, obj := range pkg.Scope.Objects {
+		if obj.Kind != ast.Typ || !ast.IsExported(obj.Name) {
+			continue
+		}
+		names = append(names, obj.Name)
+	}
+	sort.Strings(names)
+	for _, name := range names {
+		obj := pkg.Scope.Objects[name]
+
+		t, ok := obj.Decl.(*ast.TypeSpec).Type.(*ast.StructType)
+		if !ok {
+			continue
+		}
+		var refs []PkgRef
+		for _, f := range t.Fields.List {
+			sel, ok := f.Type.(*ast.SelectorExpr)
+			if !ok {
+				continue
+			}
+			if ref, ok := v.parseRef(sel); ok {
+				refs = append(refs, ref)
+			}
+		}
+		if len(refs) > 0 {
+			v.Embedders = append(v.Embedders, Struct{
+				Name: obj.Name,
+				Pkg:  pkg.Name,
+
+				Refs: refs,
+			})
+		}
+	}
+}
+
 func newRefsSaver(pkgPrefix string) *refsSaver {
 	s := &refsSaver{
 		pkgPrefix: pkgPrefix,
@@ -130,27 +190,37 @@
 	}
 }
 
+func (v *refsSaver) parseRef(sel *ast.SelectorExpr) (PkgRef, bool) {
+	x, ok := sel.X.(*ast.Ident)
+	if !ok || x.Obj == nil {
+		return PkgRef{}, false
+	}
+	imp, ok := x.Obj.Decl.(*ast.ImportSpec)
+	if !ok {
+		return PkgRef{}, false
+	}
+	pkgPath, err := strconv.Unquote(imp.Path.Value)
+	if err != nil {
+		return PkgRef{}, false
+	}
+	if !strings.HasPrefix(pkgPath, v.pkgPrefix) {
+		return PkgRef{}, false
+	}
+	pkgPath = pkgPath[len(v.pkgPrefix):]
+	return PkgRef{Pkg: pkgPath, Name: sel.Sel.Name}, true
+}
+
 func (v *refsSaver) Visit(n ast.Node) ast.Visitor {
 	switch n := n.(type) {
 	case *ast.SelectorExpr:
 		v.Names[n.Sel.Name] = struct{}{}
-		if x, ok := n.X.(*ast.Ident); ok && x.Obj != nil {
-			if imp, ok := x.Obj.Decl.(*ast.ImportSpec); ok {
-				pkgPath, err := strconv.Unquote(imp.Path.Value)
-				if err != nil {
-					return nil
-				}
-				if strings.HasPrefix(pkgPath, v.pkgPrefix) {
-					pkgPath = pkgPath[len(v.pkgPrefix):]
-					ref := PkgRef{Pkg: pkgPath, Name: n.Sel.Name}
-					if _, exists := v.refMap[ref]; !exists {
-						v.refMap[ref] = struct{}{}
-						v.Refs = append(v.Refs, ref)
-					}
-				}
-				return nil
+		if ref, ok := v.parseRef(n); ok {
+			if _, exists := v.refMap[ref]; !exists {
+				v.refMap[ref] = struct{}{}
+				v.Refs = append(v.Refs, ref)
 			}
 		}
+		return nil
 	case *ast.FuncDecl:
 		if n.Recv != nil { // Methods
 			v.Names[n.Name.Name] = struct{}{}
diff --git a/internal/importers/ast_test.go b/internal/importers/ast_test.go
index 92ba531..ee9bc8c 100644
--- a/internal/importers/ast_test.go
+++ b/internal/importers/ast_test.go
@@ -3,6 +3,7 @@
 import (
 	"go/parser"
 	"go/token"
+	"reflect"
 	"testing"
 )
 
@@ -12,6 +13,10 @@
 import "Prefix/some/pkg/Name"
 
 const c = Name.Constant
+
+type T struct {
+	Name.Type
+}
 `
 	fset := token.NewFileSet()
 	f, err := parser.ParseFile(fset, "ast_test.go", file, parser.AllErrors)
@@ -22,8 +27,8 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	if len(refs.Refs) != 1 {
-		t.Fatalf("expected 1 reference; got %d", len(refs.Refs))
+	if len(refs.Refs) != 2 {
+		t.Fatalf("expected 2 references; got %d", len(refs.Refs))
 	}
 	got := refs.Refs[0]
 	if exp := (PkgRef{"some/pkg/Name", "Constant"}); exp != got {
@@ -32,4 +37,16 @@
 	if _, exists := refs.Names["Constant"]; !exists {
 		t.Errorf("expected \"Constant\" in the names set")
 	}
+	if len(refs.Embedders) != 1 {
+		t.Fatalf("expected 1 struct; got %d", len(refs.Embedders))
+	}
+	s := refs.Embedders[0]
+	exp := Struct{
+		Name: "T",
+		Pkg:  "ast_test",
+		Refs: []PkgRef{{"some/pkg/Name", "Type"}},
+	}
+	if !reflect.DeepEqual(exp, s) {
+		t.Errorf("expected struct %v; got %v", exp, got)
+	}
 }