internal/refactor/inline: permit return conversions in tailcall
Previously, the tail-call strategies required that the callee's
implicit return conversions must be trivial. That meant
returning a nil error (for example) defeated these strategies,
even though it is common in tail-call situations for the caller
function to have identical result types. For example:
func callee() error { return nil } // nontrivial conversion
func caller() error { return callee() } // identical result types
This change permits the tail-call strategies when the
callee and caller's results tuples are identical.
Fixes golang/go#63336
Change-Id: I57d62213023861a2cfebed25b01ec28921efe441
Reviewed-on: https://go-review.googlesource.com/c/tools/+/533075
Reviewed-by: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go
index f3bc4db..9615ab4 100644
--- a/internal/refactor/inline/inline.go
+++ b/internal/refactor/inline/inline.go
@@ -760,6 +760,7 @@
//
// If:
// - the body is just "return expr" with trivial implicit conversions,
+ // or the caller's return type matches the callee's,
// - all parameters and result vars can be eliminated
// or replaced by a binding decl,
// then the call expression can be replaced by the
@@ -767,7 +768,7 @@
if len(calleeDecl.Body.List) == 1 &&
is[*ast.ReturnStmt](calleeDecl.Body.List[0]) &&
len(calleeDecl.Body.List[0].(*ast.ReturnStmt).Results) > 0 && // not a bare return
- callee.TrivialReturns == callee.TotalReturns {
+ safeReturn(caller, calleeSymbol, callee) {
results := calleeDecl.Body.List[0].(*ast.ReturnStmt).Results
context := callContext(caller.path)
@@ -880,7 +881,8 @@
// so long as:
// - all parameters can be eliminated or replaced by a binding decl,
// - call is a tail-call;
- // - all returns in body have trivial result conversions;
+ // - all returns in body have trivial result conversions,
+ // or the caller's return type matches the callee's,
// - there is no label conflict;
// - no result variable is referenced by name,
// or implicitly by a bare return.
@@ -896,7 +898,7 @@
// or implicit) return.
if ret, ok := callContext(caller.path).(*ast.ReturnStmt); ok &&
len(ret.Results) == 1 &&
- callee.TrivialReturns == callee.TotalReturns &&
+ safeReturn(caller, calleeSymbol, callee) &&
!callee.HasBareReturn &&
(!needBindingDecl || bindingDeclStmt != nil) &&
!hasLabelConflict(caller.path, callee.Labels) &&
@@ -2601,3 +2603,34 @@
}
return names
}
+
+// safeReturn reports whether the callee's return statements may be safely
+// used to return from the function enclosing the caller (which must exist).
+func safeReturn(caller *Caller, calleeSymbol *types.Func, callee *gobCallee) bool {
+ // It is safe if all callee returns involve only trivial conversions.
+ if callee.TrivialReturns == callee.TotalReturns {
+ return true
+ }
+
+ var callerType types.Type
+ // Find type of innermost function enclosing call.
+ // (Beware: Caller.enclosingFunc is the outermost.)
+loop:
+ for _, n := range caller.path {
+ switch f := n.(type) {
+ case *ast.FuncDecl:
+ callerType = caller.Info.ObjectOf(f.Name).Type()
+ break loop
+ case *ast.FuncLit:
+ callerType = caller.Info.TypeOf(f)
+ break loop
+ }
+ }
+
+ // Non-trivial return conversions in the callee are permitted
+ // if the same non-trivial conversion would occur after inlining,
+ // i.e. if the caller and callee results tuples are identical.
+ callerResults := callerType.(*types.Signature).Results()
+ calleeResults := calleeSymbol.Type().(*types.Signature).Results()
+ return types.Identical(callerResults, calleeResults)
+}
diff --git a/internal/refactor/inline/inline_test.go b/internal/refactor/inline/inline_test.go
index 8918977..8362e44 100644
--- a/internal/refactor/inline/inline_test.go
+++ b/internal/refactor/inline/inline_test.go
@@ -565,6 +565,25 @@
`func _() { f() }`,
`func _() { func() { defer f(); println() }() }`,
},
+ // Tests for issue #63336:
+ {
+ "Tail call with non-trivial return conversion (caller.sig = callee.sig).",
+ `func f() error { if true { return nil } else { return e } }; var e struct{error}`,
+ `func _() error { return f() }`,
+ `func _() error {
+ if true {
+ return nil
+ } else {
+ return e
+ }
+}`,
+ },
+ {
+ "Tail call with non-trivial return conversion (caller.sig != callee.sig).",
+ `func f() error { return E{} }; type E struct{error}`,
+ `func _() any { return f() }`,
+ `func _() any { return func() error { return E{} }() }`,
+ },
})
}