blob: 529550aabcfa680efa24f307e80f91627e849597 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package modernize
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/edge"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/internal/analysisinternal"
typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex"
"golang.org/x/tools/internal/typesinternal/typeindex"
)
// The testingContext pass replaces calls to context.WithCancel from within
// tests with a use of testing.{T,B,F}.Context(), added in Go 1.24.
//
// Specifically, the testingContext pass suggests to replace:
//
//	ctx, cancel := context.WithCancel(context.Background()) // or context.TODO
//	defer cancel()
//
// with:
//
//	ctx := t.Context()
//
// provided:
//
//   - the parent context is context.Background() or context.TODO()
//   - ctx and cancel are newly declared by the assignment (not redeclared)
//   - the defer statement immediately follows the assignment
//   - the deferred call is the only use of cancel
//   - the call is within a test or subtest function
//   - the relevant testing.{T,B,F} is named and not shadowed at the call
//   - fileUses reports that the enclosing file uses go1.24
//
// The diagnostic is reported on the context.WithCancel function expression.
// The suggested fix replaces everything from the start of the assignment to
// the end of the defer statement with a single "ctx := tb.Context()".
//
// NOTE(review): because the edit spans both statements, any comments between
// them (including a trailing comment on the assignment line) are deleted by
// the fix.
func testingContext(pass *analysis.Pass) {
	skipGenerated(pass)
	var (
		index             = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
		info              = pass.TypesInfo
		contextWithCancel = index.Object("context", "WithCancel")
	)
	// Visit every call to context.WithCancel recorded in the type index.
	// "continue calls" abandons the current candidate from inside inner loops.
calls:
	for cur := range index.Calls(contextWithCancel) {
		call := cur.Node().(*ast.CallExpr)
		// Have: context.WithCancel(...)

		// NOTE(review): Args[0] is indexed without a length check; this relies
		// on the package being well-typed (WithCancel takes one argument).
		arg, ok := call.Args[0].(*ast.CallExpr)
		if !ok {
			continue
		}
		if !analysisinternal.IsFunctionNamed(typeutil.Callee(info, arg), "context", "Background", "TODO") {
			continue
		}
		// Have: context.WithCancel(context.{Background,TODO}())

		// The call must be the right-hand side of a short variable declaration.
		parent := cur.Parent()
		assign, ok := parent.Node().(*ast.AssignStmt)
		if !ok || assign.Tok != token.DEFINE {
			continue
		}
		// Have: a, b := context.WithCancel(context.{Background,TODO}())

		// Check that both a and b are declared, not redeclarations.
		// An identifier with no info.Defs entry was not newly defined by this
		// statement, so the candidate is rejected.
		var lhs []types.Object
		for _, expr := range assign.Lhs {
			id, ok := expr.(*ast.Ident)
			if !ok {
				continue calls
			}
			obj, ok := info.Defs[id]
			if !ok {
				continue calls
			}
			lhs = append(lhs, obj)
		}
		// lhs[0] is a (the context), lhs[1] is b (the cancel func). This
		// assumes well-typed code, where WithCancel's two results require
		// exactly two left-hand operands.

		// The statement immediately following the assignment must be "defer b()".
		next, ok := parent.NextSibling()
		if !ok {
			continue
		}
		defr, ok := next.Node().(*ast.DeferStmt)
		if !ok {
			continue
		}
		deferId, ok := defr.Call.Fun.(*ast.Ident)
		if !ok || !soleUseIs(index, lhs[1], deferId) {
			continue // b is used elsewhere
		}

		// Have:
		// a, b := context.WithCancel(context.{Background,TODO}())
		// defer b()

		// Check that we are in a test func.
		var testObj types.Object // relevant testing.{T,B,F}, or nil
		if curFunc, ok := enclosingFunc(cur); ok {
			switch n := curFunc.Node().(type) {
			case *ast.FuncLit:
				// Subtest / sub-benchmark: the literal must be the argument at
				// index 1 of a call to (*testing.T).Run or (*testing.B).Run.
				if ek, idx := curFunc.ParentEdge(); ek == edge.CallExpr_Args && idx == 1 {
					// Have: call(..., func(...) { ...context.WithCancel(...)... })
					obj := typeutil.Callee(info, curFunc.Parent().Node().(*ast.CallExpr))
					// Params.List[0] is only indexed once the callee is known
					// to be Run (short-circuit &&); this assumes well-typed
					// code, where Run's callback has a single parameter.
					if (analysisinternal.IsMethodNamed(obj, "testing", "T", "Run") ||
						analysisinternal.IsMethodNamed(obj, "testing", "B", "Run")) &&
						len(n.Type.Params.List[0].Names) == 1 {
						// Have tb.Run(..., func(..., tb *testing.[TB]) { ...context.WithCancel(...)... }
						testObj = info.Defs[n.Type.Params.List[0].Names[0]]
					}
				}
			case *ast.FuncDecl:
				// A TestX/BenchmarkX/FuzzX function; isTestFn returns nil
				// for anything else.
				testObj = isTestFn(info, n)
			}
		}
		if testObj != nil && fileUses(info, enclosingFile(cur), "go1.24") {
			// Have a test function. Check that we can resolve the relevant
			// testing.{T,B,F} at the current position.
			// Resolving the name from a's scope at a's position ensures it is
			// not shadowed there. NOTE(review): if a is the blank identifier it
			// presumably has no scope, so the lookup yields nil and no fix is
			// offered (which is desirable, since "_ := t.Context()" would not
			// compile).
			if _, obj := lhs[0].Parent().LookupParent(testObj.Name(), lhs[0].Pos()); obj == testObj {
				pass.Report(analysis.Diagnostic{
					Pos:      call.Fun.Pos(),
					End:      call.Fun.End(),
					Category: "testingcontext",
					Message:  fmt.Sprintf("context.WithCancel can be modernized using %s.Context", testObj.Name()),
					SuggestedFixes: []analysis.SuggestedFix{{
						Message: fmt.Sprintf("Replace context.WithCancel with %s.Context", testObj.Name()),
						TextEdits: []analysis.TextEdit{{
							// Covers both the assignment and the defer statement.
							Pos:     assign.Pos(),
							End:     defr.End(),
							NewText: fmt.Appendf(nil, "%s := %s.Context()", lhs[0].Name(), testObj.Name()),
						}},
					}},
				})
			}
		}
	}
}
// soleUseIs reports whether id is the one and only identifier that refers
// to obj according to index. If obj is never used, the result is false.
func soleUseIs(index *typeindex.Index, obj types.Object, id *ast.Ident) bool {
	seen := false
	for cur := range index.Uses(obj) {
		if cur.Node() != id {
			// Some other identifier also refers to obj.
			return false
		}
		seen = true
	}
	return seen
}
// isTestFn checks whether fn is a test function (TestX, BenchmarkX, FuzzX),
// returning the corresponding types.Object of the *testing.{T,B,F} argument.
// Returns nil if fn is not a test function, or if it is a test function but
// the testing.{T,B,F} argument is unnamed (or _), since callers need a name
// through which to refer to it.
//
// Only package-level functions are considered. Methods are never run by the
// testing package, even when named like tests, so a method such as
// func (s *suite) TestX(t *testing.T) is treated as an ordinary helper.
//
// TODO(rfindley): consider handling the case of an unnamed argument, by adding
// an edit to give the argument a name.
//
// Adapted from go/analysis/passes/tests.
// TODO(rfindley): consider refactoring to share logic.
func isTestFn(info *types.Info, fn *ast.FuncDecl) types.Object {
	// Test functions are package-level functions; go/analysis/passes/tests
	// likewise skips every FuncDecl that has a receiver.
	if fn.Recv != nil {
		return nil
	}
	// Want functions with 0 results and exactly 1 named parameter.
	if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
		fn.Type.Params == nil ||
		len(fn.Type.Params.List) != 1 ||
		len(fn.Type.Params.List[0].Names) != 1 {
		return nil
	}
	prefix := testKind(fn.Name.Name)
	if prefix == "" {
		return nil
	}
	if tparams := fn.Type.TypeParams; tparams != nil && len(tparams.List) > 0 {
		return nil // test functions must not be generic
	}
	// go/types records a non-nil *types.Var in Defs even for a blank "_"
	// parameter, so a nil check alone does not reject "_ *testing.T".
	obj := info.Defs[fn.Type.Params.List[0].Names[0]]
	if obj == nil || obj.Name() == "_" {
		return nil
	}
	var name string
	switch prefix {
	case "Test":
		name = "T"
	case "Benchmark":
		name = "B"
	case "Fuzz":
		name = "F"
	}
	if !analysisinternal.IsPointerToNamed(obj.Type(), "testing", name) {
		return nil
	}
	return obj
}
// testKind classifies name as a test, benchmark, or fuzz function name,
// returning "Test", "Benchmark", or "Fuzz" respectively, or "" if name is
// none of these.
//
// A name qualifies when it is exactly the prefix, or when the rune that
// follows the prefix is not a lowercase letter (so "TestFoo" and "Test_x"
// qualify, but "Testify" does not).
//
// Adapted from go/analysis/passes/tests.isTestName.
func testKind(name string) string {
	for _, kind := range [...]string{"Test", "Benchmark", "Fuzz"} {
		rest, found := strings.CutPrefix(name, kind)
		if !found {
			continue
		}
		// The bare prefix ("Test", "Benchmark", "Fuzz") is itself valid.
		if rest == "" {
			return kind
		}
		if first, _ := utf8.DecodeRuneInString(rest); unicode.IsLower(first) {
			return ""
		}
		return kind
	}
	return ""
}