go/types/typeutil: add support for mapping generic types

Add support to the typeutil package for hashing the new types produced
when type-checking generic code.

Change-Id: I05a213baee80c53c673442f3c28fddb26ad0b03f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/366614
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Tim King <taking@google.com>
diff --git a/go/types/typeutil/map.go b/go/types/typeutil/map.go
index c7f7545..490ee90 100644
--- a/go/types/typeutil/map.go
+++ b/go/types/typeutil/map.go
@@ -11,6 +11,8 @@
 	"fmt"
 	"go/types"
 	"reflect"
+
+	"golang.org/x/tools/internal/typeparams"
 )
 
 // Map is a hash-table-based mapping from types (types.Type) to
@@ -211,11 +213,29 @@
 // Call MakeHasher to create a Hasher.
 type Hasher struct {
 	memo map[types.Type]uint32
+
+	// ptrMap records pointer identity.
+	ptrMap map[interface{}]uint32
+
+	// sigTParams holds type parameters from the signature being hashed.
+	// Signatures are considered identical modulo renaming of type parameters, so
+	// within the scope of a signature type the identity of the signature's type
+	// parameters is just their index.
+	//
+	// Since the language does not currently support referring to uninstantiated
+	// generic types or functions, and instantiated signatures do not have type
+	// parameter lists, we should never encounter a second non-empty type
+	// parameter list when hashing a generic signature.
+	sigTParams *typeparams.TypeParamList
 }
 
 // MakeHasher returns a new Hasher instance.
 func MakeHasher() Hasher {
-	return Hasher{make(map[types.Type]uint32)}
+	return Hasher{
+		memo:       make(map[types.Type]uint32),
+		ptrMap:     make(map[interface{}]uint32),
+		sigTParams: nil,
+	}
 }
 
 // Hash computes a hash value for the given type t such that
@@ -273,17 +293,62 @@
 		if t.Variadic() {
 			hash *= 8863
 		}
+
+		// Use a separate hasher for types inside of the signature, where type
+		// parameter identity is modified to be (index, constraint). We must use a
+		// new memo for this hasher as type identity may be affected by this
+		// masking. For example, in func[T any](*T), the identity of *T depends on
+		// whether we are mapping the argument in isolation, or recursively as part
+		// of hashing the signature.
+		//
+		// We should never encounter a generic signature while hashing another
+		// generic signature, but defensively set sigTParams only if h.mask is
+		// unset.
+		tparams := typeparams.ForSignature(t)
+		if h.sigTParams == nil && tparams.Len() != 0 {
+			h = Hasher{
+				// There may be something more efficient than discarding the existing
+				// memo, but it would require detecting whether types are 'tainted' by
+				// references to type parameters.
+				memo: make(map[types.Type]uint32),
+				// Re-using ptrMap ensures that pointer identity is preserved in this
+				// hasher.
+				ptrMap:     h.ptrMap,
+				sigTParams: tparams,
+			}
+		}
+
+		for i := 0; i < tparams.Len(); i++ {
+			tparam := tparams.At(i)
+			hash += 7 * h.Hash(tparam.Constraint())
+		}
+
 		return hash + 3*h.hashTuple(t.Params()) + 5*h.hashTuple(t.Results())
 
+	case *typeparams.Union:
+		return h.hashUnion(t)
+
 	case *types.Interface:
+		// Interfaces are identical if they have the same set of methods, with
+		// identical names and types, and they have the same set of type
+		// restrictions. See go/types.identical for more details.
 		var hash uint32 = 9103
+
+		// Hash methods.
 		for i, n := 0, t.NumMethods(); i < n; i++ {
-			// See go/types.identicalMethods for rationale.
 			// Method order is not significant.
 			// Ignore m.Pkg().
 			m := t.Method(i)
 			hash += 3*hashString(m.Name()) + 5*h.Hash(m.Type())
 		}
+
+		// Hash type restrictions.
+		terms, err := typeparams.InterfaceTermSet(t)
+		// if err != nil t has invalid type restrictions.
+		if err == nil {
+			hash += h.hashTermSet(terms)
+		}
+
 		return hash
 
 	case *types.Map:
@@ -293,13 +358,22 @@
 		return 9127 + 2*uint32(t.Dir()) + 3*h.Hash(t.Elem())
 
 	case *types.Named:
-		// Not safe with a copying GC; objects may move.
-		return uint32(reflect.ValueOf(t.Obj()).Pointer())
+		hash := h.hashPtr(t.Obj())
+		targs := typeparams.NamedTypeArgs(t)
+		for i := 0; i < targs.Len(); i++ {
+			targ := targs.At(i)
+			hash += 2 * h.Hash(targ)
+		}
+		return hash
+
+	case *typeparams.TypeParam:
+		return h.hashTypeParam(t)
 
 	case *types.Tuple:
 		return h.hashTuple(t)
 	}
