internal/refactor/inline: permit return conversions in tailcall

Previously, the tail-call strategies required that the callee's
implicit return conversions must be trivial. That meant
returning a nil error (for example) defeated these strategies,
even though it is common in tail-call situations for the caller
function to have identical result types. For example:

  func callee() error { return nil }      // nontrivial conversion

  func caller() error { return callee() } // identical result types

This change permits the tail-call strategies when the
callee and caller's results tuples are identical.

Fixes golang/go#63336

Change-Id: I57d62213023861a2cfebed25b01ec28921efe441
Reviewed-on: https://go-review.googlesource.com/c/tools/+/533075
Reviewed-by: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go
index f3bc4db..9615ab4 100644
--- a/internal/refactor/inline/inline.go
+++ b/internal/refactor/inline/inline.go
@@ -760,6 +760,7 @@
 	//
 	// If:
 	// - the body is just "return expr" with trivial implicit conversions,
+	//   or the caller's return type matches the callee's,
 	// - all parameters and result vars can be eliminated
 	//   or replaced by a binding decl,
 	// then the call expression can be replaced by the
@@ -767,7 +768,7 @@
 	if len(calleeDecl.Body.List) == 1 &&
 		is[*ast.ReturnStmt](calleeDecl.Body.List[0]) &&
 		len(calleeDecl.Body.List[0].(*ast.ReturnStmt).Results) > 0 && // not a bare return
-		callee.TrivialReturns == callee.TotalReturns {
+		safeReturn(caller, calleeSymbol, callee) {
 		results := calleeDecl.Body.List[0].(*ast.ReturnStmt).Results
 
 		context := callContext(caller.path)
@@ -880,7 +881,8 @@
 	// so long as:
 	// - all parameters can be eliminated or replaced by a binding decl,
 	// - call is a tail-call;
-	// - all returns in body have trivial result conversions;
+	// - all returns in body have trivial result conversions,
+	//   or the caller's return type matches the callee's,
 	// - there is no label conflict;
 	// - no result variable is referenced by name,
 	//   or implicitly by a bare return.
@@ -896,7 +898,7 @@
 	// or implicit) return.
 	if ret, ok := callContext(caller.path).(*ast.ReturnStmt); ok &&
 		len(ret.Results) == 1 &&
-		callee.TrivialReturns == callee.TotalReturns &&
+		safeReturn(caller, calleeSymbol, callee) &&
 		!callee.HasBareReturn &&
 		(!needBindingDecl || bindingDeclStmt != nil) &&
 		!hasLabelConflict(caller.path, callee.Labels) &&
@@ -2601,3 +2603,34 @@
 	}
 	return names
 }
+
+// safeReturn reports whether the callee's return statements may be safely
+// used to return from the function enclosing the caller (which must exist).
+func safeReturn(caller *Caller, calleeSymbol *types.Func, callee *gobCallee) bool {
+	// It is safe if all callee returns involve only trivial conversions.
+	if callee.TrivialReturns == callee.TotalReturns {
+		return true
+	}
+
+	var callerType types.Type
+	// Find type of innermost function enclosing call.
+	// (Beware: Caller.enclosingFunc is the outermost.)
+loop:
+	for _, n := range caller.path {
+		switch f := n.(type) {
+		case *ast.FuncDecl:
+			callerType = caller.Info.ObjectOf(f.Name).Type()
+			break loop
+		case *ast.FuncLit:
+			callerType = caller.Info.TypeOf(f)
+			break loop
+		}
+	}
+
+	// Non-trivial return conversions in the callee are permitted
+	// if the same non-trivial conversion would occur after inlining,
+	// i.e. if the caller and callee results tuples are identical.
+	callerResults := callerType.(*types.Signature).Results()
+	calleeResults := calleeSymbol.Type().(*types.Signature).Results()
+	return types.Identical(callerResults, calleeResults)
+}
diff --git a/internal/refactor/inline/inline_test.go b/internal/refactor/inline/inline_test.go
index 8918977..8362e44 100644
--- a/internal/refactor/inline/inline_test.go
+++ b/internal/refactor/inline/inline_test.go
@@ -565,6 +565,25 @@
 			`func _() { f() }`,
 			`func _() { func() { defer f(); println() }() }`,
 		},
+		// Tests for issue #63336:
+		{
+			"Tail call with non-trivial return conversion (caller.sig = callee.sig).",
+			`func f() error { if true { return nil } else { return e } }; var e struct{error}`,
+			`func _() error { return f() }`,
+			`func _() error {
+	if true {
+		return nil
+	} else {
+		return e
+	}
+}`,
+		},
+		{
+			"Tail call with non-trivial return conversion (caller.sig != callee.sig).",
+			`func f() error { return E{} }; type E struct{error}`,
+			`func _() any { return f() }`,
+			`func _() any { return func() error { return E{} }() }`,
+		},
 	})
 }