go/ssa: use core type for composite literal addresses

Dereferences using the core type during compLit and when creating
addresses for composite literals.

Also adds a new utility fieldOf for selecting a field from a
type whose core type is a struct.

Change-Id: I2fd0a1caf99819d0b9be5f3ba79a00f8053565e3
Reviewed-on: https://go-review.googlesource.com/c/tools/+/494978
TryBot-Result: Gopher Robot <gobot@golang.org>
gopls-CI: kokoro <noreply+kokoro@google.com>
Run-TryBot: Tim King <taking@google.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/go/ssa/builder.go b/go/ssa/builder.go
index 772edd6..8931fb4 100644
--- a/go/ssa/builder.go
+++ b/go/ssa/builder.go
@@ -429,12 +429,12 @@
 		return &address{addr: v, pos: e.Pos(), expr: e}
 
 	case *ast.CompositeLit:
-		t, _ := deptr(fn.typeOf(e))
+		typ, _ := deref(fn.typeOf(e))
 		var v *Alloc
 		if escaping {
-			v = emitNew(fn, t, e.Lbrace)
+			v = emitNew(fn, typ, e.Lbrace)
 		} else {
-			v = fn.addLocal(t, e.Lbrace)
+			v = fn.addLocal(typ, e.Lbrace)
 		}
 		v.Comment = "complit"
 		var sb storebuf
@@ -457,8 +457,7 @@
 		wantAddr := true
 		v := b.receiver(fn, e.X, wantAddr, escaping, sel)
 		index := sel.index[len(sel.index)-1]
-		dt, _ := deptr(v.Type())
-		fld := typeparams.CoreType(dt).(*types.Struct).Field(index)
+		fld := fieldOf(mustDeref(v.Type()), index) // v is an addr.
 
 		// Due to the two phases of resolving AssignStmt, a panic from x.f = p()
 		// when x is nil is required to come after the side-effects of
