blob: 6e9a088f77906dba4cdeebc512f9b0148b051264 [file] [log] [blame]
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bind
import (
"fmt"
"go/token"
"log"
"strings"
"golang.org/x/tools/go/types"
)
type goGen struct {
*printer
fset *token.FileSet
pkg *types.Package
err ErrorList
}
func (g *goGen) errorf(format string, args ...interface{}) {
g.err = append(g.err, fmt.Errorf(format, args...))
}
const goPreamble = `// Package go_%s is an autogenerated binder stub for package %s.
// gobind -lang=go %s
//
// File is generated by gobind. Do not edit.
package go_%s
import (
"golang.org/x/mobile/bind/seq"
%q
)
`
func (g *goGen) genPreamble() {
n := g.pkg.Name()
g.Printf(goPreamble, n, n, g.pkg.Path(), n, g.pkg.Path())
}
func (g *goGen) genFuncBody(o *types.Func, selectorLHS string) {
sig := o.Type().(*types.Signature)
params := sig.Params()
for i := 0; i < params.Len(); i++ {
p := params.At(i)
g.genRead("param_"+p.Name(), "in", p.Type())
}
res := sig.Results()
if res.Len() > 2 || res.Len() == 2 && !isErrorType(res.At(1).Type()) {
g.errorf("functions and methods must return either zero or one values, and optionally an error")
return
}
returnsValue := false
returnsError := false
if res.Len() == 1 {
if isErrorType(res.At(0).Type()) {
returnsError = true
g.Printf("err := ")
} else {
returnsValue = true
g.Printf("res := ")
}
} else if res.Len() == 2 {
returnsValue = true
returnsError = true
g.Printf("res, err := ")
}
g.Printf("%s.%s(", selectorLHS, o.Name())
for i := 0; i < params.Len(); i++ {
if i > 0 {
g.Printf(", ")
}
g.Printf("param_%s", params.At(i).Name())
}
g.Printf(")\n")
if returnsValue {
g.genWrite("res", "out", res.At(0).Type())
}
if returnsError {
g.genWrite("err", "out", res.At(res.Len()-1).Type())
}
}
func (g *goGen) genWrite(valName, seqName string, T types.Type) {
if isErrorType(T) {
g.Printf("if %s == nil {\n", valName)
g.Printf(" %s.WriteUTF16(\"\");\n", seqName)
g.Printf("} else {\n")
g.Printf(" %s.WriteUTF16(%s.Error());\n", seqName, valName)
g.Printf("}\n")
return
}
switch T := T.(type) {
case *types.Pointer:
// TODO(crawshaw): test *int
// TODO(crawshaw): test **Generator
switch T := T.Elem().(type) {
case *types.Named:
obj := T.Obj()
if obj.Pkg() != g.pkg {
g.errorf("type %s not defined in package %s", T, g.pkg)
return
}
g.Printf("%s.WriteGoRef(%s)\n", seqName, valName)
default:
g.errorf("unsupported type %s", T)
}
case *types.Named:
switch u := T.Underlying().(type) {
case *types.Interface, *types.Pointer:
g.Printf("%s.WriteGoRef(%s)\n", seqName, valName)
default:
g.errorf("unsupported, direct named type %s: %s", T, u)
}
default:
g.Printf("%s.Write%s(%s);\n", seqName, seqType(T), valName)
}
}
func (g *goGen) genFunc(o *types.Func) {
g.Printf("func proxy_%s(out, in *seq.Buffer) {\n", o.Name())
g.Indent()
g.genFuncBody(o, g.pkg.Name())
g.Outdent()
g.Printf("}\n\n")
}
func exportedMethodSet(T types.Type) []*types.Func {
var methods []*types.Func
methodset := types.NewMethodSet(T)
for i := 0; i < methodset.Len(); i++ {
obj := methodset.At(i).Obj()
if !obj.Exported() {
continue
}
switch obj := obj.(type) {
case *types.Func:
methods = append(methods, obj)
default:
log.Panicf("unexpected methodset obj: %s", obj)
}
}
return methods
}
func exportedFields(T *types.Struct) []*types.Var {
var fields []*types.Var
for i := 0; i < T.NumFields(); i++ {
f := T.Field(i)
if !f.Exported() {
continue
}
fields = append(fields, f)
}
return fields
}
func (g *goGen) genStruct(obj *types.TypeName, T *types.Struct) {
fields := exportedFields(T)
methods := exportedMethodSet(types.NewPointer(obj.Type()))
g.Printf("const (\n")
g.Indent()
g.Printf("proxy%sDescriptor = \"go.%s.%s\"\n", obj.Name(), g.pkg.Name(), obj.Name())
for i, f := range fields {
g.Printf("proxy%s%sGetCode = 0x%x0f\n", obj.Name(), f.Name(), i)
g.Printf("proxy%s%sSetCode = 0x%x1f\n", obj.Name(), f.Name(), i)
}
for i, m := range methods {
g.Printf("proxy%s%sCode = 0x%x0c\n", obj.Name(), m.Name(), i)
}
g.Outdent()
g.Printf(")\n\n")
g.Printf("type proxy%s seq.Ref\n\n", obj.Name())
for _, f := range fields {
g.Printf("func proxy%s%sSet(out, in *seq.Buffer) {\n", obj.Name(), f.Name())
g.Indent()
g.Printf("ref := in.ReadRef()\n")
g.Printf("v := in.Read%s()\n", seqType(f.Type()))
// TODO(crawshaw): other kinds of non-ptr types.
g.Printf("ref.Get().(*%s.%s).%s = v\n", g.pkg.Name(), obj.Name(), f.Name())
g.Outdent()
g.Printf("}\n\n")
g.Printf("func proxy%s%sGet(out, in *seq.Buffer) {\n", obj.Name(), f.Name())
g.Indent()
g.Printf("ref := in.ReadRef()\n")
g.Printf("v := ref.Get().(*%s.%s).%s\n", g.pkg.Name(), obj.Name(), f.Name())
g.Printf("out.Write%s(v)\n", seqType(f.Type()))
g.Outdent()
g.Printf("}\n\n")
}
for _, m := range methods {
g.Printf("func proxy%s%s(out, in *seq.Buffer) {\n", obj.Name(), m.Name())
g.Indent()
g.Printf("ref := in.ReadRef()\n")
g.Printf("v := ref.Get().(*%s.%s)\n", g.pkg.Name(), obj.Name())
g.genFuncBody(m, "v")
g.Outdent()
g.Printf("}\n\n")
}
g.Printf("func init() {\n")
g.Indent()
for _, f := range fields {
n := f.Name()
g.Printf("seq.Register(proxy%sDescriptor, proxy%s%sSetCode, proxy%s%sSet)\n", obj.Name(), obj.Name(), n, obj.Name(), n)
g.Printf("seq.Register(proxy%sDescriptor, proxy%s%sGetCode, proxy%s%sGet)\n", obj.Name(), obj.Name(), n, obj.Name(), n)
}
for _, m := range methods {
n := m.Name()
g.Printf("seq.Register(proxy%sDescriptor, proxy%s%sCode, proxy%s%s)\n", obj.Name(), obj.Name(), n, obj.Name(), n)
}
g.Outdent()
g.Printf("}\n\n")
}
func (g *goGen) genInterface(obj *types.TypeName) {
iface := obj.Type().(*types.Named).Underlying().(*types.Interface)
// Descriptor and code for interface methods.
g.Printf("const (\n")
g.Indent()
g.Printf("proxy%sDescriptor = \"go.%s.%s\"\n", obj.Name(), g.pkg.Name(), obj.Name())
for i := 0; i < iface.NumMethods(); i++ {
g.Printf("proxy%s%sCode = 0x%x0a\n", obj.Name(), iface.Method(i).Name(), i+1)
}
g.Outdent()
g.Printf(")\n\n")
// Define the entry points.
for i := 0; i < iface.NumMethods(); i++ {
m := iface.Method(i)
g.Printf("func proxy%s%s(out, in *seq.Buffer) {\n", obj.Name(), m.Name())
g.Indent()
g.Printf("ref := in.ReadRef()\n")
g.Printf("v := ref.Get().(%s.%s)\n", g.pkg.Name(), obj.Name())
g.genFuncBody(m, "v")
g.Outdent()
g.Printf("}\n\n")
}
// Register the method entry points.
g.Printf("func init() {\n")
g.Indent()
for i := 0; i < iface.NumMethods(); i++ {
g.Printf("seq.Register(proxy%sDescriptor, proxy%s%sCode, proxy%s%s)\n",
obj.Name(), obj.Name(), iface.Method(i).Name(), obj.Name(), iface.Method(i).Name())
}
g.Outdent()
g.Printf("}\n\n")
// Define a proxy interface.
g.Printf("type proxy%s seq.Ref\n\n", obj.Name())
for i := 0; i < iface.NumMethods(); i++ {
m := iface.Method(i)
sig := m.Type().(*types.Signature)
params := sig.Params()
res := sig.Results()
if res.Len() > 2 ||
(res.Len() == 2 && !isErrorType(res.At(1).Type())) {
g.errorf("functions and methods must return either zero or one value, and optionally an error: %s.%s", obj.Name(), m.Name())
continue
}
g.Printf("func (p *proxy%s) %s(", obj.Name(), m.Name())
for i := 0; i < params.Len(); i++ {
if i > 0 {
g.Printf(", ")
}
g.Printf("%s %s", paramName(params, i), g.typeString(params.At(i).Type()))
}
g.Printf(") ")
if res.Len() == 1 {
g.Printf(g.typeString(res.At(0).Type()))
} else if res.Len() == 2 {
g.Printf("(%s, error)", g.typeString(res.At(0).Type()))
}
g.Printf(" {\n")
g.Indent()
g.Printf("in := new(seq.Buffer)\n")
for i := 0; i < params.Len(); i++ {
g.genWrite(paramName(params, i), "in", params.At(i).Type())
}
if res.Len() == 0 {
g.Printf("seq.Transact((*seq.Ref)(p), proxy%s%sCode, in)\n", obj.Name(), m.Name())
} else {
g.Printf("out := seq.Transact((*seq.Ref)(p), proxy%s%sCode, in)\n", obj.Name(), m.Name())
var rvs []string
for i := 0; i < res.Len(); i++ {
rv := fmt.Sprintf("res_%d", i)
g.genRead(rv, "out", res.At(i).Type())
rvs = append(rvs, rv)
}
g.Printf("return %s\n", strings.Join(rvs, ","))
}
g.Outdent()
g.Printf("}\n\n")
}
}
func (g *goGen) genRead(valName, seqName string, typ types.Type) {
if isErrorType(typ) {
g.Printf("%s := %s.ReadError()\n", valName, seqName)
return
}
switch t := typ.(type) {
case *types.Pointer:
switch u := t.Elem().(type) {
case *types.Named:
o := u.Obj()
if o.Pkg() != g.pkg {
g.errorf("type %s not defined in package %s", u, g.pkg)
return
}
g.Printf("// Must be a Go object\n")
g.Printf("%s_ref := %s.ReadRef()\n", valName, seqName)
g.Printf("%s := %s_ref.Get().(*%s.%s)\n", valName, valName, g.pkg.Name(), o.Name())
default:
g.errorf("unsupported type %s", t)
}
case *types.Named:
switch t.Underlying().(type) {
case *types.Interface, *types.Pointer:
o := t.Obj()
if o.Pkg() != g.pkg {
g.errorf("type %s not defined in package %s", t, g.pkg)
return
}
g.Printf("var %s %s\n", valName, g.typeString(t))
g.Printf("%s_ref := %s.ReadRef()\n", valName, seqName)
g.Printf("if %s_ref.Num < 0 { // go object \n", valName)
g.Printf(" %s = %s_ref.Get().(%s.%s)\n", valName, valName, g.pkg.Name(), o.Name())
g.Printf("} else { // foreign object \n")
g.Printf(" %s = (*proxy%s)(%s_ref)\n", valName, o.Name(), valName)
g.Printf("}\n")
}
default:
g.Printf("%s := %s.Read%s()\n", valName, seqName, seqType(t))
}
}
func (g *goGen) typeString(typ types.Type) string {
pkg := g.pkg
switch t := typ.(type) {
case *types.Named:
obj := t.Obj()
if obj.Pkg() == nil { // e.g. error type is *types.Named.
return types.TypeString(pkg, typ)
}
if obj.Pkg() != g.pkg {
g.errorf("type %s not defined in package %s", t, g.pkg)
}
switch t.Underlying().(type) {
case *types.Interface, *types.Struct:
return fmt.Sprintf("%s.%s", pkg.Name(), types.TypeString(pkg, typ))
default:
g.errorf("unsupported named type %s / %T", t, t)
}
case *types.Pointer:
switch t := t.Elem().(type) {
case *types.Named:
return fmt.Sprintf("*%s", g.typeString(t))
default:
g.errorf("not yet supported, pointer type %s / %T", t, t)
}
default:
return types.TypeString(pkg, typ)
}
return ""
}
func (g *goGen) gen() error {
g.genPreamble()
var funcs []string
scope := g.pkg.Scope()
names := scope.Names()
for _, name := range names {
obj := scope.Lookup(name)
if !obj.Exported() {
continue
}
switch obj := obj.(type) {
// TODO(crawshaw): case *types.Const:
// TODO(crawshaw): case *types.Var:
case *types.Func:
g.genFunc(obj)
funcs = append(funcs, obj.Name())
case *types.TypeName:
named := obj.Type().(*types.Named)
switch T := named.Underlying().(type) {
case *types.Struct:
g.genStruct(obj, T)
case *types.Interface:
g.genInterface(obj)
}
default:
g.errorf("not yet supported, name for %v / %T", obj, obj)
continue
}
}
g.Printf("func init() {\n")
g.Indent()
for i, name := range funcs {
g.Printf("seq.Register(%q, %d, proxy_%s)\n", g.pkg.Name(), i+1, name)
}
g.Outdent()
g.Printf("}\n")
if len(g.err) > 0 {
return g.err
}
return nil
}