blob: 8228faf0fc89c66aa96af8cd7721ff54957cc89e [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package golang
// This file defines the behavior of the "Add test for FUNC" command.
import (
"bytes"
"context"
"errors"
"fmt"
"go/ast"
"go/format"
"go/token"
"go/types"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"text/template"
"unicode"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/gopls/internal/cache"
"golang.org/x/tools/gopls/internal/cache/metadata"
"golang.org/x/tools/gopls/internal/cache/parsego"
"golang.org/x/tools/gopls/internal/protocol"
goplsastutil "golang.org/x/tools/gopls/internal/util/astutil"
"golang.org/x/tools/internal/imports"
"golang.org/x/tools/internal/typesinternal"
)
// testTmplString is the text/template source that, executed with a testInfo,
// renders one table-driven test function for the target function or method.
//
// The error-handling sections are keyed on the result names assigned by
// AddTestForFunc ("gotErr" for the tested function's trailing error, "err"
// for the receiver constructor's trailing error) rather than on the spelled
// result type. This keeps them in sync with how the results are named, so a
// result whose type is an alias of error is still checked instead of leaving
// an unused variable in the generated test.
const testTmplString = `
func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
{{- /* Test cases struct declaration and empty initialization. */}}
tests := []struct {
name string // description of this test case
{{- $commentPrinted := false }}
{{- if and .Receiver .Receiver.Constructor}}
{{- range .Receiver.Constructor.Args}}
{{- if .Name}}
{{- if not $commentPrinted}}
// Named input parameters for receiver constructor.
{{- $commentPrinted = true }}
{{- end}}
{{.Name}} {{.Type}}
{{- end}}
{{- end}}
{{- end}}
{{- $commentPrinted := false }}
{{- range .Func.Args}}
{{- if .Name}}
{{- if not $commentPrinted}}
// Named input parameters for target function.
{{- $commentPrinted = true }}
{{- end}}
{{.Name}} {{.Type}}
{{- end}}
{{- end}}
{{- range $index, $res := .Func.Results}}
{{- if eq $res.Name "gotErr"}}
wantErr bool
{{- else if eq $index 0}}
want {{$res.Type}}
{{- else}}
want{{add $index 1}} {{$res.Type}}
{{- end}}
{{- end}}
}{
// TODO: Add test cases.
}
{{- /* Loop over all the test cases. */}}
for _, tt := range tests {
t.Run(tt.name, func(t *{{.TestingPackageName}}.T) {
{{- /* Constructor or empty initialization. */}}
{{- if .Receiver}}
{{- if .Receiver.Constructor}}
{{- /* Receiver variable by calling constructor. */}}
{{fieldNames .Receiver.Constructor.Results ""}} := {{if .PackageName}}{{.PackageName}}.{{end}}
{{- .Receiver.Constructor.Name}}
{{- /* Constructor input parameters. */ -}}
(
{{- range $index, $arg := .Receiver.Constructor.Args}}
{{- if ne $index 0}}, {{end}}
{{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}}
{{- end -}}
)
{{- /* Handles the error return from constructor. */}}
{{- $last := last .Receiver.Constructor.Results}}
{{- if and (gt (len .Receiver.Constructor.Results) 1) (eq $last.Name "err")}}
if err != nil {
t.Fatalf("could not construct receiver type: %v", err)
}
{{- end}}
{{- else}}
{{- /* Receiver variable declaration. */}}
// TODO: construct the receiver type.
var {{.Receiver.Var.Name}} {{.Receiver.Var.Type}}
{{- end}}
{{- end}}
{{- /* Got variables. */}}
{{if .Func.Results}}{{fieldNames .Func.Results ""}} := {{end}}
{{- /* Call expression. */}}
{{- if .Receiver}}{{/* Call method by VAR.METHOD. */}}
{{- .Receiver.Var.Name}}.
{{- else if .PackageName}}{{/* Call function by PACKAGE.FUNC. */}}
{{- .PackageName}}.
{{- end}}{{.Func.Name}}
{{- /* Input parameters. */ -}}
(
{{- range $index, $arg := .Func.Args}}
{{- if ne $index 0}}, {{end}}
{{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}}
{{- end -}}
)
{{- /* Handles the returned error before the rest of return value. */}}
{{- $last := last .Func.Results}}
{{- if eq $last.Name "gotErr"}}
if gotErr != nil {
if !tt.wantErr {
t.Errorf("{{$.Func.Name}}() failed: %v", gotErr)
}
return
}
if tt.wantErr {
t.Fatal("{{$.Func.Name}}() succeeded unexpectedly")
}
{{- end}}
{{- /* Compare the returned values except for the last returned error. */}}
{{- if and .Func.Results (or (gt (len .Func.Results) 1) (ne $last.Name "gotErr"))}}
// TODO: update the condition below to compare got with tt.want.
{{- range $index, $res := .Func.Results}}
{{- if ne $res.Name "gotErr"}}
if true {
t.Errorf("{{$.Func.Name}}() = %v, want %v", {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}})
}
{{- end}}
{{- end}}
{{- end}}
})
}
}
`
// field describes one value rendered by the test template: an input
// parameter (of the tested function or of the receiver constructor), a
// result, or the receiver variable.
//
// For an input parameter, Name is the name of the test-case field the
// argument is read from, and Value is an expression passed verbatim instead
// (for example a zero value or context.Background()). Exactly one of Name or
// Value must be set. Type is the type as it should be spelled in the test
// file. Results and the receiver variable use only Name and Type.
type field struct {
	Name  string
	Type  string
	Value string
}
// function describes a call the generated test makes: the callee's name,
// the arguments passed to it, and the variables its results are bound to.
type function struct {
	Name          string
	Args, Results []field
}
// receiver describes the receiver of a method under test.
type receiver struct {
	// Var is the variable (name and type) that holds the receiver value in
	// the generated test.
	Var field

	// Constructor is the package-level function selected to create the
	// receiver value; it is nil when no suitable constructor was found.
	Constructor *function
}
// testInfo is the data the test template is executed with to render one
// test function.
type testInfo struct {
	// TestingPackageName is the name by which the generated code refers to
	// package "testing".
	TestingPackageName string

	// PackageName is the name by which the generated code refers to the
	// package that declares the function or method under test.
	PackageName string

	// TestFuncName is the name of the test function being generated.
	TestFuncName string

	// Func describes the function or method under test.
	Func function

	// Receiver describes the receiver of the method under test; it is nil
	// when a plain function is being tested and non-nil for a method.
	Receiver *receiver
}
// testTmpl is testTmplString parsed once at package initialization, with
// these helper functions available to the template:
//   - add: integer addition (used to number want2, want3, ...).
//   - last: the final element of a []field, or the zero field when empty.
//   - fieldNames: the Names of a []field, each prefixed by a qualifier,
//     joined with ", ".
var testTmpl = template.Must(
	template.New("test").
		Funcs(template.FuncMap{
			"add": func(x, y int) int {
				return x + y
			},
			"last": func(fields []field) field {
				var tail field
				if n := len(fields); n > 0 {
					tail = fields[n-1]
				}
				return tail
			},
			"fieldNames": func(fields []field, qualifier string) string {
				names := make([]string, 0, len(fields))
				for i := range fields {
					names = append(names, qualifier+fields[i].Name)
				}
				return strings.Join(names, ", ")
			},
		}).
		Parse(testTmplString),
)
// AddTestForFunc adds a test for the function enclosing the given input range.
// It creates a _test.go file if one does not already exist.
//
// The test goes into the "<name>_test.go" sibling of the file named by loc.
// If that file is missing, the returned changes first create it. The new file
// starts with a header copied from the originating file (copyright comment
// and build constraints) and a package clause. It uses the external "x_test"
// package when the function can be tested from outside, and "x" otherwise.
// If the file already exists, its package clause decides between the two and
// any missing imports are added. The rendered test function is inserted at
// the end of an existing file, or after the header and imports of a new one.
//
// It returns an error in any of these cases:
//   - the package is a command-line-arguments package;
//   - the package has parse or type errors;
//   - loc is not inside a function declaration;
//   - an external test could not reference the function or its receiver type.
func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.Location) (changes []protocol.DocumentChange, _ error) {
	pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, loc.URI)
	if err != nil {
		return nil, err
	}

	if metadata.IsCommandLineArguments(pkg.Metadata().ID) {
		return nil, fmt.Errorf("current file in command-line-arguments package")
	}

	if errs := pkg.ParseErrors(); len(errs) > 0 {
		return nil, fmt.Errorf("package has parse errors: %v", errs[0])
	}
	if errs := pkg.TypeErrors(); len(errs) > 0 {
		return nil, fmt.Errorf("package has type errors: %v", errs[0])
	}

	// All three maps map the path of an imported package to
	// the local name if explicit or "" otherwise.
	var (
		fileImports  map[string]string         // imports in foo.go file
		testImports  map[string]string         // imports in foo_test.go file
		extraImports = make(map[string]string) // imports to add to test file
	)

	// collectImports returns the non-blank imports of file keyed by import
	// path. It rejects files that contain dot imports.
	collectImports := func(file *ast.File) (map[string]string, error) {
		imps := make(map[string]string)
		for _, spec := range file.Imports {
			// TODO(hxjiang): support dot imports.
			if spec.Name != nil && spec.Name.Name == "." {
				return nil, fmt.Errorf("\"add test for func\" does not support files containing dot imports")
			}
			path, err := strconv.Unquote(spec.Path.Value)
			if err != nil {
				return nil, err
			}
			if spec.Name != nil {
				if spec.Name.Name == "_" {
					continue
				}
				imps[path] = spec.Name.Name
			} else {
				imps[path] = ""
			}
		}
		return imps, nil
	}

	// Collect all the imports from the x.go, keep track of the local package name.
	if fileImports, err = collectImports(pgf.File); err != nil {
		return nil, err
	}

	testBase := strings.TrimSuffix(filepath.Base(loc.URI.Path()), ".go") + "_test.go"
	goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.DirPath(), testBase))

	testFH, err := snapshot.ReadFile(ctx, goTestFileURI)
	if err != nil {
		return nil, err
	}

	// TODO(hxjiang): use a fresh name if the same test function name already
	// exist.

	var (
		// eofRange is an empty selection at the end of an existing test
		// file; for a new file it stays the zero Range (start of file).
		eofRange protocol.Range
		// edits contains all the text edits to be applied to the test file.
		edits []protocol.TextEdit
		// xtest indicates whether the test file use package x or x_test.
		// TODO(hxjiang): We can discuss the option to interpret the user's
		// intention by which function they are selecting. Have one file for
		// x_test package testing, one file for x package testing.
		xtest = true
	)

	start, end, err := pgf.RangePos(loc.Range)
	if err != nil {
		return nil, err
	}

	path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
	if len(path) < 2 {
		return nil, fmt.Errorf("no enclosing function")
	}

	decl, ok := path[len(path)-2].(*ast.FuncDecl)
	if !ok {
		return nil, fmt.Errorf("no enclosing function")
	}

	// Check the assertion rather than panicking if the declaration has no
	// *types.Func recorded for it.
	fn, ok := pkg.TypesInfo().Defs[decl.Name].(*types.Func)
	if !ok {
		return nil, fmt.Errorf("no type information for function %s", decl.Name.Name)
	}
	sig := fn.Signature()

	// recvTypeIdent returns the identifier of the receiver's base type as
	// reported by goplsastutil.UnpackRecv. It returns nil if decl has no
	// receiver or the identifier cannot be determined.
	recvTypeIdent := func() *ast.Ident {
		if decl.Recv == nil || len(decl.Recv.List) == 0 {
			return nil
		}
		_, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type)
		return ident
	}

	testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header)
	if err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return nil, err
		}
		changes = append(changes, protocol.DocumentChangeCreate(goTestFileURI))

		// header is the buffer containing the text to add to the beginning of the file.
		var header bytes.Buffer

		// If this test file was created by the gopls, add a copyright header and
		// package decl based on the originating file.
		// Search for something that looks like a copyright header, to replicate
		// in the new file.
		if c := copyrightComment(pgf.File); c != nil {
			start, end, err := pgf.NodeOffsets(c)
			if err != nil {
				return nil, err
			}
			header.Write(pgf.Src[start:end])
			// One empty line between copyright header and following.
			header.WriteString("\n\n")
		}

		// If this test file was created by gopls, add build constraints
		// matching the non-test file.
		if c := buildConstraintComment(pgf.File); c != nil {
			start, end, err := pgf.NodeOffsets(c)
			if err != nil {
				return nil, err
			}
			header.Write(pgf.Src[start:end])
			// One empty line between build constraint and following.
			header.WriteString("\n\n")
		}

		// Determine if a new test file should use in-package test (package x)
		// or external test (package x_test). If any of the function parameters
		// reference an unexported object, we cannot write out test cases from
		// an x_test package.
		externalTestOK := func() bool {
			if !fn.Exported() {
				return false
			}
			if sig.Recv() != nil {
				if ident := recvTypeIdent(); ident == nil || !ident.IsExported() {
					return false
				}
			}
			refsUnexported := false
			ast.Inspect(decl, func(n ast.Node) bool {
				// The original function refs to an unexported object from the
				// same package, so further inspection is unnecessary.
				if refsUnexported {
					return false
				}
				switch t := n.(type) {
				case *ast.BlockStmt:
					// Avoid inspect the function body.
					return false
				case *ast.Ident:
					// Use test variant (package foo) if the function signature
					// references any unexported objects (like types or
					// constants) from the same package.
					// Note: types.PkgName is excluded from this check as it's
					// always defined in the same package.
					if obj, ok := pkg.TypesInfo().Uses[t]; ok && !obj.Exported() && obj.Pkg() == pkg.Types() && !is[*types.PkgName](obj) {
						refsUnexported = true
					}
					return false
				default:
					return true
				}
			})
			return !refsUnexported
		}

		xtest = externalTestOK()
		if xtest {
			fmt.Fprintf(&header, "package %s_test\n", pkg.Types().Name())
		} else {
			fmt.Fprintf(&header, "package %s\n", pkg.Types().Name())
		}

		// Write the copyright and package decl to the beginning of the file.
		edits = append(edits, protocol.TextEdit{
			Range:   protocol.Range{},
			NewText: header.String(),
		})
	} else { // existing _test.go file.
		if testPGF.File.Name == nil || testPGF.File.Name.NamePos == token.NoPos {
			return nil, fmt.Errorf("missing package declaration")
		}
		switch testPGF.File.Name.Name {
		case pgf.File.Name.Name:
			xtest = false
		case pgf.File.Name.Name + "_test":
			xtest = true
		default:
			// Identify the file by URI: formatting the parsed-file struct
			// itself with %q would dump its fields.
			return nil, fmt.Errorf("invalid package declaration %q in test file %q", testPGF.File.Name, goTestFileURI)
		}

		eofRange, err = testPGF.PosRange(testPGF.File.FileEnd, testPGF.File.FileEnd)
		if err != nil {
			return nil, err
		}

		// Collect all the imports from the foo_test.go.
		if testImports, err = collectImports(testPGF.File); err != nil {
			return nil, err
		}
	}

	// qf qualifier determines the correct package name to use for a type in
	// foo_test.go. It does this by:
	// - Consult imports map from test file foo_test.go.
	// - If not found, consult imports map from original file foo.go.
	// If the package is not imported in test file foo_test.go, it is added to
	// extraImports map.
	qf := func(p *types.Package) string {
		// References from an in-package test should not be qualified.
		if !xtest && p == pkg.Types() {
			return ""
		}
		// Prefer using the package name if already defined in foo_test.go
		if local, ok := testImports[p.Path()]; ok {
			if local != "" {
				return local
			}
			return p.Name()
		}
		// TODO(hxjiang): we should consult the scope of the test package to
		// ensure these new imports do not shadow any package-level names.
		// Prefer the local import name (if any) used in the package under test.
		if local, ok := fileImports[p.Path()]; ok && local != "" {
			extraImports[p.Path()] = local
			return local
		}
		// Fall back to the package name since there is no renaming.
		extraImports[p.Path()] = ""
		return p.Name()
	}

	if xtest {
		// Reject if function/method is unexported.
		if !fn.Exported() {
			return nil, fmt.Errorf("cannot add test of unexported function %s to external test package %s_test", decl.Name, pgf.File.Name)
		}

		// Reject if receiver is unexported.
		if sig.Recv() != nil {
			if ident := recvTypeIdent(); ident == nil || !ident.IsExported() {
				// ident may be nil on this path; fall back to the receiver's
				// source expression rather than dereferencing it.
				recvName := types.ExprString(decl.Recv.List[0].Type)
				if ident != nil {
					recvName = ident.Name
				}
				return nil, fmt.Errorf("cannot add external test for method %s.%s as receiver type is not exported", recvName, decl.Name)
			}
		}

		// TODO(hxjiang): reject if the any input parameter type is unexported.
		// TODO(hxjiang): reject if any return value type is unexported. Explore
		// the option to drop the return value if the type is unexported.
	}

	testFuncName, err := testName(fn)
	if err != nil {
		return nil, err
	}

	data := testInfo{
		TestingPackageName: qf(types.NewPackage("testing", "testing")),
		PackageName:        qf(pkg.Types()),
		TestFuncName:       testFuncName,
		Func: function{
			Name: fn.Name(),
		},
	}

	errorType := types.Universe.Lookup("error").Type()

	// isContextType reports whether t is the named type context.Context.
	// Objects declared in the Universe scope (such as the predeclared error
	// type) have a nil Pkg, so that is checked before reading the path.
	isContextType := func(t types.Type) bool {
		named, ok := t.(*types.Named)
		if !ok {
			return false
		}
		obj := named.Obj()
		return obj.Pkg() != nil && obj.Pkg().Path() == "context" && obj.Name() == "Context"
	}

	// paramFields converts a parameter list into template fields:
	//   - a leading context.Context parameter is passed context.Background();
	//   - an unnamed or blank parameter is passed its type's zero value;
	//   - every other parameter becomes a named test-case field.
	paramFields := func(params *types.Tuple) []field {
		var fields []field
		for i := range params.Len() {
			param := params.At(i)
			name, typ := param.Name(), param.Type()
			f := field{Type: types.TypeString(typ, qf)}
			if i == 0 && isContextType(typ) {
				f.Value = qf(types.NewPackage("context", "context")) + ".Background()"
			} else if name == "" || name == "_" {
				f.Value = typesinternal.ZeroString(typ, qf)
			} else {
				f.Name = name
			}
			fields = append(fields, f)
		}
		return fields
	}

	data.Func.Args = paramFields(sig.Params())

	for i := range sig.Results().Len() {
		typ := sig.Results().At(i).Type()
		var name string
		if i == sig.Results().Len()-1 && types.Identical(typ, errorType) {
			name = "gotErr"
		} else if i == 0 {
			name = "got"
		} else {
			name = fmt.Sprintf("got%d", i+1)
		}
		data.Func.Results = append(data.Func.Results, field{
			Name: name,
			Type: types.TypeString(typ, qf),
		})
	}

	if sig.Recv() != nil {
		// Find the preferred type for the receiver. We don't use
		// typesinternal.ReceiverNamed here as we want to preserve aliases.
		recvType := sig.Recv().Type()
		if ptr, ok := recvType.(*types.Pointer); ok {
			recvType = ptr.Elem()
		}

		t, ok := recvType.(typesinternal.NamedOrAlias)
		if !ok {
			return nil, fmt.Errorf("the receiver type is neither named type nor alias type")
		}

		// firstRunes returns the first n runes of s, or all of s if shorter.
		// Slicing by rune keeps multi-byte identifiers valid; byte indexing
		// would split them.
		firstRunes := func(s string, n int) string {
			r := []rune(s)
			if len(r) > n {
				r = r[:n]
			}
			return string(r)
		}

		var varName string
		{
			typeName := t.Obj().Name()
			var possibleNames []string // list of candidates, preferring earlier entries.
			if recvName := sig.Recv().Name(); recvName != "" {
				possibleNames = append(possibleNames,
					recvName,                // receiver name.
					firstRunes(recvName, 1), // first character of receiver name.
				)
			}
			possibleNames = append(possibleNames,
				firstRunes(typeName, 1), // first character of receiver type name.
			)
			if len([]rune(typeName)) >= 2 {
				possibleNames = append(possibleNames,
					firstRunes(typeName, 2), // first two characters of receiver type name.
				)
			}
			var camelCase []rune
			for i, s := range typeName {
				if i == 0 || unicode.IsUpper(s) {
					camelCase = append(camelCase, s)
				}
			}
			possibleNames = append(possibleNames,
				string(camelCase), // capitalized initials.
			)
			for _, name := range possibleNames {
				name = strings.ToLower(name)
				// Skip names that cannot be declared as a variable ("", "_")
				// or that clash with identifiers the template declares in
				// the subtest: "t" (its *testing.T), "tt" (the current test
				// case), "err" (the constructor's error) and "got" (the
				// first result).
				switch name {
				case "", "_", "t", "tt", "err", "got":
					continue
				}
				varName = name
				break
			}
			if varName == "" {
				varName = "r" // default as "r" for "receiver".
			}
		}

		data.Receiver = &receiver{
			Var: field{
				Name: varName,
				Type: types.TypeString(recvType, qf),
			},
		}

		// constructor is the selected constructor for type T.
		var constructor *types.Func

		// A qualifying constructor must return a type whose named type is
		// the same as T's named type.
		_, wantType := typesinternal.ReceiverNamed(sig.Recv())
		for _, name := range pkg.Types().Scope().Names() {
			f, ok := pkg.Types().Scope().Lookup(name).(*types.Func)
			if !ok {
				continue
			}
			if f.Signature().Recv() != nil {
				continue
			}
			// Unexported constructor is not visible in x_test package.
			if xtest && !f.Exported() {
				continue
			}
			// Only allow constructors returning T, *T, (T, error), or (*T, error).
			if f.Signature().Results().Len() > 2 || f.Signature().Results().Len() == 0 {
				continue
			}

			_, gotType := typesinternal.ReceiverNamed(f.Signature().Results().At(0))
			if gotType == nil || !types.Identical(gotType, wantType) {
				continue
			}

			if f.Signature().Results().Len() == 2 && !types.Identical(f.Signature().Results().At(1).Type(), errorType) {
				continue
			}

			if constructor == nil {
				constructor = f
			}

			// Functions named NewType are prioritized as constructors over other
			// functions that match only the signature criteria.
			if strings.EqualFold(f.Name(), "new"+t.Obj().Name()) {
				constructor = f
			}
		}

		if constructor != nil {
			data.Receiver.Constructor = &function{
				Name: constructor.Name(),
				Args: paramFields(constructor.Signature().Params()),
			}
			for i := range constructor.Signature().Results().Len() {
				typ := constructor.Signature().Results().At(i).Type()
				var name string
				if i == 0 {
					// The first return value must be of type T, *T, or a type whose named
					// type is the same as named type of T.
					name = varName
				} else if i == constructor.Signature().Results().Len()-1 && types.Identical(typ, errorType) {
					name = "err"
				} else {
					// Drop any return values beyond the first and the last.
					// e.g., "f, _, _, err := NewFoo()".
					name = "_"
				}
				data.Receiver.Constructor.Results = append(data.Receiver.Constructor.Results, field{
					Name: name,
					Type: types.TypeString(typ, qf),
				})
			}
		}
	}

	// Resolves duplicate parameter names between the function and its
	// receiver's constructor. It adds prefix to the constructor's parameters
	// until no conflicts remain.
	if data.Receiver != nil && data.Receiver.Constructor != nil {
		seen := map[string]bool{}
		for _, f := range data.Func.Args {
			if f.Name == "" {
				continue
			}
			seen[f.Name] = true
		}

		// "" for no change, "c" for constructor, "i" for input.
		for _, prefix := range []string{"", "c", "c_", "i", "i_"} {
			conflict := false
			for _, f := range data.Receiver.Constructor.Args {
				if f.Name == "" {
					continue
				}
				if seen[prefix+f.Name] {
					conflict = true
					break
				}
			}
			if !conflict {
				for i, f := range data.Receiver.Constructor.Args {
					if f.Name == "" {
						continue
					}
					data.Receiver.Constructor.Args[i].Name = prefix + data.Receiver.Constructor.Args[i].Name
				}
				break
			}
		}
	}

	// Compute edits to update imports.
	//
	// If we're adding to an existing test file, we need to adjust existing
	// imports. Otherwise, we can simply write out the imports to the new file.
	if testPGF != nil {
		var importFixes []*imports.ImportFix
		for path, name := range extraImports {
			importFixes = append(importFixes, &imports.ImportFix{
				StmtInfo: imports.ImportInfo{
					ImportPath: path,
					Name:       name,
				},
				FixType: imports.AddImport,
			})
		}
		importEdits, err := ComputeImportFixEdits(snapshot.Options().Local, testPGF.Src, importFixes...)
		if err != nil {
			return nil, fmt.Errorf("could not compute the import fix edits: %w", err)
		}
		edits = append(edits, importEdits...)
	} else {
		var importsBuffer bytes.Buffer
		if len(extraImports) == 1 {
			importsBuffer.WriteString("\nimport ")
			for path, name := range extraImports {
				if name != "" {
					importsBuffer.WriteString(name + " ")
				}
				fmt.Fprintf(&importsBuffer, "%q\n", path)
			}
		} else {
			importsBuffer.WriteString("\nimport(")
			// Loop over the map in sorted order ensures deterministic outcome.
			paths := make([]string, 0, len(extraImports))
			for key := range extraImports {
				paths = append(paths, key)
			}
			sort.Strings(paths)
			for _, path := range paths {
				importsBuffer.WriteString("\n\t")
				if name := extraImports[path]; name != "" {
					importsBuffer.WriteString(name + " ")
				}
				fmt.Fprintf(&importsBuffer, "%q", path)
			}
			importsBuffer.WriteString("\n)\n")
		}
		edits = append(edits, protocol.TextEdit{
			Range:   protocol.Range{},
			NewText: importsBuffer.String(),
		})
	}

	var test bytes.Buffer
	if err := testTmpl.Execute(&test, data); err != nil {
		return nil, err
	}

	formatted, err := format.Source(test.Bytes())
	if err != nil {
		return nil, err
	}

	edits = append(edits,
		protocol.TextEdit{
			Range:   eofRange,
			NewText: string(formatted),
		})

	return append(changes, protocol.DocumentChangeEdit(testFH, edits)), nil
}
// testName returns the name of the function to use for the new function that
// tests fn.
// Returns empty string if the fn is ill typed or nil.
func testName(fn *types.Func) (string, error) {
if fn == nil {
return "", fmt.Errorf("input nil function")
}
testName := "Test"
if recv := fn.Signature().Recv(); recv != nil { // method declaration.
// Retrieve the unpointered receiver type to ensure the test name is based
// on the topmost alias or named type, not the alias' RHS type (potentially
// unexported) type.
// For example:
// type Foo = foo // Foo is an exported alias for the unexported type foo
recvType := recv.Type()
if ptr, ok := recv.Type().(*types.Pointer); ok {
recvType = ptr.Elem()
}
t, ok := recvType.(typesinternal.NamedOrAlias)
if !ok {
return "", fmt.Errorf("receiver type is not named type or alias type")
}
if !t.Obj().Exported() {
testName += "_"
}
testName += t.Obj().Name() + "_"
} else if !fn.Exported() { // unexported function declaration.
testName += "_"
}
return testName + fn.Name(), nil
}