|  | // Copyright 2022 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  | package ssa | 
|  |  | 
|  | import ( | 
|  | "fmt" | 
|  | "go/types" | 
|  |  | 
|  | "golang.org/x/tools/internal/typeparams" | 
|  | ) | 
|  |  | 
|  | // Type substituter for a fixed set of replacement types. | 
|  | // | 
|  | // A nil *subster is an valid, empty substitution map. It always acts as | 
|  | // the identity function. This allows for treating parameterized and | 
|  | // non-parameterized functions identically while compiling to ssa. | 
|  | // | 
|  | // Not concurrency-safe. | 
|  | type subster struct { | 
|  | // TODO(zpavlinovic): replacements can contain type params | 
|  | // when generating instances inside of a generic function body. | 
|  | replacements map[*typeparams.TypeParam]types.Type // values should contain no type params | 
|  | cache        map[types.Type]types.Type            // cache of subst results | 
|  | ctxt         *typeparams.Context | 
|  | debug        bool // perform extra debugging checks | 
|  | // TODO(taking): consider adding Pos | 
|  | } | 
|  |  | 
|  | // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache. | 
|  | // targs should not contain any types in tparams. | 
|  | func makeSubster(ctxt *typeparams.Context, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster { | 
|  | assert(tparams.Len() == len(targs), "makeSubster argument count must match") | 
|  |  | 
|  | subst := &subster{ | 
|  | replacements: make(map[*typeparams.TypeParam]types.Type, tparams.Len()), | 
|  | cache:        make(map[types.Type]types.Type), | 
|  | ctxt:         ctxt, | 
|  | debug:        debug, | 
|  | } | 
|  | for i := 0; i < tparams.Len(); i++ { | 
|  | subst.replacements[tparams.At(i)] = targs[i] | 
|  | } | 
|  | if subst.debug { | 
|  | if err := subst.wellFormed(); err != nil { | 
|  | panic(err) | 
|  | } | 
|  | } | 
|  | return subst | 
|  | } | 
|  |  | 
|  | // wellFormed returns an error if subst was not properly initialized. | 
|  | func (subst *subster) wellFormed() error { | 
|  | if subst == nil || len(subst.replacements) == 0 { | 
|  | return nil | 
|  | } | 
|  | // Check that all of the type params do not appear in the arguments. | 
|  | s := make(map[types.Type]bool, len(subst.replacements)) | 
|  | for tparam := range subst.replacements { | 
|  | s[tparam] = true | 
|  | } | 
|  | for _, r := range subst.replacements { | 
|  | if reaches(r, s) { | 
|  | return fmt.Errorf("\n‰r %s s %v replacements %v\n", r, s, subst.replacements) | 
|  | } | 
|  | } | 
|  | return nil | 
|  | } | 
|  |  | 
|  | // typ returns the type of t with the type parameter tparams[i] substituted | 
|  | // for the type targs[i] where subst was created using tparams and targs. | 
|  | func (subst *subster) typ(t types.Type) (res types.Type) { | 
|  | if subst == nil { | 
|  | return t // A nil subst is type preserving. | 
|  | } | 
|  | if r, ok := subst.cache[t]; ok { | 
|  | return r | 
|  | } | 
|  | defer func() { | 
|  | subst.cache[t] = res | 
|  | }() | 
|  |  | 
|  | // fall through if result r will be identical to t, types.Identical(r, t). | 
|  | switch t := t.(type) { | 
|  | case *typeparams.TypeParam: | 
|  | r := subst.replacements[t] | 
|  | assert(r != nil, "type param without replacement encountered") | 
|  | return r | 
|  |  | 
|  | case *types.Basic: | 
|  | return t | 
|  |  | 
|  | case *types.Array: | 
|  | if r := subst.typ(t.Elem()); r != t.Elem() { | 
|  | return types.NewArray(r, t.Len()) | 
|  | } | 
|  | return t | 
|  |  | 
|  | case *types.Slice: | 
|  | if r := subst.typ(t.Elem()); r != t.Elem() { | 
|  | return types.NewSlice(r) | 
|  | } | 
|  | return t | 
|  |  | 
|  | case *types.Pointer: | 
|  | if r := subst.typ(t.Elem()); r != t.Elem() { | 
|  | return types.NewPointer(r) | 
|  | } | 
|  | return t | 
|  |  | 
|  | case *types.Tuple: | 
|  | return subst.tuple(t) | 
|  |  | 
|  | case *types.Struct: | 
|  | return subst.struct_(t) | 
|  |  | 
|  | case *types.Map: | 
|  | key := subst.typ(t.Key()) | 
|  | elem := subst.typ(t.Elem()) | 
|  | if key != t.Key() || elem != t.Elem() { | 
|  | return types.NewMap(key, elem) | 
|  | } | 
|  | return t | 
|  |  | 
|  | case *types.Chan: | 
|  | if elem := subst.typ(t.Elem()); elem != t.Elem() { | 
|  | return types.NewChan(t.Dir(), elem) | 
|  | } | 
|  | return t | 
|  |  | 
|  | case *types.Signature: | 
|  | return subst.signature(t) | 
|  |  | 
|  | case *typeparams.Union: | 
|  | return subst.union(t) | 
|  |  | 
|  | case *types.Interface: | 
|  | return subst.interface_(t) | 
|  |  | 
|  | case *types.Named: | 
|  | return subst.named(t) | 
|  |  | 
|  | default: | 
|  | panic("unreachable") | 
|  | } | 
|  | } | 
|  |  | 
|  | // types returns the result of {subst.typ(ts[i])}. | 
|  | func (subst *subster) types(ts []types.Type) []types.Type { | 
|  | res := make([]types.Type, len(ts)) | 
|  | for i := range ts { | 
|  | res[i] = subst.typ(ts[i]) | 
|  | } | 
|  | return res | 
|  | } | 
|  |  | 
|  | func (subst *subster) tuple(t *types.Tuple) *types.Tuple { | 
|  | if t != nil { | 
|  | if vars := subst.varlist(t); vars != nil { | 
|  | return types.NewTuple(vars...) | 
|  | } | 
|  | } | 
|  | return t | 
|  | } | 
|  |  | 
|  | type varlist interface { | 
|  | At(i int) *types.Var | 
|  | Len() int | 
|  | } | 
|  |  | 
|  | // fieldlist is an adapter for structs for the varlist interface. | 
|  | type fieldlist struct { | 
|  | str *types.Struct | 
|  | } | 
|  |  | 
|  | func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) } | 
|  | func (fl fieldlist) Len() int            { return fl.str.NumFields() } | 
|  |  | 
|  | func (subst *subster) struct_(t *types.Struct) *types.Struct { | 
|  | if t != nil { | 
|  | if fields := subst.varlist(fieldlist{t}); fields != nil { | 
|  | tags := make([]string, t.NumFields()) | 
|  | for i, n := 0, t.NumFields(); i < n; i++ { | 
|  | tags[i] = t.Tag(i) | 
|  | } | 
|  | return types.NewStruct(fields, tags) | 
|  | } | 
|  | } | 
|  | return t | 
|  | } | 
|  |  | 
|  | // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i. | 
|  | func (subst *subster) varlist(in varlist) []*types.Var { | 
|  | var out []*types.Var // nil => no updates | 
|  | for i, n := 0, in.Len(); i < n; i++ { | 
|  | v := in.At(i) | 
|  | w := subst.var_(v) | 
|  | if v != w && out == nil { | 
|  | out = make([]*types.Var, n) | 
|  | for j := 0; j < i; j++ { | 
|  | out[j] = in.At(j) | 
|  | } | 
|  | } | 
|  | if out != nil { | 
|  | out[i] = w | 
|  | } | 
|  | } | 
|  | return out | 
|  | } | 
|  |  | 
|  | func (subst *subster) var_(v *types.Var) *types.Var { | 
|  | if v != nil { | 
|  | if typ := subst.typ(v.Type()); typ != v.Type() { | 
|  | if v.IsField() { | 
|  | return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded()) | 
|  | } | 
|  | return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ) | 
|  | } | 
|  | } | 
|  | return v | 
|  | } | 
|  |  | 
|  | func (subst *subster) union(u *typeparams.Union) *typeparams.Union { | 
|  | var out []*typeparams.Term // nil => no updates | 
|  |  | 
|  | for i, n := 0, u.Len(); i < n; i++ { | 
|  | t := u.Term(i) | 
|  | r := subst.typ(t.Type()) | 
|  | if r != t.Type() && out == nil { | 
|  | out = make([]*typeparams.Term, n) | 
|  | for j := 0; j < i; j++ { | 
|  | out[j] = u.Term(j) | 
|  | } | 
|  | } | 
|  | if out != nil { | 
|  | out[i] = typeparams.NewTerm(t.Tilde(), r) | 
|  | } | 
|  | } | 
|  |  | 
|  | if out != nil { | 
|  | return typeparams.NewUnion(out) | 
|  | } | 
|  | return u | 
|  | } | 
|  |  | 
|  | func (subst *subster) interface_(iface *types.Interface) *types.Interface { | 
|  | if iface == nil { | 
|  | return nil | 
|  | } | 
|  |  | 
|  | // methods for the interface. Initially nil if there is no known change needed. | 
|  | // Signatures for the method where recv is nil. NewInterfaceType fills in the recievers. | 
|  | var methods []*types.Func | 
|  | initMethods := func(n int) { // copy first n explicit methods | 
|  | methods = make([]*types.Func, iface.NumExplicitMethods()) | 
|  | for i := 0; i < n; i++ { | 
|  | f := iface.ExplicitMethod(i) | 
|  | norecv := changeRecv(f.Type().(*types.Signature), nil) | 
|  | methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv) | 
|  | } | 
|  | } | 
|  | for i := 0; i < iface.NumExplicitMethods(); i++ { | 
|  | f := iface.ExplicitMethod(i) | 
|  | // On interfaces, we need to cycle break on anonymous interface types | 
|  | // being in a cycle with their signatures being in cycles with their recievers | 
|  | // that do not go through a Named. | 
|  | norecv := changeRecv(f.Type().(*types.Signature), nil) | 
|  | sig := subst.typ(norecv) | 
|  | if sig != norecv && methods == nil { | 
|  | initMethods(i) | 
|  | } | 
|  | if methods != nil { | 
|  | methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature)) | 
|  | } | 
|  | } | 
|  |  | 
|  | var embeds []types.Type | 
|  | initEmbeds := func(n int) { // copy first n embedded types | 
|  | embeds = make([]types.Type, iface.NumEmbeddeds()) | 
|  | for i := 0; i < n; i++ { | 
|  | embeds[i] = iface.EmbeddedType(i) | 
|  | } | 
|  | } | 
|  | for i := 0; i < iface.NumEmbeddeds(); i++ { | 
|  | e := iface.EmbeddedType(i) | 
|  | r := subst.typ(e) | 
|  | if e != r && embeds == nil { | 
|  | initEmbeds(i) | 
|  | } | 
|  | if embeds != nil { | 
|  | embeds[i] = r | 
|  | } | 
|  | } | 
|  |  | 
|  | if methods == nil && embeds == nil { | 
|  | return iface | 
|  | } | 
|  | if methods == nil { | 
|  | initMethods(iface.NumExplicitMethods()) | 
|  | } | 
|  | if embeds == nil { | 
|  | initEmbeds(iface.NumEmbeddeds()) | 
|  | } | 
|  | return types.NewInterfaceType(methods, embeds).Complete() | 
|  | } | 
|  |  | 
|  | func (subst *subster) named(t *types.Named) types.Type { | 
|  | // A name type may be: | 
|  | // (1) ordinary (no type parameters, no type arguments), | 
|  | // (2) generic (type parameters but no type arguments), or | 
|  | // (3) instantiated (type parameters and type arguments). | 
|  | tparams := typeparams.ForNamed(t) | 
|  | if tparams.Len() == 0 { | 
|  | // case (1) ordinary | 
|  |  | 
|  | // Note: If Go allows for local type declarations in generic | 
|  | // functions we may need to descend into underlying as well. | 
|  | return t | 
|  | } | 
|  | targs := typeparams.NamedTypeArgs(t) | 
|  |  | 
|  | // insts are arguments to instantiate using. | 
|  | insts := make([]types.Type, tparams.Len()) | 
|  |  | 
|  | // case (2) generic ==> targs.Len() == 0 | 
|  | // Instantiating a generic with no type arguments should be unreachable. | 
|  | // Please report a bug if you encounter this. | 
|  | assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported") | 
|  |  | 
|  | // case (3) instantiated. | 
|  | // Substitute into the type arguments and instantiate the replacements/ | 
|  | // Example: | 
|  | //    type N[A any] func() A | 
|  | //    func Foo[T](g N[T]) {} | 
|  | //  To instantiate Foo[string], one goes through {T->string}. To get the type of g | 
|  | //  one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} } | 
|  | //  to get {N with TypeArgs == {string} and typeparams == {A} }. | 
|  | assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present") | 
|  | for i, n := 0, targs.Len(); i < n; i++ { | 
|  | inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion | 
|  | insts[i] = inst | 
|  | } | 
|  | r, err := typeparams.Instantiate(subst.ctxt, typeparams.NamedTypeOrigin(t), insts, false) | 
|  | assert(err == nil, "failed to Instantiate Named type") | 
|  | return r | 
|  | } | 
|  |  | 
|  | func (subst *subster) signature(t *types.Signature) types.Type { | 
|  | tparams := typeparams.ForSignature(t) | 
|  |  | 
|  | // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. | 
|  | // | 
|  | // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. | 
|  | // To support tparams.Len() > 0, we just need to do the following [psuedocode]: | 
|  | //   targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) | 
|  |  | 
|  | assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") | 
|  |  | 
|  | // Either: | 
|  | // (1)non-generic function. | 
|  | //    no type params to substitute | 
|  | // (2)generic method and recv needs to be substituted. | 
|  |  | 
|  | // Recievers can be either: | 
|  | // named | 
|  | // pointer to named | 
|  | // interface | 
|  | // nil | 
|  | // interface is the problematic case. We need to cycle break there! | 
|  | recv := subst.var_(t.Recv()) | 
|  | params := subst.tuple(t.Params()) | 
|  | results := subst.tuple(t.Results()) | 
|  | if recv != t.Recv() || params != t.Params() || results != t.Results() { | 
|  | return typeparams.NewSignatureType(recv, nil, nil, params, results, t.Variadic()) | 
|  | } | 
|  | return t | 
|  | } | 
|  |  | 
|  | // reaches returns true if a type t reaches any type t' s.t. c[t'] == true. | 
|  | // Updates c to cache results. | 
|  | func reaches(t types.Type, c map[types.Type]bool) (res bool) { | 
|  | if c, ok := c[t]; ok { | 
|  | return c | 
|  | } | 
|  | c[t] = false // prevent cycles | 
|  | defer func() { | 
|  | c[t] = res | 
|  | }() | 
|  |  | 
|  | switch t := t.(type) { | 
|  | case *typeparams.TypeParam, *types.Basic: | 
|  | // no-op => c == false | 
|  | case *types.Array: | 
|  | return reaches(t.Elem(), c) | 
|  | case *types.Slice: | 
|  | return reaches(t.Elem(), c) | 
|  | case *types.Pointer: | 
|  | return reaches(t.Elem(), c) | 
|  | case *types.Tuple: | 
|  | for i := 0; i < t.Len(); i++ { | 
|  | if reaches(t.At(i).Type(), c) { | 
|  | return true | 
|  | } | 
|  | } | 
|  | case *types.Struct: | 
|  | for i := 0; i < t.NumFields(); i++ { | 
|  | if reaches(t.Field(i).Type(), c) { | 
|  | return true | 
|  | } | 
|  | } | 
|  | case *types.Map: | 
|  | return reaches(t.Key(), c) || reaches(t.Elem(), c) | 
|  | case *types.Chan: | 
|  | return reaches(t.Elem(), c) | 
|  | case *types.Signature: | 
|  | if t.Recv() != nil && reaches(t.Recv().Type(), c) { | 
|  | return true | 
|  | } | 
|  | return reaches(t.Params(), c) || reaches(t.Results(), c) | 
|  | case *typeparams.Union: | 
|  | for i := 0; i < t.Len(); i++ { | 
|  | if reaches(t.Term(i).Type(), c) { | 
|  | return true | 
|  | } | 
|  | } | 
|  | case *types.Interface: | 
|  | for i := 0; i < t.NumEmbeddeds(); i++ { | 
|  | if reaches(t.Embedded(i), c) { | 
|  | return true | 
|  | } | 
|  | } | 
|  | for i := 0; i < t.NumExplicitMethods(); i++ { | 
|  | if reaches(t.ExplicitMethod(i).Type(), c) { | 
|  | return true | 
|  | } | 
|  | } | 
|  | case *types.Named: | 
|  | return reaches(t.Underlying(), c) | 
|  | default: | 
|  | panic("unreachable") | 
|  | } | 
|  | return false | 
|  | } |