go/ssa: substitute type parameterized aliases
Adds support to substitute type parameterized aliases in
generic functions.
Change-Id: I4fb2e5f5fd9b626781efdc4db808c52cb22ba241
Reviewed-on: https://go-review.googlesource.com/c/tools/+/602195
Reviewed-by: Alan Donovan <adonovan@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/go/ssa/builder_generic_test.go b/go/ssa/builder_generic_test.go
index 33531da..55dc79f 100644
--- a/go/ssa/builder_generic_test.go
+++ b/go/ssa/builder_generic_test.go
@@ -550,7 +550,13 @@
}
// Collect calls to the builtin print function.
- probes := callsTo(p, "print")
+ fns := make(map[*ssa.Function]bool)
+ for _, mem := range p.Members {
+ if fn, ok := mem.(*ssa.Function); ok {
+ fns[fn] = true
+ }
+ }
+ probes := callsTo(fns, "print")
expectations := matchNotes(prog.Fset, notes, probes)
for call := range probes {
@@ -576,17 +582,15 @@
// callsTo finds all calls to an SSA value named fname,
// and returns a map from each call site to its enclosing function.
-func callsTo(p *ssa.Package, fname string) map[*ssa.CallCommon]*ssa.Function {
+func callsTo(fns map[*ssa.Function]bool, fname string) map[*ssa.CallCommon]*ssa.Function {
callsites := make(map[*ssa.CallCommon]*ssa.Function)
- for _, mem := range p.Members {
- if fn, ok := mem.(*ssa.Function); ok {
- for _, bb := range fn.Blocks {
- for _, i := range bb.Instrs {
- if i, ok := i.(ssa.CallInstruction); ok {
- call := i.Common()
- if call.Value.Name() == fname {
- callsites[call] = fn
- }
+ for fn := range fns {
+ for _, bb := range fn.Blocks {
+ for _, i := range bb.Instrs {
+ if i, ok := i.(ssa.CallInstruction); ok {
+ call := i.Common()
+ if call.Value.Name() == fname {
+ callsites[call] = fn
}
}
}
diff --git a/go/ssa/builder_go122_test.go b/go/ssa/builder_go122_test.go
index d984312..bde5bae 100644
--- a/go/ssa/builder_go122_test.go
+++ b/go/ssa/builder_go122_test.go
@@ -168,7 +168,13 @@
}
// Collect calls to the built-in print function.
- probes := callsTo(p, "print")
+ fns := make(map[*ssa.Function]bool)
+ for _, mem := range p.Members {
+ if fn, ok := mem.(*ssa.Function); ok {
+ fns[fn] = true
+ }
+ }
+ probes := callsTo(fns, "print")
expectations := matchNotes(fset, notes, probes)
for call := range probes {
diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go
index ed1d84f..f6fae50 100644
--- a/go/ssa/builder_test.go
+++ b/go/ssa/builder_test.go
@@ -14,6 +14,7 @@
"go/token"
"go/types"
"os"
+ "os/exec"
"path/filepath"
"reflect"
"sort"
@@ -1260,3 +1261,143 @@
g.Wait() // ignore error
}
+
+func TestGenericAliases(t *testing.T) {
+ testenv.NeedsGo1Point(t, 23)
+
+ if os.Getenv("GENERICALIASTEST_CHILD") == "1" {
+ testGenericAliases(t)
+ return
+ }
+
+ testenv.NeedsExec(t)
+ testenv.NeedsTool(t, "go")
+
+ cmd := exec.Command(os.Args[0], "-test.run=TestGenericAliases")
+ cmd.Env = append(os.Environ(),
+ "GENERICALIASTEST_CHILD=1",
+ "GODEBUG=gotypesalias=1",
+ "GOEXPERIMENT=aliastypeparams",
+ )
+ out, err := cmd.CombinedOutput()
+ if len(out) > 0 {
+ t.Logf("out=<<%s>>", out)
+ }
+ var exitcode int
+ if err, ok := err.(*exec.ExitError); ok {
+ exitcode = err.ExitCode()
+ }
+ const want = 0
+ if exitcode != want {
+ t.Errorf("exited %d, want %d", exitcode, want)
+ }
+}
+
+func testGenericAliases(t *testing.T) {
+ t.Setenv("GOEXPERIMENT", "aliastypeparams=1")
+
+ const source = `
+package P
+
+type A = uint8
+type B[T any] = [4]T
+
+var F = f[string]
+
+func f[S any]() {
+ // Two copies of f are made: p.f[S] and p.f[string]
+
+ var v A // application of A that is declared outside of f without no type arguments
+ print("p.f", "String", "p.A", v)
+ print("p.f", "==", v, uint8(0))
+ print("p.f[string]", "String", "p.A", v)
+ print("p.f[string]", "==", v, uint8(0))
+
+
+ var u B[S] // application of B that is declared outside declared outside of f with type arguments
+ print("p.f", "String", "p.B[S]", u)
+ print("p.f", "==", u, [4]S{})
+ print("p.f[string]", "String", "p.B[string]", u)
+ print("p.f[string]", "==", u, [4]string{})
+
+ type C[T any] = struct{ s S; ap *B[T]} // declaration within f with type params
+ var w C[int] // application of C with type arguments
+ print("p.f", "String", "p.C[int]", w)
+ print("p.f", "==", w, struct{ s S; ap *[4]int}{})
+ print("p.f[string]", "String", "p.C[int]", w)
+ print("p.f[string]", "==", w, struct{ s string; ap *[4]int}{})
+}
+`
+
+ conf := loader.Config{Fset: token.NewFileSet()}
+ f, err := parser.ParseFile(conf.Fset, "p.go", source, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conf.CreateFromFiles("p", f)
+ iprog, err := conf.Load()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create and build SSA program.
+ prog := ssautil.CreateProgram(iprog, ssa.InstantiateGenerics)
+ prog.Build()
+
+ probes := callsTo(ssautil.AllFunctions(prog), "print")
+ if got, want := len(probes), 3*4*2; got != want {
+ t.Errorf("Found %v probes, expected %v", got, want)
+ }
+
+ const debug = false // enable to debug skips
+ skipped := 0
+ for probe, fn := range probes {
+ // Each probe is of the form:
+ // print("within", "test", head, tail)
+ // The probe only matches within a function whose fn.String() is within.
+ // This allows for different instantiations of fn to match different probes.
+ // On a match, it applies the test named "test" to head::tail.
+ if len(probe.Args) < 3 {
+ t.Fatalf("probe %v did not have enough arguments", probe)
+ }
+ within, test, head, tail := constString(probe.Args[0]), probe.Args[1], probe.Args[2], probe.Args[3:]
+ if within != fn.String() {
+ skipped++
+ if debug {
+ t.Logf("Skipping %q within %q", within, fn.String())
+ }
+ continue // does not match function
+ }
+
+ switch test := constString(test); test {
+ case "==": // All of the values are types.Identical.
+ for _, v := range tail {
+ if !types.Identical(head.Type(), v.Type()) {
+ t.Errorf("Expected %v and %v to have identical types", head, v)
+ }
+ }
+ case "String": // head is a string constant that all values in tail must match Type().String()
+ want := constString(head)
+ for _, v := range tail {
+ if got := v.Type().String(); got != want {
+ t.Errorf("%s: %v had the Type().String()=%q. expected %q", within, v, got, want)
+ }
+ }
+ default:
+ t.Errorf("%q is not a test subcommand", test)
+ }
+ }
+ if want := 3 * 4; skipped != want {
+ t.Errorf("Skipped %d probes, expected to skip %d", skipped, want)
+ }
+}
+
+// constString returns the value of a string constant
+// or "<not a constant string>" if the value is not a string constant.
+func constString(v ssa.Value) string {
+ if c, ok := v.(*ssa.Const); ok {
+ str := c.Value.String()
+ return strings.Trim(str, `"`)
+ }
+ return "<not a constant string>"
+}
diff --git a/go/ssa/subst.go b/go/ssa/subst.go
index 75d887d..4dcb871 100644
--- a/go/ssa/subst.go
+++ b/go/ssa/subst.go
@@ -318,15 +318,80 @@
}
func (subst *subster) alias(t *aliases.Alias) types.Type {
- // TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types.
- u := aliases.Unalias(t)
- if s := subst.typ(u); s != u {
- // If there is any change, do not create a new alias.
- return s
+ // See subster.named. This follows the same strategy.
+ tparams := aliases.TypeParams(t)
+ targs := aliases.TypeArgs(t)
+ tname := t.Obj()
+ torigin := aliases.Origin(t)
+
+ if !declaredWithin(tname, subst.origin) {
+ // t is declared outside of the function origin. So t is a package level type alias.
+ if targs.Len() == 0 {
+ // No type arguments so no instantiation needed.
+ return t
+ }
+
+ // Instantiate with the substituted type arguments.
+ newTArgs := subst.typelist(targs)
+ return subst.instantiate(torigin, newTArgs)
}
- // If there is no change, t did not reach any type parameter.
- // Keep the Alias.
- return t
+
+ if targs.Len() == 0 {
+ // t is declared within the function origin and has no type arguments.
+ //
+ // Example: This corresponds to A or B in F, but not A[int]:
+ //
+ // func F[T any]() {
+ // type A[S any] = struct{t T, s S}
+ // type B = T
+ // var x A[int]
+ // ...
+ // }
+ //
+ // This is somewhat different than *Named as *Alias cannot be created recursively.
+
+ // Copy and substitute type params.
+ var newTParams []*types.TypeParam
+ for i := 0; i < tparams.Len(); i++ {
+ cur := tparams.At(i)
+ cobj := cur.Obj()
+ cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
+ ntp := types.NewTypeParam(cname, nil)
+ subst.cache[cur] = ntp // See the comment "Note: Subtle" in subster.named.
+ newTParams = append(newTParams, ntp)
+ }
+
+ // Substitute rhs.
+ rhs := subst.typ(aliases.Rhs(t))
+
+ // Create the fresh alias.
+ obj := aliases.NewAlias(true, tname.Pos(), tname.Pkg(), tname.Name(), rhs)
+ fresh := obj.Type()
+ if fresh, ok := fresh.(*aliases.Alias); ok {
+ // TODO: assume ok when aliases are always materialized (go1.27).
+ aliases.SetTypeParams(fresh, newTParams)
+ }
+
+ // Substitute into all of the constraints after they are created.
+ for i, ntp := range newTParams {
+ bound := tparams.At(i).Constraint()
+ ntp.SetConstraint(subst.typ(bound))
+ }
+ return fresh
+ }
+
+ // t is declared within the function origin and has type arguments.
+ //
+ // Example: This corresponds to A[int] in F. Cases A and B are handled above.
+ // func F[T any]() {
+ // type A[S any] = struct{t T, s S}
+ // type B = T
+ // var x A[int]
+ // ...
+ // }
+ subOrigin := subst.typ(torigin)
+ subTArgs := subst.typelist(targs)
+ return subst.instantiate(subOrigin, subTArgs)
}
func (subst *subster) named(t *types.Named) types.Type {
@@ -456,7 +521,7 @@
func (subst *subster) instantiate(orig types.Type, targs []types.Type) types.Type {
i, err := types.Instantiate(subst.ctxt, orig, targs, false)
- assert(err == nil, "failed to Instantiate Named type")
+ assert(err == nil, "failed to Instantiate named (Named or Alias) type")
if c, _ := subst.uniqueness.At(i).(types.Type); c != nil {
return c.(types.Type)
}
diff --git a/internal/aliases/aliases_go121.go b/internal/aliases/aliases_go121.go
index 63391e5..6652f7d 100644
--- a/internal/aliases/aliases_go121.go
+++ b/internal/aliases/aliases_go121.go
@@ -15,11 +15,14 @@
// It will never be created by go/types.
type Alias struct{}
-func (*Alias) String() string { panic("unreachable") }
-func (*Alias) Underlying() types.Type { panic("unreachable") }
-func (*Alias) Obj() *types.TypeName { panic("unreachable") }
-func Rhs(alias *Alias) types.Type { panic("unreachable") }
-func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") }
+func (*Alias) String() string { panic("unreachable") }
+func (*Alias) Underlying() types.Type { panic("unreachable") }
+func (*Alias) Obj() *types.TypeName { panic("unreachable") }
+func Rhs(alias *Alias) types.Type { panic("unreachable") }
+func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") }
+func SetTypeParams(alias *Alias, tparams []*types.TypeParam) { panic("unreachable") }
+func TypeArgs(alias *Alias) *types.TypeList { panic("unreachable") }
+func Origin(alias *Alias) *Alias { panic("unreachable") }
// Unalias returns the type t for go <=1.21.
func Unalias(t types.Type) types.Type { return t }
diff --git a/internal/aliases/aliases_go122.go b/internal/aliases/aliases_go122.go
index 96fcd16..3ef1afe 100644
--- a/internal/aliases/aliases_go122.go
+++ b/internal/aliases/aliases_go122.go
@@ -36,6 +36,34 @@
return nil
}
+// SetTypeParams sets the type parameters of the alias type.
+func SetTypeParams(alias *Alias, tparams []*types.TypeParam) {
+ if alias, ok := any(alias).(interface {
+ SetTypeParams(tparams []*types.TypeParam)
+ }); ok {
+ alias.SetTypeParams(tparams) // go1.23+
+ } else if len(tparams) > 0 {
+ panic("cannot set type parameters of an Alias type in go1.22")
+ }
+}
+
+// TypeArgs returns the type arguments used to instantiate the Alias type.
+func TypeArgs(alias *Alias) *types.TypeList {
+ if alias, ok := any(alias).(interface{ TypeArgs() *types.TypeList }); ok {
+ return alias.TypeArgs() // go1.23+
+ }
+ return nil // empty (go1.22)
+}
+
+// Origin returns the generic Alias type of which alias is an instance.
+// If alias is not an instance of a generic alias, Origin returns alias.
+func Origin(alias *Alias) *Alias {
+ if alias, ok := any(alias).(interface{ Origin() *types.Alias }); ok {
+ return alias.Origin() // go1.23+
+ }
+ return alias // not an instance of a generic alias (go1.22)
+}
+
// Unalias is a wrapper of types.Unalias.
func Unalias(t types.Type) types.Type { return types.Unalias(t) }