bind,internal/importers: add Unwrap methods to unwrap Java wrappers

For Java classes implemented in Go, it is useful to take a Java instance
and extract its wrapped Go instance. For example, consider the
java.lang.Runnable implementation wrapping a Go function:

package somepkg

type GoRunnable struct {
    lang.Runnable
    f func()
}

Java methods that take a java.lang.Runnable cannot directly take a
*GoRunnable, so this CL adds a Unwrap method:

import gorun "Java/somepkg/GoRunnable"

...

r := gorun.New()
r.Unwrap().(*GoRunnable).f = func() { ... }
javapkg.Run(r)

The extra interface conversion is unfortunately needed to avoid
import cycles.

Change-Id: Ib775a5712cd25aa75a19d364a55d76b1e11dce77
Reviewed-on: https://go-review.googlesource.com/35295
Reviewed-by: David Crawshaw <crawshaw@golang.org>
diff --git a/bind/bind_test.go b/bind/bind_test.go
index b9fe99a..ed95611 100644
--- a/bind/bind_test.go
+++ b/bind/bind_test.go
@@ -14,6 +14,7 @@
 	"log"
 	"os"
 	"os/exec"
+	"path"
 	"path/filepath"
 	"runtime"
 	"strings"
@@ -65,6 +66,10 @@
 	if err != nil {
 		t.Fatalf("%s: %v", filename, err)
 	}
+	fakePath := path.Dir(filename)
+	for i := range refs.Embedders {
+		refs.Embedders[i].PkgPath = fakePath
+	}
 	return refs
 }
 
@@ -299,11 +304,7 @@
 						Buf:        new(bytes.Buffer),
 					},
 				}
-				var genNames []string
-				for _, emb := range refs.Embedders {
-					genNames = append(genNames, emb.Pkg+"."+emb.Name)
-				}
-				cg.Init(classes, genNames)
+				cg.Init(classes, refs.Embedders)
 				genJavaPackages(t, tmpGopath, cg)
 				cg.Buf = &buf
 			}
@@ -419,11 +420,7 @@
 				Buf:        &buf,
 			},
 		}
-		var genNames []string
-		for _, emb := range refs.Embedders {
-			genNames = append(genNames, emb.Pkg+"."+emb.Name)
-		}
-		cg.Init(classes, genNames)
+		cg.Init(classes, refs.Embedders)
 		genJavaPackages(t, tmpGopath, cg)
 		pkg := typeCheck(t, filename, tmpGopath)
 		cg.GenGo()
diff --git a/bind/genclasses.go b/bind/genclasses.go
index 7e99683..6be2fee 100644
--- a/bind/genclasses.go
+++ b/bind/genclasses.go
@@ -12,6 +12,7 @@
 	"unicode"
 	"unicode/utf8"
 
+	"golang.org/x/mobile/internal/importers"
 	"golang.org/x/mobile/internal/importers/java"
 )
 
@@ -25,6 +26,9 @@
 	// will work.
 	ClassGen struct {
 		*Printer
+		// JavaPkg is the Java package prefix for the generated classes. The prefix is prepended to the Go
+		// package name to create the full Java package name.
+		JavaPkg  string
 		imported map[string]struct{}
 		// The list of imported Java classes
 		classes []*java.Class
@@ -35,8 +39,12 @@
 		// For each Go package path, the Java class with static functions
 		// or constants.
 		clsPkgs map[string]*java.Class
-		// supers is the map of classes that need Super methods
-		supers map[string]struct{}
+		// goClsMap is the map of Java class names to Go type names, qualified with package name. Go types
+		// that implement Java classes need Super methods and Unwrap methods.
+		goClsMap map[string]string
+		// goClsImports is the list of imports of user packages that contains the Go types implementing Java
+		// classes.
+		goClsImports []string
 	}
 )
 
@@ -110,12 +118,22 @@
 }
 
 // Init initializes the class wrapper generator. Classes is the
