go/ssa: Update callee for wrapper function instantiation.

Updates golang/go#48525

Change-Id: Iee30bee08f124118d22524e276762389c8358244
Reviewed-on: https://go-review.googlesource.com/c/tools/+/400374
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go
index 84f0692..3fd9a8a 100644
--- a/go/ssa/builder_test.go
+++ b/go/ssa/builder_test.go
@@ -565,6 +565,129 @@
 	}
 }
 
+func TestGenericWrappers(t *testing.T) {
+	if !typeparams.Enabled {
+		t.Skip("TestGenericWrappers only works with type parameters enabled.")
+	}
+	const input = `
+package p
+
+type S[T any] struct {
+	t *T
+}
+
+func (x S[T]) M() T {
+	return *(x.t)
+}
+
+var thunk = S[int].M
+
+var g S[int]
+var bound = g.M
+
+type R[T any] struct{ S[T] }
+
+var indirect = R[int].M
+`
+	// The relevant SSA members for this package should look something like this:
+	// var   bound      func() int
+	// var   thunk      func(S[int]) int
+	// var   wrapper    func(R[int]) int
+
+	// Parse
+	var conf loader.Config
+	f, err := conf.ParseFile("<input>", input)
+	if err != nil {
+		t.Fatalf("parse: %v", err)
+	}
+	conf.CreateFromFiles("p", f)
+
+	// Load
+	lprog, err := conf.Load()
+	if err != nil {
+		t.Fatalf("Load: %v", err)
+	}
+
+	// Create and build SSA
+	prog := ssautil.CreateProgram(lprog, 0)
+	p := prog.Package(lprog.Package("p").Pkg)
+	p.Build()
+
+	for _, entry := range []struct {
+		name    string // name of the package variable
+		typ     string // type of the package variable
+		wrapper string // wrapper function to which the package variable is set
+		callee  string // callee within the wrapper function
+	}{
+		{
+			"bound",
+			"*func() int",
+			"(p.S[int]).M$bound",
+			"(p.S[int]).M[[int]]",
+		},
+		{
+			"thunk",
+			"*func(p.S[int]) int",
+			"(p.S[int]).M$thunk",
+			"(p.S[int]).M[[int]]",
+		},
+		{
+			"indirect",
+			"*func(p.R[int]) int",
+			"(p.R[int]).M$thunk",
+			"(p.S[int]).M[[int]]",
+		},
+	} {
+		entry := entry
+		t.Run(entry.name, func(t *testing.T) {
+			v := p.Var(entry.name)
+			if v == nil {
+				t.Fatalf("Did not find variable for %q in %s", entry.name, p.String())
+			}
+			if v.Type().String() != entry.typ {
+				t.Errorf("Expected type for variable %s: %q. got %q", v, entry.typ, v.Type())
+			}
+
+			// Find the wrapper for v. This is stored exactly once in init.
+			var wrapper *ssa.Function
+			for _, bb := range p.Func("init").Blocks {
+				for _, i := range bb.Instrs {
+					if store, ok := i.(*ssa.Store); ok && v == store.Addr {
+						switch val := store.Val.(type) {
+						case *ssa.Function:
+							wrapper = val
+						case *ssa.MakeClosure:
+							wrapper = val.Fn.(*ssa.Function)
+						}
+					}
+				}
+			}
+			if wrapper == nil {
+				t.Fatalf("failed to find wrapper function for %s", entry.name)
+			}
+			if wrapper.String() != entry.wrapper {
+				t.Errorf("Expected wrapper function %q. got %q", wrapper, entry.wrapper)
+			}
+
+			// Find the callee within the wrapper. There should be exactly one call.
+			var callee *ssa.Function
+			for _, bb := range wrapper.Blocks {
+				for _, i := range bb.Instrs {
+					if call, ok := i.(*ssa.Call); ok {
+						callee = call.Call.StaticCallee()
+					}
+				}
+			}
+			if callee == nil {
+				t.Fatalf("failed to find callee within wrapper %s", wrapper)
+			}
+			if callee.String() != entry.callee {
+				t.Errorf("Expected callee in wrapper %q is %q. got %q", v, entry.callee, callee)
+			}
+		})
+	}
+}
+
 // TestTypeparamTest builds SSA over compilable examples in $GOROOT/test/typeparam/*.go.
 
 func TestTypeparamTest(t *testing.T) {
diff --git a/go/ssa/wrappers.go b/go/ssa/wrappers.go
index 799ba14..deaa87f 100644
--- a/go/ssa/wrappers.go
+++ b/go/ssa/wrappers.go
@@ -126,7 +126,7 @@
 		}
 		callee := prog.originFunc(obj)
 		if len(callee._TypeParams) > 0 {
-			prog.instances[callee].lookupOrCreate(receiverTypeArgs(obj), cr)
+			callee = prog.instances[callee].lookupOrCreate(receiverTypeArgs(obj), cr)
 		}
 		c.Call.Value = callee
 		c.Call.Args = append(c.Call.Args, v)