refactor/satisfy: don't crash on type parameters

This change causes the satisfy constraint pass to correctly handle
type parameters. In nearly all cases this means calling coreType(T)
instead of T.Underlying(). This, and the addition of cases for
C[T] and C[X, Y], should make the code robust to generic syntax.

However, it is still not clear what the semantics of constraints
are for the renaming tool. That work is left to a follow-up.

Also, add a test suite that exercises all the basic operators,
using generics in each case.

Fixes golang/go#52940

Change-Id: Ic1261eb551c99b582c35fadaa148b979532588df
Reviewed-on: https://go-review.googlesource.com/c/tools/+/413690
Reviewed-by: Robert Findley <rfindley@google.com>
diff --git a/refactor/satisfy/find.go b/refactor/satisfy/find.go
index ff4212b..91fb7de 100644
--- a/refactor/satisfy/find.go
+++ b/refactor/satisfy/find.go
@@ -10,10 +10,8 @@
 //
 // THIS PACKAGE IS EXPERIMENTAL AND MAY CHANGE AT ANY TIME.
 //
-// It is provided only for the gorename tool.  Ideally this
-// functionality will become part of the type-checker in due course,
-// since it is computing it anyway, and it is robust for ill-typed
-// inputs, which this package is not.
+// It is provided only for the gopls tool. It requires well-typed inputs.
+//
 package satisfy // import "golang.org/x/tools/refactor/satisfy"
 
 // NOTES:
@@ -25,9 +23,6 @@
 //     ...
 //   }})
 //
-// TODO(adonovan): make this robust against ill-typed input.
-// Or move it into the type-checker.
-//
 // Assignability conversions are possible in the following places:
 // - in assignments y = x, y := x, var y = x.
 // - from call argument types to formal parameter types
@@ -51,11 +46,15 @@
 
 	"golang.org/x/tools/go/ast/astutil"
 	"golang.org/x/tools/go/types/typeutil"
+	"golang.org/x/tools/internal/typeparams"
 )
 
 // A Constraint records the fact that the RHS type does and must
 // satisfy the LHS type, which is an interface.
 // The names are suggestive of an assignment statement LHS = RHS.
+//
+// The constraint is implicitly universally quantified over any type
+// parameters appearing within the two types.
 type Constraint struct {
 	LHS, RHS types.Type
 }
@@ -129,13 +128,13 @@
 
 	case *ast.CallExpr:
 		// x, err := f(args)
-		sig := f.expr(e.Fun).Underlying().(*types.Signature)
+		sig := coreType(f.expr(e.Fun)).(*types.Signature)
 		f.call(sig, e.Args)
 
 	case *ast.IndexExpr:
 		// y, ok := x[i]
 		x := f.expr(e.X)
-		f.assign(f.expr(e.Index), x.Underlying().(*types.Map).Key())
+		f.assign(f.expr(e.Index), coreType(x).(*types.Map).Key())
 
 	case *ast.TypeAssertExpr:
 		// y, ok := x.(T)
