go/analysis/passes/waitgroup: report WaitGroup.Add in goroutine
This CL defines a new analyzer, "waitgroup", that reports a
common mistake with sync.WaitGroup: calling wg.Add(1) inside
the new goroutine, instead of before starting it.
This is a port of Dominik Honnef's SA2000 algorithm,
which uses tree-based pattern matching, to elementary
go/{ast,types} + inspector operations.
Fixes golang/go#18022
Updates golang/go#63796
Change-Id: I9d6d3b602ce963912422ee0459bb1f9522fc51f9
Reviewed-on: https://go-review.googlesource.com/c/tools/+/632915
Reviewed-by: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/go/analysis/passes/waitgroup/doc.go b/go/analysis/passes/waitgroup/doc.go
new file mode 100644
index 0000000..207f741
--- /dev/null
+++ b/go/analysis/passes/waitgroup/doc.go
@@ -0,0 +1,34 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package waitgroup defines an Analyzer that detects simple misuses
+// of sync.WaitGroup.
+//
+// # Analyzer waitgroup
+//
+// waitgroup: check for misuses of sync.WaitGroup
+//
+// This analyzer detects mistaken calls to the (*sync.WaitGroup).Add
+// method from inside a new goroutine, causing Add to race with Wait:
+//
+// // WRONG
+// var wg sync.WaitGroup
+// go func() {
+// wg.Add(1) // "WaitGroup.Add called from inside new goroutine"
+// defer wg.Done()
+// ...
+// }()
+// wg.Wait() // (may return prematurely before new goroutine starts)
+//
+// The correct code calls Add before starting the goroutine:
+//
+// // RIGHT
+// var wg sync.WaitGroup
+// wg.Add(1)
+// go func() {
+// defer wg.Done()
+// ...
+// }()
+// wg.Wait()
+package waitgroup
diff --git a/go/analysis/passes/waitgroup/main.go b/go/analysis/passes/waitgroup/main.go
new file mode 100644
index 0000000..785eadd
--- /dev/null
+++ b/go/analysis/passes/waitgroup/main.go
@@ -0,0 +1,16 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build ignore
+
+// The waitgroup command applies the golang.org/x/tools/go/analysis/passes/waitgroup
+// analysis to the specified packages of Go source code.
+package main
+
+import (
+ "golang.org/x/tools/go/analysis/passes/waitgroup"
+ "golang.org/x/tools/go/analysis/singlechecker"
+)
+
+func main() { singlechecker.Main(waitgroup.Analyzer) }
diff --git a/go/analysis/passes/waitgroup/testdata/src/a/a.go b/go/analysis/passes/waitgroup/testdata/src/a/a.go
new file mode 100644
index 0000000..c1fecc2
--- /dev/null
+++ b/go/analysis/passes/waitgroup/testdata/src/a/a.go
@@ -0,0 +1,14 @@
+package a
+
+import "sync"
+
+func f() {
+ var wg sync.WaitGroup
+ wg.Add(1) // ok
+ go func() {
+ wg.Add(1) // want "WaitGroup.Add called from inside new goroutine"
+ // ...
+ wg.Add(1) // ok
+ }()
+ wg.Add(1) // ok
+}
diff --git a/go/analysis/passes/waitgroup/waitgroup.go b/go/analysis/passes/waitgroup/waitgroup.go
new file mode 100644
index 0000000..cbb0bfc
--- /dev/null
+++ b/go/analysis/passes/waitgroup/waitgroup.go
@@ -0,0 +1,105 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package waitgroup defines an Analyzer that detects simple misuses
+// of sync.WaitGroup.
+package waitgroup
+
+import (
+ _ "embed"
+ "go/ast"
+ "go/types"
+ "reflect"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/inspect"
+ "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
+ "golang.org/x/tools/go/ast/inspector"
+ "golang.org/x/tools/go/types/typeutil"
+ "golang.org/x/tools/internal/typesinternal"
+)
+
+//go:embed doc.go
+var doc string
+
+var Analyzer = &analysis.Analyzer{
+ Name: "waitgroup",
+ Doc: analysisutil.MustExtractDoc(doc, "waitgroup"),
+ URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/waitgroup",
+ Requires: []*analysis.Analyzer{inspect.Analyzer},
+ Run: run,
+}
+
+func run(pass *analysis.Pass) (any, error) {
+ if !analysisutil.Imports(pass.Pkg, "sync") {
+ return nil, nil // doesn't directly import sync
+ }
+
+ inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
+ nodeFilter := []ast.Node{
+ (*ast.CallExpr)(nil),
+ }
+
+ inspect.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) (proceed bool) {
+ if push {
+ call := n.(*ast.CallExpr)
+ if fn, ok := typeutil.Callee(pass.TypesInfo, call).(*types.Func); ok &&
+ isMethodNamed(fn, "sync", "WaitGroup", "Add") &&
+ hasSuffix(stack, wantSuffix) &&
+ backindex(stack, 1) == backindex(stack, 2).(*ast.BlockStmt).List[0] { // ExprStmt must be Block's first stmt
+
+ pass.Reportf(call.Lparen, "WaitGroup.Add called from inside new goroutine")
+ }
+ }
+ return true
+ })
+
+ return nil, nil
+}
+
+// go func() {
+// wg.Add(1)
+// ...
+// }()
+var wantSuffix = []ast.Node{
+ (*ast.GoStmt)(nil),
+ (*ast.CallExpr)(nil),
+ (*ast.FuncLit)(nil),
+ (*ast.BlockStmt)(nil),
+ (*ast.ExprStmt)(nil),
+ (*ast.CallExpr)(nil),
+}
+
+// hasSuffix reports whether stack has the matching suffix,
+// considering only node types.
+func hasSuffix(stack, suffix []ast.Node) bool {
+ // TODO(adonovan): the inspector could implement this for us.
+ if len(stack) < len(suffix) {
+ return false
+ }
+ for i := range len(suffix) {
+ if reflect.TypeOf(backindex(stack, i)) != reflect.TypeOf(backindex(suffix, i)) {
+ return false
+ }
+ }
+ return true
+}
+
+// isMethodNamed reports whether f is a method with the specified
+// package, receiver type, and method names.
+func isMethodNamed(fn *types.Func, pkg, recv, name string) bool {
+ if fn.Pkg() != nil && fn.Pkg().Path() == pkg && fn.Name() == name {
+ if r := fn.Type().(*types.Signature).Recv(); r != nil {
+ if _, gotRecv := typesinternal.ReceiverNamed(r); gotRecv != nil {
+ return gotRecv.Obj().Name() == recv
+ }
+ }
+ }
+ return false
+}
+
+// backindex is like [slices.Index] but from the back of the slice.
+func backindex[T any](slice []T, i int) T {
+ return slice[len(slice)-1-i]
+}
diff --git a/go/analysis/passes/waitgroup/waitgroup_test.go b/go/analysis/passes/waitgroup/waitgroup_test.go
new file mode 100644
index 0000000..bd6443a
--- /dev/null
+++ b/go/analysis/passes/waitgroup/waitgroup_test.go
@@ -0,0 +1,16 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package waitgroup_test
+
+import (
+ "testing"
+
+ "golang.org/x/tools/go/analysis/analysistest"
+ "golang.org/x/tools/go/analysis/passes/waitgroup"
+)
+
+func Test(t *testing.T) {
+ analysistest.Run(t, analysistest.TestData(), waitgroup.Analyzer, "a")
+}
diff --git a/gopls/internal/util/typesutil/typesutil.go b/gopls/internal/util/typesutil/typesutil.go
index 35a14c6..98f5605 100644
--- a/gopls/internal/util/typesutil/typesutil.go
+++ b/gopls/internal/util/typesutil/typesutil.go
@@ -257,8 +257,3 @@
}
return nil
}
-
-func is[T any](x any) bool {
- _, ok := x.(T)
- return ok
-}