blob: 7b6f4f45137d5c3d96bafd2661e547216cc82ed2 [file] [log] [blame]
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package source
import (
"bytes"
"context"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"go/types"
"sort"
"strings"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/gopls/internal/lsp/analysis/stubmethods"
"golang.org/x/tools/gopls/internal/lsp/protocol"
"golang.org/x/tools/gopls/internal/lsp/safetoken"
"golang.org/x/tools/gopls/internal/span"
"golang.org/x/tools/internal/typeparams"
)
// stubSuggestedFixFunc computes a suggested fix that adds stub methods to the
// concrete type selected by rng in fh, so that it implements the interface
// reported by stubmethods.GetStubInfo.
//
// The stubs are inserted immediately after the top-level declaration that
// declares the concrete type. That type may live in a different file from fh.
// Any imports the stubs require are added, and the whole file is reformatted
// with go/format. The result is returned as text edits against that file.
//
// The returned FileSet is the snapshot's, because the edit positions refer to
// the file parsed by snapshot.ParseGo.
func stubSuggestedFixFunc(ctx context.Context, snapshot Snapshot, fh FileHandle, rng protocol.Range) (*token.FileSet, *analysis.SuggestedFix, error) {
	pkg, pgf, err := PackageForFile(ctx, snapshot, fh.URI(), TypecheckWorkspace, NarrowestPackage)
	if err != nil {
		return nil, nil, fmt.Errorf("GetTypedFile: %w", err)
	}
	nodes, pos, err := getStubNodes(pgf, rng)
	if err != nil {
		return nil, nil, fmt.Errorf("getNodes: %w", err)
	}
	si := stubmethods.GetStubInfo(pkg.GetTypesInfo(), nodes, pos)
	if si == nil {
		return nil, nil, fmt.Errorf("nil interface request")
	}

	// A function-local type cannot be stubbed
	// since there's nowhere to put the methods.
	conc := si.Concrete.Obj()
	if conc != conc.Pkg().Scope().Lookup(conc.Name()) {
		return nil, nil, fmt.Errorf("local type %q cannot be stubbed", conc.Name())
	}

	// Parse the file defining the concrete type.
	concreteFilename := safetoken.StartPosition(snapshot.FileSet(), conc.Pos()).Filename
	concreteFH, err := snapshot.GetFile(ctx, span.URIFromPath(concreteFilename))
	if err != nil {
		return nil, nil, err
	}
	parsedConcreteFile, err := snapshot.ParseGo(ctx, concreteFH, ParseFull)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to parse file declaring implementation type: %w", err)
	}

	var (
		methodsSrc  []byte
		stubImports []*stubImport // additional imports needed for method stubs
	)
	if si.Interface.Pkg() == nil && si.Interface.Name() == "error" && si.Interface.Parent() == types.Universe {
		methodsSrc = stubErr(ctx, parsedConcreteFile.File, si, snapshot)
	} else {
		methodsSrc, stubImports, err = stubMethods(ctx, parsedConcreteFile.File, si, snapshot)
		if err != nil {
			return nil, nil, fmt.Errorf("stubMethods: %w", err)
		}
	}

	// Splice the methods into the file.
	// The insertion point is after the top-level declaration
	// enclosing the (package-level) type object.
	//
	// Splice into parsedConcreteFile.Src rather than a fresh read of
	// concreteFH. The offset comes from parsedConcreteFile.Tok and the diff
	// below is taken against parsedConcreteFile.Src, so using those same
	// bytes keeps all three consistent.
	src := parsedConcreteFile.Src
	insertPos := parsedConcreteFile.File.End()
	for _, decl := range parsedConcreteFile.File.Decls {
		if decl.End() > conc.Pos() {
			insertPos = decl.End()
			break
		}
	}
	insertOffset, err := safetoken.Offset(parsedConcreteFile.Tok, insertPos)
	if err != nil {
		return nil, nil, fmt.Errorf("computing insertion offset: %w", err)
	}
	// An offset equal to len(src) is valid: the declaration may end exactly
	// at EOF when the file has no trailing newline.
	if insertOffset > len(src) {
		return nil, nil, fmt.Errorf("insertion position is past the end of the file")
	}
	var buf bytes.Buffer
	buf.Write(src[:insertOffset])
	buf.WriteByte('\n')
	buf.Write(methodsSrc)
	buf.Write(src[insertOffset:])

	// Re-parse it, splice in imports, pretty-print it.
	fset := token.NewFileSet()
	newF, err := parser.ParseFile(fset, concreteFilename, buf.Bytes(), parser.ParseComments)
	if err != nil {
		return nil, nil, fmt.Errorf("could not reparse file: %w", err)
	}
	for _, imp := range stubImports {
		astutil.AddNamedImport(fset, newF, imp.Name, imp.Path)
	}
	var source strings.Builder
	if err := format.Node(&source, fset, newF); err != nil {
		return nil, nil, fmt.Errorf("format.Node: %w", err)
	}

	// Return the diff, expressed in positions of the originally parsed file.
	diffs := snapshot.View().Options().ComputeEdits(string(src), source.String())
	var edits []analysis.TextEdit
	for _, edit := range diffs {
		edits = append(edits, analysis.TextEdit{
			Pos:     parsedConcreteFile.Tok.Pos(edit.Start),
			End:     parsedConcreteFile.Tok.Pos(edit.End),
			NewText: []byte(edit.New),
		})
	}
	return snapshot.FileSet(), // to match snapshot.ParseGo above
		&analysis.SuggestedFix{TextEdits: edits},
		nil
}
// stubMethods generates Go source for a stub of every method of si.Interface,
// including methods reached through embedded interfaces, that the concrete
// type si.Concrete lacks. It also returns the imports that the stubs'
// signatures need in concreteFile, each (name, path) pair listed once.
//
// It returns an error if no methods are missing.
func stubMethods(ctx context.Context, concreteFile *ast.File, si *stubmethods.StubInfo, snapshot Snapshot) ([]byte, []*stubImport, error) {
	concObj := si.Concrete.Obj()
	concPkg := concObj.Pkg()

	// Use the method set of *T, which contains both value- and
	// pointer-receiver methods of T.
	ptrMethods := types.NewMethodSet(types.NewPointer(concObj.Type()))
	ifaces, err := missingMethods(ctx, snapshot, ptrMethods, concPkg, si.Interface, map[string]struct{}{})
	if err != nil {
		return nil, nil, fmt.Errorf("missingMethods: %w", err)
	}
	if len(ifaces) == 0 {
		return nil, nil, fmt.Errorf("no missing methods found")
	}

	// The receiver text and the interface name are identical for every stub.
	receiver := getStubReceiver(si)
	ifaceName := deduceIfaceName(concPkg, si.Interface.Pkg(), si.Interface)

	var imports []*stubImport
	// recordImport notes an import needed by a stub signature, skipping duplicates.
	recordImport := func(name, path string) {
		for _, existing := range imports {
			if existing.Name == name && existing.Path == path {
				return
			}
		}
		imports = append(imports, &stubImport{name, path})
	}

	var out bytes.Buffer
	for _, iface := range ifaces {
		for _, method := range iface.missing {
			// TODO(marwan-at-work): this should share the same logic with source.FormatVarType
			// as it also accounts for type aliases.
			qual := stubmethods.RelativeToFiles(concPkg, concreteFile, iface.imports, recordImport)
			sig := strings.TrimPrefix(types.TypeString(method.Type(), qual), "func")
			stub := printStubMethod(methodData{
				Method:    method.Name(),
				Concrete:  receiver,
				Interface: ifaceName,
				Signature: sig,
			})
			if _, err := out.Write(stub); err != nil {
				return nil, nil, fmt.Errorf("error printing method: %w", err)
			}
			out.WriteRune('\n')
		}
	}
	return out.Bytes(), imports, nil
}
// stubErr returns the Go source of an Error method stub. The stub makes the
// concrete type described by si implement the built-in error interface.
//
// ctx, concreteFile, and snapshot are not used; the signature mirrors that of
// stubMethods.
func stubErr(ctx context.Context, concreteFile *ast.File, si *stubmethods.StubInfo, snapshot Snapshot) []byte {
	return printStubMethod(methodData{
		Method:    "Error",
		Interface: "error",
		Concrete:  getStubReceiver(si),
		Signature: "() string",
	})
}
// getStubReceiver renders the receiver type for a stub method. This is the
// concrete type's name followed by its type parameters (if any), as produced
// by FormatTypeParams. The result is prefixed with "*" when si.Pointer is set.
func getStubReceiver(si *stubmethods.StubInfo) string {
	recv := si.Concrete.Obj().Name() + FormatTypeParams(typeparams.ForNamed(si.Concrete))
	if !si.Pointer {
		return recv
	}
	return "*" + recv
}
// methodData holds the text fragments that printStubMethod assembles into a
// method stub.
type methodData struct {
	Method    string // method name, e.g. "Error"
	Interface string // interface name shown in the stub's doc comment
	Concrete  string // receiver type text, e.g. "T" or "*T"
	Signature string // parameters and results, without the leading "func"
}

