|  | // Copyright 2022 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | package stubmethods | 
|  |  | 
|  | import ( | 
|  | "bytes" | 
|  | "fmt" | 
|  | "go/ast" | 
|  | "go/format" | 
|  | "go/token" | 
|  | "go/types" | 
|  | "strconv" | 
|  | "strings" | 
|  |  | 
|  | "golang.org/x/tools/go/analysis" | 
|  | "golang.org/x/tools/go/analysis/passes/inspect" | 
|  | "golang.org/x/tools/go/ast/astutil" | 
|  | "golang.org/x/tools/internal/analysisinternal" | 
|  | "golang.org/x/tools/internal/typesinternal" | 
|  | ) | 
|  |  | 
|  | const Doc = `stub methods analyzer | 
|  |  | 
|  | This analyzer generates method stubs for concrete types | 
|  | in order to implement a target interface` | 
|  |  | 
|  | var Analyzer = &analysis.Analyzer{ | 
|  | Name:             "stubmethods", | 
|  | Doc:              Doc, | 
|  | Requires:         []*analysis.Analyzer{inspect.Analyzer}, | 
|  | Run:              run, | 
|  | RunDespiteErrors: true, | 
|  | } | 
|  |  | 
|  | func run(pass *analysis.Pass) (interface{}, error) { | 
|  | for _, err := range pass.TypeErrors { | 
|  | ifaceErr := strings.Contains(err.Msg, "missing method") || strings.HasPrefix(err.Msg, "cannot convert") | 
|  | if !ifaceErr { | 
|  | continue | 
|  | } | 
|  | var file *ast.File | 
|  | for _, f := range pass.Files { | 
|  | if f.Pos() <= err.Pos && err.Pos < f.End() { | 
|  | file = f | 
|  | break | 
|  | } | 
|  | } | 
|  | if file == nil { | 
|  | continue | 
|  | } | 
|  | // Get the end position of the error. | 
|  | _, _, endPos, ok := typesinternal.ReadGo116ErrorData(err) | 
|  | if !ok { | 
|  | var buf bytes.Buffer | 
|  | if err := format.Node(&buf, pass.Fset, file); err != nil { | 
|  | continue | 
|  | } | 
|  | endPos = analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), err.Pos) | 
|  | } | 
|  | path, _ := astutil.PathEnclosingInterval(file, err.Pos, endPos) | 
|  | si := GetStubInfo(pass.TypesInfo, path, err.Pos) | 
|  | if si == nil { | 
|  | continue | 
|  | } | 
|  | qf := RelativeToFiles(si.Concrete.Obj().Pkg(), file, nil, nil) | 
|  | pass.Report(analysis.Diagnostic{ | 
|  | Pos:     err.Pos, | 
|  | End:     endPos, | 
|  | Message: fmt.Sprintf("Implement %s", types.TypeString(si.Interface.Type(), qf)), | 
|  | }) | 
|  | } | 
|  | return nil, nil | 
|  | } | 
|  |  | 
// StubInfo describes a concrete type that should be given method stubs so
// that it implements a target interface type.
type StubInfo struct {
	// Interface is the named interface the client wants Concrete to
	// implement. It is held as a *types.TypeName rather than a types.Type so
	// that a reference to the declaring object's package (and the file it
	// is declared in) is kept. This matters when the concrete type's file
	// needs a new import that the interface's file happens to rename.
	// TODO(marwan-at-work): implement interface literals.
	Interface *types.TypeName

	// Concrete is the named type that will receive the stub methods.
	Concrete *types.Named

	// Pointer reports whether the concrete value was used through a
	// pointer (for example &T{} rather than T{}).
	Pointer bool
}
|  |  | 
|  | // GetStubInfo determines whether the "missing method error" | 
|  | // can be used to deduced what the concrete and interface types are. | 
|  | func GetStubInfo(ti *types.Info, path []ast.Node, pos token.Pos) *StubInfo { | 
|  | for _, n := range path { | 
|  | switch n := n.(type) { | 
|  | case *ast.ValueSpec: | 
|  | return fromValueSpec(ti, n, pos) | 
|  | case *ast.ReturnStmt: | 
|  | // An error here may not indicate a real error the user should know about, but it may. | 
|  | // Therefore, it would be best to log it out for debugging/reporting purposes instead of ignoring | 
|  | // it. However, event.Log takes a context which is not passed via the analysis package. | 
|  | // TODO(marwan-at-work): properly log this error. | 
|  | si, _ := fromReturnStmt(ti, pos, path, n) | 
|  | return si | 
|  | case *ast.AssignStmt: | 
|  | return fromAssignStmt(ti, n, pos) | 
|  | case *ast.CallExpr: | 
|  | // Note that some call expressions don't carry the interface type | 
|  | // because they don't point to a function or method declaration elsewhere. | 
|  | // For eaxmple, "var Interface = (*Concrete)(nil)". In that case, continue | 
|  | // this loop to encounter other possibilities such as *ast.ValueSpec or others. | 
|  | si := fromCallExpr(ti, pos, n) | 
|  | if si != nil { | 
|  | return si | 
|  | } | 
|  | } | 
|  | } | 
|  | return nil | 
|  | } | 
|  |  | 
|  | // fromCallExpr tries to find an *ast.CallExpr's function declaration and | 
|  | // analyzes a function call's signature against the passed in parameter to deduce | 
|  | // the concrete and interface types. | 
|  | func fromCallExpr(ti *types.Info, pos token.Pos, ce *ast.CallExpr) *StubInfo { | 
|  | paramIdx := -1 | 
|  | for i, p := range ce.Args { | 
|  | if pos >= p.Pos() && pos <= p.End() { | 
|  | paramIdx = i | 
|  | break | 
|  | } | 
|  | } | 
|  | if paramIdx == -1 { | 
|  | return nil | 
|  | } | 
|  | p := ce.Args[paramIdx] | 
|  | concObj, pointer := concreteType(p, ti) | 
|  | if concObj == nil || concObj.Obj().Pkg() == nil { | 
|  | return nil | 
|  | } | 
|  | tv, ok := ti.Types[ce.Fun] | 
|  | if !ok { | 
|  | return nil | 
|  | } | 
|  | sig, ok := tv.Type.(*types.Signature) | 
|  | if !ok { | 
|  | return nil | 
|  | } | 
|  | sigVar := sig.Params().At(paramIdx) | 
|  | iface := ifaceObjFromType(sigVar.Type()) | 
|  | if iface == nil { | 
|  | return nil | 
|  | } | 
|  | return &StubInfo{ | 
|  | Concrete:  concObj, | 
|  | Pointer:   pointer, | 
|  | Interface: iface, | 
|  | } | 
|  | } | 
|  |  | 
|  | // fromReturnStmt analyzes a "return" statement to extract | 
|  | // a concrete type that is trying to be returned as an interface type. | 
|  | // | 
|  | // For example, func() io.Writer { return myType{} } | 
|  | // would return StubInfo with the interface being io.Writer and the concrete type being myType{}. | 
|  | func fromReturnStmt(ti *types.Info, pos token.Pos, path []ast.Node, rs *ast.ReturnStmt) (*StubInfo, error) { | 
|  | returnIdx := -1 | 
|  | for i, r := range rs.Results { | 
|  | if pos >= r.Pos() && pos <= r.End() { | 
|  | returnIdx = i | 
|  | } | 
|  | } | 
|  | if returnIdx == -1 { | 
|  | return nil, fmt.Errorf("pos %d not within return statement bounds: [%d-%d]", pos, rs.Pos(), rs.End()) | 
|  | } | 
|  | concObj, pointer := concreteType(rs.Results[returnIdx], ti) | 
|  | if concObj == nil || concObj.Obj().Pkg() == nil { | 
|  | return nil, nil | 
|  | } | 
|  | ef := enclosingFunction(path, ti) | 
|  | if ef == nil { | 
|  | return nil, fmt.Errorf("could not find the enclosing function of the return statement") | 
|  | } | 
|  | iface := ifaceType(ef.Results.List[returnIdx].Type, ti) | 
|  | if iface == nil { | 
|  | return nil, nil | 
|  | } | 
|  | return &StubInfo{ | 
|  | Concrete:  concObj, | 
|  | Pointer:   pointer, | 
|  | Interface: iface, | 
|  | }, nil | 
|  | } | 
|  |  | 
|  | // fromValueSpec returns *StubInfo from a variable declaration such as | 
|  | // var x io.Writer = &T{} | 
|  | func fromValueSpec(ti *types.Info, vs *ast.ValueSpec, pos token.Pos) *StubInfo { | 
|  | var idx int | 
|  | for i, vs := range vs.Values { | 
|  | if pos >= vs.Pos() && pos <= vs.End() { | 
|  | idx = i | 
|  | break | 
|  | } | 
|  | } | 
|  |  | 
|  | valueNode := vs.Values[idx] | 
|  | ifaceNode := vs.Type | 
|  | callExp, ok := valueNode.(*ast.CallExpr) | 
|  | // if the ValueSpec is `var _ = myInterface(...)` | 
|  | // as opposed to `var _ myInterface = ...` | 
|  | if ifaceNode == nil && ok && len(callExp.Args) == 1 { | 
|  | ifaceNode = callExp.Fun | 
|  | valueNode = callExp.Args[0] | 
|  | } | 
|  | concObj, pointer := concreteType(valueNode, ti) | 
|  | if concObj == nil || concObj.Obj().Pkg() == nil { | 
|  | return nil | 
|  | } | 
|  | ifaceObj := ifaceType(ifaceNode, ti) | 
|  | if ifaceObj == nil { | 
|  | return nil | 
|  | } | 
|  | return &StubInfo{ | 
|  | Concrete:  concObj, | 
|  | Interface: ifaceObj, | 
|  | Pointer:   pointer, | 
|  | } | 
|  | } | 
|  |  | 
|  | // fromAssignStmt returns *StubInfo from a variable re-assignment such as | 
|  | // var x io.Writer | 
|  | // x = &T{} | 
|  | func fromAssignStmt(ti *types.Info, as *ast.AssignStmt, pos token.Pos) *StubInfo { | 
|  | idx := -1 | 
|  | var lhs, rhs ast.Expr | 
|  | // Given a re-assignment interface conversion error, | 
|  | // the compiler error shows up on the right hand side of the expression. | 
|  | // For example, x = &T{} where x is io.Writer highlights the error | 
|  | // under "&T{}" and not "x". | 
|  | for i, hs := range as.Rhs { | 
|  | if pos >= hs.Pos() && pos <= hs.End() { | 
|  | idx = i | 
|  | break | 
|  | } | 
|  | } | 
|  | if idx == -1 { | 
|  | return nil | 
|  | } | 
|  | // Technically, this should never happen as | 
|  | // we would get a "cannot assign N values to M variables" | 
|  | // before we get an interface conversion error. Nonetheless, | 
|  | // guard against out of range index errors. | 
|  | if idx >= len(as.Lhs) { | 
|  | return nil | 
|  | } | 
|  | lhs, rhs = as.Lhs[idx], as.Rhs[idx] | 
|  | ifaceObj := ifaceType(lhs, ti) | 
|  | if ifaceObj == nil { | 
|  | return nil | 
|  | } | 
|  | concType, pointer := concreteType(rhs, ti) | 
|  | if concType == nil || concType.Obj().Pkg() == nil { | 
|  | return nil | 
|  | } | 
|  | return &StubInfo{ | 
|  | Concrete:  concType, | 
|  | Interface: ifaceObj, | 
|  | Pointer:   pointer, | 
|  | } | 
|  | } | 
|  |  | 
// RelativeToFiles returns a types.Qualifier that formats package names
// according to the import environments of the file that defines the concrete
// type (concFile) and the file that defines the interface type. Only the
// import specs of the interface's file are provided (ifaceImports).
//
// The qualifier resolves a package name in this order:
//   - concPkg itself is unqualified ("").
//   - If concFile imports the package under a usable name, that name is used:
//     either its explicit rename or the package's own name. Dot and blank
//     imports are not usable.
//   - Otherwise, a rename found in ifaceImports is preferred, unless it
//     collides with concPkg's name.
//   - Failing that, the package's own name is used.
//
// If missingImport is non-nil, it is called for every package that concFile
// does not import under a usable name. It receives the preferred local name
// ("" when the import need not be renamed) and the package path. This lets a
// caller collect the imports that stubbing will require while
// types.TypeString formats a signature.
func RelativeToFiles(concPkg *types.Package, concFile *ast.File, ifaceImports []*ast.ImportSpec, missingImport func(name, path string)) types.Qualifier {
	// usable reports whether imp imports path under a name other than "."
	// or "_". It also returns imp's explicit name, which is nil if the
	// import is not renamed.
	usable := func(imp *ast.ImportSpec, path string) (*ast.Ident, bool) {
		impPath, _ := strconv.Unquote(imp.Path.Value)
		// TODO(adonovan): this comparison disregards a vendor prefix in 'other'.
		if impPath != path {
			return nil, false
		}
		if imp.Name != nil && (imp.Name.Name == "." || imp.Name.Name == "_") {
			return nil, false
		}
		return imp.Name, true
	}

	return func(other *types.Package) string {
		if other == concPkg {
			return ""
		}

		// The concrete file already imports the package: use its name there.
		for _, imp := range concFile.Imports {
			if rename, ok := usable(imp, other.Path()); ok {
				if rename != nil {
					return rename.Name
				}
				return other.Name()
			}
		}

		// Not imported by the concrete file: adopt the interface file's
		// rename, provided it does not clash with the concrete package name.
		preferred := ""
		for _, imp := range ifaceImports {
			if rename, ok := usable(imp, other.Path()); ok {
				if rename != nil && rename.Name != concPkg.Name() {
					preferred = rename.Name
				}
				break
			}
		}

		// preferred must still be empty here when there is no rename;
		// otherwise the caller would produce `import time "time"`, which is
		// not idiomatic.
		if missingImport != nil {
			missingImport(preferred, other.Path())
		}
		if preferred != "" {
			return preferred
		}
		return other.Name()
	}
}
|  |  | 
|  | // ifaceType will try to extract the types.Object that defines | 
|  | // the interface given the ast.Expr where the "missing method" | 
|  | // or "conversion" errors happen. | 
|  | func ifaceType(n ast.Expr, ti *types.Info) *types.TypeName { | 
|  | tv, ok := ti.Types[n] | 
|  | if !ok { | 
|  | return nil | 
|  | } | 
|  | return ifaceObjFromType(tv.Type) | 
|  | } | 
|  |  | 
// ifaceObjFromType returns the type name declaring t if t is a named type
// whose underlying type is an interface, and nil otherwise.
//
// Predeclared interfaces such as "error" have a nil Pkg(), yet "error" is a
// real interface that must be supported, so it is special-cased. Any other
// package-less interface (e.g. comparable, or one added in a future Go
// release) yields nil, which protects gopls from panicking on a nil Pkg().
func ifaceObjFromType(t types.Type) *types.TypeName {
	named, isNamed := t.(*types.Named)
	if !isNamed {
		return nil
	}
	if _, isIface := named.Underlying().(*types.Interface); !isIface {
		return nil
	}
	obj := named.Obj()
	if obj.Pkg() == nil && obj.Name() != "error" {
		return nil
	}
	return obj
}
|  |  | 
// concreteType returns the *types.Named that ti records for n, the
// expression on the concrete side of a "missing method" or "conversion"
// error. One level of pointer is looked through. The boolean result
// reports whether such a pointer was present.
//
// If ti has no type for n, or the (pointed-to) type is not a named type
// (e.g. a basic type, which cannot have methods declared on it), the results
// are nil and false.
func concreteType(n ast.Expr, ti *types.Info) (*types.Named, bool) {
	tv, found := ti.Types[n]
	if !found {
		return nil, false
	}
	base, viaPointer := tv.Type, false
	if p, ok := base.(*types.Pointer); ok {
		base, viaPointer = p.Elem(), true
	}
	if named, ok := base.(*types.Named); ok {
		return named, viaPointer
	}
	return nil, false
}
|  |  | 
// enclosingFunction returns the *ast.FuncType of the first function node on
// path (typically ordered innermost-first) that the type checker has
// information about. That is either an *ast.FuncDecl whose name has an entry
// in info.Defs, or an *ast.FuncLit that has an entry in info.Types. It returns
// nil if there is no such node.
func enclosingFunction(path []ast.Node, info *types.Info) *ast.FuncType {
	for _, n := range path {
		if decl, ok := n.(*ast.FuncDecl); ok {
			if _, known := info.Defs[decl.Name]; known {
				return decl.Type
			}
			continue
		}
		if lit, ok := n.(*ast.FuncLit); ok {
			if _, known := info.Types[lit]; known {
				return lit.Type
			}
		}
	}
	return nil
}