go/analysis: extend the loopclosure checker to considering errgroup.Group.Go.
errgroup.Group.Go(f) executes f asynchronously in a Go routine. This Go call is used quite often in Go projects.
Change-Id: I397af118300a25a5c38dbce83fcead974b58cef2
Reviewed-on: https://go-review.googlesource.com/c/tools/+/287173
Reviewed-by: Michael Matloob <matloob@golang.org>
Reviewed-by: Alan Donovan <adonovan@google.com>
Trust: Tim King <taking@google.com>
diff --git a/go/analysis/passes/loopclosure/loopclosure.go b/go/analysis/passes/loopclosure/loopclosure.go
index a14e7eb..3ea9157 100644
--- a/go/analysis/passes/loopclosure/loopclosure.go
+++ b/go/analysis/passes/loopclosure/loopclosure.go
@@ -8,22 +8,14 @@
import (
"go/ast"
+ "go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
+ "golang.org/x/tools/go/types/typeutil"
)
-// TODO(adonovan): also report an error for the following structure,
-// which is often used to ensure that deferred calls do not accumulate
-// in a loop:
-//
-// for i, x := range c {
-// func() {
-// ...reference to i or x...
-// }()
-// }
-
const Doc = `check references to loop variables from within nested functions
This analyzer checks for references to loop variables from within a
@@ -95,16 +87,19 @@
if len(body.List) == 0 {
return
}
- var last *ast.CallExpr
+ // The function invoked in the last return statement.
+ var fun ast.Expr
switch s := body.List[len(body.List)-1].(type) {
case *ast.GoStmt:
- last = s.Call
+ fun = s.Call.Fun
case *ast.DeferStmt:
- last = s.Call
- default:
- return
+ fun = s.Call.Fun
+ case *ast.ExprStmt: // check for errgroup.Group.Go()
+ if call, ok := s.X.(*ast.CallExpr); ok {
+ fun = goInvokes(pass.TypesInfo, call)
+ }
}
- lit, ok := last.Fun.(*ast.FuncLit)
+ lit, ok := fun.(*ast.FuncLit)
if !ok {
return
}
@@ -128,3 +123,43 @@
})
return nil, nil
}
+
+// goInvokes returns a function expression that would be called asynchronously
+// (but not awaited) in another goroutine as a consequence of the call.
+// For example, given the g.Go call below, it returns the function literal expression.
+//
+// import "sync/errgroup"
+// var g errgroup.Group
+// g.Go(func() error { ... })
+//
+// Currently only "golang.org/x/sync/errgroup.Group()" is considered.
+func goInvokes(info *types.Info, call *ast.CallExpr) ast.Expr {
+ f := typeutil.StaticCallee(info, call)
+ // Note: Currently only supports: golang.org/x/sync/errgroup.Go.
+ if f == nil || f.Name() != "Go" {
+ return nil
+ }
+ recv := f.Type().(*types.Signature).Recv()
+ if recv == nil {
+ return nil
+ }
+ rtype, ok := recv.Type().(*types.Pointer)
+ if !ok {
+ return nil
+ }
+ named, ok := rtype.Elem().(*types.Named)
+ if !ok {
+ return nil
+ }
+ if named.Obj().Name() != "Group" {
+ return nil
+ }
+ pkg := f.Pkg()
+ if pkg == nil {
+ return nil
+ }
+ if pkg.Path() != "golang.org/x/sync/errgroup" {
+ return nil
+ }
+ return call.Args[0]
+}
diff --git a/go/analysis/passes/loopclosure/testdata/src/a/a.go b/go/analysis/passes/loopclosure/testdata/src/a/a.go
index e1f7bad..2c8e2e6 100644
--- a/go/analysis/passes/loopclosure/testdata/src/a/a.go
+++ b/go/analysis/passes/loopclosure/testdata/src/a/a.go
@@ -6,6 +6,8 @@
package testdata
+import "golang.org/x/sync/errgroup"
+
func _() {
var s []int
for i, v := range s {
@@ -88,3 +90,31 @@
}()
}
}
+
+// Group is used to test that loopclosure does not match on any type named "Group".
+// The checker only matches on methods "(*...errgroup.Group).Go".
+type Group struct{};
+
+func (g *Group) Go(func() error) {}
+
+func _() {
+ var s []int
+ // errgroup.Group.Go() invokes Go routines
+ g := new(errgroup.Group)
+ for i, v := range s {
+ g.Go(func() error {
+ print(i) // want "loop variable i captured by func literal"
+ print(v) // want "loop variable v captured by func literal"
+ return nil
+ })
+ }
+ // Do not match other Group.Go cases
+ g1 := new(Group)
+ for i, v := range s {
+ g1.Go(func() error {
+ print(i)
+ print(v)
+ return nil
+ })
+ }
+}
diff --git a/go/analysis/passes/loopclosure/testdata/src/golang.org/x/sync/errgroup/errgroup.go b/go/analysis/passes/loopclosure/testdata/src/golang.org/x/sync/errgroup/errgroup.go
new file mode 100644
index 0000000..939fd52
--- /dev/null
+++ b/go/analysis/passes/loopclosure/testdata/src/golang.org/x/sync/errgroup/errgroup.go
@@ -0,0 +1,12 @@
+// Package errgroup synthesizes Go's package "golang.org/x/sync/errgroup",
+// which is used in unit-testing.
+package errgroup
+
+type Group struct {
+}
+
+func (g *Group) Go(f func() error) {
+ go func() {
+ f()
+ }()
+}