blob: 32c3a9a08cf3f6ba6f69f24516edbbe1fb36028c [file] [log] [blame]
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssa_test
import (
"fmt"
"go/types"
"reflect"
"sort"
"strings"
"testing"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)
// TestNeedsInstance ensures that new method instances can be created via MethodValue.
//
// It runs once per builder mode (with and without ssa.InstantiateGenerics),
// each as its own subtest so that a fatal failure in one mode does not hide
// results from the other and every failure identifies its mode.
func TestNeedsInstance(t *testing.T) {
	const input = `
package p
import "unsafe"
type Pointer[T any] struct {
v unsafe.Pointer
}
func (x *Pointer[T]) Load() *T {
return (*T)(LoadPointer(&x.v))
}
func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer)
`
	// The SSA members for this package should look something like this:
	// func LoadPointer func(addr *unsafe.Pointer) (val unsafe.Pointer)
	// type Pointer struct{v unsafe.Pointer}
	// method (*Pointer[T any]) Load() *T
	// func init func()
	// var init$guard bool
	for _, mode := range []ssa.BuilderMode{
		ssa.SanityCheckFunctions,
		ssa.SanityCheckFunctions | ssa.InstantiateGenerics,
	} {
		// t.Run is synchronous here (no t.Parallel), so capturing mode is safe.
		t.Run(fmt.Sprint(mode), func(t *testing.T) {
			p, _ := buildPackage(t, input, mode)
			prog := p.Prog

			ptr := p.Type("Pointer").Type().(*types.Named)
			if ptr.NumMethods() != 1 {
				t.Fatalf("Expected Pointer to have 1 method. got %d", ptr.NumMethods())
			}
			obj := ptr.Method(0)
			if obj.Name() != "Load" {
				t.Errorf("Expected Pointer to have method named 'Load'. got %q", obj.Name())
			}
			meth := prog.FuncValue(obj)

			// instantiateLoadMethod returns the first method (Load) of the instantiation *Pointer[T].
			instantiateLoadMethod := func(T types.Type) *ssa.Function {
				ptrT, err := types.Instantiate(nil, ptr, []types.Type{T}, false)
				if err != nil {
					// Report the underlying error; without it the failure is opaque.
					t.Fatalf("Failed to Instantiate %q by %q: %v", ptr, T, err)
				}
				methods := types.NewMethodSet(types.NewPointer(ptrT))
				if methods.Len() != 1 {
					t.Fatalf("Expected 1 method for %q. got %d", ptrT, methods.Len())
				}
				return prog.MethodValue(methods.At(0))
			}

			intSliceTyp := types.NewSlice(types.Typ[types.Int])
			instance := instantiateLoadMethod(intSliceTyp) // (*Pointer[[]int]).Load
			if instance.Origin() != meth {
				t.Errorf("Expected Origin of %s to be %s. got %s", instance, meth, instance.Origin())
			}
			if len(instance.TypeArgs()) != 1 || !types.Identical(instance.TypeArgs()[0], intSliceTyp) {
				t.Errorf("Expected TypeArgs of %s to be %v. got %v", instance, []types.Type{intSliceTyp}, instance.TypeArgs())
			}

			// A second request with an identical (but distinct) type returns the same Function.
			second := instantiateLoadMethod(types.NewSlice(types.Typ[types.Int]))
			if second != instance {
				t.Error("Expected second identical instantiation to be the same function")
			}

			// (*Pointer[[]uint]).Load
			inst2 := instantiateLoadMethod(types.NewSlice(types.Typ[types.Uint]))
			if instance.Name() >= inst2.Name() {
				t.Errorf("Expected name of instance %s to be before instance %v", instance, inst2)
			}
		})
	}
}
// TestCallsToInstances checks that callees of calls to generic functions,
// without monomorphization, are wrappers around the origin generic function.
//
// For each expected instance it verifies the single-block wrapper body, the
// number of ChangeType instructions in it, and its type params/type args.
func TestCallsToInstances(t *testing.T) {
	const input = `
package p
type I interface {
Foo()
}
type A int
func (a A) Foo() {}
type J[T any] interface{ Bar() T }
type K[T any] struct{ J[T] }
func Id[T any] (t T) T {
return t
}
func Lambda[T I]() func() func(T) {
return func() func(T) {
return T.Foo
}
}
func NoOp[T any]() {}
func Bar[T interface { Foo(); ~int | ~string }, U any] (t T, u U) {
Id[U](u)
Id[T](t)
}
func Make[T any]() interface{} {
NoOp[K[T]]()
return nil
}
func entry(i int, a A) int {
Lambda[A]()()(a)
x := Make[int]()
if j, ok := x.(interface{ Bar() int }); ok {
print(j)
}
Bar[A, int](a, i)
return Id[int](i)
}
`
	p, _ := buildPackage(t, input, ssa.SanityCheckFunctions)
	all := ssautil.AllFunctions(p.Prog)

	// t.Run is synchronous (no t.Parallel), so the loop variable may be
	// captured directly by the subtest closure.
	for _, test := range []struct {
		orig         string
		instance     string
		tparams      string
		targs        string
		chTypeInstrs int // number of ChangeType instructions in f's body
	}{
		{"Id", "Id[int]", "[T]", "[int]", 2},
		{"Lambda", "Lambda[p.A]", "[T]", "[p.A]", 1},
		{"Make", "Make[int]", "[T]", "[int]", 0},
		{"NoOp", "NoOp[p.K[T]]", "[T]", "[p.K[T]]", 0},
	} {
		t.Run(test.instance, func(t *testing.T) {
			// Comma-ok form: a missing member must reach the nil check
			// below rather than panic inside the type assertion.
			f, _ := p.Members[test.orig].(*ssa.Function)
			if f == nil {
				t.Fatalf("origin function not found")
			}

			var i *ssa.Function
			for _, fn := range instancesOf(all, f) {
				if fn.Name() == test.instance {
					i = fn
					break
				}
			}
			if i == nil {
				t.Fatalf("instance not found")
			}

			// for logging on failures
			var body strings.Builder
			i.WriteTo(&body)
			t.Log(body.String())

			if len(i.Blocks) != 1 {
				t.Fatalf("want 1 block; got %d", len(i.Blocks))
			}
			if instrs := changeTypeInstrs(i.Blocks[0]); instrs != test.chTypeInstrs {
				t.Errorf("want %v instructions; got %v", test.chTypeInstrs, instrs)
			}
			if got := tparams(i); test.tparams != got {
				t.Errorf("want %v type params; got %v", test.tparams, got)
			}
			if got := targs(i); test.targs != got {
				t.Errorf("want %v type arguments; got %v", test.targs, got)
			}
		})
	}
}
// tparams renders the type parameters declared by f as a bracketed,
// space-separated list such as "[T]". It yields "[]" when f declares none.
func tparams(f *ssa.Function) string {
	list := f.TypeParams()
	names := make([]string, list.Len())
	for idx := range names {
		names[idx] = list.At(idx).String()
	}
	return fmt.Sprint(names)
}
// targs renders the type arguments of f as a bracketed, space-separated
// list such as "[int]". It yields "[]" when f has no type arguments.
func targs(f *ssa.Function) string {
	args := f.TypeArgs()
	strs := make([]string, len(args))
	for k := range args {
		strs[k] = args[k].String()
	}
	return fmt.Sprint(strs)
}
// changeTypeInstrs reports how many instructions in block b are
// *ssa.ChangeType instructions.
func changeTypeInstrs(b *ssa.BasicBlock) int {
	n := 0
	for _, instr := range b.Instrs {
		switch instr.(type) {
		case *ssa.ChangeType:
			n++
		}
	}
	return n
}
// TestInstanceUniqueness checks that repeated calls to a generic function
// with the same type arguments share a single instance, so each distinct
// instantiation appears exactly once among the program's functions.
func TestInstanceUniqueness(t *testing.T) {
	const input = `
package p
func H[T any](t T) {
print(t)
}
func F[T any](t T) {
H[T](t)
H[T](t)
H[T](t)
}
func G[T any](t T) {
H[T](t)
H[T](t)
}
func Foo[T any, S any](t T, s S) {
Foo[S, T](s, t)
Foo[T, S](t, s)
}
`
	p, _ := buildPackage(t, input, ssa.SanityCheckFunctions)
	all := ssautil.AllFunctions(p.Prog)
	for _, test := range []struct {
		orig      string
		instances string
	}{
		{"H", "[p.H[T] p.H[T]]"},
		{"Foo", "[p.Foo[S T] p.Foo[T S]]"},
	} {
		t.Run(test.orig, func(t *testing.T) {
			// Comma-ok form: a missing member must reach the nil check
			// below rather than panic inside the type assertion.
			f, _ := p.Members[test.orig].(*ssa.Function)
			if f == nil {
				t.Fatalf("origin function not found")
			}
			instances := instancesOf(all, f)
			// instancesOf returns map-iteration order; sort for a stable comparison.
			sort.Slice(instances, func(i, j int) bool { return instances[i].Name() < instances[j].Name() })
			if got := fmt.Sprintf("%v", instances); !reflect.DeepEqual(got, test.instances) {
				t.Errorf("got %v instances, want %v", got, test.instances)
			}
		})
	}
}
// instancesOf collects, in no particular order, every function among the
// keys of fns whose Origin is g, excluding g itself. The map's bool values
// are ignored; only key membership matters. The result is a freshly
// allocated slice (nil when there are no matches).
func instancesOf(fns map[*ssa.Function]bool, g *ssa.Function) []*ssa.Function {
	var result []*ssa.Function
	for fn := range fns {
		if fn == g || fn.Origin() != g {
			continue
		}
		result = append(result, fn)
	}
	return result
}