| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package stubmethods |
| |
| import ( |
| "bytes" |
| "fmt" |
| "go/ast" |
| "go/token" |
| "go/types" |
| "strings" |
| "unicode" |
| |
| "golang.org/x/tools/go/ast/inspector" |
| "golang.org/x/tools/gopls/internal/cache/parsego" |
| "golang.org/x/tools/gopls/internal/util/cursorutil" |
| "golang.org/x/tools/gopls/internal/util/typesutil" |
| "golang.org/x/tools/internal/typesinternal" |
| ) |
| |
// anyType is the predeclared type "any", used by collectParams as the
// fallback parameter type when an argument's type cannot be used.
var anyType types.Type = types.Universe.Lookup("any").Type()
| |
// CallStubInfo describes a method that a receiver type is missing, as
// reported by a "type X has no field or method Y" error, together with the
// information needed to generate a stub declaration for that method.
type CallStubInfo struct {
	Fset       *token.FileSet             // the FileSet used to type-check the types below
	Receiver   typesinternal.NamedOrAlias // the method's receiver type
	MethodName string                     // name of the missing method (the selector in the call)
	After      types.Object               // decl after which to insert the new decl
	pointer    bool                       // whether the stub is emitted with a pointer receiver
	info       *types.Info                // type information used to infer parameter and result types
	curCall    inspector.Cursor           // cursor to the CallExpr
}
| |
| // GetCallStubInfo extracts necessary information to generate a method definition from |
| // a CallExpr. |
| func GetCallStubInfo(fset *token.FileSet, info *types.Info, pgf *parsego.File, start, end token.Pos) *CallStubInfo { |
| callCur, _ := pgf.Cursor.FindByPos(start, end) |
| call, callCur := cursorutil.FirstEnclosing[*ast.CallExpr](callCur) |
| if call == nil { |
| return nil |
| } |
| s, ok := call.Fun.(*ast.SelectorExpr) |
| // TODO: support generating stub functions in the same way. |
| if !ok { |
| return nil |
| } |
| |
| // If recvExpr is a package name, compiler error would be |
| // e.g., "undefined: http.bar", thus will not hit this code path. |
| recvExpr := s.X |
| recvType, pointer := concreteType(recvExpr, info) |
| |
| if recvType == nil || recvType.Obj().Pkg() == nil { |
| return nil |
| } |
| |
| // A method of a function-local type cannot be stubbed |
| // since there's nowhere to put the methods. |
| recv := recvType.Obj() |
| if recv.Parent() != recv.Pkg().Scope() { |
| return nil |
| } |
| |
| after := types.Object(recv) |
| // If the enclosing function declaration is a method declaration, |
| // and matches the receiver type of the diagnostic, |
| // insert after the enclosing method. |
| decl, _ := cursorutil.FirstEnclosing[*ast.FuncDecl](callCur) |
| if decl != nil && decl.Recv != nil { |
| if len(decl.Recv.List) != 1 { |
| return nil |
| } |
| mrt := info.TypeOf(decl.Recv.List[0].Type) |
| if mrt != nil && types.Identical(types.Unalias(typesinternal.Unpointer(mrt)), recv.Type()) { |
| after = info.ObjectOf(decl.Name) |
| } |
| } |
| return &CallStubInfo{ |
| Fset: fset, |
| Receiver: recvType, |
| MethodName: s.Sel.Name, |
| After: after, |
| pointer: pointer, |
| curCall: callCur, |
| info: info, |
| } |
| } |
| |
| // Emit writes to out the missing method based on type info of si.Receiver and CallExpr. |
| func (si *CallStubInfo) Emit(out *bytes.Buffer, qual types.Qualifier) error { |
| params := si.collectParams() |
| rets := typesutil.TypesFromContext(si.info, si.curCall) |
| recv := si.Receiver.Obj() |
| // Pointer receiver? |
| var star string |
| if si.pointer { |
| star = "*" |
| } |
| |
| // Choose receiver name. |
| // If any method has a named receiver, choose the first one. |
| // Otherwise, use lowercase for the first letter of the object. |
| recvName := strings.ToLower(fmt.Sprintf("%.1s", recv.Name())) |
| if named, ok := types.Unalias(si.Receiver).(*types.Named); ok { |
| for method := range named.Methods() { |
| if recv := method.Signature().Recv(); recv.Name() != "" { |
| recvName = recv.Name() |
| break |
| } |
| } |
| } |
| |
| // Emit method declaration. |
| fmt.Fprintf(out, "\nfunc (%s %s%s%s) %s", |
| recvName, |
| star, |
| recv.Name(), |
| typesutil.FormatTypeParams(si.Receiver.TypeParams()), |
| si.MethodName) |
| |
| // Emit parameters, avoiding name conflicts. |
| seen := map[string]bool{recvName: true} |
| out.WriteString("(") |
| for i, param := range params { |
| name := param.name |
| if seen[name] { |
| name = fmt.Sprintf("param%d", i+1) |
| } |
| seen[name] = true |
| |
| if i > 0 { |
| out.WriteString(", ") |
| } |
| fmt.Fprintf(out, "%s %s", name, types.TypeString(param.typ, qual)) |
| } |
| out.WriteString(") ") |
| |
| // Emit result types. |
| if len(rets) > 1 { |
| out.WriteString("(") |
| } |
| for i, r := range rets { |
| if i > 0 { |
| out.WriteString(", ") |
| } |
| out.WriteString(types.TypeString(r, qual)) |
| } |
| if len(rets) > 1 { |
| out.WriteString(")") |
| } |
| |
| // Emit body. |
| out.WriteString(` { |
| panic("unimplemented") |
| }`) |
| return nil |
| } |
| |
// param describes one parameter of a generated method stub.
type param struct {
	name string     // the parameter's name in the generated signature
	typ  types.Type // the type of param, inferred from CallExpr
}
| |
| // collectParams gathers the parameter information needed to generate a method stub. |
| // The param's type default to any if there is a type error in the argument. |
| func (si *CallStubInfo) collectParams() []param { |
| var params []param |
| appendParam := func(e ast.Expr, t types.Type) { |
| p := param{"param", anyType} |
| if t != nil && !containsInvalid(t) { |
| t = types.Default(t) |
| p = param{paramName(e, t), t} |
| } |
| params = append(params, p) |
| } |
| |
| args := si.curCall.Node().(*ast.CallExpr).Args |
| for _, arg := range args { |
| t := si.info.TypeOf(arg) |
| switch t := t.(type) { |
| // This is the case where another function call returning multiple |
| // results is used as an argument. |
| case *types.Tuple: |
| for v := range t.Variables() { |
| appendParam(arg, v.Type()) |
| } |
| default: |
| appendParam(arg, t) |
| } |
| } |
| return params |
| } |
| |
// containsInvalid reports whether the printed form of t mentions the
// invalid type anywhere (for example "[]invalid type"). Such a type is
// not valid syntax and so cannot appear in generated code.
func containsInvalid(t types.Type) bool {
	invalid := types.Typ[types.Invalid].String()
	return strings.Contains(types.TypeString(t, nil), invalid)
}
| |
| // paramName heuristically chooses a parameter name from |
| // its argument expression and type. Caller should ensure |
| // typ is non-nil. |
| func paramName(e ast.Expr, typ types.Type) string { |
| if typ == types.Universe.Lookup("error").Type() { |
| return "err" |
| } |
| switch t := e.(type) { |
| // Use the identifier's name as the argument name. |
| case *ast.Ident: |
| return t.Name |
| // Use the Sel.Name's last section as the argument name. |
| case *ast.SelectorExpr: |
| return lastSection(t.Sel.Name) |
| } |
| |
| typ = typesinternal.Unpointer(typ) |
| switch t := typ.(type) { |
| // Use the first character of the type name as the argument name for builtin types |
| case *types.Basic: |
| return t.Name()[:1] |
| case *types.Slice: |
| return paramName(e, t.Elem()) |
| case *types.Array: |
| return paramName(e, t.Elem()) |
| case *types.Signature: |
| return "f" |
| case *types.Map: |
| return "m" |
| case *types.Chan: |
| return "ch" |
| case *types.Named: |
| return lastSection(t.Obj().Name()) |
| default: |
| return lastSection(t.String()) |
| } |
| } |
| |
// lastSection returns the final camel-case section of identName, in
// lowercase. The section is the substring that starts at the last
// uppercase letter. If identName contains no uppercase letter, it is
// returned unchanged.
//
// Example: lastSection("registryManagerFactory") = "factory"
func lastSection(identName string) string {
	start := strings.LastIndexFunc(identName, unicode.IsUpper)
	if start < 0 {
		return identName
	}
	return strings.ToLower(identName[start:])
}