go/ssa: substitute type parameterized aliases

Adds support to substitute type parameterized aliases in
generic functions.

Change-Id: I4fb2e5f5fd9b626781efdc4db808c52cb22ba241
Reviewed-on: https://go-review.googlesource.com/c/tools/+/602195
Reviewed-by: Alan Donovan <adonovan@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/go/ssa/builder_generic_test.go b/go/ssa/builder_generic_test.go
index 33531da..55dc79f 100644
--- a/go/ssa/builder_generic_test.go
+++ b/go/ssa/builder_generic_test.go
@@ -550,7 +550,13 @@
 			}
 
 			// Collect calls to the builtin print function.
-			probes := callsTo(p, "print")
+			fns := make(map[*ssa.Function]bool)
+			for _, mem := range p.Members {
+				if fn, ok := mem.(*ssa.Function); ok {
+					fns[fn] = true
+				}
+			}
+			probes := callsTo(fns, "print")
 			expectations := matchNotes(prog.Fset, notes, probes)
 
 			for call := range probes {
@@ -576,17 +582,15 @@
 
 // callsTo finds all calls to an SSA value named fname,
 // and returns a map from each call site to its enclosing function.
-func callsTo(p *ssa.Package, fname string) map[*ssa.CallCommon]*ssa.Function {
+func callsTo(fns map[*ssa.Function]bool, fname string) map[*ssa.CallCommon]*ssa.Function {
 	callsites := make(map[*ssa.CallCommon]*ssa.Function)
-	for _, mem := range p.Members {
-		if fn, ok := mem.(*ssa.Function); ok {
-			for _, bb := range fn.Blocks {
-				for _, i := range bb.Instrs {
-					if i, ok := i.(ssa.CallInstruction); ok {
-						call := i.Common()
-						if call.Value.Name() == fname {
-							callsites[call] = fn
-						}
+	for fn := range fns {
+		for _, bb := range fn.Blocks {
+			for _, i := range bb.Instrs {
+				if i, ok := i.(ssa.CallInstruction); ok {
+					call := i.Common()
+					if call.Value.Name() == fname {
+						callsites[call] = fn
 					}
 				}
 			}
diff --git a/go/ssa/builder_go122_test.go b/go/ssa/builder_go122_test.go
index d984312..bde5bae 100644
--- a/go/ssa/builder_go122_test.go
+++ b/go/ssa/builder_go122_test.go
@@ -168,7 +168,13 @@
 	}
 
 	// Collect calls to the built-in print function.
-	probes := callsTo(p, "print")
+	fns := make(map[*ssa.Function]bool)
+	for _, mem := range p.Members {
+		if fn, ok := mem.(*ssa.Function); ok {
+			fns[fn] = true
+		}
+	}
+	probes := callsTo(fns, "print")
 	expectations := matchNotes(fset, notes, probes)
 
 	for call := range probes {
diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go
index ed1d84f..f6fae50 100644
--- a/go/ssa/builder_test.go
+++ b/go/ssa/builder_test.go
@@ -14,6 +14,7 @@
 	"go/token"
 	"go/types"
 	"os"
+	"os/exec"
 	"path/filepath"
 	"reflect"
 	"sort"
@@ -1260,3 +1261,143 @@
 
 	g.Wait() // ignore error
 }
+
+func TestGenericAliases(t *testing.T) {
+	testenv.NeedsGo1Point(t, 23)
+
+	if os.Getenv("GENERICALIASTEST_CHILD") == "1" {
+		testGenericAliases(t)
+		return
+	}
+
+	testenv.NeedsExec(t)
+	testenv.NeedsTool(t, "go")
+
+	cmd := exec.Command(os.Args[0], "-test.run=TestGenericAliases")
+	cmd.Env = append(os.Environ(),
+		"GENERICALIASTEST_CHILD=1",
+		"GODEBUG=gotypesalias=1",
+		"GOEXPERIMENT=aliastypeparams",
+	)
+	out, err := cmd.CombinedOutput()
+	if len(out) > 0 {
+		t.Logf("out=<<%s>>", out)
+	}
+	var exitcode int
+	if err, ok := err.(*exec.ExitError); ok {
+		exitcode = err.ExitCode()
+	}
+	const want = 0
+	if exitcode != want {
+		t.Errorf("exited %d, want %d", exitcode, want)
+	}
+}
+
+func testGenericAliases(t *testing.T) {
+	t.Setenv("GOEXPERIMENT", "aliastypeparams=1")
+
+	const source = `
+package P
+
+type A = uint8
+type B[T any] = [4]T
+
+var F = f[string]
+
+func f[S any]() {
+	// Two copies of f are made: p.f[S] and p.f[string]
+
+	var v A // application of A that is declared outside of f without no type arguments
+	print("p.f", "String", "p.A", v)
+	print("p.f", "==", v, uint8(0))
+	print("p.f[string]", "String", "p.A", v)
+	print("p.f[string]", "==", v, uint8(0))
+
+
+	var u B[S] // application of B that is declared outside declared outside of f with type arguments
+	print("p.f", "String", "p.B[S]", u)
+	print("p.f", "==", u, [4]S{})
+	print("p.f[string]", "String", "p.B[string]", u)
+	print("p.f[string]", "==", u, [4]string{})
+
+	type C[T any] = struct{ s S; ap *B[T]} // declaration within f with type params
+	var w C[int] // application of C with type arguments
+	print("p.f", "String", "p.C[int]", w)
+	print("p.f", "==", w, struct{ s S; ap *[4]int}{})
+	print("p.f[string]", "String", "p.C[int]", w)
+	print("p.f[string]", "==", w, struct{ s string; ap *[4]int}{})
+}
+`
+
+	conf := loader.Config{Fset: token.NewFileSet()}
+	f, err := parser.ParseFile(conf.Fset, "p.go", source, 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+	conf.CreateFromFiles("p", f)
+	iprog, err := conf.Load()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Create and build SSA program.
+	prog := ssautil.CreateProgram(iprog, ssa.InstantiateGenerics)
+	prog.Build()
+
+	probes := callsTo(ssautil.AllFunctions(prog), "print")
+	if got, want := len(probes), 3*4*2; got != want {
+		t.Errorf("Found %v probes, expected %v", got, want)
+	}
+
+	const debug = false // enable to debug skips
+	skipped := 0
+	for probe, fn := range probes {
+		// Each probe is of the form:
+		// 		print("within", "test", head, tail)
+		// The probe only matches within a function whose fn.String() is within.
+		// This allows for different instantiations of fn to match different probes.
+		// On a match, it applies the test named "test" to head::tail.
+		if len(probe.Args) < 3 {
+			t.Fatalf("probe %v did not have enough arguments", probe)
+		}
+		within, test, head, tail := constString(probe.Args[0]), probe.Args[1], probe.Args[2], probe.Args[3:]
+		if within != fn.String() {
+			skipped++
+			if debug {
+				t.Logf("Skipping %q within %q", within, fn.String())
+			}
+			continue // does not match function
+		}
+
+		switch test := constString(test); test {
+		case "==": // All of the values are types.Identical.
+			for _, v := range tail {
+				if !types.Identical(head.Type(), v.Type()) {
+					t.Errorf("Expected %v and %v to have identical types", head, v)
+				}
+			}
+		case "String": // head is a string constant that all values in tail must match Type().String()
+			want := constString(head)
+			for _, v := range tail {
+				if got := v.Type().String(); got != want {
+					t.Errorf("%s: %v had the Type().String()=%q. expected %q", within, v, got, want)
+				}
+			}
+		default:
+			t.Errorf("%q is not a test subcommand", test)
+		}
+	}
+	if want := 3 * 4; skipped != want {
+		t.Errorf("Skipped %d probes, expected to skip %d", skipped, want)
+	}
+}
+
+// constString returns the value of a string constant
+// or "<not a constant string>" if the value is not a string constant.
+func constString(v ssa.Value) string {
+	if c, ok := v.(*ssa.Const); ok {
+		str := c.Value.String()
+		return strings.Trim(str, `"`)
+	}
+	return "<not a constant string>"
+}
diff --git a/go/ssa/subst.go b/go/ssa/subst.go
index 75d887d..4dcb871 100644
--- a/go/ssa/subst.go
+++ b/go/ssa/subst.go
@@ -318,15 +318,80 @@
 }
 
 func (subst *subster) alias(t *aliases.Alias) types.Type {
-	// TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types.
-	u := aliases.Unalias(t)
-	if s := subst.typ(u); s != u {
-		// If there is any change, do not create a new alias.
-		return s
+	// See subster.named. This follows the same strategy.
+	tparams := aliases.TypeParams(t)
+	targs := aliases.TypeArgs(t)
+	tname := t.Obj()
+	torigin := aliases.Origin(t)
+
+	if !declaredWithin(tname, subst.origin) {
+		// t is declared outside of the function origin. So t is a package level type alias.
+		if targs.Len() == 0 {
+			// No type arguments so no instantiation needed.
+			return t
+		}
+
+		// Instantiate with the substituted type arguments.
+		newTArgs := subst.typelist(targs)
+		return subst.instantiate(torigin, newTArgs)
 	}
-	// If there is no change, t did not reach any type parameter.
-	// Keep the Alias.
-	return t
+
+	if targs.Len() == 0 {
+		// t is declared within the function origin and has no type arguments.
+		//
+		// Example: This corresponds to A or B in F, but not A[int]:
+		//
+		//     func F[T any]() {
+		//       type A[S any] = struct{t T, s S}
+		//       type B = T
+		//       var x A[int]
+		//       ...
+		//     }
+		//
+		// This is somewhat different than *Named as *Alias cannot be created recursively.
+
+		// Copy and substitute type params.
+		var newTParams []*types.TypeParam
+		for i := 0; i < tparams.Len(); i++ {
+			cur := tparams.At(i)
+			cobj := cur.Obj()
+			cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
+			ntp := types.NewTypeParam(cname, nil)
+			subst.cache[cur] = ntp // See the comment "Note: Subtle" in subster.named.
+			newTParams = append(newTParams, ntp)
+		}
+
+		// Substitute rhs.
+		rhs := subst.typ(aliases.Rhs(t))
+
+		// Create the fresh alias.
+		obj := aliases.NewAlias(true, tname.Pos(), tname.Pkg(), tname.Name(), rhs)
+		fresh := obj.Type()
+		if fresh, ok := fresh.(*aliases.Alias); ok {
+			// TODO: assume ok when aliases are always materialized (go1.27).
+			aliases.SetTypeParams(fresh, newTParams)
+		}
+
+		// Substitute into all of the constraints after they are created.
+		for i, ntp := range newTParams {
+			bound := tparams.At(i).Constraint()
+			ntp.SetConstraint(subst.typ(bound))
+		}
+		return fresh
+	}
+
+	// t is declared within the function origin and has type arguments.
+	//
+	// Example: This corresponds to A[int] in F. Cases A and B are handled above.
+	//     func F[T any]() {
+	//       type A[S any] = struct{t T, s S}
+	//       type B = T
+	//       var x A[int]
+	//       ...
+	//     }
+	subOrigin := subst.typ(torigin)
+	subTArgs := subst.typelist(targs)
+	return subst.instantiate(subOrigin, subTArgs)
 }
 
 func (subst *subster) named(t *types.Named) types.Type {
@@ -456,7 +521,7 @@
 
 func (subst *subster) instantiate(orig types.Type, targs []types.Type) types.Type {
 	i, err := types.Instantiate(subst.ctxt, orig, targs, false)
-	assert(err == nil, "failed to Instantiate Named type")
+	assert(err == nil, "failed to Instantiate named (Named or Alias) type")
 	if c, _ := subst.uniqueness.At(i).(types.Type); c != nil {
 		return c.(types.Type)
 	}
diff --git a/internal/aliases/aliases_go121.go b/internal/aliases/aliases_go121.go
index 63391e5..6652f7d 100644
--- a/internal/aliases/aliases_go121.go
+++ b/internal/aliases/aliases_go121.go
@@ -15,11 +15,14 @@
 // It will never be created by go/types.
 type Alias struct{}
 
-func (*Alias) String() string                      { panic("unreachable") }
-func (*Alias) Underlying() types.Type              { panic("unreachable") }
-func (*Alias) Obj() *types.TypeName                { panic("unreachable") }
-func Rhs(alias *Alias) types.Type                  { panic("unreachable") }
-func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") }
+func (*Alias) String() string                                { panic("unreachable") }
+func (*Alias) Underlying() types.Type                        { panic("unreachable") }
+func (*Alias) Obj() *types.TypeName                          { panic("unreachable") }
+func Rhs(alias *Alias) types.Type                            { panic("unreachable") }
+func TypeParams(alias *Alias) *types.TypeParamList           { panic("unreachable") }
+func SetTypeParams(alias *Alias, tparams []*types.TypeParam) { panic("unreachable") }
+func TypeArgs(alias *Alias) *types.TypeList                  { panic("unreachable") }
+func Origin(alias *Alias) *Alias                             { panic("unreachable") }
 
 // Unalias returns the type t for go <=1.21.
 func Unalias(t types.Type) types.Type { return t }
diff --git a/internal/aliases/aliases_go122.go b/internal/aliases/aliases_go122.go
index 96fcd16..3ef1afe 100644
--- a/internal/aliases/aliases_go122.go
+++ b/internal/aliases/aliases_go122.go
@@ -36,6 +36,34 @@
 	return nil
 }
 
+// SetTypeParams sets the type parameters of the alias type.
+func SetTypeParams(alias *Alias, tparams []*types.TypeParam) {
+	if alias, ok := any(alias).(interface {
+		SetTypeParams(tparams []*types.TypeParam)
+	}); ok {
+		alias.SetTypeParams(tparams) // go1.23+
+	} else if len(tparams) > 0 {
+		panic("cannot set type parameters of an Alias type in go1.22")
+	}
+}
+
+// TypeArgs returns the type arguments used to instantiate the Alias type.
+func TypeArgs(alias *Alias) *types.TypeList {
+	if alias, ok := any(alias).(interface{ TypeArgs() *types.TypeList }); ok {
+		return alias.TypeArgs() // go1.23+
+	}
+	return nil // empty (go1.22)
+}
+
+// Origin returns the generic Alias type of which alias is an instance.
+// If alias is not an instance of a generic alias, Origin returns alias.
+func Origin(alias *Alias) *Alias {
+	if alias, ok := any(alias).(interface{ Origin() *types.Alias }); ok {
+		return alias.Origin() // go1.23+
+	}
+	return alias // not an instance of a generic alias (go1.22)
+}
+
 // Unalias is a wrapper of types.Unalias.
 func Unalias(t types.Type) types.Type { return types.Unalias(t) }