-	panic(t)
+
+	panic(fmt.Sprintf("%T: %v", t, t))
 }
 
 func (h Hasher) hashTuple(tuple *types.Tuple) uint32 {
@@ -311,3 +385,57 @@
 	}
 	return hash
 }
+
+func (h Hasher) hashUnion(t *typeparams.Union) uint32 {
+	// Hash type restrictions.
+	terms, err := typeparams.UnionTermSet(t)
+	// if err != nil t has invalid type restrictions. Fall back on a non-zero
+	// hash.
+	if err != nil {
+		return 9151
+	}
+	return h.hashTermSet(terms)
+}
+
+func (h Hasher) hashTermSet(terms []*typeparams.Term) uint32 {
+	var hash uint32 = 9157 + 2*uint32(len(terms))
+	for _, term := range terms {
+		// term order is not significant.
+		termHash := h.Hash(term.Type())
+		if term.Tilde() {
+			termHash *= 9161
+		}
+		hash += 3 * termHash
+	}
+	return hash
+}
+
+// hashTypeParam returns a hash of the type parameter t, with a hash value
+// depending on whether t is contained in h.sigTParams.
+//
+// If h.sigTParams is set and contains t, then we are in the process of hashing
+// a signature, and the hash value of t must depend only on t's index and
+// constraint: signatures are considered identical modulo type parameter
+// renaming.
+//
+// Otherwise the hash of t depends only on t's pointer identity.
+func (h Hasher) hashTypeParam(t *typeparams.TypeParam) uint32 {
+	if h.sigTParams != nil {
+		i := t.Index()
+		if i >= 0 && i < h.sigTParams.Len() && t == h.sigTParams.At(i) {
+			return 9173 + 2*h.Hash(t.Constraint()) + 3*uint32(i)
+		}
+	}
+	return h.hashPtr(t.Obj())
+}
+
+// hashPtr hashes the pointer identity of ptr. It uses h.ptrMap to ensure that
+// pointers values are not dependent on the GC.
+func (h Hasher) hashPtr(ptr interface{}) uint32 {
+	if hash, ok := h.ptrMap[ptr]; ok {
+		return hash
+	}
+	hash := uint32(reflect.ValueOf(ptr).Pointer())
+	h.ptrMap[ptr] = hash
+	return hash
+}
diff --git a/go/types/typeutil/map_test.go b/go/types/typeutil/map_test.go
index d4b0f63..17f87ed 100644
--- a/go/types/typeutil/map_test.go
+++ b/go/types/typeutil/map_test.go
@@ -10,10 +10,14 @@
 //   (e.g. all types generated by type-checking some body of real code).
 
 import (
+	"go/ast"
+	"go/parser"
+	"go/token"
 	"go/types"
 	"testing"
 
 	"golang.org/x/tools/go/types/typeutil"
+	"golang.org/x/tools/internal/typeparams"
 )
 
 var (
@@ -172,3 +176,190 @@
 		t.Errorf("Len(): got %q, want %q", s, "")
 	}
 }