-// list of classes to wrap, supers is the list of class names
-// that need Super methods.
-func (g *ClassGen) Init(classes []*java.Class, supers []string) {
-	g.supers = make(map[string]struct{})
-	for _, s := range supers {
-		g.supers[s] = struct{}{}
+// list of classes to wrap, goClasses is the list of Java classes
+// implemented in Go.
+func (g *ClassGen) Init(classes []*java.Class, goClasses []importers.Struct) {
+	g.goClsMap = make(map[string]string)
+	impMap := make(map[string]struct{})
+	for _, s := range goClasses {
+		n := s.Pkg + "." + s.Name
+		jn := n
+		if g.JavaPkg != "" {
+			jn = g.JavaPkg + "." + jn
+		}
+		g.goClsMap[jn] = n
+		if _, exists := impMap[s.PkgPath]; !exists {
+			impMap[s.PkgPath] = struct{}{}
+			g.goClsImports = append(g.goClsImports, s.PkgPath)
+		}
 	}
 	g.classes = classes
 	g.imported = make(map[string]struct{})
@@ -194,6 +212,9 @@
 		pkgName := strings.Replace(cls.Name, ".", "/", -1)
 		g.Printf("import %q\n", "Java/"+pkgName)
 	}
+	for _, imp := range g.goClsImports {
+		g.Printf("import %q\n", imp)
+	}
 	if len(g.classes) > 0 {
 		g.Printf("import \"unsafe\"\n\n")
 		g.Printf("import \"reflect\"\n\n")
@@ -235,7 +256,7 @@
 				g.Printf("extern ")
 				g.genCMethodDecl("cproxy", cls.JNIName, f)
 				g.Printf(";\n")
-				if _, ok := g.supers[cls.Name]; ok {
+				if _, ok := g.goClsMap[cls.Name]; ok {
 					g.Printf("extern ")
 					g.genCMethodDecl("csuper", cls.JNIName, f)
 					g.Printf(";\n")
@@ -252,7 +273,7 @@
 	g.Printf(classesCHeader)
 	for _, cls := range g.classes {
 		g.Printf("static jclass class_%s;\n", cls.JNIName)
-		if _, ok := g.supers[cls.Name]; ok {
+		if _, ok := g.goClsMap[cls.Name]; ok {
 			g.Printf("static jclass sclass_%s;\n", cls.JNIName)
 		}
 		for _, fs := range cls.Funcs {
@@ -267,7 +288,7 @@
 			for _, f := range fs.Funcs {
 				if g.isFuncSupported(f) {
 					g.Printf("static jmethodID m_%s_%s;\n", cls.JNIName, f.JNIName)
-					if _, ok := g.supers[cls.Name]; ok {
+					if _, ok := g.goClsMap[cls.Name]; ok {
 						g.Printf("static jmethodID sm_%s_%s;\n", cls.JNIName, f.JNIName)
 					}
 				}
@@ -283,7 +304,7 @@
 	for _, cls := range g.classes {
 		g.Printf("clazz = (*env)->FindClass(env, %q);\n", strings.Replace(cls.FindName, ".", "/", -1))
 		g.Printf("class_%s = (*env)->NewGlobalRef(env, clazz);\n", cls.JNIName)
-		if _, ok := g.supers[cls.Name]; ok {
+		if _, ok := g.goClsMap[cls.Name]; ok {
 			g.Printf("sclass_%s = (*env)->GetSuperclass(env, clazz);\n", cls.JNIName)
 			g.Printf("sclass_%s = (*env)->NewGlobalRef(env, sclass_%s);\n", cls.JNIName, cls.JNIName)
 		}
@@ -304,7 +325,7 @@
 			for _, f := range fs.Funcs {
 				if g.isFuncSupported(f) {
 					g.Printf("m_%s_%s = go_seq_get_method_id(clazz, %q, %q);\n", cls.JNIName, f.JNIName, f.Name, f.Desc)
-					if _, ok := g.supers[cls.Name]; ok {
+					if _, ok := g.goClsMap[cls.Name]; ok {
 						g.Printf("sm_%s_%s = go_seq_get_method_id(sclass_%s, %q, %q);\n", cls.JNIName, f.JNIName, cls.JNIName, f.Name, f.Desc)
 					}
 				}
@@ -322,7 +343,7 @@
 				}
 				g.genCMethodDecl("cproxy", cls.JNIName, f)
 				g.genCMethodBody(cls, f, false)
-				if _, ok := g.supers[cls.Name]; ok {
+				if _, ok := g.goClsMap[cls.Name]; ok {
 					g.genCMethodDecl("csuper", cls.JNIName, f)
 					g.genCMethodBody(cls, f, true)
 				}
@@ -561,11 +582,17 @@
 		g.Printf("	return p.ToString()\n")
 		g.Printf("}\n")
 	}
-	if _, ok := g.supers[cls.Name]; ok {
+	if goName, ok := g.goClsMap[cls.Name]; ok {
 		g.Printf("func (p *proxy_class_%s) Super() Java.%s {\n", cls.JNIName, goClsName(cls.Name))
 		g.Printf("	return &super_%s{p}\n", cls.JNIName)
 		g.Printf("}\n\n")
 		g.Printf("type super_%s struct {*proxy_class_%[1]s}\n\n", cls.JNIName)
+		g.Printf("func (p *proxy_class_%s) Unwrap() interface{} {\n", cls.JNIName)
+		g.Indent()
+		g.Printf("goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))\n")
+		g.Printf("return _seq.FromRefNum(int32(goRefnum)).Get().(*%s)\n", goName)
+		g.Outdent()
+		g.Printf("}\n\n")
 		for _, fs := range cls.AllMethods {
 			if !g.isFuncSetSupported(fs) {
 				continue
@@ -847,8 +874,13 @@
 		g.genFuncDecl(true, fs)
 		g.Printf("\n")
 	}
-	if _, ok := g.supers[cls.Name]; ok {
+	if goName, ok := g.goClsMap[cls.Name]; ok {
 		g.Printf("Super() %s\n", goClsName(cls.Name))
+		g.Printf("// Unwrap returns the Go object this Java instance\n")
+		g.Printf("// is wrapping.\n")
+		g.Printf("// The return value is a %s, but the delclared type is\n", goName)
+		g.Printf("// interface{} to avoid import cycles.\n")
+		g.Printf("Unwrap() interface{}\n")
 	}
 	if cls.Throwable {
 		g.Printf("Error() string\n")
diff --git a/bind/java/ClassesTest.java b/bind/java/ClassesTest.java
index c8264df..00f02ed 100644
--- a/bind/java/ClassesTest.java
+++ b/bind/java/ClassesTest.java
@@ -18,6 +18,7 @@
 import javapkg.GoRunnable;
 import javapkg.GoSubset;
 import javapkg.GoInputStream;
+import javapkg.GoArrayList;
 
 public class ClassesTest extends InstrumentationTestCase {
 	public void testConst() {
@@ -148,4 +149,9 @@
 		Runnable r4c = Javapkg.castRunnable(new Object());
 		assertTrue("Invalid cast", r4c == null);
 	}
+
+	public void testUnwrap() {
+		GoArrayList l = new GoArrayList();
+		Javapkg.unwrapGoArrayList(l);
+	}
 }
diff --git a/bind/java/seq.h b/bind/java/seq.h
index 1f36a1c..84c1dbb 100644
--- a/bind/java/seq.h
+++ b/bind/java/seq.h
@@ -34,6 +34,9 @@
 
 extern void go_seq_dec_ref(int32_t ref);
 extern void go_seq_inc_ref(int32_t ref);
+// go_seq_unwrap takes a reference number to a Java wrapper and returns
+// a reference number to its wrapped Go object.
+extern int32_t go_seq_unwrap(jint refnum);
 extern int32_t go_seq_to_refnum(JNIEnv *env, jobject o);
 extern int32_t go_seq_to_refnum_go(JNIEnv *env, jobject o);
 extern jobject go_seq_from_refnum(JNIEnv *env, int32_t refnum, jclass proxy_class, jmethodID proxy_cons);
diff --git a/bind/java/seq_android.c.support b/bind/java/seq_android.c.support
index 8b4e60f..fbc6d55 100644
--- a/bind/java/seq_android.c.support
+++ b/bind/java/seq_android.c.support
@@ -233,6 +233,14 @@
 	return (int32_t)(*env)->CallStaticIntMethod(env, seq_class, seq_incRef, o);
 }
 
+int32_t go_seq_unwrap(jint refnum) {
+	JNIEnv *env = go_seq_push_local_frame(0);
+	jobject jobj = go_seq_from_refnum(env, refnum, NULL, NULL);
+	int32_t goref = go_seq_to_refnum_go(env, jobj);
+	go_seq_pop_local_frame(env);
+	return goref;
+}
+
 jobject go_seq_from_refnum(JNIEnv *env, int32_t refnum, jclass proxy_class, jmethodID proxy_cons) {
 	if (refnum == NULL_REFNUM) {
 		return NULL;
diff --git a/bind/testdata/classes.go.golden b/bind/testdata/classes.go.golden
index eb17775..9227ab6 100644
--- a/bind/testdata/classes.go.golden
+++ b/bind/testdata/classes.go.golden
@@ -475,22 +475,42 @@
 type Java_Future interface {
 	Get(a0 ...interface{}) (Java_lang_Object, error)
 	Super() Java_Future
+	// Unwrap returns the Go object this Java instance
+	// is wrapping.
+	// The return value is a java.Future, but the delclared type is
+	// interface{} to avoid import cycles.
+	Unwrap() interface{}
 }
 
 type Java_InputStream interface {
 	Read(a0 ...interface{}) (int32, error)
 	ToString() string
 	Super() Java_InputStream
+	// Unwrap returns the Go object this Java instance
+	// is wrapping.
+	// The return value is a java.InputStream, but the delclared type is
+	// interface{} to avoid import cycles.
+	Unwrap() interface{}
 }
 
 type Java_Object interface {
 	ToString() string
 	Super() Java_Object
+	// Unwrap returns the Go object this Java instance
+	// is wrapping.
+	// The return value is a java.Object, but the delclared type is
+	// interface{} to avoid import cycles.
+	Unwrap() interface{}
 }
 
 type Java_Runnable interface {
 	Run()
 	Super() Java_Runnable
+	// Unwrap returns the Go object this Java instance
+	// is wrapping.
+	// The return value is a java.Runnable, but the delclared type is
+	// interface{} to avoid import cycles.
+	Unwrap() interface{}
 }
 
 type Java_util_Iterator interface {
@@ -559,6 +579,7 @@
 import "Java/java/util/PrimitiveIterator/OfDouble"
 import "Java/java/util/Spliterator/OfDouble"
 import "Java/java/io/Console"
+import "testdata"
 import "unsafe"
 
 import "reflect"
@@ -1213,6 +1234,11 @@
 
 type super_java_Future struct {*proxy_class_java_Future}
 
+func (p *proxy_class_java_Future) Unwrap() interface{} {
+	goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))
+	return _seq.FromRefNum(int32(goRefnum)).Get().(*java.Future)
+}
+
 func (p *super_java_Future) Get(a0 ...interface{}) (Java.Java_lang_Object, error) {
 	switch 0 + len(a0) {
 	case 0:
@@ -1374,6 +1400,11 @@
 
 type super_java_InputStream struct {*proxy_class_java_InputStream}
 
+func (p *proxy_class_java_InputStream) Unwrap() interface{} {
+	goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))
+	return _seq.FromRefNum(int32(goRefnum)).Get().(*java.InputStream)
+}
+
 func (p *super_java_InputStream) Read(a0 ...interface{}) (int32, error) {
 	switch 0 + len(a0) {
 	case 0:
@@ -1494,6 +1525,11 @@
 
 type super_java_Object struct {*proxy_class_java_Object}
 
+func (p *proxy_class_java_Object) Unwrap() interface{} {
+	goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))
+	return _seq.FromRefNum(int32(goRefnum)).Get().(*java.Object)
+}
+
 func (p *super_java_Object) ToString() string {
 	res := C.csuper_java_Object_toString(C.jint(p.Bind_proxy_refnum__()))
 	_res := decodeString(res.res)
@@ -1555,6 +1591,11 @@
 
 type super_java_Runnable struct {*proxy_class_java_Runnable}
 
+func (p *proxy_class_java_Runnable) Unwrap() interface{} {
+	goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))
+	return _seq.FromRefNum(int32(goRefnum)).Get().(*java.Runnable)
+}
+
 func (p *super_java_Runnable) Run() {
 	res := C.csuper_java_Runnable_run(C.jint(p.Bind_proxy_refnum__()))
 	var _exc error
diff --git a/bind/testpkg/javapkg/classes.go b/bind/testpkg/javapkg/classes.go
index f6cd17b..2cdcdd7 100644
--- a/bind/testpkg/javapkg/classes.go
+++ b/bind/testpkg/javapkg/classes.go
@@ -138,6 +138,10 @@
 	return new(GoArrayList)
 }
 
+func UnwrapGoArrayList(l gopkg.GoArrayList) {
+	_ = l.Unwrap().(*GoArrayList)
+}
+
 func CallSubset(s Character.Subset) {
 	s.ToString()
 }
diff --git a/cmd/gobind/gen.go b/cmd/gobind/gen.go
index b8215d0..d0399d3 100644
--- a/cmd/gobind/gen.go
+++ b/cmd/gobind/gen.go
@@ -20,6 +20,7 @@
 	"unicode/utf8"
 
 	"golang.org/x/mobile/bind"
+	"golang.org/x/mobile/internal/importers"
 	"golang.org/x/mobile/internal/importers/java"
 )
 
@@ -140,15 +141,16 @@
 	}
 }
 
-func genJavaPackages(ctx *build.Context, dir string, classes []*java.Class, genNames []string) error {
+func genJavaPackages(ctx *build.Context, dir string, classes []*java.Class, embedders []importers.Struct) error {
 	var buf bytes.Buffer
 	cg := &bind.ClassGen{
+		JavaPkg: *javaPkg,
 		Printer: &bind.Printer{
 			IndentEach: []byte("\t"),
 			Buf:        &buf,
 		},
 	}
-	cg.Init(classes, genNames)
+	cg.Init(classes, embedders)
 	pkgBase := filepath.Join(dir, "src", "Java")
 	if err := os.MkdirAll(pkgBase, 0700); err != nil {
 		return err
diff --git a/cmd/gobind/main.go b/cmd/gobind/main.go
index ef42cd1..62f19cd 100644
--- a/cmd/gobind/main.go
+++ b/cmd/gobind/main.go
@@ -75,15 +75,7 @@
 				log.Fatal(err)
 			}
 			defer os.RemoveAll(tmpGopath)
-			var genNames []string
-			for _, emb := range refs.Embedders {
-				n := emb.Pkg + "." + emb.Name
-				if *javaPkg != "" {
-					n = *javaPkg + "." + n
-				}
-				genNames = append(genNames, n)
-			}
-			if err := genJavaPackages(ctx, tmpGopath, classes, genNames); err != nil {
+			if err := genJavaPackages(ctx, tmpGopath, classes, refs.Embedders); err != nil {
 				log.Fatal(err)
 			}
 			gopath := ctx.GOPATH
diff --git a/cmd/gomobile/bind.go b/cmd/gomobile/bind.go
index 8888e3b..40259f8 100644
--- a/cmd/gomobile/bind.go
+++ b/cmd/gomobile/bind.go
@@ -415,20 +415,13 @@
 	}
 	var buf bytes.Buffer
 	g := &bind.ClassGen{
+		JavaPkg: bindJavaPkg,
 		Printer: &bind.Printer{
 			IndentEach: []byte("\t"),
 			Buf:        &buf,
 		},
 	}
-	var genNames []string
-	for _, emb := range refs.Embedders {
-		n := emb.Pkg + "." + emb.Name
-		if bindJavaPkg != "" {
-			n = bindJavaPkg + "." + n
-		}
-		genNames = append(genNames, n)
-	}
-	g.Init(classes, genNames)
+	g.Init(classes, refs.Embedders)
 	for i, jpkg := range g.Packages() {
 		pkgDir := filepath.Join(jpkgSrc, "src", "Java", jpkg)
 		if err := os.MkdirAll(pkgDir, 0700); err != nil {
diff --git a/internal/importers/ast.go b/internal/importers/ast.go
index 9c8ac2a..91a4f31 100644
--- a/internal/importers/ast.go
+++ b/internal/importers/ast.go
@@ -67,15 +67,16 @@
 // Struct is a representation of a struct type with embedded
 // types.
 type Struct struct {
-	Name string
-	Pkg  string
-	Refs []PkgRef
+	Name    string
+	Pkg     string
+	PkgPath string
+	Refs    []PkgRef
 }
 
 // PkgRef is a reference to an identifier in a package.
 type PkgRef struct {
-	Pkg  string
 	Name string
+	Pkg  string
 }
 
 type refsSaver struct {
@@ -94,7 +95,7 @@
 	// Ignore errors (from unknown packages)
 	pkg, _ := ast.NewPackage(fset, files, visitor.importer(), nil)
 	ast.Walk(visitor, pkg)
-	visitor.findEmbeddingStructs(pkg)
+	visitor.findEmbeddingStructs("", pkg)
 	return visitor.References, nil
 }
 
@@ -117,7 +118,7 @@
 		// Ignore errors (from unknown packages)
 		astpkg, _ := ast.NewPackage(fset, files, imp, nil)
 		ast.Walk(visitor, astpkg)
-		visitor.findEmbeddingStructs(astpkg)
+		visitor.findEmbeddingStructs(pkg.ImportPath, astpkg)
 	}
 	return visitor.References, nil
 }
@@ -131,7 +132,7 @@
 // type T struct {
 //     Package.Class
 // }
-func (v *refsSaver) findEmbeddingStructs(pkg *ast.Package) {
+func (v *refsSaver) findEmbeddingStructs(pkgpath string, pkg *ast.Package) {
 	var names []string
 	for _, obj := range pkg.Scope.Objects {
 		if obj.Kind != ast.Typ || !ast.IsExported(obj.Name) {
@@ -162,8 +163,9 @@
 		}
 		if len(refs) > 0 {
 			v.Embedders = append(v.Embedders, Struct{
-				Name: obj.Name,
-				Pkg:  pkg.Name,
+				Name:    obj.Name,
+				Pkg:     pkg.Name,
+				PkgPath: pkgpath,
 
 				Refs: refs,
 			})
diff --git a/internal/importers/java/java.go b/internal/importers/java/java.go
index 8cbc6ae..b9bf3bd 100644
--- a/internal/importers/java/java.go
+++ b/internal/importers/java/java.go
@@ -237,10 +237,11 @@
 		}
 		clsSet[n] = struct{}{}
 		cls := &Class{
-			Name:     n,
-			FindName: n,
-			JNIName:  JNIMangle(n),
-			PkgName:  emb.Name,
+			Name:        n,
+			FindName:    n,
+			JNIName:     JNIMangle(n),
+			PkgName:     emb.Name,
+			HasNoArgCon: true,
 		}
 		for _, ref := range emb.Refs {
 			jpkg := strings.Replace(ref.Pkg, "/", ".", -1)
diff --git a/internal/importers/java/java_test.go b/internal/importers/java/java_test.go
index ab0f50f..ff59c8b 100644
--- a/internal/importers/java/java_test.go
+++ b/internal/importers/java/java_test.go
@@ -21,7 +21,7 @@
 		methods []*FuncSet
 	}{
 		{
-			ref:  importers.PkgRef{"java/lang/Object", "equals"},
+			ref:  importers.PkgRef{Pkg: "java/lang/Object", Name: "equals"},
 			name: "java.lang.Object",
 			methods: []*FuncSet{
 				&FuncSet{
@@ -35,7 +35,7 @@
 			},
 		},
 		{
-			ref:  importers.PkgRef{"java/lang/Runnable", "run"},
+			ref:  importers.PkgRef{Pkg: "java/lang/Runnable", Name: "run"},
 			name: "java.lang.Runnable",
 			methods: []*FuncSet{
 				&FuncSet{