@@ -553,7 +552,7 @@
 		// so if the type of the location is a pointer,
 		// an &-operation is implied.
 		if _, ok := loc.(blank); !ok { // avoid calling blank.typ()
-			if _, ok := deptr(loc.typ()); ok {
+			if _, ok := deref(loc.typ()); ok {
 				ptr := b.addr(fn, e, true).address(fn)
 				// copy address
 				if sb != nil {
@@ -583,7 +582,7 @@
 
 				// Subtle: emit debug ref for aggregate types only;
 				// slice and map are handled by store ops in compLit.
-				switch loc.typ().Underlying().(type) { // TODO(taking): check if Underlying() appropriate.
+				switch typeparams.CoreType(loc.typ()).(type) {
 				case *types.Struct, *types.Array:
 					emitDebugRef(fn, e, addr, true)
 				}
@@ -1253,39 +1252,13 @@
 // literal has type *T behaves like &T{}.
 // In that case, addr must hold a T, not a *T.
 func (b *builder) compLit(fn *Function, addr Value, e *ast.CompositeLit, isZero bool, sb *storebuf) {
-	typ, _ := deptr(fn.typeOf(e))           // type with name [may be type param]
-	t, _ := deptr(typeparams.CoreType(typ)) // core type for comp lit case
-	t = t.Underlying()
-
-	// Computing typ and t is subtle as these handle pointer types.
-	// For example, &T{...} is valid even for maps and slices.
-	// Also typ should refer to T (not *T) while t should be the core type of T.
-	//
-	// To show the ordering to take into account, consider the composite literal
-	// expressions `&T{f: 1}` and `{f: 1}` within the expression `[]S{{f: 1}}` here:
-	//   type N struct{f int}
-	//   func _[T N, S *N]() {
-	//     _ = &T{f: 1}
-	//     _ = []S{{f: 1}}
-	//   }
-	// For `&T{f: 1}`, we compute `typ` and `t` as:
-	//     typeOf(&T{f: 1}) == *T
-	//     deref(*T)        == T (typ)
-	//     CoreType(T)      == N
-	//     deref(N)         == N
-	//     N.Underlying()   == struct{f int} (t)
-	// For `{f: 1}` in `[]S{{f: 1}}`,  we compute `typ` and `t` as:
-	//     typeOf({f: 1})   == S
-	//     deref(S)         == S (typ)
-	//     CoreType(S)      == *N
-	//     deref(*N)        == N
-	//     N.Underlying()   == struct{f int} (t)
-	switch t := t.(type) {
+	typ, _ := deref(fn.typeOf(e)) // type with name [may be type param]
+	switch t := typeparams.CoreType(typ).(type) {
 	case *types.Struct:
 		if !isZero && len(e.Elts) != t.NumFields() {
 			// memclear
-			dt, _ := deptr(addr.Type())
-			sb.store(&address{addr, e.Lbrace, nil}, zeroConst(dt))
+			zt, _ := deref(addr.Type())
+			sb.store(&address{addr, e.Lbrace, nil}, zeroConst(zt))
 			isZero = true
 		}
 		for i, e := range e.Elts {
@@ -1329,8 +1302,8 @@
 
 			if !isZero && int64(len(e.Elts)) != at.Len() {
 				// memclear
-				dt, _ := deptr(array.Type())
-				sb.store(&address{array, e.Lbrace, nil}, zeroConst(dt))
+				zt, _ := deref(array.Type())
+				sb.store(&address{array, e.Lbrace, nil}, zeroConst(zt))
 			}
 		}
 
@@ -1385,7 +1358,7 @@
 			//	map[*struct{}]bool{&struct{}{}: true}
 			wantAddr := false
 			if _, ok := unparen(e.Key).(*ast.CompositeLit); ok {
-				_, wantAddr = t.Key().Underlying().(*types.Pointer)
+				_, wantAddr = deref(t.Key())
 			}
 
 			var key Value
@@ -1416,7 +1389,7 @@
 		sb.store(&address{addr: addr, pos: e.Lbrace, expr: e}, m)
 
 	default:
-		panic("unexpected CompositeLit type: " + t.String())
+		panic("unexpected CompositeLit type: " + typ.String())
 	}
 }
 
diff --git a/go/ssa/builder_generic_test.go b/go/ssa/builder_generic_test.go
index 7187c35..77de326 100644
--- a/go/ssa/builder_generic_test.go
+++ b/go/ssa/builder_generic_test.go
@@ -607,6 +607,13 @@
 		return u
 	}
 
+	//@ instrs("f1b", "*ssa.Alloc", "new T (complit)")
+	//@ instrs("f1b", "*ssa.FieldAddr", "&t0.x [#0]")
+	func f1b[T ~struct{ x string }]() *T {
+		u := &T{"lorem"}
+		return u
+	}
+
 	//@ instrs("f2", "*ssa.TypeAssert", "typeassert t0.(interface{})")
 	//@ instrs("f2", "*ssa.Call", "invoke x.foo()")
 	func f2[T interface{ foo() string }](x T) {
@@ -628,6 +635,61 @@
 			print(i, v)
 		}
 	}
+
+	//@ instrs("f5", "*ssa.Call", "nil:func()()")
+	func f5() {
+		var f func()
+		f()
+	}
+
+	type S struct{ f int }
+
+	//@ instrs("f6", "*ssa.Alloc", "new [1]P (slicelit)", "new S (complit)")
+	//@ instrs("f6", "*ssa.IndexAddr", "&t0[0:int]")
+	//@ instrs("f6", "*ssa.FieldAddr", "&t2.f [#0]")
+	func f6[P *S]() []P { return []P{{f: 1}} }
+
+	//@ instrs("f7", "*ssa.Alloc", "local S (complit)")
+	//@ instrs("f7", "*ssa.FieldAddr", "&t0.f [#0]")
+	func f7[T any, S struct{f T}](x T) S { return S{f: x} }
+
+	//@ instrs("f8", "*ssa.Alloc", "new [1]P (slicelit)", "new struct{f T} (complit)")
+	//@ instrs("f8", "*ssa.IndexAddr", "&t0[0:int]")
+	//@ instrs("f8", "*ssa.FieldAddr", "&t2.f [#0]")
+	func f8[T any, P *struct{f T}](x T) []P { return []P{{f: x}} }
+
+	//@ instrs("f9", "*ssa.Alloc", "new [1]PS (slicelit)", "new S (complit)")
+	//@ instrs("f9", "*ssa.IndexAddr", "&t0[0:int]")
+	//@ instrs("f9", "*ssa.FieldAddr", "&t2.f [#0]")
+	func f9[T any, S struct{f T}, PS *S](x T) {
+		_ = []PS{{f: x}}
+	}
+
+	//@ instrs("f10", "*ssa.FieldAddr", "&t0.x [#0]")
+	//@ instrs("f10", "*ssa.Store", "*t0 = *new(T):T", "*t1 = 4:int")
+	func f10[T ~struct{ x, y int }]() T {
+		var u T
+		u = T{x: 4}
+		return u
+	}
+
+	//@ instrs("f11", "*ssa.FieldAddr", "&t1.y [#1]")
+	//@ instrs("f11", "*ssa.Store", "*t1 = *new(T):T", "*t2 = 5:int")
+	func f11[T ~struct{ x, y int }, PT *T]() PT {
+		var u PT = new(T)
+		*u = T{y: 5}
+		return u
+	}
+
+	//@ instrs("f12", "*ssa.Alloc", "new struct{f T} (complit)")
+	//@ instrs("f12", "*ssa.MakeMap", "make map[P]bool 1:int")
+	func f12[T any, P *struct{f T}](x T) map[P]bool { return map[P]bool{{}: true} }
+
+	//@ instrs("f13", "&v[0:int]")
+	//@ instrs("f13", "*ssa.Store", "*t0 = 7:int", "*v = *new(A):A")
+	func f13[A [3]int, PA *A](v PA) {
+		*v = A{7}
+	}
 	`
 
 	// Parse
diff --git a/go/ssa/emit.go b/go/ssa/emit.go
index d402e67..80e30b6 100644
--- a/go/ssa/emit.go
+++ b/go/ssa/emit.go
@@ -11,8 +11,6 @@
 	"go/ast"
 	"go/token"
 	"go/types"
-
-	"golang.org/x/tools/internal/typeparams"
 )
 
 // emitNew emits to f a new (heap Alloc) instruction allocating an
@@ -478,9 +476,8 @@
 // value of a field.
 func emitImplicitSelections(f *Function, v Value, indices []int, pos token.Pos) Value {
 	for _, index := range indices {
-		st, vptr := deptr(v.Type())
-		fld := typeparams.CoreType(st).(*types.Struct).Field(index)
-		if vptr {
+		if st, vptr := deptr(v.Type()); vptr {
+			fld := fieldOf(st, index)
 			instr := &FieldAddr{
 				X:     v,
 				Field: index,
@@ -493,6 +490,7 @@
 				v = emitLoad(f, v)
 			}
 		} else {
+			fld := fieldOf(v.Type(), index)
 			instr := &Field{
 				X:     v,
 				Field: index,
@@ -512,15 +510,8 @@
 // field's value.
 // Ident id is used for position and debug info.
 func emitFieldSelection(f *Function, v Value, index int, wantAddr bool, id *ast.Ident) Value {
-	// TODO(taking): Cover the following cases of interest
-	// 	func f[T any, S struct{f T}, P *struct{f T}, PS *S](x T) {
-	// 		_ := S{f: x}
-	//      _ := P{f: x}
-	//      _ := PS{f: x}
-	//  }
-	st, vptr := deptr(v.Type())
-	fld := typeparams.CoreType(st).(*types.Struct).Field(index)
-	if vptr {
+	if st, vptr := deptr(v.Type()); vptr {
+		fld := fieldOf(st, index)
 		instr := &FieldAddr{
 			X:     v,
 			Field: index,
@@ -533,6 +524,7 @@
 			v = emitLoad(f, v)
 		}
 	} else {
+		fld := fieldOf(v.Type(), index)
 		instr := &Field{
 			X:     v,
 			Field: index,
diff --git a/go/ssa/print.go b/go/ssa/print.go
index e47e516..7f34a7b 100644
--- a/go/ssa/print.go
+++ b/go/ssa/print.go
@@ -259,22 +259,19 @@
 }
 
 func (v *FieldAddr) String() string {
-	dt, _ := deptr(v.X.Type())
-	st := typeparams.CoreType(dt).(*types.Struct)
 	// Be robust against a bad index.
 	name := "?"
-	if 0 <= v.Field && v.Field < st.NumFields() {
-		name = st.Field(v.Field).Name()
+	if fld := fieldOf(mustDeref(v.X.Type()), v.Field); fld != nil {
+		name = fld.Name()
 	}
 	return fmt.Sprintf("&%s.%s [#%d]", relName(v.X, v), name, v.Field)
 }
 
 func (v *Field) String() string {
-	st := typeparams.CoreType(v.X.Type()).(*types.Struct)
 	// Be robust against a bad index.
 	name := "?"
-	if 0 <= v.Field && v.Field < st.NumFields() {
-		name = st.Field(v.Field).Name()
+	if fld := fieldOf(v.X.Type(), v.Field); fld != nil {
+		name = fld.Name()
 	}
 	return fmt.Sprintf("%s.%s [#%d]", relName(v.X, v), name, v.Field)
 }
diff --git a/go/ssa/util.go b/go/ssa/util.go
index 53a7487..7735dd8 100644
--- a/go/ssa/util.go
+++ b/go/ssa/util.go
@@ -128,6 +128,17 @@
 	return obj.Type().(*types.Signature).Recv().Type()
 }
 
+// fieldOf returns the index'th field of the (core type of) a struct type;
+// otherwise returns nil.
+func fieldOf(typ types.Type, index int) *types.Var {
+	if st, ok := typeparams.CoreType(typ).(*types.Struct); ok {
+		if 0 <= index && index < st.NumFields() {
+			return st.Field(index)
+		}
+	}
+	return nil
+}
+
 // isUntyped returns true for types that are untyped.
 func isUntyped(typ types.Type) bool {
 	b, ok := typ.(*types.Basic)