@@ -215,7 +214,7 @@
 			f.expr(args[1])
 		} else {
 			// append(x, y, z)
-			tElem := s.Underlying().(*types.Slice).Elem()
+			tElem := coreType(s).(*types.Slice).Elem()
 			for _, arg := range args[1:] {
 				f.assign(tElem, f.expr(arg))
 			}
@@ -224,7 +223,7 @@
 	case "delete":
 		m := f.expr(args[0])
 		k := f.expr(args[1])
-		f.assign(m.Underlying().(*types.Map).Key(), k)
+		f.assign(coreType(m).(*types.Map).Key(), k)
 
 	default:
 		// ordinary call
@@ -358,6 +357,7 @@
 		f.sig = saved
 
 	case *ast.CompositeLit:
+		// No need for coreType here: go1.18 disallows P{...} for type param P.
 		switch T := deref(tv.Type).Underlying().(type) {
 		case *types.Struct:
 			for i, elem := range e.Elts {
@@ -403,12 +403,20 @@
 		}
 
 	case *ast.IndexExpr:
-		x := f.expr(e.X)
-		i := f.expr(e.Index)
-		if ux, ok := x.Underlying().(*types.Map); ok {
-			f.assign(ux.Key(), i)
+		if instance(f.info, e.X) {
+			// f[T] or C[T] -- generic instantiation
+		} else {
+			// x[i] or m[k] -- index or lookup operation
+			x := f.expr(e.X)
+			i := f.expr(e.Index)
+			if ux, ok := coreType(x).(*types.Map); ok {
+				f.assign(ux.Key(), i)
+			}
 		}
 
+	case *typeparams.IndexListExpr:
+		// f[X, Y] -- generic instantiation
+
 	case *ast.SliceExpr:
 		f.expr(e.X)
 		if e.Low != nil {
@@ -439,7 +447,7 @@
 				}
 			}
 			// ordinary call
-			f.call(f.expr(e.Fun).Underlying().(*types.Signature), e.Args)
+			f.call(coreType(f.expr(e.Fun)).(*types.Signature), e.Args)
 		}
 
 	case *ast.StarExpr:
@@ -499,7 +507,7 @@
 	case *ast.SendStmt:
 		ch := f.expr(s.Chan)
 		val := f.expr(s.Value)
-		f.assign(ch.Underlying().(*types.Chan).Elem(), val)
+		f.assign(coreType(ch).(*types.Chan).Elem(), val)
 
 	case *ast.IncDecStmt:
 		f.expr(s.X)
@@ -647,35 +655,35 @@
 			if s.Key != nil {
 				k := f.expr(s.Key)
 				var xelem types.Type
-				// keys of array, *array, slice, string aren't interesting
-				switch ux := x.Underlying().(type) {
+				// Keys of array, *array, slice, string aren't interesting
+				// since the RHS key type is just an int.
+				switch ux := coreType(x).(type) {
 				case *types.Chan:
 					xelem = ux.Elem()
 				case *types.Map:
 					xelem = ux.Key()
 				}
 				if xelem != nil {
-					f.assign(xelem, k)
+					f.assign(k, xelem)
 				}
 			}
 			if s.Value != nil {
 				val := f.expr(s.Value)
 				var xelem types.Type
-				// values of strings aren't interesting
-				switch ux := x.Underlying().(type) {
+				// Values of type strings aren't interesting because
+				// the RHS value type is just a rune.
+				switch ux := coreType(x).(type) {
 				case *types.Array:
 					xelem = ux.Elem()
-				case *types.Chan:
-					xelem = ux.Elem()
 				case *types.Map:
 					xelem = ux.Elem()
 				case *types.Pointer: // *array
-					xelem = deref(ux).(*types.Array).Elem()
+					xelem = coreType(deref(ux)).(*types.Array).Elem()
 				case *types.Slice:
 					xelem = ux.Elem()
 				}
 				if xelem != nil {
-					f.assign(xelem, val)
+					f.assign(val, xelem)
 				}
 			}
 		}
@@ -690,7 +698,7 @@
 
 // deref returns a pointer's element type; otherwise it returns typ.
 func deref(typ types.Type) types.Type {
-	if p, ok := typ.Underlying().(*types.Pointer); ok {
+	if p, ok := coreType(typ).(*types.Pointer); ok {
 		return p.Elem()
 	}
 	return typ
@@ -699,3 +707,19 @@
 func unparen(e ast.Expr) ast.Expr { return astutil.Unparen(e) }
 
 func isInterface(T types.Type) bool { return types.IsInterface(T) }
+
+func coreType(T types.Type) types.Type { return typeparams.CoreType(T) }
+
+func instance(info *types.Info, expr ast.Expr) bool {
+	var id *ast.Ident
+	switch x := expr.(type) {
+	case *ast.Ident:
+		id = x
+	case *ast.SelectorExpr:
+		id = x.Sel
+	default:
+		return false
+	}
+	_, ok := typeparams.GetInstances(info)[id]
+	return ok
+}
diff --git a/refactor/satisfy/find_test.go b/refactor/satisfy/find_test.go
new file mode 100644
index 0000000..234bce9
--- /dev/null
+++ b/refactor/satisfy/find_test.go
@@ -0,0 +1,226 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package satisfy_test
+
+import (
+	"fmt"
+	"go/ast"
+	"go/parser"
+	"go/token"
+	"go/types"
+	"reflect"
+	"sort"
+	"testing"
+
+	"golang.org/x/tools/internal/typeparams"
+	"golang.org/x/tools/refactor/satisfy"
+)
+
+// This test exercises various operations on core types of type parameters.
+// (It also provides pretty decent coverage of the non-generic operations.)
+func TestGenericCoreOperations(t *testing.T) {
+	if !typeparams.Enabled {
+		t.Skip("!typeparams.Enabled")
+	}
+
+	const src = `package foo
+
+type I interface { f() }
+
+type impl struct{}
+func (impl) f() {}
+
+// A big pile of single-serving types that implement I.
+type A struct{impl}
+type B struct{impl}
+type C struct{impl}
+type D struct{impl}
+type E struct{impl}
+type F struct{impl}
+type G struct{impl}
+type H struct{impl}
+type J struct{impl}
+type K struct{impl}
+type L struct{impl}
+type M struct{impl}
+type N struct{impl}
+type O struct{impl}
+type P struct{impl}
+type Q struct{impl}
+type R struct{impl}
+type S struct{impl}
+type T struct{impl}
+type U struct{impl}
+
+type Generic[T any] struct{impl}
+func (Generic[T]) g(T) {}
+
+type GI[T any] interface{
+	g(T)
+}
+
+func _[Slice interface{ []I }](s Slice) Slice {
+	s[0] = L{} // I <- L
+	return append(s, A{}) // I <- A
+}
+
+func _[Func interface{ func(I) B }](fn Func) {
+	b := fn(C{}) // I <- C
+	var _ I = b // I <- B
+}
+
+func _[Chan interface{ chan D }](ch Chan) {
+	var i I
+	for i = range ch {} // I <- D
+	_ = i
+}
+
+func _[Chan interface{ chan E }](ch Chan) {
+	var _ I = <-ch // I <- E
+}
+
+func _[Chan interface{ chan I }](ch Chan) {
+	ch <- F{} // I <- F
+}
+
+func _[Map interface{ map[G]H }](m Map) {
+	var k, v I
+	for k, v = range m {} // I <- G, I <- H
+	_, _ = k, v
+}
+
+func _[Map interface{ map[I]K }](m Map) {
+	var _ I = m[J{}] // I <- J, I <- K
+	delete(m, R{}) // I <- R
+	_, _ = m[J{}]
+}
+
+func _[Array interface{ [1]I }](a Array) {
+	a[0] = M{} // I <- M
+}
+
+func _[Array interface{ [1]N }](a Array) {
+	var _ I = a[0] // I <- N
+}
+
+func _[Array interface{ [1]O }](a Array) {
+	var v I
+	for _, v = range a {} // I <- O
+	_ = v
+}
+
+func _[ArrayPtr interface{ *[1]P }](a ArrayPtr) {
+	var v I
+	for _, v = range a {} // I <- P
+	_ = v
+}
+
+func _[Slice interface{ []Q }](s Slice) {
+	var v I
+	for _, v = range s {} // I <- Q
+	_ = v
+}
+
+func _[Func interface{ func() (S, bool) }](fn Func) {
+	var i I
+	i, _ = fn() // I <- S
+	_ = i
+}
+
+func _() I {
+	var _ I = T{} // I <- T
+	var _ I = Generic[T]{} // I <- Generic[T]
+	var _ I = Generic[string]{} // I <- Generic[string]
+	return U{} // I <- U
+}
+
+var _ GI[string] = Generic[string]{} //  GI[string] <- Generic[string]
+
+// universally quantified constraints:
+// the type parameter may appear on the left, the right, or both sides.
+
+func  _[T any](g Generic[T]) GI[T] {
+	return g // GI[T] <- Generic[T]
+}
+
+func  _[T any]() {
+	type GI2[T any] interface{ g(string) }
+	var _ GI2[T] = Generic[string]{} // GI2[T] <- Generic[string]
+}
+
+type Gen2[T any] struct{}
+func (f Gen2[T]) g(string) { global = f } // GI[string] <- Gen2[T]
+
+var global GI[string] 
+
+`
+	got := constraints(t, src)
+	want := []string{
+		"p.GI2[T] <- p.Generic[string]", // implicitly "forall T" quantified
+		"p.GI[T] <- p.Generic[T]",       // implicitly "forall T" quantified
+		"p.GI[string] <- p.Gen2[T]",     // implicitly "forall T" quantified
+		"p.GI[string] <- p.Generic[string]",
+		"p.I <- p.A",
+		"p.I <- p.B",
+		"p.I <- p.C",
+		"p.I <- p.D",
+		"p.I <- p.E",
+		"p.I <- p.F",
+		"p.I <- p.G",
+		"p.I <- p.Generic[p.T]",
+		"p.I <- p.Generic[string]",
+		"p.I <- p.H",
+		"p.I <- p.J",
+		"p.I <- p.K",
+		"p.I <- p.L",
+		"p.I <- p.M",
+		"p.I <- p.N",
+		"p.I <- p.O",
+		"p.I <- p.P",
+		"p.I <- p.Q",
+		"p.I <- p.R",
+		"p.I <- p.S",
+		"p.I <- p.T",
+		"p.I <- p.U",
+	}
+	if !reflect.DeepEqual(got, want) {
+		t.Fatalf("found unexpected constraints: got %s, want %s", got, want)
+	}
+}
+
+func constraints(t *testing.T, src string) []string {
+	// parse
+	fset := token.NewFileSet()
+	f, err := parser.ParseFile(fset, "p.go", src, 0)
+	if err != nil {
+		t.Fatal(err) // parse error
+	}
+	files := []*ast.File{f}
+
+	// type-check
+	info := &types.Info{
+		Types:      make(map[ast.Expr]types.TypeAndValue),
+		Defs:       make(map[*ast.Ident]types.Object),
+		Uses:       make(map[*ast.Ident]types.Object),
+		Implicits:  make(map[ast.Node]types.Object),
+		Scopes:     make(map[ast.Node]*types.Scope),
+		Selections: make(map[*ast.SelectorExpr]*types.Selection),
+	}
+	typeparams.InitInstanceInfo(info)
+	conf := types.Config{}
+	if _, err := conf.Check("p", fset, files, info); err != nil {
+		t.Fatal(err) // type error
+	}
+
+	// gather constraints
+	var finder satisfy.Finder
+	finder.Find(info, files)
+	var constraints []string
+	for c := range finder.Result {
+		constraints = append(constraints, fmt.Sprintf("%v <- %v", c.LHS, c.RHS))
+	}
+	sort.Strings(constraints)
+	return constraints
+}