go/ssa: Update callee for wrapper function instantiation.
Updates golang/go#48525
Change-Id: Iee30bee08f124118d22524e276762389c8358244
Reviewed-on: https://go-review.googlesource.com/c/tools/+/400374
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go
index 84f0692..3fd9a8a 100644
--- a/go/ssa/builder_test.go
+++ b/go/ssa/builder_test.go
@@ -565,6 +565,129 @@
}
}
+func TestGenericWrappers(t *testing.T) {
+ if !typeparams.Enabled {
+ t.Skip("TestGenericWrappers only works with type parameters enabled.")
+ }
+ const input = `
+package p
+
+type S[T any] struct {
+ t *T
+}
+
+func (x S[T]) M() T {
+ return *(x.t)
+}
+
+var thunk = S[int].M
+
+var g S[int]
+var bound = g.M
+
+type R[T any] struct{ S[T] }
+
+var indirect = R[int].M
+`
+ // The relevant SSA members for this package should look something like this:
+ // var bound func() int
+ // var thunk func(S[int]) int
+ // var wrapper func(R[int]) int
+
+ // Parse
+ var conf loader.Config
+ f, err := conf.ParseFile("<input>", input)
+ if err != nil {
+ t.Fatalf("parse: %v", err)
+ }
+ conf.CreateFromFiles("p", f)
+
+ // Load
+ lprog, err := conf.Load()
+ if err != nil {
+ t.Fatalf("Load: %v", err)
+ }
+
+ // Create and build SSA
+ prog := ssautil.CreateProgram(lprog, 0)
+ p := prog.Package(lprog.Package("p").Pkg)
+ p.Build()
+
+ for _, entry := range []struct {
+ name string // name of the package variable
+ typ string // type of the package variable
+ wrapper string // wrapper function to which the package variable is set
+ callee string // callee within the wrapper function
+ }{
+ {
+ "bound",
+ "*func() int",
+ "(p.S[int]).M$bound",
+ "(p.S[int]).M[[int]]",
+ },
+ {
+ "thunk",
+ "*func(p.S[int]) int",
+ "(p.S[int]).M$thunk",
+ "(p.S[int]).M[[int]]",
+ },
+ {
+ "indirect",
+ "*func(p.R[int]) int",
+ "(p.R[int]).M$thunk",
+ "(p.S[int]).M[[int]]",
+ },
+ } {
+ entry := entry
+ t.Run(entry.name, func(t *testing.T) {
+ v := p.Var(entry.name)
+ if v == nil {
+ t.Fatalf("Did not find variable for %q in %s", entry.name, p.String())
+ }
+ if v.Type().String() != entry.typ {
+ t.Errorf("Expected type for variable %s: %q. got %q", v, entry.typ, v.Type())
+ }
+
+ // Find the wrapper for v. This is stored exactly once in init.
+ var wrapper *ssa.Function
+ for _, bb := range p.Func("init").Blocks {
+ for _, i := range bb.Instrs {
+ if store, ok := i.(*ssa.Store); ok && v == store.Addr {
+ switch val := store.Val.(type) {
+ case *ssa.Function:
+ wrapper = val
+ case *ssa.MakeClosure:
+ wrapper = val.Fn.(*ssa.Function)
+ }
+ }
+ }
+ }
+ if wrapper == nil {
+ t.Fatalf("failed to find wrapper function for %s", entry.name)
+ }
+ if wrapper.String() != entry.wrapper {
+ t.Errorf("Expected wrapper function %q. got %q", wrapper, entry.wrapper)
+ }
+
+ // Find the callee within the wrapper. There should be exactly one call.
+ var callee *ssa.Function
+ for _, bb := range wrapper.Blocks {
+ for _, i := range bb.Instrs {
+ if call, ok := i.(*ssa.Call); ok {
+ callee = call.Call.StaticCallee()
+ }
+ }
+ }
+ if callee == nil {
+ t.Fatalf("failed to find callee within wrapper %s", wrapper)
+ }
+ if callee.String() != entry.callee {
+ t.Errorf("Expected callee in wrapper %q is %q. got %q", v, entry.callee, callee)
+ }
+ })
+ }
+}
+
// TestTypeparamTest builds SSA over compilable examples in $GOROOT/test/typeparam/*.go.
func TestTypeparamTest(t *testing.T) {
diff --git a/go/ssa/wrappers.go b/go/ssa/wrappers.go
index 799ba14..deaa87f 100644
--- a/go/ssa/wrappers.go
+++ b/go/ssa/wrappers.go
@@ -126,7 +126,7 @@
}
callee := prog.originFunc(obj)
if len(callee._TypeParams) > 0 {
- prog.instances[callee].lookupOrCreate(receiverTypeArgs(obj), cr)
+ callee = prog.instances[callee].lookupOrCreate(receiverTypeArgs(obj), cr)
}
c.Call.Value = callee
c.Call.Args = append(c.Call.Args, v)