blob: 5558c6743a6724e902129a860a8c02dbb46de420 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package typesutil
import (
"bytes"
"go/ast"
"go/token"
"go/types"
"strings"
"golang.org/x/tools/go/ast/edge"
"golang.org/x/tools/go/ast/inspector"
)
// FormatTypeParams turns TypeParamList into its Go representation, such as:
// [T, Y]. Note that it does not print constraints as this is mainly used for
// formatting type params in method receivers.
func FormatTypeParams(tparams *types.TypeParamList) string {
if tparams == nil || tparams.Len() == 0 {
return ""
}
var buf bytes.Buffer
buf.WriteByte('[')
for i := 0; i < tparams.Len(); i++ {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(tparams.At(i).Obj().Name())
}
buf.WriteByte(']')
return buf.String()
}
// TypesFromContext returns the type (or perhaps zero or multiple types)
// of the "hole" into which the expression identified by path must fit.
//
// For example, given
//
// s, i := "", 0
// s, i = EXPR
//
// the hole that must be filled by EXPR has type (string, int).
//
// It returns nil on failure.
func TypesFromContext(info *types.Info, cur inspector.Cursor) []types.Type {
anyType := types.Universe.Lookup("any").Type()
var typs []types.Type
// TODO: do cur = unparenEnclosing(cur), once CL 701035 lands.
for {
ek, _ := cur.ParentEdge()
if ek != edge.ParenExpr_X {
break
}
cur = cur.Parent()
}
validType := func(t types.Type) types.Type {
if t != nil && !containsInvalid(t) {
return types.Default(t)
} else {
return anyType
}
}
ek, idx := cur.ParentEdge()
switch ek {
case edge.AssignStmt_Lhs, edge.AssignStmt_Rhs:
assign := cur.Parent().Node().(*ast.AssignStmt)
// Append all lhs's type
if len(assign.Rhs) == 1 {
for _, lhs := range assign.Lhs {
t := info.TypeOf(lhs)
typs = append(typs, validType(t))
}
break
}
// Lhs and Rhs counts do not match, give up
if len(assign.Lhs) != len(assign.Rhs) {
break
}
// Append corresponding index of lhs's type
if ek == edge.AssignStmt_Rhs {
t := info.TypeOf(assign.Lhs[idx])
typs = append(typs, validType(t))
}
case edge.ValueSpec_Names, edge.ValueSpec_Type, edge.ValueSpec_Values:
spec := cur.Parent().Node().(*ast.ValueSpec)
if len(spec.Values) == 1 {
for _, lhs := range spec.Names {
t := info.TypeOf(lhs)
typs = append(typs, validType(t))
}
break
}
if len(spec.Values) != len(spec.Names) {
break
}
t := info.TypeOf(spec.Type)
typs = append(typs, validType(t))
case edge.ReturnStmt_Results:
returnstmt := cur.Parent().Node().(*ast.ReturnStmt)
sig := EnclosingSignature(cur, info)
if sig == nil || sig.Results() == nil {
break
}
retsig := sig.Results()
// Append all return declarations' type
if len(returnstmt.Results) == 1 {
for v := range retsig.Variables() {
t := v.Type()
typs = append(typs, validType(t))
}
break
}
// Return declaration and actual return counts do not match, give up
if retsig.Len() != len(returnstmt.Results) {
break
}
// Append corresponding index of return declaration's type
t := retsig.At(idx).Type()
typs = append(typs, validType(t))
case edge.CallExpr_Args:
call := cur.Parent().Node().(*ast.CallExpr)
t := info.TypeOf(call.Fun)
if t == nil {
break
}
if sig, ok := t.Underlying().(*types.Signature); ok {
var paramType types.Type
if sig.Variadic() && idx >= sig.Params().Len()-1 {
v := sig.Params().At(sig.Params().Len() - 1)
if s, _ := v.Type().(*types.Slice); s != nil {
paramType = s.Elem()
}
} else if idx < sig.Params().Len() {
paramType = sig.Params().At(idx).Type()
} else {
break
}
if paramType == nil || containsInvalid(paramType) {
paramType = anyType
}
typs = append(typs, paramType)
}
case edge.IfStmt_Cond:
typs = append(typs, types.Typ[types.Bool])
case edge.ForStmt_Cond:
typs = append(typs, types.Typ[types.Bool])
case edge.UnaryExpr_X:
unexpr := cur.Parent().Node().(*ast.UnaryExpr)
var t types.Type
switch unexpr.Op {
case token.NOT:
t = types.Typ[types.Bool]
case token.ADD, token.SUB, token.XOR:
t = types.Typ[types.Int]
default:
t = anyType
}
typs = append(typs, t)
case edge.BinaryExpr_X, edge.BinaryExpr_Y:
binexpr := cur.Parent().Node().(*ast.BinaryExpr)
switch ek {
case edge.BinaryExpr_X:
t := info.TypeOf(binexpr.Y)
typs = append(typs, validType(t))
case edge.BinaryExpr_Y:
t := info.TypeOf(binexpr.X)
typs = append(typs, validType(t))
}
default:
// TODO: support other kinds of "holes" as the need arises.
}
return typs
}
// containsInvalid checks if the type name contains "invalid type",
// which is not a valid syntax to generate.
func containsInvalid(t types.Type) bool {
typeString := types.TypeString(t, nil)
return strings.Contains(typeString, types.Typ[types.Invalid].String())
}
// EnclosingSignature returns the signature of the innermost
// function enclosing the syntax node denoted by cur
// or nil if the node is not within a function.
func EnclosingSignature(cur inspector.Cursor, info *types.Info) *types.Signature {
for c := range cur.Enclosing((*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)) {
switch n := c.Node().(type) {
case *ast.FuncDecl:
if f, ok := info.Defs[n.Name]; ok {
return f.Type().(*types.Signature)
}
return nil
case *ast.FuncLit:
if f, ok := info.Types[n]; ok {
return f.Type.(*types.Signature)
}
return nil
}
}
return nil
}