go/ssa: fixing Named type substitution

Reorganizes how type substitution is done for Named types.
This fixes several issues with Named types support for
parameterized types defined within parameterized functions.

Additionally supports recursive substitution of a type parameter
that is not a type parameter being substituted.

Fixes golang/go#66783

Change-Id: I31478e622e854d58620687c1964cf8b254bf419f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/581835
Reviewed-by: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go
index 062a221..fdd98f4 100644
--- a/go/ssa/builder_test.go
+++ b/go/ssa/builder_test.go
@@ -1186,3 +1186,28 @@
 		pkg.Build()
 	}
 }
+
+func TestFixedBugs(t *testing.T) {
+	for _, name := range []string{
+		"issue66783a",
+		"issue66783b",
+	} {
+
+		t.Run(name, func(t *testing.T) {
+			base := name + ".go"
+			path := filepath.Join(analysistest.TestData(), "fixedbugs", base)
+			fset := token.NewFileSet()
+			f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
+			if err != nil {
+				t.Fatal(err)
+			}
+			files := []*ast.File{f}
+			pkg := types.NewPackage(name, name)
+			mode := ssa.SanityCheckFunctions | ssa.InstantiateGenerics
+			// mode |= ssa.PrintFunctions // debug mode
+			if _, _, err := ssautil.BuildPackage(&types.Config{}, fset, pkg, files, mode); err != nil {
+				t.Error(err)
+			}
+		})
+	}
+}
diff --git a/go/ssa/instantiate.go b/go/ssa/instantiate.go
index e5e7162..2cd7405 100644
--- a/go/ssa/instantiate.go
+++ b/go/ssa/instantiate.go
@@ -78,8 +78,7 @@
 	if prog.mode&InstantiateGenerics != 0 && !prog.isParameterized(targs...) {
 		synthetic = fmt.Sprintf("instance of %s", fn.Name())
 		if fn.syntax != nil {
-			scope := obj.Origin().Scope()
-			subst = makeSubster(prog.ctxt, scope, fn.typeparams, targs, false)
+			subst = makeSubster(prog.ctxt, obj, fn.typeparams, targs, false)
 			build = (*builder).buildFromSyntax
 		} else {
 			build = (*builder).buildParamsOnly
diff --git a/go/ssa/interp/interp_test.go b/go/ssa/interp/interp_test.go
index 2cd7ee9..2ad6a9a 100644
--- a/go/ssa/interp/interp_test.go
+++ b/go/ssa/interp/interp_test.go
@@ -134,6 +134,7 @@
 	"fixedbugs/issue55115.go",
 	"fixedbugs/issue52835.go",
 	"fixedbugs/issue55086.go",
+	"fixedbugs/issue66783.go",
 	"typeassert.go",
 	"zeros.go",
 }
diff --git a/go/ssa/interp/testdata/fixedbugs/issue66783.go b/go/ssa/interp/testdata/fixedbugs/issue66783.go
new file mode 100644
index 0000000..e49e86d
--- /dev/null
+++ b/go/ssa/interp/testdata/fixedbugs/issue66783.go
@@ -0,0 +1,54 @@
+package main
+
+import "fmt"
+
+func Fn[N any]() (any, any, any) {
+	// Very recursive type to exercise substitution.
+	type t[x any, ignored *N] struct {
+		f  x
+		g  N
+		nx *t[x, *N]
+		nn *t[N, *N]
+	}
+	n := t[N, *N]{}
+	s := t[string, *N]{}
+	i := t[int, *N]{}
+	return n, s, i
+}
+
+func main() {
+
+	sn, ss, si := Fn[string]()
+	in, is, ii := Fn[int]()
+
+	for i, t := range []struct {
+		x, y any
+		want bool
+	}{
+		{sn, ss, true},  // main.t[string;string,*string] == main.t[string;string,*string]
+		{sn, si, false}, // main.t[string;string,*string] != main.t[string;int,*string]
+		{sn, in, false}, // main.t[string;string,*string] != main.t[int;int,*int]
+		{sn, is, false}, // main.t[string;string,*string] != main.t[int;string,*int]
+		{sn, ii, false}, // main.t[string;string,*string] != main.t[int;int,*int]
+
+		{ss, si, false}, // main.t[string;string,*string] != main.t[string;int,*string]
+		{ss, in, false}, // main.t[string;string,*string] != main.t[int;int,*int]
+		{ss, is, false}, // main.t[string;string,*string] != main.t[int;string,*int]
+		{ss, ii, false}, // main.t[string;string,*string] != main.t[int;int,*int]
+
+		{si, in, false}, // main.t[string;int,*string] != main.t[int;int,*int]
+		{si, is, false}, // main.t[string;int,*string] != main.t[int;string,*int]
+		{si, ii, false}, // main.t[string;int,*string] != main.t[int;int,*int]
+
+		{in, is, false}, // main.t[int;int,*int] != main.t[int;string,*int]
+		{in, ii, true},  // main.t[int;int,*int] == main.t[int;int,*int]
+
+		{is, ii, false}, // main.t[int;string,*int] != main.t[int;int,*int]
+	} {
+		x, y, want := t.x, t.y, t.want
+		if got := x == y; got != want {
+			msg := fmt.Sprintf("(case %d) %T == %T. got %v. wanted %v", i, x, y, got, want)
+			panic(msg)
+		}
+	}
+}
diff --git a/go/ssa/subst.go b/go/ssa/subst.go
index 6490db8..75d887d 100644
--- a/go/ssa/subst.go
+++ b/go/ssa/subst.go
@@ -7,66 +7,73 @@
 import (
 	"go/types"
 
+	"golang.org/x/tools/go/types/typeutil"
 	"golang.org/x/tools/internal/aliases"
 )
 
-// Type substituter for a fixed set of replacement types.
+// subster defines a type substitution operation of a set of type parameters
+// to type parameter free replacement types. Substitution is done within
+// the context of a package-level function instantiation. *Named types
+// declared in the function are unique to the instantiation.
 //
-// A nil *subster is an valid, empty substitution map. It always acts as
+// For example, given a parameterized function F
+//
+//	  func F[S, T any]() any {
+//	    type X struct{ s S; next *X }
+//		var p *X
+//	    return p
+//	  }
+//
+// calling the instantiation F[string, int]() returns an interface
+// value (*X[string,int], nil) where the underlying value of
+// X[string,int] is a struct{s string; next *X[string,int]}.
+//
+// A nil *subster is a valid, empty substitution map. It always acts as
 // the identity function. This allows for treating parameterized and
 // non-parameterized functions identically while compiling to ssa.
 //
 // Not concurrency-safe.
+//
+// Note: Some may find it helpful to think through some of the most
+// complex substitution cases using lambda calculus inspired notation.
+// subst.typ() solves evaluating a type expression E
+// within the body of a function Fn[m] with the type parameters m
+// once we have applied the type arguments N.
+// We can succinctly write this as a function application:
+//
+//	((λm. E) N)
+//
+// go/types does not provide this interface directly.
+// So what subster provides is a type substitution operation
+//
+//	E[m:=N]
 type subster struct {
 	replacements map[*types.TypeParam]types.Type // values should contain no type params
 	cache        map[types.Type]types.Type       // cache of subst results
-	ctxt         *types.Context                  // cache for instantiation
-	scope        *types.Scope                    // *types.Named declared within this scope can be substituted (optional)
-	debug        bool                            // perform extra debugging checks
+	origin       *types.Func                     // types.Objects declared within this origin function are unique within this context
+	ctxt         *types.Context                  // speeds up repeated instantiations
+	uniqueness   typeutil.Map                    // determines the uniqueness of the instantiations within the function
 	// TODO(taking): consider adding Pos
-	// TODO(zpavlinovic): replacements can contain type params
-	// when generating instances inside of a generic function body.
 }
 
 // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
 // targs should not contain any types in tparams.
-// scope is the (optional) lexical block of the generic function for which we are substituting.
-func makeSubster(ctxt *types.Context, scope *types.Scope, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster {
+// fn is the generic function for which we are substituting.
+func makeSubster(ctxt *types.Context, fn *types.Func, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster {
 	assert(tparams.Len() == len(targs), "makeSubster argument count must match")
 
 	subst := &subster{
 		replacements: make(map[*types.TypeParam]types.Type, tparams.Len()),
 		cache:        make(map[types.Type]types.Type),
+		origin:       fn.Origin(),
 		ctxt:         ctxt,
-		scope:        scope,
-		debug:        debug,
 	}
 	for i := 0; i < tparams.Len(); i++ {
 		subst.replacements[tparams.At(i)] = targs[i]
 	}
-	if subst.debug {
-		subst.wellFormed()
-	}
 	return subst
 }
 
-// wellFormed asserts that subst was properly initialized.
-func (subst *subster) wellFormed() {
-	if subst == nil {
-		return
-	}
-	// Check that all of the type params do not appear in the arguments.
-	s := make(map[types.Type]bool, len(subst.replacements))
-	for tparam := range subst.replacements {
-		s[tparam] = true
-	}
-	for _, r := range subst.replacements {
-		if reaches(r, s) {
-			panic(subst)
-		}
-	}
-}
-
 // typ returns the type of t with the type parameter tparams[i] substituted
 // for the type targs[i] where subst was created using tparams and targs.
 func (subst *subster) typ(t types.Type) (res types.Type) {
@@ -82,9 +89,10 @@
 
 	switch t := t.(type) {
 	case *types.TypeParam:
-		r := subst.replacements[t]
-		assert(r != nil, "type param without replacement encountered")
-		return r
+		if r := subst.replacements[t]; r != nil {
+			return r
+		}
+		return t
 
 	case *types.Basic:
 		return t
@@ -194,7 +202,7 @@
 	return t
 }
 
-// varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
+// varlist returns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
 func (subst *subster) varlist(in varlist) []*types.Var {
 	var out []*types.Var // nil => no updates
 	for i, n := 0, in.Len(); i < n; i++ {
@@ -322,71 +330,146 @@
 }
 
 func (subst *subster) named(t *types.Named) types.Type {
-	// A named type may be:
-	// (1) ordinary named type (non-local scope, no type parameters, no type arguments),
-	// (2) locally scoped type,
-	// (3) generic (type parameters but no type arguments), or
-	// (4) instantiated (type parameters and type arguments).
-	tparams := t.TypeParams()
-	if tparams.Len() == 0 {
-		if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) {
-			// Outside the current function scope?
-			return t // case (1) ordinary
+	// A Named type is a user defined type.
+	// Ignoring generics, Named types are canonical: they are identical if
+	// and only if they have the same defining symbol.
+	// Generics complicate things, both if the type definition itself is
+	// parameterized, and if the type is defined within the scope of a
+	// parameterized function. In this case, two named types are identical if
+	// and only if their identifying symbols are identical, and all type
+	// arguments bindings in scope of the named type definition (including the
+	// type parameters of the definition itself) are equivalent.
+	//
+	// Notably:
+	// 1. For type definition type T[P1 any] struct{}, T[A] and T[B] are identical
+	//    only if A and B are identical.
+	// 2. Inside the generic func Fn[m any]() any { type T struct{}; return T{} },
+	//    the result of Fn[A] and Fn[B] have identical type if and only if A and
+	//    B are identical.
+	// 3. Both 1 and 2 could apply, such as in
+	//    func F[m any]() any { type T[x any] struct{}; return T{} }
+	//
+	// A subster replaces type parameters within a function scope, and therefore must
+	// also replace free type parameters in the definitions of local types.
+	//
+	// Note: There are some detailed notes sprinkled throughout that borrow from
+	// lambda calculus notation. These contain some over simplifying math.
+	//
+	// LC: One way to think about subster is that it is  a way of evaluating
+	//   ((λm. E) N) as E[m:=N].
+	// Each Named type t has an object *TypeName within a scope S that binds an
+	// underlying type expression U. U can refer to symbols within S (+ S's ancestors).
+	// Let x = t.TypeParams() and A = t.TypeArgs().
+	// Each Named type t is then either:
+	//   U              where len(x) == 0 && len(A) == 0
+	//   λx. U          where len(x) != 0 && len(A) == 0
+	//   ((λx. U) A)    where len(x) == len(A)
+	// In each case, we will evaluate t[m:=N].
+	tparams := t.TypeParams() // x
+	targs := t.TypeArgs()     // A
+
+	if !declaredWithin(t.Obj(), subst.origin) {
+		// t is declared outside of Fn[m].
+		//
+		// In this case, we can skip substituting t.Underlying().
+		// The underlying type cannot refer to the type parameters.
+		//
+		// LC: Let free(E) be the set of free type parameters in an expression E.
+		// Then whenever m ∉ free(E), then E = E[m:=N].
+		// t ∉ Scope(fn) so therefore m ∉ free(U) and m ∩ x = ∅.
+		if targs.Len() == 0 {
+			// t has no type arguments. So it does not need to be instantiated.
+			//
+			// This is the normal case in real Go code, where t is not parameterized,
+			// declared at some package scope, and m is a TypeParam from a parameterized
+			// function F[m] or method.
+			//
+			// LC: m ∉ free(A) lets us conclude m ∉ free(t). So t=t[m:=N].
+			return t
 		}
 
-		// case (2) locally scoped type.
-		// Create a new named type to represent this instantiation.
-		// We assume that local types of distinct instantiations of a
-		// generic function are distinct, even if they don't refer to
-		// type parameters, but the spec is unclear; see golang/go#58573.
+		// t is declared outside of Fn[m] and has type arguments.
+		// The type arguments may contain type parameters m so
+		// substitute the type arguments, and instantiate the substituted
+		// type arguments.
+		//
+		// LC: Evaluate this as ((λx. U) A') where A' = A[m := N].
+		newTArgs := subst.typelist(targs)
+		return subst.instantiate(t.Origin(), newTArgs)
+	}
+
+	// t is declared within Fn[m].
+
+	if targs.Len() == 0 { // no type arguments?
+		assert(t == t.Origin(), "local parameterized type abstraction must be an origin type")
+
+		// t has no type arguments.
+		// The underlying type of t may contain the function's type parameters,
+		// replace these, and create a new type.
 		//
 		// Subtle: We short circuit substitution and use a newly created type in
-		// subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive
-		// types during traversal. This both breaks infinite cycles and allows for
-		// constructing types with the replacement applied in subst.typ(under).
+		// subst, i.e. cache[t]=fresh, to preemptively replace t with fresh
+		// in recursive types during traversal. This both breaks infinite cycles
+		// and allows for constructing types with the replacement applied in
+		// subst.typ(U).
 		//
-		// Example:
-		// func foo[T any]() {
-		//   type linkedlist struct {
-		//     next *linkedlist
-		//     val T
-		//   }
-		// }
+		// A new copy of the Named and Typename (and constraints) per function
+		// instantiation matches the semantics of Go, which treats all function
+		// instantiations F[N] as having distinct local types.
 		//
-		// When the field `next *linkedlist` is visited during subst.typ(under),
-		// we want the substituted type for the field `next` to be `*n`.
-		n := types.NewNamed(t.Obj(), nil, nil)
-		subst.cache[t] = n
-		subst.cache[n] = n
-		n.SetUnderlying(subst.typ(t.Underlying()))
-		return n
+		// LC: x.Len()=0 can be thought of as a special case of λx. U.
+		// LC: Evaluate (λx. U)[m:=N] as (λx'. U') where U'=U[x:=x',m:=N].
+		tname := t.Obj()
+		obj := types.NewTypeName(tname.Pos(), tname.Pkg(), tname.Name(), nil)
+		fresh := types.NewNamed(obj, nil, nil)
+		var newTParams []*types.TypeParam
+		for i := 0; i < tparams.Len(); i++ {
+			cur := tparams.At(i)
+			cobj := cur.Obj()
+			cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
+			ntp := types.NewTypeParam(cname, nil)
+			subst.cache[cur] = ntp
+			newTParams = append(newTParams, ntp)
+		}
+		fresh.SetTypeParams(newTParams)
+		subst.cache[t] = fresh
+		subst.cache[fresh] = fresh
+		fresh.SetUnderlying(subst.typ(t.Underlying()))
+		// Substitute into all of the constraints after they are created.
+		for i, ntp := range newTParams {
+			bound := tparams.At(i).Constraint()
+			ntp.SetConstraint(subst.typ(bound))
+		}
+		return fresh
 	}
-	targs := t.TypeArgs()
 
-	// insts are arguments to instantiate using.
-	insts := make([]types.Type, tparams.Len())
+	// t is defined within Fn[m] and t has type arguments (an instantiation).
+	// We reduce this to the two cases above:
+	// (1) substitute the function's type parameters into t.Origin().
+	// (2) substitute t's type arguments A and instantiate the updated t.Origin() with these.
+	//
+	// LC: Evaluate ((λx. U) A)[m:=N] as (t' A') where t' = (λx. U)[m:=N] and A'=A [m:=N]
+	subOrigin := subst.typ(t.Origin())
+	subTArgs := subst.typelist(targs)
+	return subst.instantiate(subOrigin, subTArgs)
+}
 
-	// case (3) generic ==> targs.Len() == 0
-	// Instantiating a generic with no type arguments should be unreachable.
-	// Please report a bug if you encounter this.
-	assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
-
-	// case (4) instantiated.
-	// Substitute into the type arguments and instantiate the replacements/
-	// Example:
-	//    type N[A any] func() A
-	//    func Foo[T](g N[T]) {}
-	//  To instantiate Foo[string], one goes through {T->string}. To get the type of g
-	//  one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} }
-	//  to get {N with TypeArgs == {string} and typeparams == {A} }.
-	assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present")
-	for i, n := 0, targs.Len(); i < n; i++ {
-		inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion
-		insts[i] = inst
-	}
-	r, err := types.Instantiate(subst.ctxt, t.Origin(), insts, false)
+func (subst *subster) instantiate(orig types.Type, targs []types.Type) types.Type {
+	i, err := types.Instantiate(subst.ctxt, orig, targs, false)
 	assert(err == nil, "failed to Instantiate Named type")
-	return r
+	if c, _ := subst.uniqueness.At(i).(types.Type); c != nil {
+		return c.(types.Type)
+	}
+	subst.uniqueness.Set(i, i)
+	return i
+}
+
+func (subst *subster) typelist(l *types.TypeList) []types.Type {
+	res := make([]types.Type, l.Len())
+	for i := 0; i < l.Len(); i++ {
+		res[i] = subst.typ(l.At(i))
+	}
+	return res
 }
 
 func (subst *subster) signature(t *types.Signature) types.Type {
diff --git a/go/ssa/subst_test.go b/go/ssa/subst_test.go
index 6652b1a..3c126fa 100644
--- a/go/ssa/subst_test.go
+++ b/go/ssa/subst_test.go
@@ -16,6 +16,10 @@
 	const source = `
 package P
 
+func within(){
+	// Pretend that the instantiation happens within this function.
+}
+
 type t0 int
 func (t0) f()
 type t1 interface{ f() }
@@ -55,6 +59,11 @@
 		t.Fatal(err)
 	}
 
+	within, _ := pkg.Scope().Lookup("within").(*types.Func)
+	if within == nil {
+		t.Fatal("Failed to find the function within()")
+	}
+
 	for _, test := range []struct {
 		expr string   // type expression of Named parameterized type
 		args []string // type expressions of args for named
@@ -94,7 +103,7 @@
 
 		T := tv.Type.(*types.Named)
 
-		subst := makeSubster(types.NewContext(), nil, T.TypeParams(), targs, true)
+		subst := makeSubster(types.NewContext(), within, T.TypeParams(), targs, true)
 		sub := subst.typ(T.Underlying())
 		if got := sub.String(); got != test.want {
 			t.Errorf("subst{%v->%v}.typ(%s) = %v, want %v", test.expr, test.args, T.Underlying(), got, test.want)
diff --git a/go/ssa/testdata/fixedbugs/issue66783a.go b/go/ssa/testdata/fixedbugs/issue66783a.go
new file mode 100644
index 0000000..d4cf0f5
--- /dev/null
+++ b/go/ssa/testdata/fixedbugs/issue66783a.go
@@ -0,0 +1,24 @@
+//go:build ignore
+// +build ignore
+
+package issue66783a
+
+type S[T any] struct {
+	a T
+}
+
+func (s S[T]) M() {
+	type A S[T]
+	type B[U any] A
+	_ = B[rune](s)
+}
+
+// M[int]
+
+// panic: in (issue66783a.S[int]).M[int]:
+// cannot convert term *t0 (issue66783a.S[int] [within struct{a int}])
+// to type issue66783a.B[rune] [within struct{a T}] [recovered]
+
+func M() {
+	S[int]{}.M()
+}
diff --git a/go/ssa/testdata/fixedbugs/issue66783b.go b/go/ssa/testdata/fixedbugs/issue66783b.go
new file mode 100644
index 0000000..50a2d30
--- /dev/null
+++ b/go/ssa/testdata/fixedbugs/issue66783b.go
@@ -0,0 +1,22 @@
+//go:build ignore
+// +build ignore
+
+package issue66783b
+
+type I1[T any] interface {
+	M(T)
+}
+
+type I2[T any] I1[T]
+
+func foo[T any](i I2[T]) {
+	_ = i.M
+}
+
+type S[T any] struct{}
+
+func (s S[T]) M(t T) {}
+
+func M2() {
+	foo[int](I2[int](S[int]{}))
+}
diff --git a/go/ssa/util.go b/go/ssa/util.go
index ed3e993..bd0e62e 100644
--- a/go/ssa/util.go
+++ b/go/ssa/util.go
@@ -149,6 +149,26 @@
 	return ok && b.Info()&types.IsUntyped != 0
 }
 
+// declaredWithin reports whether an object is declared within a function.
+//
+// obj must not be a method or a field.
+func declaredWithin(obj types.Object, fn *types.Func) bool {
+	if obj.Pos() != token.NoPos {
+		return fn.Scope().Contains(obj.Pos()) // trust the positions if they exist.
+	}
+	if fn.Pkg() != obj.Pkg() {
+		return false // fast path for different packages
+	}
+
+	// Traverse Parent() scopes for fn.Scope().
+	for p := obj.Parent(); p != nil; p = p.Parent() {
+		if p == fn.Scope() {
+			return true
+		}
+	}
+	return false
+}
+
 // logStack prints the formatted "start" message to stderr and
 // returns a closure that prints the corresponding "end" message.
 // Call using 'defer logStack(...)()' to show builder stack on panic.