cmd/compile/internal/types2: add support for inferring type instances

This is an essentially clean port of CL 356489 from go/types to types2,
with minor adjustments due to the different AST packages and error
reporting.

Fixes #47990.

Change-Id: I52187872474bfc1fb49eb77905f22fc820b7295b
Reviewed-on: https://go-review.googlesource.com/c/go/+/356514
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
diff --git a/src/cmd/compile/internal/types2/call.go b/src/cmd/compile/internal/types2/call.go
index 8b45b28..3859e39 100644
--- a/src/cmd/compile/internal/types2/call.go
+++ b/src/cmd/compile/internal/types2/call.go
@@ -57,7 +57,7 @@
 	}
 
 	// instantiate function signature
-	res := check.instantiate(x.Pos(), sig, targs, poslist).(*Signature)
+	res := check.instantiateSignature(x.Pos(), sig, targs, poslist)
 	assert(res.TypeParams().Len() == 0) // signature is not generic anymore
 	check.recordInstance(inst.X, targs, res)
 	x.typ = res
@@ -65,6 +65,34 @@
 	x.expr = inst
 }
 
+func (check *Checker) instantiateSignature(pos syntax.Pos, typ *Signature, targs []Type, posList []syntax.Pos) (res *Signature) {
+	assert(check != nil)
+	assert(len(targs) == typ.TypeParams().Len())
+
+	if check.conf.Trace {
+		check.trace(pos, "-- instantiating %s with %s", typ, targs)
+		check.indent++
+		defer func() {
+			check.indent--
+			check.trace(pos, "=> %s (under = %s)", res, res.Underlying())
+		}()
+	}
+
+	inst := check.instance(pos, typ, targs, check.conf.Context).(*Signature)
+	assert(len(posList) <= len(targs))
+	tparams := typ.TypeParams().list()
+	if i, err := check.verify(pos, tparams, targs); err != nil {
+		// best position for error reporting
+		pos := pos
+		if i < len(posList) {
+			pos = posList[i]
+		}
+		check.softErrorf(pos, err.Error())
+	}
+
+	return inst
+}
+
 func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
 	var inst *syntax.IndexExpr // function instantiation, if any
 	if iexpr, _ := call.Fun.(*syntax.IndexExpr); iexpr != nil {
@@ -345,7 +373,7 @@
 		}
 
 		// compute result signature
-		rsig = check.instantiate(call.Pos(), sig, targs, nil).(*Signature)
+		rsig = check.instantiateSignature(call.Pos(), sig, targs, nil)
 		assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
 		check.recordInstance(call.Fun, targs, rsig)
 
diff --git a/src/cmd/compile/internal/types2/instantiate.go b/src/cmd/compile/internal/types2/instantiate.go
index 5e45ea3..d6cefb4 100644
--- a/src/cmd/compile/internal/types2/instantiate.go
+++ b/src/cmd/compile/internal/types2/instantiate.go
@@ -49,55 +49,6 @@
 	return inst, err
 }
 
-// instantiate creates an instance and defers verification of constraints to
-// later in the type checking pass. For Named types the resulting instance will
-// be unexpanded.
-func (check *Checker) instantiate(pos syntax.Pos, typ Type, targs []Type, posList []syntax.Pos) (res Type) {
-	assert(check != nil)
-	if check.conf.Trace {
-		check.trace(pos, "-- instantiating %s with %s", typ, NewTypeList(targs))
-		check.indent++
-		defer func() {
-			check.indent--
-			var under Type
-			if res != nil {
-				// Calling under() here may lead to endless instantiations.
-				// Test case: type T[P any] T[P]
-				under = safeUnderlying(res)
-			}
-			check.trace(pos, "=> %s (under = %s)", res, under)
-		}()
-	}
-
-	inst := check.instance(pos, typ, targs, check.conf.Context)
-
-	assert(len(posList) <= len(targs))
-	check.later(func() {
-		// Collect tparams again because lazily loaded *Named types may not have
-		// had tparams set up above.
-		var tparams []*TypeParam
-		switch t := typ.(type) {
-		case *Named:
-			tparams = t.TypeParams().list()
-		case *Signature:
-			tparams = t.TypeParams().list()
-		}
-		// Avoid duplicate errors; instantiate will have complained if tparams
-		// and targs do not have the same length.
-		if len(tparams) == len(targs) {
-			if i, err := check.verify(pos, tparams, targs); err != nil {
-				// best position for error reporting
-				pos := pos
-				if i < len(posList) {
-					pos = posList[i]
-				}
-				check.softErrorf(pos, err.Error())
-			}
-		}
-	})
-	return inst
-}
-
 // instance creates a type or function instance using the given original type
 // typ and arguments targs. For Named types the resulting instance will be
 // unexpanded.
diff --git a/src/cmd/compile/internal/types2/named.go b/src/cmd/compile/internal/types2/named.go
index 3971e03..8f2a52b 100644
--- a/src/cmd/compile/internal/types2/named.go
+++ b/src/cmd/compile/internal/types2/named.go
@@ -239,7 +239,8 @@
 
 	check := n.check
 