+
+func TestMapGenerics(t *testing.T) {
+	if !typeparams.Enabled {
+		t.Skip("type params are not enabled at this Go version")
+	}
+
+	const src = `
+package p
+
+// Basic defined types.
+type T1 int
+type T2 int
+
+// Identical methods.
+func (T1) M(int) {}
+func (T2) M(int) {}
+
+// A constraint interface.
+type C interface {
+	~int | string
+}
+
+type I interface {
+}
+
+// A generic type.
+type G[P C] int
+
+// Generic functions with identical signature.
+func Fa1[P C](p P) {}
+func Fa2[Q C](q Q) {}
+
+// Fb1 and Fb2 are identical and should be mapped to the same entry, even if we
+// map their arguments first.
+func Fb1[P any](x *P) {
+	var y *P // Map this first.
+	_ = y
+}
+func Fb2[Q any](x *Q) {
+}
+
+// G1 and G2 are mutally recursive, and have identical methods.
+type G1[P any] struct{
+	Field *G2[P]
+}
+func (G1[P]) M(G1[P], G2[P]) {}
+type G2[Q any] struct{
+	Field *G1[Q]
+}
+func (G2[P]) M(G1[P], G2[P]) {}
+
+// Method type expressions on different generic types are different.
+var ME1 = G1[int].M
+var ME2 = G2[int].M
+
+// ME1Type should have identical type as ME1.
+var ME1Type func(G1[int], G1[int], G2[int])
+`
+
+	fset := token.NewFileSet()
+	file, err := parser.ParseFile(fset, "p.go", src, 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var conf types.Config
+	pkg, err := conf.Check("", fset, []*ast.File{file}, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Collect types.
+	scope := pkg.Scope()
+	var (
+		T1      = scope.Lookup("T1").Type().(*types.Named)
+		T2      = scope.Lookup("T2").Type().(*types.Named)
+		T1M     = T1.Method(0).Type()
+		T2M     = T2.Method(0).Type()
+		G       = scope.Lookup("G").Type()
+		GInt1   = instantiate(t, G, types.Typ[types.Int])
+		GInt2   = instantiate(t, G, types.Typ[types.Int])
+		GStr    = instantiate(t, G, types.Typ[types.String])
+		C       = scope.Lookup("C").Type()
+		CI      = C.Underlying().(*types.Interface)
+		I       = scope.Lookup("I").Type()
+		II      = I.Underlying().(*types.Interface)
+		U       = CI.EmbeddedType(0).(*typeparams.Union)
+		Fa1     = scope.Lookup("Fa1").Type().(*types.Signature)
+		Fa2     = scope.Lookup("Fa2").Type().(*types.Signature)
+		Fa1P    = typeparams.ForSignature(Fa1).At(0)
+		Fa2Q    = typeparams.ForSignature(Fa2).At(0)
+		Fb1     = scope.Lookup("Fb1").Type().(*types.Signature)
+		Fb1x    = Fb1.Params().At(0).Type()
+		Fb1y    = scope.Lookup("Fb1").(*types.Func).Scope().Lookup("y").Type()
+		Fb2     = scope.Lookup("Fb2").Type().(*types.Signature)
+		Fb2x    = Fb2.Params().At(0).Type()
+		G1      = scope.Lookup("G1").Type().(*types.Named)
+		G1M     = G1.Method(0).Type()
+		G1IntM1 = instantiate(t, G1, types.Typ[types.Int]).(*types.Named).Method(0).Type()
+		G1IntM2 = instantiate(t, G1, types.Typ[types.Int]).(*types.Named).Method(0).Type()
+		G1StrM  = instantiate(t, G1, types.Typ[types.String]).(*types.Named).Method(0).Type()
+		G2      = scope.Lookup("G2").Type()
+		// See below.
+		// G2M     = G2.Method(0).Type()
+		G2IntM  = instantiate(t, G2, types.Typ[types.Int]).(*types.Named).Method(0).Type()
+		ME1     = scope.Lookup("ME1").Type()
+		ME1Type = scope.Lookup("ME1Type").Type()
+		ME2     = scope.Lookup("ME2").Type()
+	)
+
+	tmap := new(typeutil.Map)
+
+	steps := []struct {
+		typ      types.Type
+		name     string
+		newEntry bool
+	}{
+		{T1, "T1", true},
+		{T2, "T2", true},
+		{G, "G", true},
+		{C, "C", true},
+		{CI, "CI", true},
+		{U, "U", true},
+		{I, "I", true},
+		{II, "II", true}, // should not be identical to CI
+
+		// Methods can be identical, even with distinct receivers.
+		{T1M, "T1M", true},
+		{T2M, "T2M", false},
+
+		// Identical instances should map to the same entry.
+		{GInt1, "GInt1", true},
+		{GInt2, "GInt2", false},
+		// ..but instantiating with different arguments should yield a new entry.
+		{GStr, "GStr", true},
+
+		// F1 and F2 should have identical signatures.
+		{Fa1, "F1", true},
+		{Fa2, "F2", false},
+
+		// The identity of P and Q should not have been affected by type parameter
+		// masking during signature hashing.
+		{Fa1P, "F1P", true},
+		{Fa2Q, "F2Q", true},
+
+		{Fb1y, "Fb1y", true},
+		{Fb1x, "Fb1x", false},
+		{Fb2x, "Fb2x", true},
+		{Fb1, "Fb1", true},
+
+		// Mapping elements of the function scope should not affect the identity of
+		// Fb2 or Fb1.
+		{Fb2, "Fb1", false},
+
+		{G1, "G1", true},
+		{G1M, "G1M", true},
+		{G2, "G2", true},
+
+		// See golang/go#49912: receiver type parameter names should be ignored
+		// when comparing method identity.
+		// {G2M, "G2M", false},
+		{G1IntM1, "G1IntM1", true},
+		{G1IntM2, "G1IntM2", false},
+		{G1StrM, "G1StrM", true},
+		{G2IntM, "G2IntM", false}, // identical to G1IntM1
+
+		{ME1, "ME1", true},
+		{ME1Type, "ME1Type", false},
+		{ME2, "ME2", true},
+	}
+
+	for _, step := range steps {
+		existing := tmap.At(step.typ)
+		if (existing == nil) != step.newEntry {
+			t.Errorf("At(%s) = %v, want new entry: %t", step.name, existing, step.newEntry)
+		}
+		tmap.Set(step.typ, step.name)
+	}
+}
+
+func instantiate(t *testing.T, origin types.Type, targs ...types.Type) types.Type {
+	inst, err := typeparams.Instantiate(nil, origin, targs, true)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return inst
+}
diff --git a/internal/typeparams/normalize.go b/internal/typeparams/normalize.go
index f41ec6e..090f142 100644
--- a/internal/typeparams/normalize.go
+++ b/internal/typeparams/normalize.go
@@ -23,9 +23,9 @@
 //
 // Structural type restrictions of a type parameter are created via
 // non-interface types embedded in its constraint interface (directly, or via a
-// chain of interface embeddings). For example, in the declaration `type T[P
-// interface{~int; m()}] int`, the structural restriction of the type parameter
-// P is ~int.
+// chain of interface embeddings). For example, in the declaration
+//  type T[P interface{~int; m()}] int
+// the structural restriction of the type parameter P is ~int.
 //
 // With interface embedding and unions, the specification of structural type
 // restrictions may be arbitrarily complex. For example, consider the
@@ -67,7 +67,31 @@
 	if iface == nil {
 		return nil, fmt.Errorf("constraint is %T, not *types.Interface", constraint.Underlying())
 	}
-	tset, err := computeTermSet(iface, make(map[types.Type]*termSet), 0)
+	return InterfaceTermSet(iface)
+}
+
+// InterfaceTermSet computes the normalized terms for a constraint interface,
+// returning an error if the term set cannot be computed or is empty. In the
+// latter case, the error will be ErrEmptyTypeSet.
+//
+// See the documentation of StructuralTerms for more information on
+// normalization.
+func InterfaceTermSet(iface *types.Interface) ([]*Term, error) {
+	return computeTermSet(iface)
+}
+
+// UnionTermSet computes the normalized terms for a union, returning an error
+// if the term set cannot be computed or is empty. In the latter case, the
+// error will be ErrEmptyTypeSet.
+//
+// See the documentation of StructuralTerms for more information on
+// normalization.
+func UnionTermSet(union *Union) ([]*Term, error) {
+	return computeTermSet(union)
+}
+
+func computeTermSet(typ types.Type) ([]*Term, error) {
+	tset, err := computeTermSetInternal(typ, make(map[types.Type]*termSet), 0)
 	if err != nil {
 		return nil, err
 	}
@@ -98,7 +122,7 @@
 	fmt.Fprintf(os.Stderr, strings.Repeat(".", depth)+format+"\n", args...)
 }
 
-func computeTermSet(t types.Type, seen map[types.Type]*termSet, depth int) (res *termSet, err error) {
+func computeTermSetInternal(t types.Type, seen map[types.Type]*termSet, depth int) (res *termSet, err error) {
 	if t == nil {
 		panic("nil type")
 	}
@@ -139,7 +163,7 @@
 			if _, ok := embedded.Underlying().(*TypeParam); ok {
 				return nil, fmt.Errorf("invalid embedded type %T", embedded)
 			}
-			tset2, err := computeTermSet(embedded, seen, depth+1)
+			tset2, err := computeTermSetInternal(embedded, seen, depth+1)
 			if err != nil {
 				return nil, err
 			}
@@ -153,7 +177,7 @@
 			var terms termlist
 			switch t.Type().Underlying().(type) {
 			case *types.Interface:
-				tset2, err := computeTermSet(t.Type(), seen, depth+1)
+				tset2, err := computeTermSetInternal(t.Type(), seen, depth+1)
 				if err != nil {
 					return nil, err
 				}
diff --git a/internal/typeparams/typeparams_go117.go b/internal/typeparams/typeparams_go117.go
index 6ad3a43..e509daf 100644
--- a/internal/typeparams/typeparams_go117.go
+++ b/internal/typeparams/typeparams_go117.go
@@ -75,6 +75,7 @@
 // this Go version. Its methods panic on use.
 type TypeParam struct{ types.Type }
 
+func (*TypeParam) Index() int             { unsupported(); return 0 }
 func (*TypeParam) Constraint() types.Type { unsupported(); return nil }
 func (*TypeParam) Obj() *types.TypeName   { unsupported(); return nil }