gopls/internal/util/fingerprint/fingerprint: unify type params

Enhance Matches to observe the bindings of type parameters.

Previously, each occurrence of a type parameter matched any type.
For example, matching these two signatures:

    func f[T any](T, T)
    func g(int, bool)

succeeded even though g is not an instantiation of f.

This CL tracks the bindings of type parameters, so that the above
match fails but matching f with this function:

    func h(int, int)

succeeds.

Change-Id: Ia1ed653b24168d8e307593ca98d7c151b9dbb458
Reviewed-on: https://go-review.googlesource.com/c/tools/+/655995
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/gopls/internal/util/fingerprint/fingerprint.go b/gopls/internal/util/fingerprint/fingerprint.go
index 2b657ba..22817e4 100644
--- a/gopls/internal/util/fingerprint/fingerprint.go
+++ b/gopls/internal/util/fingerprint/fingerprint.go
@@ -338,44 +338,126 @@
 	}
 }
 
-// unify reports whether the types of methods x and y match, in the
-// presence of type parameters, each of which matches anything at all.
-// (It's not true unification as we don't track substitutions.)
-//
-// TODO(adonovan): implement full unification.
+// unify reports whether x and y match, in the presence of type parameters.
+// The constraints on type parameters are ignored, but each type parameter must
+// have a consistent binding.
 func unify(x, y sexpr) bool {
-	if isTypeParam(x) >= 0 || isTypeParam(y) >= 0 {
-		return true // a type parameter matches anything
+
+	// maxTypeParam returns the maximum type parameter index in x.
+	var maxTypeParam func(x sexpr) int
+	maxTypeParam = func(x sexpr) int {
+		if i := typeParamIndex(x); i >= 0 {
+			return i
+		}
+		if c, ok := x.(*cons); ok {
+			return max(maxTypeParam(c.car), maxTypeParam(c.cdr))
+		}
+		return 0
 	}
-	if reflect.TypeOf(x) != reflect.TypeOf(y) {
-		return false // type mismatch
+
+	// xBindings[i] is the binding for type parameter #i in x, and similarly for y.
+	// Although type parameters are nominally bound to sexprs, each bindings[i]
+	// is a *sexpr, so unbound variables can share a binding.
+	xBindings := make([]*sexpr, maxTypeParam(x)+1)
+	for i := range len(xBindings) {
+		xBindings[i] = new(sexpr)
 	}
-	switch x := x.(type) {
-	case nil, string, int, symbol:
-		return x == y
-	case *cons:
-		y := y.(*cons)
-		if !unify(x.car, y.car) {
+	yBindings := make([]*sexpr, maxTypeParam(y)+1)
+	for i := range len(yBindings) {
+		yBindings[i] = new(sexpr)
+	}
+
+	// bind sets binding b to s from bindings if it does not occur in s.
+	bind := func(b *sexpr, s sexpr, bindings []*sexpr) bool {
+		// occurs reports whether b is present in s.
+		var occurs func(s sexpr) bool
+		occurs = func(s sexpr) bool {
+			if j := typeParamIndex(s); j >= 0 {
+				return b == bindings[j]
+			}
+			if c, ok := s.(*cons); ok {
+				return occurs(c.car) || occurs(c.cdr)
+			}
 			return false
 		}
-		if x.cdr == nil {
-			return y.cdr == nil
-		}
-		if y.cdr == nil {
+
+		if occurs(s) {
 			return false
 		}
-		return unify(x.cdr, y.cdr)
-	default:
-		panic(fmt.Sprintf("unify %T %T", x, y))
+		*b = s
+		return true
 	}
+
+	var uni func(x, y sexpr) bool
+	uni = func(x, y sexpr) bool {
+		var bx, by *sexpr
+		ix := typeParamIndex(x)
+		if ix >= 0 {
+			bx = xBindings[ix]
+		}
+		iy := typeParamIndex(y)
+		if iy >= 0 {
+			by = yBindings[iy]
+		}
+
+		if bx != nil || by != nil {
+			// If both args are type params and neither is bound, have them share a binding.
+			if bx != nil && by != nil && *bx == nil && *by == nil {
+				xBindings[ix] = yBindings[iy]
+				return true
+			}
+			// Treat param bindings like original args in what follows.
+			if bx != nil && *bx != nil {
+				x = *bx
+			}
+			if by != nil && *by != nil {
+				y = *by
+			}
+			// If the x param is unbound, bind it to y.
+			if bx != nil && *bx == nil {
+				return bind(bx, y, yBindings)
+			}
+			// If the y param is unbound, bind it to x.
+			if by != nil && *by == nil {
+				return bind(by, x, xBindings)
+			}
+			// Unify the binding of a bound parameter.
+			return uni(x, y)
+		}
+
+		// Neither arg is a type param.
+		if reflect.TypeOf(x) != reflect.TypeOf(y) {
+			return false // type mismatch
+		}
+		switch x := x.(type) {
+		case nil, string, int, symbol:
+			return x == y
+		case *cons:
+			y := y.(*cons)
+			if !uni(x.car, y.car) {
+				return false
+			}
+			if x.cdr == nil {
+				return y.cdr == nil
+			}
+			if y.cdr == nil {
+				return false
+			}
+			return uni(x.cdr, y.cdr)
+		default:
+			panic(fmt.Sprintf("unify %T %T", x, y))
+		}
+	}
+	// At least one param is bound. Unify its binding with the other.
+	return uni(x, y)
 }
 
-// isTypeParam returns the index of the type parameter,
+// typeParamIndex returns the index of the type parameter,
 // if x has the form "(typeparam INTEGER)", otherwise -1.
-func isTypeParam(x sexpr) int {
+func typeParamIndex(x sexpr) int {
 	if x, ok := x.(*cons); ok {
 		if sym, ok := x.car.(symbol); ok && sym == symTypeparam {
-			return 0
+			return x.cdr.(*cons).car.(int)
 		}
 	}
 	return -1
diff --git a/gopls/internal/util/fingerprint/fingerprint_test.go b/gopls/internal/util/fingerprint/fingerprint_test.go
index 7a7a2fe..737c689 100644
--- a/gopls/internal/util/fingerprint/fingerprint_test.go
+++ b/gopls/internal/util/fingerprint/fingerprint_test.go
@@ -104,6 +104,7 @@
 func C3(int, bool, ...string) rune
 func C4(int, bool, ...string)
 func C5(int, float64, bool, string) bool
+func C6(int, bool, ...string) bool
 
 func DAny[T any](Named[T]) { panic(0) }
 func DString(Named[string])
@@ -114,6 +115,17 @@
 func E1(byte) rune
 func E2(uint8) int32
 func E3(int8) uint32
+
+// generic vs. generic
+func F1[T any](T) { panic(0) }
+func F2[T any](*T) { panic(0) }
+func F3[T any](T, T) { panic(0) }
+func F4[U any](U, *U) {panic(0) }
+func F5[T, U any](T, U, U) { panic(0) }
+func F6[T any](T, int, T) { panic(0) }
+func F7[T any](bool, T, T) { panic(0) }
+func F8[V any](*V, int, int) { panic(0) }
+func F9[V any](V, *V, V) { panic(0) }
 `
 	pkg := testfiles.LoadPackages(t, txtar.Parse([]byte(src)), "./a")[0]
 	scope := pkg.Types.Scope()
@@ -128,11 +140,12 @@
 		{"B", "String", "", true},
 		{"B", "Int", "", true},
 		{"B", "A", "", true},
-		{"C1", "C2", "", true}, // matches despite inconsistent substitution
-		{"C1", "C3", "", true},
+		{"C1", "C2", "", false},
+		{"C1", "C3", "", false},
 		{"C1", "C4", "", false},
 		{"C1", "C5", "", false},
-		{"C2", "C3", "", false}, // intransitive (C1≡C2 ^ C1≡C3)
+		{"C1", "C6", "", true},
+		{"C2", "C3", "", false},
 		{"C2", "C4", "", false},
 		{"C3", "C4", "", false},
 		{"DAny", "DString", "", true},
@@ -140,6 +153,13 @@
 		{"DString", "DInt", "", false}, // different instantiations of Named
 		{"E1", "E2", "", true},         // byte and rune are just aliases
 		{"E2", "E3", "", false},
+		// The following tests cover all of the type param cases of unify.
+		{"F1", "F2", "", true},  // F1[*int] = F2[int]
+		{"F3", "F4", "", false}, // would require U identical to *U, prevented by occur check
+		{"F5", "F6", "", true},  // one param is bound, the other is not
+		{"F6", "F7", "", false}, // both are bound
+		{"F5", "F8", "", true},  // T=*int, U=int, V=int
+		{"F5", "F9", "", false}, // T is unbound, V is bound, and T occurs in V
 	} {
 		lookup := func(name string) types.Type {
 			obj := scope.Lookup(name)
@@ -155,20 +175,30 @@
 			return obj.Type()
 		}
 
-		a := lookup(test.a)
-		b := lookup(test.b)
+		check := func(sa, sb string, want bool) {
+			t.Helper()
 
-		afp, _ := fingerprint.Encode(a)
-		bfp, _ := fingerprint.Encode(b)
+			a := lookup(sa)
+			b := lookup(sb)
 
-		atree := fingerprint.Parse(afp)
-		btree := fingerprint.Parse(bfp)
+			afp, _ := fingerprint.Encode(a)
+			bfp, _ := fingerprint.Encode(b)
 
-		got := fingerprint.Matches(atree, btree)
-		if got != test.want {
-			t.Errorf("a=%s b=%s method=%s: unify returned %t for these inputs:\n- %s\n- %s",
-				test.a, test.b, test.method,
-				got, atree, btree)
+			atree := fingerprint.Parse(afp)
+			btree := fingerprint.Parse(bfp)
+
+			got := fingerprint.Matches(atree, btree)
+			if got != want {
+				t.Errorf("a=%s b=%s method=%s: unify returned %t for these inputs:\n- %s\n- %s",
+					sa, sb, test.method, got, a, b)
+			}
 		}
+
+		check(test.a, test.b, test.want)
+		// Matches is symmetric
+		check(test.b, test.a, test.want)
+		// Matches is reflexive
+		check(test.a, test.a, true)
+		check(test.b, test.b, true)
 	}
 }