blob: abf5885cee28cc00eaf60dc5ce53bc98d6e79b6c [file] [log] [blame] [edit]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package modernize
import (
"bytes"
"fmt"
"go/ast"
"go/printer"
"slices"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/internal/analysis/analyzerutil"
typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
"golang.org/x/tools/internal/astutil"
"golang.org/x/tools/internal/refactor"
"golang.org/x/tools/internal/typesinternal/typeindex"
"golang.org/x/tools/internal/versions"
)
// WaitGroupAnalyzer reports goroutines that are launched with the
// "wg.Add(1); go func() { ... wg.Done() ... }()" idiom and offers to
// rewrite them using sync.WaitGroup.Go. Its Run function is waitgroup.
var WaitGroupAnalyzer = &analysis.Analyzer{
	Name:     "waitgroup",
	Doc:      analyzerutil.MustExtractDoc(doc, "waitgroup"),
	URL:      "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#waitgroup",
	Run:      waitgroup,
	Requires: []*analysis.Analyzer{inspect.Analyzer, typeindexanalyzer.Analyzer},
}
// waitgroup is the Run function of WaitGroupAnalyzer. It replaces the
// older, more verbose way of launching a goroutine tracked by a
// sync.WaitGroup with the WaitGroup.Go method added in Go 1.25.
//
// Pattern 1 ("defer wg.Done()" is the first statement of the body):
//
//	wg.Add(1); go func() { defer wg.Done(); ... }()
//	=>
//	wg.Go(func() { ... })
//
// Pattern 2 ("wg.Done()" is the last statement of the body):
//
//	wg.Add(1); go func() { ...; wg.Done() }()
//	=>
//	wg.Go(func() { ... })
//
// The wg.Add(1) call must be an expression statement immediately
// followed by the go statement. The goroutine must be a func literal
// invoked with no arguments that declares no parameters and no results,
// because WaitGroup.Go accepts only a func(). The receiver of Done must
// be syntactically identical to the receiver of Add. The fix removes
// only the first "defer wg.Done()" statement, or failing that the last
// "wg.Done()" statement. It does not attempt to repair other misuses of
// sync.WaitGroup.
//
// The use of WaitGroup.Go in pattern 2 implicitly introduces a
// 'defer', which may change the behavior in the case of panic from
// the "..." logic. In this instance, the change is safe: before and
// after the transformation, an unhandled panic inevitably results in
// a fatal crash. The fact that the transformed code calls wg.Done()
// before the crash doesn't materially change anything. (If Done had
// other effects, or blocked, or if WaitGroup.Go propagated panics
// from child to parent goroutine, the argument would be different.)
func waitgroup(pass *analysis.Pass) (any, error) {
	var (
		index             = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
		info              = pass.TypesInfo
		syncWaitGroupAdd  = index.Selection("sync", "WaitGroup", "Add")
		syncWaitGroupDone = index.Selection("sync", "WaitGroup", "Done")
	)
	if !index.Used(syncWaitGroupDone) {
		return nil, nil // no package uses wg.Done, so no pattern can match
	}

	for curAddCall := range index.Calls(syncWaitGroupAdd) {
		// Extract receiver from wg.Add call.
		addCall := curAddCall.Node().(*ast.CallExpr)
		if !isIntLiteral(info, addCall.Args[0], 1) {
			continue // not a call to wg.Add(1)
		}
		// Inv: the Args[0] check ensures addCall is not of
		// the form sync.WaitGroup.Add(&wg, 1).
		addCallRecv := ast.Unparen(addCall.Fun).(*ast.SelectorExpr).X

		// Following statement must be go func() { ... } ().
		curAddStmt := curAddCall.Parent()
		if !is[*ast.ExprStmt](curAddStmt.Node()) {
			continue // unnecessary parens?
		}
		curNext, ok := curAddStmt.NextSibling()
		if !ok {
			continue // no successor
		}
		goStmt, ok := curNext.Node().(*ast.GoStmt)
		if !ok {
			continue // not a go stmt
		}
		lit, ok := goStmt.Call.Fun.(*ast.FuncLit)
		if !ok || len(goStmt.Call.Args) != 0 {
			continue // go argument is not func(){...}()
		}
		// WaitGroup.Go accepts only a func(). A literal that declares
		// parameters (e.g. a variadic func(...int) invoked with no
		// arguments) or results is legal in a go statement, but passing
		// it to Go would not compile.
		if lit.Type.Params.NumFields() > 0 || lit.Type.Results.NumFields() > 0 {
			continue
		}
		list := lit.Body.List
		if len(list) == 0 {
			continue
		}

		// Body must start with "defer wg.Done()" or end with "wg.Done()".
		// The Callee checks guarantee that Fun is a selector on a
		// WaitGroup value, so the type assertions below cannot fail.
		var doneStmt ast.Stmt
		if deferStmt, ok := list[0].(*ast.DeferStmt); ok &&
			typeutil.Callee(info, deferStmt.Call) == syncWaitGroupDone &&
			astutil.EqualSyntax(ast.Unparen(deferStmt.Call.Fun).(*ast.SelectorExpr).X, addCallRecv) {
			doneStmt = deferStmt // "defer wg.Done()"
		} else if lastStmt, ok := list[len(list)-1].(*ast.ExprStmt); ok {
			if doneCall, ok := lastStmt.X.(*ast.CallExpr); ok &&
				typeutil.Callee(info, doneCall) == syncWaitGroupDone &&
				astutil.EqualSyntax(ast.Unparen(doneCall.Fun).(*ast.SelectorExpr).X, addCallRecv) {
				doneStmt = lastStmt // "wg.Done()"
			}
		}
		if doneStmt == nil {
			continue
		}
		curDoneStmt, ok := curNext.FindNode(doneStmt)
		if !ok {
			panic("can't find Cursor for 'done' statement")
		}

		// WaitGroup.Go is only available from go1.25.
		file := astutil.EnclosingFile(curAddCall)
		if !analyzerutil.FileUsesGoVersion(pass, file, versions.Go1_25) {
			continue
		}
		tokFile := pass.Fset.File(file.Pos())

		var addCallRecvText bytes.Buffer
		if err := printer.Fprint(&addCallRecvText, pass.Fset, addCallRecv); err != nil {
			continue // error getting text for the edit
		}

		pass.Report(analysis.Diagnostic{
			// go func() {
			// ~~~~~~~~~
			Pos:     goStmt.Pos(),
			End:     lit.Type.End(),
			Message: "Goroutine creation can be simplified using WaitGroup.Go",
			SuggestedFixes: []analysis.SuggestedFix{{
				Message: "Simplify by using WaitGroup.Go",
				TextEdits: slices.Concat(
					// delete "wg.Add(1)"
					refactor.DeleteStmt(tokFile, curAddStmt),
					// delete "wg.Done()" or "defer wg.Done()"
					refactor.DeleteStmt(tokFile, curDoneStmt),
					[]analysis.TextEdit{
						// go func()
						// ------
						// wg.Go(func()
						{
							Pos:     goStmt.Pos(),
							End:     goStmt.Call.Pos(),
							NewText: fmt.Appendf(nil, "%s.Go(", addCallRecvText.String()),
						},
						// ... }()
						//      -
						// ... } )
						{
							Pos: goStmt.Call.Lparen,
							End: goStmt.Call.Rparen,
						},
					},
				),
			}},
		})
	}
	return nil, nil
}