-	if check.validateTArgLen(instPos, n.orig.tparams.Len(), n.targs.Len()) {
+	// Mismatching arg and tparam length may be checked elsewhere.
+	if n.orig.tparams.Len() == n.targs.Len() {
 		// We must always have a context, to avoid infinite recursion.
 		ctxt = check.bestContext(ctxt)
 		h := ctxt.TypeHash(n.orig, n.targs.list())
diff --git a/src/cmd/compile/internal/types2/testdata/check/tinference.go2 b/src/cmd/compile/internal/types2/testdata/check/funcinference.go2
similarity index 82%
rename from src/cmd/compile/internal/types2/testdata/check/tinference.go2
rename to src/cmd/compile/internal/types2/testdata/check/funcinference.go2
index 2409fef..7160e18 100644
--- a/src/cmd/compile/internal/types2/testdata/check/tinference.go2
+++ b/src/cmd/compile/internal/types2/testdata/check/funcinference.go2
@@ -2,26 +2,25 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package tinferenceB
+package funcInference
 
 import "strconv"
 
 type any interface{}
 
-// Embedding stand-alone type parameters is not permitted for now. Disabled.
-// func f0[A any, B interface{~C}, C interface{~D}, D interface{~A}](A, B, C, D)
-// func _() {
-// 	f := f0[string]
-// 	f("a", "b", "c", "d")
-// 	f0("a", "b", "c", "d")
-// }
-//
-// func f1[A any, B interface{~A}](A, B)
-// func _() {
-// 	f := f1[int]
-// 	f(int(0), int(0))
-// 	f1(int(0), int(0))
-// }
+func f0[A any, B interface{~*C}, C interface{~*D}, D interface{~*A}](A, B, C, D) {}
+func _() {
+	f := f0[string]
+	f("a", nil, nil, nil)
+	f0("a", nil, nil, nil)
+}
+
+func f1[A any, B interface{~*A}](A, B) {}
+func _() {
+	f := f1[int]
+	f(int(0), new(int))
+	f1(int(0), new(int))
+}
 
 func f2[A any, B interface{~[]A}](A, B) {}
 func _() {
diff --git a/src/cmd/compile/internal/types2/testdata/check/typeinference.go2 b/src/cmd/compile/internal/types2/testdata/check/typeinference.go2
new file mode 100644
index 0000000..8876cca
--- /dev/null
+++ b/src/cmd/compile/internal/types2/testdata/check/typeinference.go2
@@ -0,0 +1,47 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package typeInference
+
+// basic inference
+type Tb[P ~*Q, Q any] int
+func _() {
+	var x Tb[*int]
+	var y Tb[*int, int]
+	x = y
+	_ = x
+}
+
+// recursive inference
+type Tr[A any, B ~*C, C ~*D, D ~*A] int
+func _() {
+	var x Tr[string]
+	var y Tr[string, ***string, **string, *string]
+	var z Tr[int, ***int, **int, *int]
+	x = y
+	x = z // ERROR cannot use z .* as Tr
+	_ = x
+}
+
+// other patterns of inference
+type To0[A any, B ~[]A] int
+type To1[A any, B ~struct{a A}] int
+type To2[A any, B ~[][]A] int
+type To3[A any, B ~[3]*A] int
+type To4[A any, B any, C ~struct{a A; b B}] int
+func _() {
+	var _ To0[int]
+	var _ To1[int]
+	var _ To2[int]
+	var _ To3[int]
+	var _ To4[int, string]
+}
+
+// failed inference
+type Tf0[A, B any] int
+type Tf1[A any, B ~struct{a A; c C}, C any] int
+func _() {
+	var _ Tf0 /* ERROR cannot infer B */ /* ERROR got 1 arguments but 2 type parameters */ [int]
+	var _ Tf1 /* ERROR cannot infer B */ /* ERROR got 1 arguments but 3 type parameters */ [int]
+}
diff --git a/src/cmd/compile/internal/types2/testdata/check/typeinst2.go2 b/src/cmd/compile/internal/types2/testdata/check/typeinst2.go2
index 49f48c7..ecf4b0f 100644
--- a/src/cmd/compile/internal/types2/testdata/check/typeinst2.go2
+++ b/src/cmd/compile/internal/types2/testdata/check/typeinst2.go2
@@ -252,3 +252,7 @@
 var _ = f0_[bool /* ERROR does not satisfy I0_ */ ]
 var _ = f0_[string /* ERROR does not satisfy I0_ */ ]
 var _ = f0_[float64 /* ERROR does not satisfy I0_ */ ]
+
+// Using a function instance as a type is an error.
+var _ f0 // ERROR not a type
+var _ f0 /* ERROR not a type */ [int]
diff --git a/src/cmd/compile/internal/types2/testdata/check/typeinstcycles.go2 b/src/cmd/compile/internal/types2/testdata/check/typeinstcycles.go2
new file mode 100644
index 0000000..74fe191
--- /dev/null
+++ b/src/cmd/compile/internal/types2/testdata/check/typeinstcycles.go2
@@ -0,0 +1,11 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package p
+
+import "unsafe"
+
+func F1[T any](_ [unsafe.Sizeof(F1[int])]T) (res T)      { return }
+func F2[T any](_ T) (res [unsafe.Sizeof(F2[string])]int) { return }
+func F3[T any](_ [unsafe.Sizeof(F1[string])]int)         {}
diff --git a/src/cmd/compile/internal/types2/typexpr.go b/src/cmd/compile/internal/types2/typexpr.go
index eae9330..20e56ca 100644
--- a/src/cmd/compile/internal/types2/typexpr.go
+++ b/src/cmd/compile/internal/types2/typexpr.go
@@ -396,13 +396,24 @@
 	return typ
 }
 
-func (check *Checker) instantiatedType(x syntax.Expr, targsx []syntax.Expr, def *Named) Type {
+func (check *Checker) instantiatedType(x syntax.Expr, targsx []syntax.Expr, def *Named) (res Type) {
+	if check.conf.Trace {
+		check.trace(x.Pos(), "-- instantiating %s with %s", x, targsx)
+		check.indent++
+		defer func() {
+			check.indent--
+			// Don't format the underlying here. It will always be nil.
+			check.trace(x.Pos(), "=> %s", res)
+		}()
+	}
+
 	gtyp := check.genericType(x, true)
 	if gtyp == Typ[Invalid] {
 		return gtyp // error already reported
 	}
-	base, _ := gtyp.(*Named)
-	if base == nil {
+
+	origin, _ := gtyp.(*Named)
+	if origin == nil {
 		panic(fmt.Sprintf("%v: cannot instantiate %v", x.Pos(), gtyp))
 	}
 
@@ -416,20 +427,67 @@
 	// determine argument positions
 	posList := make([]syntax.Pos, len(targs))
 	for i, arg := range targsx {
-		posList[i] = syntax.StartPos(arg)
+		posList[i] = arg.Pos()
 	}
 
-	typ := check.instantiate(x.Pos(), base, targs, posList)
-	def.setUnderlying(typ)
-	check.recordInstance(x, targs, typ)
+	// create the instance
+	h := check.conf.Context.TypeHash(origin, targs)
+	// targs may be incomplete, and require inference. In any case we should de-duplicate.
+	inst := check.conf.Context.typeForHash(h, nil)
+	// If inst is non-nil, we can't just return here. Inst may have been
+	// constructed via recursive substitution, in which case we wouldn't do the
+	// validation below. Ensure that the validation (and resulting errors) runs
+	// for each instantiated type in the source.
+	if inst == nil {
+		tname := NewTypeName(x.Pos(), origin.obj.pkg, origin.obj.name, nil)
+		inst = check.newNamed(tname, origin, nil, nil, nil) // underlying, methods and tparams are set when named is resolved
+		inst.targs = NewTypeList(targs)
+		inst = check.conf.Context.typeForHash(h, inst)
+	}
+	def.setUnderlying(inst)
 
-	// make sure we check instantiation works at least once
-	// and that the resulting type is valid
+	inst.resolver = func(ctxt *Context, n *Named) (*TypeParamList, Type, []*Func) {
+		tparams := origin.TypeParams().list()
+
+		inferred := targs
+		if len(targs) < len(tparams) {
+			// If inference fails, len(inferred) will be 0, and inst.underlying will
+			// be set to Typ[Invalid] in expandNamed.
+			inferred = check.infer(x.Pos(), tparams, targs, nil, nil)
+			if len(inferred) > len(targs) {
+				inst.targs = NewTypeList(inferred)
+			}
+		}
+
+		check.recordInstance(x, inferred, inst)
+		return expandNamed(ctxt, n, x.Pos())
+	}
+
+	// origin.tparams may not be set up, so we need to do expansion later.
 	check.later(func() {
-		check.validType(typ, nil)
+		// This is an instance from the source, not from recursive substitution,
+		// and so it must be resolved during type-checking so that we can report
+		// errors.
+		inst.resolve(check.conf.Context)
+		// Since check is non-nil, we can still mutate inst. Unpinning the resolver
+		// frees some memory.
+		inst.resolver = nil
+
+		if check.validateTArgLen(x.Pos(), inst.tparams.Len(), inst.targs.Len()) {
+			if i, err := check.verify(x.Pos(), inst.tparams.list(), inst.targs.list()); err != nil {
+				// best position for error reporting
+				pos := x.Pos()
+				if i < len(posList) {
+					pos = posList[i]
+				}
+				check.softErrorf(pos, err.Error())
+			}
+		}
+
+		check.validType(inst, nil)
 	})
 
-	return typ
+	return inst
 }
 
 // arrayLength type-checks the array length expression e