// printStubMethod returns Go source for a stub method of the form:
//
//	// {{ .Method }} implements {{ .Interface }}
//	func ({{ .Concrete }}) {{ .Method }}{{ .Signature }} {
//		panic("unimplemented")
//	}
//
// The output ends with a newline.
func printStubMethod(md methodData) []byte {
	var out bytes.Buffer
	fmt.Fprintf(&out, "// %s implements %s\nfunc (%s) %s%s {\n\tpanic(\"unimplemented\")\n}\n",
		md.Method, md.Interface, md.Concrete, md.Method, md.Signature)
	return out.Bytes()
}
// deduceIfaceName returns how the interface ifaceObj is named in the doc
// comment of a stub generated in concretePkg. The result is the bare name if
// the interface is declared in a package with the same path as concretePkg, or
// if it has no package (predeclared interfaces such as error). Otherwise it is
// "pkgname.Name".
//
// The qualifier is the interface package's declared name. It may differ from
// a local import name used in the concrete file.
func deduceIfaceName(concretePkg, ifacePkg *types.Package, ifaceObj types.Object) string {
	// Guard against nil packages: calling Path on a nil *types.Package panics.
	if ifacePkg == nil || (concretePkg != nil && concretePkg.Path() == ifacePkg.Path()) {
		return ifaceObj.Name()
	}
	return fmt.Sprintf("%s.%s", ifacePkg.Name(), ifaceObj.Name())
}
// getStubNodes converts the protocol range pRng into a token range within pgf.
// It returns the path of AST nodes enclosing that range, as computed by
// astutil.PathEnclosingInterval, together with the range's start position.
func getStubNodes(pgf *ParsedGoFile, pRng protocol.Range) ([]ast.Node, token.Pos, error) {
	tokRng, err := pgf.RangeToTokenRange(pRng)
	if err != nil {
		return nil, 0, err
	}
	start := tokRng.Start
	path, _ := astutil.PathEnclosingInterval(pgf.File, start, tokRng.End)
	return path, start, nil
}
// missingMethods reports which methods of the interface ifaceObj are absent
// from concMS, the method set of the concrete type.
//
// It returns one missingInterface for ifaceObj itself, if any of its
// explicitly declared methods are missing. It also returns, recursively, one
// entry for each embedded interface that contributes missing methods. For
// example, given
//
//	type I interface {
//		io.Writer
//		Hello()
//	}
//
// and a concrete type with neither method, it returns
//
//	[]*missingInterface{
//		{
//			iface:   the *types.TypeName of io.Writer,
//			imports: the imports of the file declaring io.Writer,
//			missing: []*types.Func{Write},
//		},
//		{
//			iface:   the *types.TypeName of I,
//			imports: the imports of the file declaring I,
//			missing: []*types.Func{Hello},
//		},
//	}
//
// Entries contributed by embedded interfaces come first, sorted by the
// interface's Id. They are followed by ifaceObj's own entry, if any.
//
// visited records method names already reported, so a method reachable along
// several embedding paths is reported only once; callers pass an empty map.
//
// An error is returned in three cases:
//   - ifaceObj is not an interface;
//   - the concrete type has a method of the same name with a non-identical signature;
//   - the file declaring an interface that has missing methods cannot be read or parsed.
func missingMethods(ctx context.Context, snapshot Snapshot, concMS *types.MethodSet, concPkg *types.Package, ifaceObj *types.TypeName, visited map[string]struct{}) ([]*missingInterface, error) {
	iface, ok := ifaceObj.Type().Underlying().(*types.Interface)
	if !ok {
		// Report ifaceObj: iface is nil here because the assertion failed.
		return nil, fmt.Errorf("expected %v to be an interface but got %T", ifaceObj, ifaceObj.Type().Underlying())
	}

	// The built-in error interface is special: no declaring file is parsed
	// for it, so its entry carries no import environment.
	if ifaceObj.Pkg() == nil && ifaceObj.Name() == "error" {
		var missingInterfaces []*missingInterface
		if concMS.Lookup(nil, "Error") == nil {
			errorMethod, _, _ := types.LookupFieldOrMethod(iface, false, nil, "Error")
			missingInterfaces = append(missingInterfaces, &missingInterface{
				iface:   ifaceObj,
				missing: []*types.Func{errorMethod.(*types.Func)},
			})
		}
		return missingInterfaces, nil
	}

	var missing []*types.Func
	// Add all the interface methods not defined by the concrete type to missing.
	for i := 0; i < iface.NumExplicitMethods(); i++ {
		method := iface.ExplicitMethod(i)
		if sel := concMS.Lookup(concPkg, method.Name()); sel == nil {
			// Concrete type does not have the interface method.
			if _, ok := visited[method.Name()]; !ok {
				missing = append(missing, method)
				visited[method.Name()] = struct{}{}
			}
		} else {
			// Concrete type does have the interface method.
			implSig := sel.Type().(*types.Signature)
			ifaceSig := method.Type().(*types.Signature)
			if !types.Identical(ifaceSig, implSig) {
				return nil, fmt.Errorf("mismatched %q function signatures:\nhave: %s\nwant: %s", method.Name(), implSig, ifaceSig)
			}
		}
	}

	// Process embedded interfaces, recursively.
	//
	// TODO(adonovan): this whole computation could be expressed
	// more simply without recursion, driven by the method
	// sets of the interface and concrete types.
	var missingInterfaces []*missingInterface
	for i := 0; i < iface.NumEmbeddeds(); i++ {
		eiface := iface.Embedded(i).Obj()
		em, err := missingMethods(ctx, snapshot, concMS, concPkg, eiface, visited)
		if err != nil {
			return nil, err
		}
		missingInterfaces = append(missingInterfaces, em...)
	}

	// The type checker is deterministic, but its choice of
	// ordering of embedded interfaces varies with Go version
	// (e.g. go1.17 was sorted, go1.18 was lexical order).
	// Sort to ensure test portability.
	sort.Slice(missingInterfaces, func(i, j int) bool {
		return missingInterfaces[i].iface.Id() < missingInterfaces[j].iface.Id()
	})

	if len(missing) > 0 {
		// Only now, when stubs will actually be generated for this interface,
		// parse the imports of the file that declares it. Those imports are
		// used to qualify the stubs' signatures.
		ifaceFilename := safetoken.StartPosition(snapshot.FileSet(), ifaceObj.Pos()).Filename
		ifaceFH, err := snapshot.GetFile(ctx, span.URIFromPath(ifaceFilename))
		if err != nil {
			return nil, err
		}
		ifaceFile, err := snapshot.ParseGo(ctx, ifaceFH, ParseHeader)
		if err != nil {
			return nil, fmt.Errorf("error parsing imports from interface file: %w", err)
		}
		missingInterfaces = append(missingInterfaces, &missingInterface{
			iface:   ifaceObj,
			imports: ifaceFile.File.Imports,
			missing: missing,
		})
	}
	return missingInterfaces, nil
}
type (
	// missingInterface describes one interface (the requested one or one it
	// embeds) together with the methods the concrete type does not define.
	missingInterface struct {
		iface   *types.TypeName   // the interface's type name
		imports []*ast.ImportSpec // the interface's import environment (import specs of its declaring file)
		missing []*types.Func     // the interface's methods absent from the concrete type
	}

	// stubImport is an import that the generated stubs require in the file of
	// the concrete type. A non-empty Name means the import must use that name.
	stubImport struct{ Name, Path string }
)