blob: 21f3147d000e6ce7f3d2f0eaf329cf59901f710e [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package inline
import (
"bytes"
"fmt"
"go/ast"
"go/constant"
"go/format"
"go/parser"
"go/token"
"go/types"
pathpkg "path"
"reflect"
"strconv"
"strings"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/imports"
internalastutil "golang.org/x/tools/internal/astutil"
"golang.org/x/tools/internal/typeparams"
)
// A Caller describes the function call and its enclosing context.
//
// The client is responsible for populating this struct and passing it to Inline.
// The exported fields are inputs; the unexported fields are working
// state filled in by the inliner while it analyzes the call.
type Caller struct {
	Fset *token.FileSet // file set used to resolve positions of File and Call
	Types *types.Package // package containing the call
	Info *types.Info // type information for the caller's syntax
	File *ast.File // syntax tree of the file containing the call
	Call *ast.CallExpr // the call to be inlined
	Content []byte // source of file containing the call
	// The remaining fields are computed by the inliner, not the client.
	path []ast.Node // path from call to root of file syntax tree
	enclosingFunc *ast.FuncDecl // top-level function/method enclosing the call, if any
}
// Options specifies parameters affecting the inliner algorithm.
// All fields are optional.
type Options struct {
	// Logf, if non-nil, receives a trace of the inliner's decisions.
	// Inline substitutes a no-op function when it is nil.
	Logf func(string, ...any) // log output function, records decision-making process
	IgnoreEffects bool // ignore potential side effects of arguments (unsound)
}
// Result holds the result of code transformation.
type Result struct {
	Content []byte // formatted, transformed content of caller file
	// Literalized reports whether the replacement for the call is itself
	// a call of a function literal, i.e. the inliner fell back to
	// rewriting callee() as func(){...}() rather than reducing it.
	Literalized bool // chosen strategy replaced callee() with func(){...}()
	// TODO(adonovan): provide an API for clients that want structured
	// output: a list of import additions and deletions plus one or more
	// localized diffs (or even AST transformations, though ownership and
	// mutation are tricky) near the call site.
}
// Inline inlines the called function (callee) into the function call (caller)
// and returns the updated, formatted content of the caller source file.
//
// opts may be nil, which is equivalent to a pointer to a zero Options.
// A nil Options.Logf is replaced by a function that discards its input.
//
// Inline does not mutate any public fields of Caller or Callee,
// nor the Options value that opts points to: defaults are applied
// to a private shallow copy.
func Inline(caller *Caller, callee *Callee, opts *Options) (*Result, error) {
	// Work on a private shallow copy so that filling in defaults
	// never writes through the client's pointer. A nil opts simply
	// leaves the copy at its zero value.
	var options Options
	if opts != nil {
		options = *opts
	}
	// Set default options.
	if options.Logf == nil {
		options.Logf = func(string, ...any) {}
	}
	st := &state{
		caller: caller,
		callee: callee,
		opts:   &options,
	}
	return st.inline()
}
// state holds the working state of the inliner.
// It bundles the three arguments of Inline so that the inlining
// steps can be written as methods.
type state struct {
	caller *Caller // the call being inlined and its context
	callee *Callee // the function being inlined
	opts *Options // private copy of the client's options, with defaults applied
}
// inline runs the whole inlining pipeline for st.caller and st.callee
// and returns the rewritten caller file.
//
// The steps, in order, are:
//   - reject callers whose syntax positions are inconsistent with
//     Content, and files generated by cgo;
//   - call inlineCall to obtain the (old, new) node pair and the
//     import additions and deletions;
//   - parenthesize new if its syntactic context requires it;
//   - decide whether the braces of a replacement block may be elided;
//   - splice the formatted new node over the extent of old in the
//     file content and re-parse the result;
//   - add the new imports and delete the obsolete ones;
//   - format the file and, if it has imports, tidy them with
//     imports.Process.
//
// inline itself applies its edits to the freshly parsed copy of the
// spliced source rather than to caller.File.
func (st *state) inline() (*Result, error) {
	logf, caller, callee := st.opts.Logf, st.caller, st.callee
	logf("inline %s @ %v",
		debugFormatNode(caller.Fset, caller.Call),
		caller.Fset.PositionFor(caller.Call.Lparen, false))
	if !consistentOffsets(caller) {
		return nil, fmt.Errorf("internal error: caller syntax positions are inconsistent with file content (did you forget to use FileSet.PositionFor when computing the file name?)")
	}
	// TODO(adonovan): use go1.21's ast.IsGenerated.
	// Break the string literal so we can use inlining in this file. :)
	if bytes.Contains(caller.Content, []byte("// Code generated by "+"cmd/cgo; DO NOT EDIT.")) {
		return nil, fmt.Errorf("cannot inline calls from files that import \"C\"")
	}
	// Compute the replacement: res.old is the caller node to be
	// replaced and res.new is the syntax that takes its place.
	res, err := st.inlineCall()
	if err != nil {
		return nil, err
	}
	// Replace the call (or some node that encloses it) by new syntax.
	assert(res.old != nil, "old is nil")
	assert(res.new != nil, "new is nil")
	// A single return operand inlined to a unary
	// expression context may need parens. Otherwise:
	//
	//	func two() int { return 1+1 }
	//	print(-two()) => print(-1+1) // oops!
	//
	// Usually it is not necessary to insert ParenExprs
	// as the formatter is smart enough to insert them as
	// needed by the context. But the res.{old,new}
	// substitution is done by formatting res.new in isolation
	// and then splicing its text over res.old, so the
	// formatter doesn't see the parent node and cannot do
	// the right thing. (One solution would be to always
	// format the enclosing node of old, but that requires
	// non-lossy comment handling, #20744.)
	//
	// So, we must analyze the call's context
	// to see whether ambiguity is possible.
	// For example, if the context is x[y:z], then
	// the x subtree is subject to precedence ambiguity
	// (replacing x by p+q would give p+q[y:z] which is wrong)
	// but the y and z subtrees are safe.
	if needsParens(caller.path, res.old, res.new) {
		res.new = &ast.ParenExpr{X: res.new.(ast.Expr)}
	}
	// Some reduction strategies return a new block holding the
	// callee's statements. The block's braces may be elided when
	// there is no conflict between names declared in the block
	// with those declared by the parent block, and no risk of
	// a caller's goto jumping forward across a declaration.
	//
	// This elision is only safe when the ExprStmt is beneath a
	// BlockStmt, CaseClause.Body, or CommClause.Body;
	// (see "statement theory").
	//
	// The inlining analysis may have already determined that eliding braces is
	// safe. Otherwise, we analyze its safety here.
	elideBraces := res.elideBraces
	if !elideBraces {
		if newBlock, ok := res.new.(*ast.BlockStmt); ok {
			// caller.path is ordered innermost first, so the entry
			// after res.old is its parent and the one after that its
			// grandparent. (nodeIndex is presumably the position of
			// res.old within the path; it is defined elsewhere.)
			i := nodeIndex(caller.path, res.old)
			parent := caller.path[i+1]
			// body stays nil unless the parent is one of the three
			// statement-list contexts in which elision is permitted.
			var body []ast.Stmt
			switch parent := parent.(type) {
			case *ast.BlockStmt:
				body = parent.List
			case *ast.CommClause:
				body = parent.Body
			case *ast.CaseClause:
				body = parent.Body
			}
			if body != nil {
				callerNames := declares(body)
				// If BlockStmt is a function body,
				// include its receiver, params, and results.
				addFieldNames := func(fields *ast.FieldList) {
					if fields != nil {
						for _, field := range fields.List {
							for _, id := range field.Names {
								callerNames[id.Name] = true
							}
						}
					}
				}
				switch f := caller.path[i+2].(type) {
				case *ast.FuncDecl:
					addFieldNames(f.Recv)
					addFieldNames(f.Type.Params)
					addFieldNames(f.Type.Results)
				case *ast.FuncLit:
					addFieldNames(f.Type.Params)
					addFieldNames(f.Type.Results)
				}
				if len(callerLabels(caller.path)) > 0 {
					// TODO(adonovan): be more precise and reject
					// only forward gotos across the inlined block.
					logf("keeping block braces: caller uses control labels")
				} else if intersects(declares(newBlock.List), callerNames) {
					logf("keeping block braces: avoids name conflict")
				} else {
					elideBraces = true
				}
			}
		}
	}
	// Don't call replaceNode(caller.File, res.old, res.new)
	// as it mutates the caller's syntax tree.
	// Instead, splice the file, replacing the extent of the "old"
	// node by a formatting of the "new" node, and re-parse.
	// We'll fix up the imports on this new tree, and format again.
	var f *ast.File
	{
		// [start, end) is the byte extent of res.old within caller.Content.
		start := offsetOf(caller.Fset, res.old.Pos())
		end := offsetOf(caller.Fset, res.old.End())
		var out bytes.Buffer
		out.Write(caller.Content[:start])
		// TODO(adonovan): might it make more sense to use
		// callee.Fset when formatting res.new?
		// The new tree is a mix of (cloned) caller nodes for
		// the argument expressions and callee nodes for the
		// function body. In essence the question is: which
		// is more likely to have comments?
		// Usually the callee body will be larger and more
		// statement-heavy than the arguments, but a
		// strategy may widen the scope of the replacement
		// (res.old) from CallExpr to, say, its enclosing
		// block, so the caller nodes dominate.
		// Precise comment handling would make this a
		// non-issue. Formatting wouldn't really need a
		// FileSet at all.
		if elideBraces {
			// Emit the block's statements one per line, without
			// the enclosing braces.
			for i, stmt := range res.new.(*ast.BlockStmt).List {
				if i > 0 {
					out.WriteByte('\n')
				}
				if err := format.Node(&out, caller.Fset, stmt); err != nil {
					return nil, err
				}
			}
		} else {
			if err := format.Node(&out, caller.Fset, res.new); err != nil {
				return nil, err
			}
		}
		out.Write(caller.Content[end:])
		const mode = parser.ParseComments | parser.SkipObjectResolution | parser.AllErrors
		// NOTE(review): the buffer holds the spliced *caller* file, yet
		// it is registered in the FileSet under the name "callee.go";
		// confirm whether the name is intentional.
		f, err = parser.ParseFile(caller.Fset, "callee.go", &out, mode)
		if err != nil {
			// Something has gone very wrong.
			logf("failed to parse <<%s>>", &out) // debugging
			return nil, err
		}
	}
	// Add new imports.
	//
	// Insert new imports after last existing import,
	// to avoid migration of pre-import comments.
	// The imports will be organized below.
	if len(res.newImports) > 0 {
		var importDecl *ast.GenDecl
		if len(f.Imports) > 0 {
			// Append specs to existing import decl
			// (Go syntax requires import declarations to precede all
			// other declarations, so Decls[0] is an import GenDecl
			// whenever the file has any imports.)
			importDecl = f.Decls[0].(*ast.GenDecl)
		} else {
			// Insert new import decl.
			importDecl = &ast.GenDecl{Tok: token.IMPORT}
			f.Decls = prepend[ast.Decl](importDecl, f.Decls...)
		}
		for _, imp := range res.newImports {
			// Check that the new imports are accessible.
			// (The Unquote error is ignored; an unquotable path
			// yields "" here.)
			path, _ := strconv.Unquote(imp.spec.Path.Value)
			if !canImport(caller.Types.Path(), path) {
				return nil, fmt.Errorf("can't inline function %v as its body refers to inaccessible package %q", callee, path)
			}
			importDecl.Specs = append(importDecl.Specs, imp.spec)
		}
	}
	// Delete imports referenced only by caller.Call.Fun.
	//
	// (We can't let imports.Process take care of it as it may
	// mistake obsolete imports for missing new imports when the
	// names are similar, as is common during a package migration.)
	for _, specToDelete := range res.oldImports {
		for _, decl := range f.Decls {
			if decl, ok := decl.(*ast.GenDecl); ok && decl.Tok == token.IMPORT {
				decl.Specs = slicesDeleteFunc(decl.Specs, func(spec ast.Spec) bool {
					imp := spec.(*ast.ImportSpec)
					// Since we re-parsed the file, we can't match by identity;
					// instead look for syntactic equivalence.
					return imp.Path.Value == specToDelete.Path.Value &&
						(imp.Name != nil) == (specToDelete.Name != nil) &&
						(imp.Name == nil || imp.Name.Name == specToDelete.Name.Name)
				})
				// Edge case: import "foo" => import ().
				// Note that this gives parens to every import decl
				// that lacks them, not only to one that became empty.
				if !decl.Lparen.IsValid() {
					decl.Lparen = decl.TokPos + token.Pos(len("import"))
					decl.Rparen = decl.Lparen + 1
				}
			}
		}
	}
	// Format the edited tree; newSrc is the result unless the
	// import-tidying step below replaces it.
	var out bytes.Buffer
	if err := format.Node(&out, caller.Fset, f); err != nil {
		return nil, err
	}
	newSrc := out.Bytes()
	// Remove imports that are no longer referenced.
	//
	// It ought to be possible to compute the set of PkgNames used
	// by the "old" code, compute the free identifiers of the
	// "new" code using a syntax-only (no go/types) algorithm, and
	// see if the reduction in the number of uses of any PkgName
	// equals the number of times it appears in caller.Info.Uses,
	// indicating that it is no longer referenced by res.new.
	//
	// However, the notorious ambiguity of resolving T{F: 0} makes this
	// unreliable: without types, we can't tell whether F refers to
	// a field of struct T, or a package-level const/var of a
	// dot-imported (!) package.
	//
	// So, for now, we run imports.Process, which is
	// unsatisfactory as it has to run the go command, and it
	// looks at the user's module cache state--unnecessarily,
	// since this step cannot add new imports.
	//
	// TODO(adonovan): replace with a simpler implementation since
	// all the necessary imports are present but merely untidy.
	// That will be faster, and also less prone to nondeterminism
	// if there are bugs in our logic for import maintenance.
	//
	// However, golang.org/x/tools/internal/imports.ApplyFixes is
	// too simple as it requires the caller to have figured out
	// all the logical edits. In our case, we know all the new
	// imports that are needed (see newImports), each of which can
	// be specified as:
	//
	//	&imports.ImportFix{
	//	  StmtInfo: imports.ImportInfo{path, name,
	//	  IdentName: name,
	//	  FixType: imports.AddImport,
	//	}
	//
	// but we don't know which imports are made redundant by the
	// inlining itself. For example, inlining a call to
	// fmt.Println may make the "fmt" import redundant.
	//
	// Also, both imports.Process and internal/imports.ApplyFixes
	// reformat the entire file, which is not ideal for clients
	// such as gopls. (That said, the point of a canonical format
	// is arguably that any tool can reformat as needed without
	// this being inconvenient.)
	//
	// We could invoke imports.Process and parse its result,
	// compare against the original AST, compute a list of import
	// fixes, and return that too.
	// Recompute imports only if there were existing ones.
	if len(f.Imports) > 0 {
		formatted, err := imports.Process("output", newSrc, nil)
		if err != nil {
			logf("cannot reformat: %v <<%s>>", err, &out)
			return nil, err // cannot reformat (a bug?)
		}
		newSrc = formatted
	}
	// The call was "literalized" if its replacement is a call
	// whose function operand is a function literal.
	literalized := false
	if call, ok := res.new.(*ast.CallExpr); ok && is[*ast.FuncLit](call.Fun) {
		literalized = true
	}
	return &Result{
		Content: newSrc,
		Literalized: literalized,
	}, nil
}
// A newImport describes an import that must be added to the
// caller's file so that the inlined code can refer to a package.
type newImport struct {
	pkgName string // local name by which the inlined code refers to the package
	spec *ast.ImportSpec // import spec to append to the file's import declaration
}
// An inlineCallResult is the outcome of inlineCall: a single
// node replacement (old => new) in the caller's syntax, plus the
// import changes that the replacement requires.
type inlineCallResult struct {
	newImports []newImport // to add
	oldImports []*ast.ImportSpec // to remove
	// If elideBraces is set, old is an ast.Stmt and new is an ast.BlockStmt to
	// be spliced in. This allows the inlining analysis to assert that inlining
	// the block is OK; if elideBraces is unset and old is an ast.Stmt and new is
	// an ast.BlockStmt, braces may still be elided if the post-processing
	// analysis determines that it is safe to do so.
	//
	// Ideally, it would not be necessary for the inlining analysis to "reach
	// through" to the post-processing pass in this way. Instead, inlining could
	// just set old to be an ast.BlockStmt and rewrite the entire BlockStmt, but
	// unfortunately in order to preserve comments, it is important that inlining
	// replace as little syntax as possible.
	elideBraces bool
	// old is the node in the caller's tree whose source extent is
	// overwritten; new is the node whose formatting replaces it.
	old, new ast.Node // e.g. replace call expr by callee function body expression
}
// inlineCall returns a pair of an old node (the call, or something
// enclosing it) and a new node (its replacement, which may be a
// combination of caller, callee, and new nodes), along with the set
// of new imports needed.
//
// TODO(adonovan): rethink the 'result' interface. The assumption of a
// one-to-one replacement seems fragile. One can easily imagine the
// transformation replacing the call and adding new variable
// declarations, for example, or replacing a call statement by zero or
// many statements.
//
// TODO(adonovan): in earlier drafts, the transformation was expressed
// by splicing substrings of the two source files because syntax
// trees don't preserve comments faithfully (see #20744), but such
// transformations don't compose. The current implementation is
// tree-based but is very lossy wrt comments. It would make a good
// candidate for evaluating an alternative fully self-contained tree
// representation, such as any proposed solution to #20744, or even
// dst or some private fork of go/ast.
func (st *state) inlineCall() (*inlineCallResult, error) {
logf, caller, callee := st.opts.Logf, st.caller, &st.callee.impl
checkInfoFields(caller.Info)
// Inlining of dynamic calls is not currently supported,
// even for local closure calls. (This would be a lot of work.)
calleeSymbol := typeutil.StaticCallee(caller.Info, caller.Call)
if calleeSymbol == nil {
// e.g. interface method
return nil, fmt.Errorf("cannot inline: not a static function call")
}
// Reject cross-package inlining if callee has
// free references to unexported symbols.
samePkg := caller.Types.Path() == callee.PkgPath
if !samePkg && len(callee.Unexported) > 0 {
return nil, fmt.Errorf("cannot inline call to %s because body refers to non-exported %s",
callee.Name, callee.Unexported[0])
}
// -- analyze callee's free references in caller context --
// Compute syntax path enclosing Call, innermost first (Path[0]=Call),
// and outermost enclosing function, if any.
caller.path, _ = astutil.PathEnclosingInterval(caller.File, caller.Call.Pos(), caller.Call.End())
for _, n := range caller.path {
if decl, ok := n.(*ast.FuncDecl); ok {
caller.enclosingFunc = decl
break
}
}
// If call is within a function, analyze all its
// local vars for the "single assignment" property.
// (Taking the address &v counts as a potential assignment.)
var assign1 func(v *types.Var) bool // reports whether v a single-assignment local var
{
updatedLocals := make(map[*types.Var]bool)
if caller.enclosingFunc != nil {
escape(caller.Info, caller.enclosingFunc, func(v *types.Var, _ bool) {
updatedLocals[v] = true
})
logf("multiple-assignment vars: %v", updatedLocals)
}
assign1 = func(v *types.Var) bool { return !updatedLocals[v] }
}
// import map, initially populated with caller imports.
//
// For simplicity we ignore existing dot imports, so that a
// qualified identifier (QI) in the callee is always
// represented by a QI in the caller, allowing us to treat a
// QI like a selection on a package name.
importMap := make(map[string][]string) // maps package path to local name(s)
for _, imp := range caller.File.Imports {
if pkgname, ok := importedPkgName(caller.Info, imp); ok &&
pkgname.Name() != "." &&
pkgname.Name() != "_" {
path := pkgname.Imported().Path()
importMap[path] = append(importMap[path], pkgname.Name())
}
}
var oldImports []*ast.ImportSpec // imports referenced only caller.Call.Fun
// localImportName returns the local name for a given imported package path.
var newImports []newImport
localImportName := func(obj *object) string {
// Does an import exist?
for _, name := range importMap[obj.PkgPath] {
// Check that either the import preexisted,
// or that it was newly added (no PkgName) but is not shadowed,
// either in the callee (shadows) or caller (caller.lookup).
if !obj.Shadow[name] {
found := caller.lookup(name)
if is[*types.PkgName](found) || found == nil {
return name
}
}
}
newlyAdded := func(name string) bool {
for _, new := range newImports {
if new.pkgName == name {
return true
}
}
return false
}
// shadowedInCaller reports whether a candidate package name
// already refers to a declaration in the caller.
shadowedInCaller := func(name string) bool {
existing := caller.lookup(name)
// If the candidate refers to a PkgName p whose sole use is
// in caller.Call.Fun of the form p.F(...), where p.F is a
// qualified identifier, the p import will be deleted,
// so it's safe (and better) to recycle the name.
//
// Only the qualified identifier case matters, as other
// references to imported package names in the Call.Fun
// expression (e.g. x.after(3*time.Second).f()
// or time.Second.String()) will remain after
// inlining, as arguments.
if pkgName, ok := existing.(*types.PkgName); ok {
if sel, ok := ast.Unparen(caller.Call.Fun).(*ast.SelectorExpr); ok {
if sole := soleUse(caller.Info, pkgName); sole == sel.X {
for _, spec := range caller.File.Imports {
pkgName2, ok := importedPkgName(caller.Info, spec)
if ok && pkgName2 == pkgName {
oldImports = append(oldImports, spec)
return false
}
}
}
}
}
return existing != nil
}
// import added by callee
//
// Choose local PkgName based on last segment of
// package path plus, if needed, a numeric suffix to
// ensure uniqueness.
//
// "init" is not a legal PkgName.
//
// TODO(rfindley): is it worth preserving local package names for callee
// imports? Are they likely to be better or worse than the name we choose
// here?
base := obj.PkgName
name := base
for n := 0; obj.Shadow[name] || shadowedInCaller(name) || newlyAdded(name) || name == "init"; n++ {
name = fmt.Sprintf("%s%d", base, n)
}
logf("adding import %s %q", name, obj.PkgPath)
spec := &ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(obj.PkgPath),
},
}
// Use explicit pkgname (out of necessity) when it differs from the declared name,
// or (for good style) when it differs from base(pkgpath).
if name != obj.PkgName || name != pathpkg.Base(obj.PkgPath) {
spec.Name = makeIdent(name)
}
newImports = append(newImports, newImport{
pkgName: name,
spec: spec,
})
importMap[obj.PkgPath] = append(importMap[obj.PkgPath], name)
return name
}
// Compute the renaming of the callee's free identifiers.
objRenames := make([]ast.Expr, len(callee.FreeObjs)) // nil => no change
for i, obj := range callee.FreeObjs {
// obj is a free object of the callee.
//
// Possible cases are:
// - builtin function, type, or value (e.g. nil, zero)
// => check not shadowed in caller.
// - package-level var/func/const/types
// => same package: check not shadowed in caller.
// => otherwise: import other package, form a qualified identifier.
// (Unexported cross-package references were rejected already.)
// - type parameter
// => not yet supported
// - pkgname
// => import other package and use its local name.
//
// There can be no free references to labels, fields, or methods.
// Note that we must consider potential shadowing both
// at the caller side (caller.lookup) and, when
// choosing new PkgNames, within the callee (obj.shadow).
var newName ast.Expr
if obj.Kind == "pkgname" {
// Use locally appropriate import, creating as needed.
newName = makeIdent(localImportName(&obj)) // imported package
} else if !obj.ValidPos {
// Built-in function, type, or value (e.g. nil, zero):
// check not shadowed at caller.
found := caller.lookup(obj.Name) // always finds something
if found.Pos().IsValid() {
return nil, fmt.Errorf("cannot inline, because the callee refers to built-in %q, which in the caller is shadowed by a %s (declared at line %d)",
obj.Name, objectKind(found),
caller.Fset.PositionFor(found.Pos(), false).Line)
}
} else {
// Must be reference to package-level var/func/const/type,
// since type parameters are not yet supported.
qualify := false
if obj.PkgPath == callee.PkgPath {
// reference within callee package
if samePkg {
// Caller and callee are in same package.
// Check caller has not shadowed the decl.
//
// This may fail if the callee is "fake", such as for signature
// refactoring where the callee is modified to be a trivial wrapper
// around the refactored signature.
found := caller.lookup(obj.Name)
if found != nil && !isPkgLevel(found) {
return nil, fmt.Errorf("cannot inline, because the callee refers to %s %q, which in the caller is shadowed by a %s (declared at line %d)",
obj.Kind, obj.Name,
objectKind(found),
caller.Fset.PositionFor(found.Pos(), false).Line)
}
} else {
// Cross-package reference.
qualify = true
}
} else {
// Reference to a package-level declaration
// in another package, without a qualified identifier:
// it must be a dot import.
qualify = true
}
// Form a qualified identifier, pkg.Name.
if qualify {
pkgName := localImportName(&obj)
newName = &ast.SelectorExpr{
X: makeIdent(pkgName),
Sel: makeIdent(obj.Name),
}
}
}
objRenames[i] = newName
}
res := &inlineCallResult{
newImports: newImports,
oldImports: oldImports,
}
// Parse callee function declaration.
calleeFset, calleeDecl, err := parseCompact(callee.Content)
if err != nil {
return nil, err // "can't happen"
}
// replaceCalleeID replaces an identifier in the callee.
// The replacement tree must not belong to the caller; use cloneNode as needed.
replaceCalleeID := func(offset int, repl ast.Expr) {
id := findIdent(calleeDecl, calleeDecl.Pos()+token.Pos(offset))
logf("- replace id %q @ #%d to %q", id.Name, offset, debugFormatNode(calleeFset, repl))
replaceNode(calleeDecl, id, repl)
}
// Generate replacements for each free identifier.
// (The same tree may be spliced in multiple times, resulting in a DAG.)
for _, ref := range callee.FreeRefs {
if repl := objRenames[ref.Object]; repl != nil {
replaceCalleeID(ref.Offset, repl)
}
}
// Gather the effective call arguments, including the receiver.
// Later, elements will be eliminated (=> nil) by parameter substitution.
args, err := st.arguments(caller, calleeDecl, assign1)
if err != nil {
return nil, err // e.g. implicit field selection cannot be made explicit
}
// Gather effective parameter tuple, including the receiver if any.
// Simplify variadic parameters to slices (in all cases but one).
var params []*parameter // including receiver; nil => parameter substituted
{
sig := calleeSymbol.Type().(*types.Signature)
if sig.Recv() != nil {
params = append(params, &parameter{
obj: sig.Recv(),
fieldType: calleeDecl.Recv.List[0].Type,
info: callee.Params[0],
})
}
// Flatten the list of syntactic types.
var types []ast.Expr
for _, field := range calleeDecl.Type.Params.List {
if field.Names == nil {
types = append(types, field.Type)
} else {
for range field.Names {
types = append(types, field.Type)
}
}
}
for i := 0; i < sig.Params().Len(); i++ {
params = append(params, &parameter{
obj: sig.Params().At(i),
fieldType: types[i],
info: callee.Params[len(params)],
})
}
// Variadic function?
//
// There are three possible types of call:
// - ordinary f(a1, ..., aN)
// - ellipsis f(a1, ..., slice...)
// - spread f(recv?, g()) where g() is a tuple.
// The first two are desugared to non-variadic calls
// with an ordinary slice parameter;
// the third is tricky and cannot be reduced, and (if
// a receiver is present) cannot even be literalized.
// Fortunately it is vanishingly rare.
//
// TODO(adonovan): extract this to a function.
if sig.Variadic() {
lastParam := last(params)
if len(args) > 0 && last(args).spread {
// spread call to variadic: tricky
lastParam.variadic = true
} else {
// ordinary/ellipsis call to variadic
// simplify decl: func(T...) -> func([]T)
lastParamField := last(calleeDecl.Type.Params.List)
lastParamField.Type = &ast.ArrayType{
Elt: lastParamField.Type.(*ast.Ellipsis).Elt,
}
if caller.Call.Ellipsis.IsValid() {
// ellipsis call: f(slice...) -> f(slice)
// nop
} else {
// ordinary call: f(a1, ... aN) -> f([]T{a1, ..., aN})
n := len(params) - 1
ordinary, extra := args[:n], args[n:]
var elts []ast.Expr
pure, effects := true, false
for _, arg := range extra {
elts = append(elts, arg.expr)
pure = pure && arg.pure
effects = effects || arg.effects
}
args = append(ordinary, &argument{
expr: &ast.CompositeLit{
Type: lastParamField.Type,
Elts: elts,
},
typ: lastParam.obj.Type(),
constant: nil,
pure: pure,
effects: effects,
duplicable: false,
freevars: nil, // not needed
})
}
}
}
}
// Log effective arguments.
for i, arg := range args {
logf("arg #%d: %s pure=%t effects=%t duplicable=%t free=%v type=%v",
i, debugFormatNode(caller.Fset, arg.expr),
arg.pure, arg.effects, arg.duplicable, arg.freevars, arg.typ)
}
// Note: computation below should be expressed in terms of
// the args and params slices, not the raw material.
// Perform parameter substitution.
// May eliminate some elements of params/args.
substitute(logf, caller, params, args, callee.Effects, callee.Falcon, replaceCalleeID)
// Update the callee's signature syntax.
updateCalleeParams(calleeDecl, params)
// Create a var (param = arg; ...) decl for use by some strategies.
bindingDecl := createBindingDecl(logf, caller, args, calleeDecl, callee.Results)
var remainingArgs []ast.Expr
for _, arg := range args {
if arg != nil {
remainingArgs = append(remainingArgs, arg.expr)
}
}
// -- let the inlining strategies begin --
//
// When we commit to a strategy, we log a message of the form:
//
// "strategy: reduce expr-context call to { return expr }"
//
// This is a terse way of saying:
//
// we plan to reduce a call
// that appears in expression context
// to a function whose body is of the form { return expr }
// TODO(adonovan): split this huge function into a sequence of
// function calls with an error sentinel that means "try the
// next strategy", and make sure each strategy writes to the
// log the reason it didn't match.
// Special case: eliminate a call to a function whose body is empty.
// (=> callee has no results and caller is a statement.)
//
// func f(params) {}
// f(args)
// => _, _ = args
//
if len(calleeDecl.Body.List) == 0 {
logf("strategy: reduce call to empty body")
// Evaluate the arguments for effects and delete the call entirely.
stmt := callStmt(caller.path, false) // cannot fail
res.old = stmt
if nargs := len(remainingArgs); nargs > 0 {
// Emit "_, _ = args" to discard results.
// TODO(adonovan): if args is the []T{a1, ..., an}
// literal synthesized during variadic simplification,
// consider unwrapping it to its (pure) elements.
// Perhaps there's no harm doing this for any slice literal.
// Make correction for spread calls
// f(g()) or recv.f(g()) where g() is a tuple.
if last := last(args); last != nil && last.spread {
nspread := last.typ.(*types.Tuple).Len()
if len(args) > 1 { // [recv, g()]
// A single AssignStmt cannot discard both, so use a 2-spec var decl.
res.new = &ast.GenDecl{
Tok: token.VAR,
Specs: []ast.Spec{
&ast.ValueSpec{
Names: []*ast.Ident{makeIdent("_")},
Values: []ast.Expr{args[0].expr},
},
&ast.ValueSpec{
Names: blanks[*ast.Ident](nspread),
Values: []ast.Expr{args[1].expr},
},
},
}
return res, nil
}
// Sole argument is spread call.
nargs = nspread
}
res.new = &ast.AssignStmt{
Lhs: blanks[ast.Expr](nargs),
Tok: token.ASSIGN,
Rhs: remainingArgs,
}
} else {
// No remaining arguments: delete call statement entirely
res.new = &ast.EmptyStmt{}
}
return res, nil
}
// If all parameters have been substituted and no result
// variable is referenced, we don't need a binding decl.
// This may enable better reduction strategies.
allResultsUnreferenced := forall(callee.Results, func(i int, r *paramInfo) bool { return len(r.Refs) == 0 })
needBindingDecl := !allResultsUnreferenced ||
exists(params, func(i int, p *parameter) bool { return p != nil })
// The two strategies below overlap for a tail call of {return exprs}:
// The expr-context reduction is nice because it keeps the
// caller's return stmt and merely switches its operand,
// without introducing a new block, but it doesn't work with
// implicit return conversions.
//
// TODO(adonovan): unify these cases more cleanly, allowing return-
// operand replacement and implicit conversions, by adding
// conversions around each return operand (if not a spread return).
// Special case: call to { return exprs }.
//
// Reduces to:
// { var (bindings); _, _ = exprs }
// or _, _ = exprs
// or expr
//
// If:
// - the body is just "return expr" with trivial implicit conversions,
// or the caller's return type matches the callee's,
// - all parameters and result vars can be eliminated
// or replaced by a binding decl,
// then the call expression can be replaced by the
// callee's body expression, suitably substituted.
if len(calleeDecl.Body.List) == 1 &&
is[*ast.ReturnStmt](calleeDecl.Body.List[0]) &&
len(calleeDecl.Body.List[0].(*ast.ReturnStmt).Results) > 0 { // not a bare return
results := calleeDecl.Body.List[0].(*ast.ReturnStmt).Results
parent, grandparent := callContext(caller.path)
// statement context
if stmt, ok := parent.(*ast.ExprStmt); ok &&
(!needBindingDecl || bindingDecl != nil) {
logf("strategy: reduce stmt-context call to { return exprs }")
clearPositions(calleeDecl.Body)
if callee.ValidForCallStmt {
logf("callee body is valid as statement")
// Inv: len(results) == 1
if !needBindingDecl {
// Reduces to: expr
res.old = caller.Call
res.new = results[0]
} else {
// Reduces to: { var (bindings); expr }
res.old = stmt
res.new = &ast.BlockStmt{
List: []ast.Stmt{
bindingDecl.stmt,
&ast.ExprStmt{X: results[0]},
},
}
}
} else {
logf("callee body is not valid as statement")
// The call is a standalone statement, but the
// callee body is not suitable as a standalone statement
// (f() or <-ch), explicitly discard the results:
// Reduces to: _, _ = exprs
discard := &ast.AssignStmt{
Lhs: blanks[ast.Expr](callee.NumResults),
Tok: token.ASSIGN,
Rhs: results,
}
res.old = stmt
if !needBindingDecl {
// Reduces to: _, _ = exprs
res.new = discard
} else {
// Reduces to: { var (bindings); _, _ = exprs }
res.new = &ast.BlockStmt{
List: []ast.Stmt{
bindingDecl.stmt,
discard,
},
}
}
}
return res, nil
}
// Assignment context.
//
// If there is no binding decl, or if the binding decl declares no names,
// an assignment a, b := f() can be reduced to a, b := x, y.
if stmt, ok := parent.(*ast.AssignStmt); ok &&
is[*ast.BlockStmt](grandparent) &&
(!needBindingDecl || (bindingDecl != nil && len(bindingDecl.names) == 0)) {
// Reduces to: { var (bindings); lhs... := rhs... }
if newStmts, ok := st.assignStmts(stmt, results); ok {
logf("strategy: reduce assign-context call to { return exprs }")
clearPositions(calleeDecl.Body)
block := &ast.BlockStmt{
List: newStmts,
}
if needBindingDecl {
block.List = prepend(bindingDecl.stmt, block.List...)
}
// assignStmts does not introduce new bindings, and replacing an
// assignment only works if the replacement occurs in the same scope.
// Therefore, we must ensure that braces are elided.
res.elideBraces = true
res.old = stmt
res.new = block
return res, nil
}
}
// expression context
if !needBindingDecl {
clearPositions(calleeDecl.Body)
anyNonTrivialReturns := hasNonTrivialReturn(callee.Returns)
if callee.NumResults == 1 {
logf("strategy: reduce expr-context call to { return expr }")
// (includes some simple tail-calls)
// Make implicit return conversion explicit.
if anyNonTrivialReturns {
results[0] = convert(calleeDecl.Type.Results.List[0].Type, results[0])
}
res.old = caller.Call
res.new = results[0]
return res, nil
} else if !anyNonTrivialReturns {
logf("strategy: reduce spread-context call to { return expr }")
// There is no general way to reify conversions in a spread
// return, hence the requirement above.
//
// TODO(adonovan): allow this reduction when no
// conversion is required by the context.
// The call returns multiple results but is
// not a standalone call statement. It must
// be the RHS of a spread assignment:
// var x, y = f()
// x, y := f()
// x, y = f()
// or the sole argument to a spread call:
// printf(f())
// or spread return statement:
// return f()
res.old = parent
switch context := parent.(type) {
case *ast.AssignStmt:
// Inv: the call must be in Rhs[0], not Lhs.
assign := shallowCopy(context)
assign.Rhs = results
res.new = assign
case *ast.ValueSpec:
// Inv: the call must be in Values[0], not Names.
spec := shallowCopy(context)
spec.Values = results
res.new = spec
case *ast.CallExpr:
// Inv: the call must be in Args[0], not Fun.
call := shallowCopy(context)
call.Args = results
res.new = call
case *ast.ReturnStmt:
// Inv: the call must be Results[0].
ret := shallowCopy(context)
ret.Results = results
res.new = ret
default:
return nil, fmt.Errorf("internal error: unexpected context %T for spread call", context)
}
return res, nil
}
}
}
// Special case: tail-call.
//
// Inlining:
// return f(args)
// where:
// func f(params) (results) { body }
// reduces to:
// { var (bindings); body }
// { body }
// so long as:
// - all parameters can be eliminated or replaced by a binding decl,
// - call is a tail-call;
// - all returns in body have trivial result conversions,
// or the caller's return type matches the callee's,
// - there is no label conflict;
// - no result variable is referenced by name,
// or implicitly by a bare return.
//
// The body may use defer, arbitrary control flow, and
// multiple returns.
//
// TODO(adonovan): add a strategy for a 'void tail
// call', i.e. a call statement prior to an (explicit
// or implicit) return.
parent, _ := callContext(caller.path)
if ret, ok := parent.(*ast.ReturnStmt); ok &&
len(ret.Results) == 1 &&
tailCallSafeReturn(caller, calleeSymbol, callee) &&
!callee.HasBareReturn &&
(!needBindingDecl || bindingDecl != nil) &&
!hasLabelConflict(caller.path, callee.Labels) &&
allResultsUnreferenced {
logf("strategy: reduce tail-call")
body := calleeDecl.Body
clearPositions(body)
if needBindingDecl {
body.List = prepend(bindingDecl.stmt, body.List...)
}
res.old = ret
res.new = body
return res, nil
}
// Special case: call to void function
//
// Inlining:
// f(args)
// where:
// func f(params) { stmts }
// reduces to:
// { var (bindings); stmts }
// { stmts }
// so long as:
// - callee is a void function (no returns)
// - callee does not use defer
// - there is no label conflict between caller and callee
// - all parameters and result vars can be eliminated
// or replaced by a binding decl,
// - caller ExprStmt is in unrestricted statement context.
if stmt := callStmt(caller.path, true); stmt != nil &&
(!needBindingDecl || bindingDecl != nil) &&
!callee.HasDefer &&
!hasLabelConflict(caller.path, callee.Labels) &&
len(callee.Returns) == 0 {
logf("strategy: reduce stmt-context call to { stmts }")
body := calleeDecl.Body
var repl ast.Stmt = body
clearPositions(repl)
if needBindingDecl {
body.List = prepend(bindingDecl.stmt, body.List...)
}
res.old = stmt
res.new = repl
return res, nil
}
// TODO(adonovan): parameterless call to { stmts; return expr }
// from one of these contexts:
// x, y = f()
// x, y := f()
// var x, y = f()
// =>
// var (x T1, y T2); { stmts; x, y = expr }
//
// Because the params are no longer declared simultaneously
// we need to check that (for example) x ∉ freevars(T2),
// in addition to the usual checks for arg/result conversions,
// complex control, etc.
// Also test cases where expr is an n-ary call (spread returns).
// Literalization isn't quite infallible.
// Consider a spread call to a method in which
// no parameters are eliminated, e.g.
// new(T).f(g())
// where
// func (recv *T) f(x, y int) { body }
// func g() (int, int)
// This would be literalized to:
// func (recv *T, x, y int) { body }(new(T), g()),
// which is not a valid argument list because g() must appear alone.
// Reject this case for now.
if len(args) == 2 && args[0] != nil && args[1] != nil && is[*types.Tuple](args[1].typ) {
return nil, fmt.Errorf("can't yet inline spread call to method")
}
// Infallible general case: literalization.
//
// func(params) { body }(args)
//
logf("strategy: literalization")
funcLit := &ast.FuncLit{
Type: calleeDecl.Type,
Body: calleeDecl.Body,
}
// Literalization can still make use of a binding
// decl as it gives a more natural reading order:
//
// func() { var params = args; body }()
//
// TODO(adonovan): relax the allResultsUnreferenced requirement
// by adding a parameter-only (no named results) binding decl.
if bindingDecl != nil && allResultsUnreferenced {
funcLit.Type.Params.List = nil
remainingArgs = nil
funcLit.Body.List = prepend(bindingDecl.stmt, funcLit.Body.List...)
}
// Emit a new call to a function literal in place of
// the callee name, with appropriate replacements.
newCall := &ast.CallExpr{
Fun: funcLit,
Ellipsis: token.NoPos, // f(slice...) is always simplified
Args: remainingArgs,
}
clearPositions(newCall.Fun)
res.old = caller.Call
res.new = newCall
return res, nil
}
// An argument describes one effective argument of the call being
// inlined: its syntax, its type, and a set of predicates that are
// computed from the caller's type information before any syntax is
// modified (see state.arguments).
type argument struct {
	expr          ast.Expr        // argument syntax; for a receiver it may be synthetic (e.g. &recv, recv.y)
	typ           types.Type      // may be tuple for sole non-receiver arg in spread call
	constant      constant.Value  // value of argument if constant
	spread        bool            // final arg is call() assigned to multiple params
	pure          bool            // expr is pure (doesn't read variables)
	effects       bool            // expr has effects (updates variables)
	duplicable    bool            // expr may be duplicated
	freevars      map[string]bool // free names of expr
	substitutable bool            // is candidate for substitution
}
// arguments returns the effective arguments of the call.
//
// If the callee is a method, the receiver operand (if the caller
// supplies one) becomes the first element of the result, followed by
// one element per ordinary call argument, in order.
//
// If the receiver argument and parameter have
// different pointerness, make the "&" or "*" explicit.
//
// Also, if x.f() is shorthand for promoted method x.y.f(),
// make the .y explicit in T.f(x.y, ...).
//
// Beware that:
//
//   - a method can only be called through a selection, but only
//     the first of these two forms needs special treatment:
//
//     expr.f(args)     -> ([&*]expr, args)	MethodVal
//     T.f(recv, args)  -> (    expr, args)	MethodExpr
//
//   - the presence of a value in receiver-position in the call
//     is a property of the caller, not the callee. A method
//     (calleeDecl.Recv != nil) may be called like an ordinary
//     function.
//
//   - the types.Signatures seen by the caller (from
//     StaticCallee) and by the callee (from decl type)
//     differ in this case.
//
// In a spread call f(g()), the sole ordinary argument g(),
// always last in args, has a tuple type.
//
// We compute type-based predicates like pure, duplicable,
// freevars, etc, now, before we start modifying syntax.
//
// The assign1 callback is forwarded unchanged to pure, which uses it
// to classify "single assignment" local variables.
//
// An error is returned if making a promoted-field selection explicit
// would require naming an unexported field of another package, or if
// re-typechecking a constant argument fails.
func (st *state) arguments(caller *Caller, calleeDecl *ast.FuncDecl, assign1 func(*types.Var) bool) ([]*argument, error) {
	var args []*argument

	callArgs := caller.Call.Args
	if calleeDecl.Recv != nil {
		// A method is always reached through a selector expression,
		// possibly parenthesized.
		sel := ast.Unparen(caller.Call.Fun).(*ast.SelectorExpr)
		seln := caller.Info.Selections[sel]
		var recvArg ast.Expr
		switch seln.Kind() {
		case types.MethodVal: // recv.f(callArgs)
			recvArg = sel.X
		case types.MethodExpr: // T.f(recv, callArgs)
			// The receiver is already the first ordinary argument.
			recvArg = callArgs[0]
			callArgs = callArgs[1:]
		}
		if recvArg != nil {
			// Compute all the type-based predicates now,
			// before we start meddling with the syntax;
			// the meddling will update them.
			arg := &argument{
				expr:       recvArg,
				typ:        caller.Info.TypeOf(recvArg),
				constant:   caller.Info.Types[recvArg].Value,
				pure:       pure(caller.Info, assign1, recvArg),
				effects:    st.effects(caller.Info, recvArg),
				duplicable: duplicable(caller.Info, recvArg),
				freevars:   freeVars(caller.Info, recvArg),
			}
			recvArg = nil // prevent accidental use

			// Move receiver argument recv.f(args) to argument list f(&recv, args).
			args = append(args, arg)

			// Make field selections explicit (recv.f -> recv.y.f),
			// updating arg.{expr,typ}.
			//
			// All but the last element of the selection's index path
			// are traversed here as (embedded) struct fields; the last
			// element is excluded from the loop.
			indices := seln.Index()
			for _, index := range indices[:len(indices)-1] {
				fld := typeparams.CoreType(typeparams.Deref(arg.typ)).(*types.Struct).Field(index)
				// The caller cannot spell an unexported field that
				// belongs to a different package.
				if fld.Pkg() != caller.Types && !fld.Exported() {
					return nil, fmt.Errorf("in %s, implicit reference to unexported field .%s cannot be made explicit",
						debugFormatNode(caller.Fset, caller.Call.Fun),
						fld.Name())
				}
				if isPointer(arg.typ) {
					arg.pure = false // implicit *ptr operation => impure
				}
				arg.expr = &ast.SelectorExpr{
					X:   arg.expr,
					Sel: makeIdent(fld.Name()),
				}
				arg.typ = fld.Type()
				// The expression is now a field selection; it is
				// conservatively treated as not duplicable.
				arg.duplicable = false
			}

			// Make * or & explicit.
			argIsPtr := isPointer(arg.typ)
			paramIsPtr := isPointer(seln.Obj().Type().Underlying().(*types.Signature).Recv().Type())
			if !argIsPtr && paramIsPtr {
				// &recv
				arg.expr = &ast.UnaryExpr{Op: token.AND, X: arg.expr}
				arg.typ = types.NewPointer(arg.typ)
			} else if argIsPtr && !paramIsPtr {
				// *recv
				// The explicit load through the pointer makes the
				// argument neither duplicable nor pure.
				arg.expr = &ast.StarExpr{X: arg.expr}
				arg.typ = typeparams.Deref(arg.typ)
				arg.duplicable = false
				arg.pure = false
			}
		}
	}

	// Ordinary (non-receiver) arguments, in call order.
	for _, expr := range callArgs {
		tv := caller.Info.Types[expr]
		args = append(args, &argument{
			expr:       expr,
			typ:        tv.Type,
			constant:   tv.Value,
			spread:     is[*types.Tuple](tv.Type), // => last
			pure:       pure(caller.Info, assign1, expr),
			effects:    st.effects(caller.Info, expr),
			duplicable: duplicable(caller.Info, expr),
			freevars:   freeVars(caller.Info, expr),
		})
	}

	// Re-typecheck each constant argument expression in a neutral context.
	//
	// In a call such as func(int16){}(1), the type checker infers
	// the type "int16", not "untyped int", for the argument 1,
	// because it has incorporated information from the left-hand
	// side of the assignment implicit in parameter passing, but
	// of course in a different context, the expression 1 may have
	// a different type.
	//
	// So, we must use CheckExpr to recompute the type of the
	// argument in a neutral context to find its inherent type.
	// (This is arguably a bug in go/types, but I'm pretty certain
	// I requested it be this way long ago... -adonovan)
	//
	// This is only needed for constants. Other implicit
	// assignment conversions, such as unnamed-to-named struct or
	// chan to <-chan, do not result in the type-checker imposing
	// the LHS type on the RHS value.
	for _, arg := range args {
		if arg.constant == nil {
			continue
		}
		// A fresh Info is used so that the caller's own type
		// information is left untouched.
		info := &types.Info{Types: make(map[ast.Expr]types.TypeAndValue)}
		if err := types.CheckExpr(caller.Fset, caller.Types, caller.Call.Pos(), arg.expr, info); err != nil {
			return nil, err
		}
		arg.typ = info.TypeOf(arg.expr)
	}

	return args, nil
}
// A parameter describes one parameter of the callee (including the
// receiver, if any), pairing the type-checker's view of it with the
// callee's declaration syntax and the callee analysis results.
type parameter struct {
	obj       *types.Var // parameter var from caller's signature
	fieldType ast.Expr   // syntax of type, from calleeDecl.Type.{Recv,Params}
	info      *paramInfo // information from AnalyzeCallee
	variadic  bool       // (final) parameter is unsimplified ...T
}
// substitute implements parameter elimination by substitution.
//
// It considers each parameter and its corresponding argument in turn
// and evaluates these conditions:
//
//   - the argument is not a spread call (otherwise this parameter and
//     all following ones are kept and substitute returns immediately);
//   - the parameter neither escapes from the callee nor is assigned by it;
//   - if the parameter refcount is > 1, the argument must be duplicable;
//   - if the parameter refcount is zero, the argument must have no
//     effects and must not contain (what may be) the last use of a
//     caller local var;
//   - no free name of the argument is shadowed at a reference to the
//     parameter within the callee.
//
// Arguments that pass become substitution candidates. Candidates are
// then filtered by checkFalconConstraints (constant arguments) and by
// resolveEffects (evaluation-order hazards among impure arguments).
//
// Each remaining candidate is substituted: every reference to the
// parameter is replaced by (a clone of) the argument, wrapped in an
// explicit conversion to the parameter's type if the argument's type
// differs non-trivially. The replaceCalleeID function is called for
// each reference to the parameter, and is provided with its relative
// offset and replacement expression (argument), and the corresponding
// elements of params and args are replaced by nil.
func substitute(logf func(string, ...any), caller *Caller, params []*parameter, args []*argument, effects []int, falcon falconResult, replaceCalleeID func(offset int, repl ast.Expr)) {
	// Inv:
	//  in        calls to     variadic, len(args) >= len(params)-1
	//  in spread calls to non-variadic, len(args) <  len(params)
	//  in spread calls to     variadic, len(args) <= len(params)
	// (In spread calls len(args) = 1, or 2 if call has receiver.)
	// Non-spread variadics have been simplified away already,
	// so the args[i] lookup is safe if we stop after the spread arg.
next:
	for i, param := range params {
		arg := args[i]

		// Check argument against parameter.
		//
		// Beware: don't use types.Info on arg since
		// the syntax may be synthetic (not created by parser)
		// and thus lacking positions and types;
		// do it earlier (see pure/duplicable/freevars).

		if arg.spread {
			// spread => last argument, but not always last parameter
			logf("keeping param %q and following ones: argument %s is spread",
				param.info.Name, debugFormatNode(caller.Fset, arg.expr))
			return // give up
		}
		assert(!param.variadic, "unsimplified variadic parameter")
		if param.info.Escapes {
			logf("keeping param %q: escapes from callee", param.info.Name)
			continue
		}
		if param.info.Assigned {
			logf("keeping param %q: assigned by callee", param.info.Name)
			continue // callee needs the parameter variable
		}
		if len(param.info.Refs) > 1 && !arg.duplicable {
			logf("keeping param %q: argument is not duplicable", param.info.Name)
			continue // incorrect or poor style to duplicate an expression
		}
		if len(param.info.Refs) == 0 {
			// Substituting an unreferenced parameter deletes the
			// argument expression entirely, so it must be effect-free.
			if arg.effects {
				logf("keeping param %q: though unreferenced, it has effects", param.info.Name)
				continue
			}

			// If the caller is within a function body,
			// eliminating an unreferenced parameter might
			// remove the last reference to a caller local var.
			if caller.enclosingFunc != nil {
				for free := range arg.freevars {
					// TODO(rfindley): we can get this 100% right by looking for
					// references among other arguments which have non-zero references
					// within the callee.
					if v, ok := caller.lookup(free).(*types.Var); ok && within(v.Pos(), caller.enclosingFunc.Body) && !isUsedOutsideCall(caller, v) {
						logf("keeping param %q: arg contains perhaps the last reference to caller local %v @ %v",
							param.info.Name, v, caller.Fset.PositionFor(v.Pos(), false))
						continue next
					}
				}
			}
		}

		// Check for shadowing.
		//
		// Consider inlining a call f(z, 1) to
		// func f(x, y int) int { z := y; return x + y + z }:
		// we can't replace x in the body by z (or any
		// expression that has z as a free identifier)
		// because there's an intervening declaration of z
		// that would shadow the caller's one.
		for free := range arg.freevars {
			if param.info.Shadow[free] {
				logf("keeping param %q: cannot replace with argument as it has free ref to %s that is shadowed", param.info.Name, free)
				continue next // shadowing conflict
			}
		}

		arg.substitutable = true // may be substituted, if effects permit
	}

	// Reject constant arguments as substitution candidates
	// if they cause violation of falcon constraints.
	checkFalconConstraints(logf, params, args, falcon)

	// As a final step, introduce bindings to resolve any
	// evaluation order hazards. This must be done last, as
	// additional subsequent bindings could introduce new hazards.
	resolveEffects(logf, args, effects)

	// The remaining candidates are safe to substitute.
	for i, param := range params {
		if arg := args[i]; arg.substitutable {

			// Wrap the argument in an explicit conversion if
			// substitution might materially change its type.
			// (We already did the necessary shadowing check
			// on the parameter type syntax.)
			//
			// This is only needed for substituted arguments. All
			// other arguments are given explicit types in either
			// a binding decl or when using the literalization
			// strategy.

			// If the types are identical, we can eliminate
			// redundant type conversions such as this:
			//
			// Callee:
			//    func f(i int32) { print(i) }
			// Caller:
			//    func g() { f(int32(1)) }
			// Inlined as:
			//    func g() { print(int32(int32(1))) }
			//
			// Recall that non-trivial does not imply non-identical
			// for constant conversions; however, at this point state.arguments
			// has already re-typechecked the constant and set arg.typ to
			// its (possibly "untyped") inherent type, so
			// the conversion from untyped 1 to int32 is non-trivial even
			// though both arg and param have identical types (int32).
			if len(param.info.Refs) > 0 &&
				!types.Identical(arg.typ, param.obj.Type()) &&
				!trivialConversion(arg.constant, arg.typ, param.obj.Type()) {
				arg.expr = convert(param.fieldType, arg.expr)
				logf("param %q: adding explicit %s -> %s conversion around argument",
					param.info.Name, arg.typ, param.obj.Type())
			}

			// It is safe to substitute param and replace it with arg.
			// The formatter introduces parens as needed for precedence.
			//
			// Because arg.expr belongs to the caller,
			// we clone it before splicing it into the callee tree.
			logf("replacing parameter %q by argument %q",
				param.info.Name, debugFormatNode(caller.Fset, arg.expr))
			for _, ref := range param.info.Refs {
				replaceCalleeID(ref, internalastutil.CloneNode(arg.expr).(ast.Expr))
			}
			// A nil entry marks this parameter/argument pair as eliminated.
			params[i] = nil // substituted
			args[i] = nil   // substituted
		}
	}
}
// isUsedOutsideCall reports whether the variable v is used anywhere
// in the body of caller.enclosingFunc other than within caller.Call.
//
// A variable that is declared as a parameter of a function type
// appearing in the body is always considered used.
func isUsedOutsideCall(caller *Caller, v *types.Var) bool {
	found := false
	visit := func(node ast.Node) bool {
		if node == caller.Call {
			return false // uses inside the call itself don't count
		}
		if id, ok := node.(*ast.Ident); ok {
			if caller.Info.Uses[id] == v {
				found = true
			}
		} else if ftype, ok := node.(*ast.FuncType); ok {
			// All params are used.
			for _, field := range ftype.Params.List {
				for _, name := range field.Names {
					if caller.Info.Defs[name] == v {
						found = true
					}
				}
			}
		}
		// Stop descending as soon as one use has been seen.
		return !found
	}
	ast.Inspect(caller.enclosingFunc.Body, visit)
	return found
}
// checkFalconConstraints checks whether constant arguments
// are safe to substitute (e.g. s[i] -> ""[0] is not safe.)
//
// It builds a throwaway package whose scope declares the falcon
// types, plus one object per named parameter: a constant if the
// argument is a substitutable constant and the parameter has a falcon
// type, otherwise a variable of the argument's type. Each constraint
// is then parsed and type-checked in that environment.
//
// Any failed constraint causes us to reject all constant arguments as
// substitution candidates (by setting args[i].substitutable = false).
//
// TODO(adonovan): we could obtain a finer result rejecting only the
// freevars of each failed constraint, and processing constraints in
// order of increasing arity, but failures are quite rare.
func checkFalconConstraints(logf func(string, ...any), params []*parameter, args []*argument, falcon falconResult) {
	// Create a dummy package, as this is the only
	// way to create an environment for CheckExpr.
	env := types.NewPackage("falcon", "falcon")
	scope := env.Scope()

	// Declare types used by constraints.
	for _, typ := range falcon.Types {
		basic := types.Typ[typ.Kind]
		logf("falcon env: type %s %s", typ.Name, basic)
		scope.Insert(types.NewTypeName(token.NoPos, env, typ.Name, basic))
	}

	// Declare constants and variables for parameters.
	constCount := 0
	for i, param := range params {
		name := param.info.Name
		if name == "" {
			continue // unreferenced
		}
		arg := args[i]
		asConst := arg.constant != nil && arg.substitutable && param.info.FalconType != ""
		if !asConst {
			scope.Insert(types.NewVar(token.NoPos, env, name, arg.typ))
			logf("falcon env: var %s %s", name, arg.typ)
			continue
		}
		constType := scope.Lookup(param.info.FalconType).Type()
		scope.Insert(types.NewConst(token.NoPos, env, name, constType, arg.constant))
		logf("falcon env: const %s %s = %v", name, param.info.FalconType, arg.constant)
		constCount++
	}
	if constCount == 0 {
		return // nothing to do
	}

	// Parse and evaluate the constraints in the environment.
	fset := token.NewFileSet()
	for _, constraint := range falcon.Constraints {
		expr, err := parser.ParseExprFrom(fset, "falcon", constraint, 0)
		if err != nil {
			panic(fmt.Sprintf("failed to parse falcon constraint %s: %v", constraint, err))
		}
		err = types.CheckExpr(fset, env, token.NoPos, expr, nil)
		if err == nil {
			logf("falcon: constraint %s satisfied", constraint)
			continue
		}

		// One violation is enough: demote every constant candidate
		// and stop checking.
		logf("falcon: constraint %s violated: %v", constraint, err)
		for j, arg := range args {
			if arg.constant != nil && arg.substitutable {
				logf("keeping param %q due falcon violation", params[j].info.Name)
				arg.substitutable = false
			}
		}
		break
	}
}
// resolveEffects marks arguments as non-substitutable to resolve
// hazards resulting from the callee evaluation order described by the
// effects list.
//
// Each argument is categorized as a read (R), a write (W), or pure.
// A hazard occurs when the order of evaluation of a W changes with
// respect to any R or W. Pure arguments are ignored, since they can
// be safely evaluated in any order.
//
// The callee effects list contains the index of each parameter in the
// order it is first evaluated during execution of the callee. In
// addition, the two special values R∞ and W∞ indicate the relative
// position of the callee's first non-parameter read and its first
// effects (or other unknown behavior).
// For example, the list [0 2 1 R∞ 3 W∞] for func(a, b, c, d)
// indicates that the callee referenced parameters a, c, and b,
// followed by an arbitrary read, then parameter d, and finally
// unknown behavior.
//
// An argument that is marked as not substitutable is said to be
// 'bound': its evaluation occurs in a binding decl or literalized
// call. Such bindings always occur in the original callee parameter
// order.
//
// "Resolving hazards" means binding arguments so that they are
// evaluated in a valid, hazard-free order. Binding every argument
// would trivially achieve that, but is not useful; the goal is to
// bind as few arguments as possible.
//
// The algorithm inspects arguments in reverse parameter order (right
// to left), preserving the invariant that every higher-ordered
// argument is either already substituted or does not need to be
// substituted. At each iteration:
//
//  1. if there is an evaluation hazard in the callee effects relative
//     to the current argument, the argument is bound;
//  2. if the argument is bound for any reason, each lower-ordered
//     argument is also bound if either it or the current argument is
//     a W---otherwise the binding itself would introduce a hazard.
//
// Thus, after each iteration, there are no hazards relative to the
// current argument. Subsequent iterations cannot introduce hazards
// with that argument because they can result only in additional
// binding of lower-ordered arguments.
func resolveEffects(logf func(string, ...any), args []*argument, effects []int) {
	// describe formats an effect for logging, e.g. "R0", "W2", "W∞".
	describe := func(write bool, idx int) string {
		label := fmt.Sprint(idx)
		if idx == len(args) {
			label = "∞"
		}
		return string("RW"[btoi(write)]) + label
	}

	for i := len(args) - 1; i >= 0; i-- {
		cur := args[i]

		// Step 1: an unbound, impure argument must be bound if the
		// callee evaluates a higher-ordered effect before it.
		if cur.substitutable && !cur.pure {
			if pos := index(effects, i); pos >= 0 {
				for _, e := range effects[:pos] {
					ji, jw := e, false // effective param index; whether it is a write
					if e == winf || e == rinf {
						ji, jw = len(args), e == winf
					} else {
						jw = args[e].effects
					}
					if ji > i && (jw || cur.effects) { // out of order evaluation
						logf("binding argument %s: preceded by %s",
							describe(cur.effects, i), describe(jw, ji))
						cur.substitutable = false
						break
					}
				}
			}
		}

		// Step 2: a bound argument forces the binding of each
		// lower-ordered impure candidate when either of the two is a write.
		if cur.substitutable {
			continue
		}
		for j, lower := range args[:i] {
			if lower.pure || !lower.substitutable {
				continue
			}
			if cur.effects || lower.effects {
				logf("binding argument %s: %s is bound",
					describe(lower.effects, j), describe(cur.effects, i))
				lower.substitutable = false
			}
		}
	}
}
// updateCalleeParams rewrites the parameter syntax of calleeDecl:
// substituted parameters (nil entries of params) are removed, and the
// receiver, if any, is moved to the head of the ordinary parameters.
// On return calleeDecl.Recv is nil.
//
// The elements of params correspond, in order, to the receiver (if
// any) followed by the ordinary parameters.
func updateCalleeParams(calleeDecl *ast.FuncDecl, params []*parameter) {
	// The logic is fiddly because of the three forms of ast.Field:
	//
	//	func(int), func(x int), func(x, y int)
	//
	// Also, ensure that all remaining parameters are named
	// to avoid a mix of named/unnamed when joining (recv, params...).
	//	func (T) f(int, bool) -> (_ T, _ int, _ bool)
	// (Strictly, we need do this only for methods and only when
	// the namednesses of Recv and Params differ; that might be tidier.)
	var (
		next int          // index in original parameter list (incl. receiver)
		kept []*ast.Field // fields of the rewritten parameter list
	)
	keep := func(field *ast.Field) {
		var names []*ast.Ident
		if field.Names == nil {
			// Unnamed parameter field, e.g. func f(int).
			// A surviving one gets the explicit name "_", since the
			// receiver (if any) becomes a regular parameter and one
			// cannot mix named and unnamed parameters.
			if params[next] != nil {
				names = []*ast.Ident{makeIdent("_")}
			}
			next++
		} else {
			// Named parameter field, e.g. func f(x, y int).
			// Keep only the names that were not substituted;
			// if none remain, the whole field is dropped.
			for _, id := range field.Names {
				p := params[next]
				next++
				if p == nil {
					continue // substituted
				}
				// Rename unreferenced parameters with "_".
				// This is crucial for binding decls, since
				// unlike parameters, they are subject to
				// "unreferenced var" checks.
				if len(p.info.Refs) == 0 {
					id = makeIdent("_")
				}
				names = append(names, id)
			}
		}
		if names == nil {
			return
		}
		kept = append(kept, &ast.Field{
			Names: names,
			Type:  field.Type,
		})
	}

	if recv := calleeDecl.Recv; recv != nil {
		keep(recv.List[0])
		calleeDecl.Recv = nil
	}
	for _, field := range calleeDecl.Type.Params.List {
		keep(field)
	}
	calleeDecl.Type.Params.List = kept
}
// bindingDeclInfo records information about the binding decl produced by
// createBindingDecl.
type bindingDeclInfo struct {
	names map[string]bool // names bound by the binding decl (blank "_" excluded); possibly empty
	stmt  ast.Stmt        // the binding decl itself
}
// createBindingDecl constructs a "binding decl" that implements
// parameter assignment and declares any named result variables
// referenced by the callee.
//
// It returns nil if no declaration is needed (there are no
// unsubstituted parameters and no referenced named results), if the
// call is a spread call, or if the declaration cannot be created
// because one spec would shadow a name needed by a later spec. When it
// succeeds, the declaration may be used by reduction strategies to
// relax the requirement that all parameters have been substituted.
//
// For example, a call:
//
//	f(a0, a1, a2)
//
// where:
//
//	func f(p0, p1 T0, p2 T1) { body }
//
// reduces to:
//
//	{
//	  var (
//	    p0, p1 T0 = a0, a1
//	    p2     T1 = a2
//	  )
//	  body
//	}
//
// so long as p0, p1 ∉ freevars(T1) or freevars(a2), and so on,
// because each spec is statically resolved in sequence and
// dynamically assigned in sequence. By contrast, all
// parameters are resolved simultaneously and assigned
// simultaneously.
//
// The pX names should already be blank ("_") if the parameter
// is unreferenced; this avoids "unreferenced local var" checks.
//
// Strategies may impose additional checks on return
// conversions, labels, defer, etc.
func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argument, calleeDecl *ast.FuncDecl, results []*paramInfo) *bindingDeclInfo {
	// Spread calls are tricky as they may not align with the
	// parameters' field groupings nor types.
	// For example, given
	//   func g() (int, string)
	// the call
	//   f(g())
	// is legal with these decls of f:
	//   func f(int, string)
	//   func f(x, y any)
	//   func f(x, y ...any)
	// TODO(adonovan): support binding decls for spread calls by
	// splitting parameter groupings as needed.
	if lastArg := last(args); lastArg != nil && lastArg.spread {
		logf("binding decls not yet supported for spread calls")
		return nil
	}

	bound := make(map[string]bool) // names defined by the specs accepted so far
	var specs []ast.Spec

	// conflicts reports whether any name referenced by spec is
	// shadowed by a name declared by a previous spec (since,
	// unlike parameters, each spec of a var decl is within the
	// scope of the previous specs). If there is no conflict, the
	// non-blank names declared by spec are recorded as bound.
	conflicts := func(spec *ast.ValueSpec) bool {
		// Compute union of free names of type and values
		// and detect shadowing. Values is the arguments
		// (caller syntax), so we can use type info.
		// But Type is the untyped callee syntax,
		// so we have to use a syntax-only algorithm.
		refs := make(map[string]bool)
		for _, value := range spec.Values {
			for name := range freeVars(caller.Info, value) {
				refs[name] = true
			}
		}
		freeishNames(refs, spec.Type)
		for name := range refs {
			if bound[name] {
				logf("binding decl would shadow free name %q", name)
				return true
			}
		}
		for _, id := range spec.Names {
			if id.Name != "_" {
				bound[id.Name] = true
			}
		}
		return false
	}

	// Parameters.
	//
	// Bind parameters that were not eliminated through
	// substitution. (Non-nil arguments correspond to the
	// remaining parameters in calleeDecl.)
	var pending []ast.Expr
	for _, arg := range args {
		if arg != nil {
			pending = append(pending, arg.expr)
		}
	}
	for _, field := range calleeDecl.Type.Params.List {
		// Each field (param group) becomes a ValueSpec that
		// consumes one pending value per declared name.
		n := len(field.Names)
		spec := &ast.ValueSpec{
			Names:  field.Names,
			Type:   field.Type,
			Values: pending[:n],
		}
		pending = pending[n:]
		if conflicts(spec) {
			return nil
		}
		specs = append(specs, spec)
	}
	assert(len(pending) == 0, "args/params mismatch")

	// Results.
	//
	// Add specs to declare any named result
	// variables that are referenced by the body.
	if resultFields := calleeDecl.Type.Results; resultFields != nil {
		idx := 0 // index into results
		for _, field := range resultFields.List {
			if field.Names == nil {
				idx++
				continue // unnamed field
			}
			var referenced []*ast.Ident
			for _, id := range field.Names {
				if len(results[idx].Refs) > 0 {
					referenced = append(referenced, id)
				}
				idx++
			}
			if len(referenced) == 0 {
				continue
			}
			spec := &ast.ValueSpec{
				Names: referenced,
				Type:  field.Type,
			}
			if conflicts(spec) {
				return nil
			}
			specs = append(specs, spec)
		}
	}

	if len(specs) == 0 {
		logf("binding decl not needed: all parameters substituted")
		return nil
	}

	stmt := &ast.DeclStmt{
		Decl: &ast.GenDecl{
			Tok:   token.VAR,
			Specs: specs,
		},
	}
	logf("binding decl: %s", debugFormatNode(caller.Fset, stmt))
	return &bindingDeclInfo{names: bound, stmt: stmt}
}
// lookup resolves name in the lexical environment of the caller, as
// seen from the position of the call.
//
// It visits the nodes of caller.path in order and, for each node that
// has an associated scope, searches that scope and its enclosing
// scopes. It returns the first object found, or nil if there is none.
func (caller *Caller) lookup(name string) types.Object {
	callPos := caller.Call.Pos()
	for _, node := range caller.path {
		scope := scopeFor(caller.Info, node)
		if scope == nil {
			continue
		}
		_, obj := scope.LookupParent(name, callPos)
		if obj != nil {
			return obj
		}
	}
	return nil
}
// scopeFor returns the scope recorded in info.Scopes for node n,
// or nil if there is none.
//
// For a function declaration or function literal, the lookup uses the
// function's type node: the function body scope (containing not just
// params) is associated with the function's type, not its body.
func scopeFor(info *types.Info, n ast.Node) *types.Scope {
	key := n
	if decl, ok := n.(*ast.FuncDecl); ok {
		key = decl.Type
	} else if lit, ok := n.(*ast.FuncLit); ok {
		key = lit.Type
	}
	return info.Scopes[key]
}
// -- predicates over expressions --
// freeVars returns the set of names of the free identifiers of e:
// those that e references lexically but whose declaration lies
// outside e. (Fields and methods are not included.)
func freeVars(info *types.Info, e ast.Expr) map[string]bool {
	names := make(map[string]bool)
	ast.Inspect(e, func(n ast.Node) bool {
		id, isIdent := n.(*ast.Ident)
		if !isIdent {
			return true
		}
		obj, used := info.Uses[id]
		if !used {
			return true
		}
		// Skip objects declared within e itself, and skip fields
		// so that we don't treat T{f: 0} as a ref to f.
		if within(obj.Pos(), e) || isField(obj) {
			return true
		}
		names[obj.Name()] = true
		return true
	})
	return names
}
// freeishNames adds to the free map an over-approximation of the
// free names of the type syntax t.
//
// Every identifier reached is recorded, except the selected name of
// a selector expression (x.Sel) and the names declared by a field
// (function parameters, interface methods, struct fields).
//
// Because we don't have go/types annotations, we can't give an exact
// result in all cases. In particular, an array type [n]T might have a
// size such as unsafe.Sizeof(func() int{stmts...}()) and now the
// precise answer depends upon all the statement syntax too. But that
// never happens in practice.
func freeishNames(free map[string]bool, t ast.Expr) {
	var collect func(ast.Node) bool
	collect = func(node ast.Node) bool {
		if id, ok := node.(*ast.Ident); ok {
			free[id.Name] = true
			return true
		}
		if sel, ok := node.(*ast.SelectorExpr); ok {
			// Only the operand can be free; don't visit .Sel.
			ast.Inspect(sel.X, collect)
			return false
		}
		if field, ok := node.(*ast.Field); ok {
			// Only the type can mention free names; don't visit .Names.
			ast.Inspect(field.Type, collect)
			return false
		}
		return true
	}
	ast.Inspect(t, collect)
}
// effects reports whether an expression might change the state of the
// program (through function calls and channel receives) and affect
// the evaluation of subsequent expressions.
//
// The bodies of function literals are not inspected. A conversion
// T(x) and a call to a pure built-in have only the effects of their
// operands.
//
// If st.opts.IgnoreEffects is set, any effects found are logged and
// the result is false.
func (st *state) effects(info *types.Info, expr ast.Expr) bool {
	found := false
	ast.Inspect(expr, func(node ast.Node) bool {
		if _, isLit := node.(*ast.FuncLit); isLit {
			return false // prune descent
		}
		if call, isCall := node.(*ast.CallExpr); isCall {
			// A conversion T(x) has only the effect of its operand,
			// and a handful of built-ins have no effect beyond those
			// of their arguments. All other calls (including append,
			// copy, recover) have unknown effects.
			//
			// As with 'pure', there is room for
			// improvement by inspecting the callee.
			if !info.Types[call.Fun].IsType() && !callsPureBuiltin(info, call) {
				found = true
			}
		} else if recv, isUnary := node.(*ast.UnaryExpr); isUnary && recv.Op == token.ARROW {
			found = true // <-ch
		}
		return true
	})

	// Even if consideration of effects is not desired,
	// we continue to compute, log, and discard them.
	if st.opts.IgnoreEffects && found {
		st.opts.Logf("ignoring potential effects of argument %s",
			debugFormatNode(st.caller.Fset, expr))
		return false
	}
	return found
}
// pure reports whether the value of an expression is independent of
// when it is evaluated relative to other expressions, so that it may
// be commuted with any other expression or statement without a
// change of meaning.
//
// Reading the contents of a variable makes an expression impure,
// except for "single assignment" local variables (as classified by
// the assign1 callback), which are never updated after they are
// initialized.
//
// Pure does not imply duplicable: new(T) and T{} are pure, yet each
// evaluation yields a different value, so copying them is unsafe.
//
// Purity does not imply freedom from run-time panics. Target
// programs are assumed neither to encounter run-time panics nor to
// rely on them for correct operation.
//
// TODO(adonovan): add unit tests of this function.
func pure(info *types.Info, assign1 func(*types.Var) bool, e ast.Expr) bool {
	var check func(ast.Expr) bool
	check = func(e ast.Expr) bool {
		switch x := e.(type) {
		case *ast.BasicLit:
			return true

		case *ast.FuncLit:
			// A function literal may allocate a closure that
			// refers to mutable variables, but a mutation can
			// be observed only by calling the function, and
			// calls are themselves considered impure.
			return true

		case *ast.ParenExpr:
			return check(x.X)

		case *ast.Ident:
			v, isVar := info.Uses[x].(*types.Var)
			if !isVar {
				// Every other kind of reference is pure.
				return true
			}
			// Variables are impure in general since they may
			// be updated, but a single-assignment local
			// variable never changes value.
			//
			// All package-level variables are assumed to be
			// updatable; for unexported ones we could do
			// better by analyzing the whole package.
			return !isPkgLevel(v) && assign1(v)

		case *ast.UnaryExpr: // + - ! ^ & but not <-
			if x.Op == token.ARROW {
				return false
			}
			return check(x.X)

		case *ast.BinaryExpr: // arithmetic, shifts, comparisons, &&/||
			return check(x.X) && check(x.Y)

		case *ast.CallExpr:
			// A conversion is exactly as pure as its operand.
			if info.Types[x.Fun].IsType() {
				return check(x.Args[0])
			}

			// All calls other than to certain built-ins are
			// impure, so we reject them without looking at x.Fun.
			//
			// A more sophisticated analysis could infer purity
			// of commonly used functions such as
			// strings.Contains; perhaps a client hook could let
			// a go/analysis-based implementation exploit the
			// results of a purity analysis. But that would make
			// the inliner's choices harder to explain.
			if !callsPureBuiltin(info, x) {
				return false
			}

			// A call to a pure built-in is as pure as its arguments.
			for _, arg := range x.Args {
				if !check(arg) {
					return false
				}
			}
			return true

		case *ast.CompositeLit:
			// T{...} is exactly as pure as its elements.
			for _, elt := range x.Elts {
				kv, isKV := elt.(*ast.KeyValueExpr)
				if !isKV {
					if !check(elt) {
						return false
					}
					continue
				}
				if !check(kv.Value) {
					return false
				}
				if id, ok := kv.Key.(*ast.Ident); ok {
					if v, ok := info.Uses[id].(*types.Var); ok && v.IsField() {
						continue // struct {field: value}
					}
				}
				// map/slice/array {key: value}
				if !check(kv.Key) {
					return false
				}
			}
			return true

		case *ast.SelectorExpr:
			seln, ok := info.Selections[x]
			if !ok {
				// A qualified identifier is treated
				// just like an unqualified one.
				return check(x.Sel)
			}
			// See types.SelectionKind for background.
			switch seln.Kind() {
			case types.MethodExpr:
				// A method expression T.f behaves like a
				// reference to a func decl, so it is pure.
				return true
			case types.MethodVal, types.FieldVal:
				// A field or method selection x.f is pure
				// if x is pure and the selection does not
				// indirect a pointer.
				return !indirectSelection(seln) && check(x.X)
			}
			panic(seln)

		case *ast.StarExpr:
			return false // *ptr depends on the state of the heap
		}
		return false
	}
	return check(e)
}
// callsPureBuiltin reports whether call is a call of a built-in
// function that is a pure computation over its operands (analogous to
// a + operator). Because it does not depend on program state, it may
// be evaluated at any point--though not necessarily at multiple
// points (consider new, make).
func callsPureBuiltin(info *types.Info, call *ast.CallExpr) bool {
if id, ok := ast.Unparen(call.Fun).(*ast.Ident); ok {
if b, ok := info.ObjectOf(id).(*types.Builtin); ok {
switch b.Name() {
case "len", "cap", "complex", "imag", "real", "make", "new", "max", "min":
return true
}
// Not: append clear close copy delete panic print println recover
}
}
return false
}
// duplicable reports whether it is appropriate for the expression to
// be freely duplicated.
//
// Given the declaration
//
//	func f(x T) T { return x + g() + x }
//
// an argument y is considered duplicable if we would wish to see a
// call f(y) simplified to y+g()+y. This is true for identifiers,
// integer literals, unary negation, and selectors x.f where x is not
// a pointer. But we would not wish to duplicate expressions that:
//   - have side effects (e.g. nearly all calls),
//   - are not referentially transparent (e.g. &T{}, ptr.field, *ptr), or
//   - are long (e.g. "huge string literal").
func duplicable(info *types.Info, e ast.Expr) bool {
	switch e := e.(type) {
	case *ast.ParenExpr:
		return duplicable(info, e.X)

	case *ast.Ident:
		return true

	case *ast.BasicLit:
		v := info.Types[e].Value
		switch e.Kind {
		case token.INT:
			return true // any int
		case token.STRING:
			return consteq(v, kZeroString) // only ""
		case token.FLOAT:
			return consteq(v, kZeroFloat) || consteq(v, kOneFloat) // only 0.0 or 1.0
		}
		// Other kinds of literal (CHAR, IMAG) are not duplicable.

	case *ast.UnaryExpr: // e.g. +1, -1
		return (e.Op == token.ADD || e.Op == token.SUB) && duplicable(info, e.X)

	case *ast.CompositeLit:
		// Empty struct or array literals T{} are duplicable.
		// (Non-empty literals are too verbose, and slice/map
		// literals allocate indirect variables.)
		if len(e.Elts) == 0 {
			switch info.TypeOf(e).Underlying().(type) {
			case *types.Struct, *types.Array:
				return true
			}
		}
		return false

	case *ast.CallExpr:
		// Treat type conversions as duplicable if they do not observably allocate.
		// The only cases of observable allocations are
		// the `[]byte(string)` and `[]rune(string)` conversions.
		//
		// Duplicating string([]byte) conversions increases
		// allocation but doesn't change behavior, but the
		// reverse, []byte(string), allocates a distinct array,
		// which is observable.

		if !info.Types[e.Fun].IsType() { // check whether e.Fun is a type conversion
			return false
		}

		fun := info.TypeOf(e.Fun)
		arg := info.TypeOf(e.Args[0])

		// Be conservative about conversions to a type parameter.
		//
		// This must be tested on the type itself: the
		// Underlying of a type parameter is the underlying
		// type of its constraint interface, never a
		// *types.TypeParam, so a case in a switch over
		// fun.Underlying() could not match.
		if _, ok := fun.(*types.TypeParam); ok {
			return false
		}

		if slice, ok := fun.Underlying().(*types.Slice); ok {
			// Do not mark []byte(string) and []rune(string) as duplicable.
			elem, ok := slice.Elem().Underlying().(*types.Basic)
			if ok && (elem.Kind() == types.Rune || elem.Kind() == types.Byte) {
				from, ok := arg.Underlying().(*types.Basic)
				isString := ok && from.Info()&types.IsString != 0
				return !isString
			}
		}
		return true

	case *ast.SelectorExpr:
		if seln, ok := info.Selections[e]; ok {
			// A field or method selection x.f is referentially
			// transparent if it does not indirect a pointer.
			return !indirectSelection(seln)
		}
		// A qualified identifier pkg.Name is referentially transparent.
		return true
	}
	return false
}
// consteq reports whether the constant values x and y compare equal.
func consteq(x, y constant.Value) bool {
	equal := constant.Compare(x, token.EQL, y)
	return equal
}
// Constant values used in comparisons against literal operands.
var kZeroInt = constant.MakeInt64(0)

var kZeroString = constant.MakeString("")

var kZeroFloat = constant.MakeFloat64(0.0)

var kOneFloat = constant.MakeFloat64(1.0)
// -- inline helpers --
// assert panics with msg unless cond holds.
func assert(cond bool, msg string) {
	if cond {
		return
	}
	panic(msg)
}
// blanks returns a slice of n > 0 distinct blank identifiers,
// each of static type E. It panics if n is zero.
func blanks[E ast.Expr](n int) []E {
	if n == 0 {
		panic("blanks(0)")
	}
	idents := make([]E, n)
	for i := 0; i < n; i++ {
		// A *ast.Ident cannot be converted to E directly;
		// go through the interface type and assert.
		var blank ast.Expr = makeIdent("_")
		idents[i] = blank.(E)
	}
	return idents
}

// makeIdent returns a new identifier with the given name
// and no other fields set.
func makeIdent(name string) *ast.Ident {
	id := new(ast.Ident)
	id.Name = name
	return id
}
// importedPkgName returns the PkgName object declared by an
// ImportSpec: for a renaming import it is the definition of the
// explicit name, otherwise it is the implicit object recorded for
// the spec itself. The boolean is false if no PkgName was recorded.
// TODO(adonovan): make this a method of types.Info (#62037).
func importedPkgName(info *types.Info, imp *ast.ImportSpec) (*types.PkgName, bool) {
	if imp.Name == nil {
		pkgname, ok := info.Implicits[imp].(*types.PkgName)
		return pkgname, ok
	}
	pkgname, ok := info.Defs[imp.Name].(*types.PkgName)
	return pkgname, ok
}
// isPkgLevel reports whether obj is the object that its own
// package's scope binds to obj's name.
func isPkgLevel(obj types.Object) bool {
	// TODO(adonovan): consider using the simpler obj.Parent() ==
	// obj.Pkg().Scope() instead. But be sure to test carefully
	// with instantiations of generics.
	pkgScope := obj.Pkg().Scope()
	declared := pkgScope.Lookup(obj.Name())
	return declared == obj
}
// callContext returns the two nearest non-parenthesis nodes that
// enclose the call, which is given as a PathEnclosingInterval whose
// first element is the call itself. Either result may be nil if the
// path is too short.
func callContext(callPath []ast.Node) (parent, grandparent ast.Node) {
	_ = callPath[0].(*ast.CallExpr) // sanity check
	for _, n := range callPath[1:] {
		if is[*ast.ParenExpr](n) {
			continue // parens are transparent
		}
		if parent != nil {
			return parent, n
		}
		parent = n
	}
	return parent, nil
}
// hasLabelConflict reports whether any of the callee's labels is
// also a label of the function enclosing the call (which is given
// as a PathEnclosingInterval).
func hasLabelConflict(callPath []ast.Node, calleeLabels []string) bool {
	inUse := callerLabels(callPath)
	for i := range calleeLabels {
		if inUse[calleeLabels[i]] {
			return true // conflict
		}
	}
	return false
}
// callerLabels returns the set of control labels declared in the
// function (if any) that encloses the call, which is given as a
// PathEnclosingInterval. Labels inside nested function literals are
// not included. The result is nil if there are no labels.
func callerLabels(callPath []ast.Node) map[string]bool {
	var body *ast.BlockStmt
	switch fn := callerFunc(callPath).(type) {
	case *ast.FuncLit:
		body = fn.Body
	case *ast.FuncDecl:
		body = fn.Body
	}
	if body == nil {
		return nil
	}

	var set map[string]bool // allocated lazily, on the first label
	ast.Inspect(body, func(n ast.Node) bool {
		if _, nested := n.(*ast.FuncLit); nested {
			return false // prune traversal
		}
		if stmt, ok := n.(*ast.LabeledStmt); ok {
			if set == nil {
				set = make(map[string]bool)
			}
			set[stmt.Label.Name] = true
		}
		return true
	})
	return set
}
// callerFunc returns the innermost FuncDecl or FuncLit node that
// encloses the call (given as a PathEnclosingInterval), or nil if
// there is none.
func callerFunc(callPath []ast.Node) ast.Node {
	_ = callPath[0].(*ast.CallExpr) // sanity check
	for i := 1; i < len(callPath); i++ {
		n := callPath[i]
		if is[*ast.FuncDecl](n) || is[*ast.FuncLit](n) {
			return n
		}
	}
	return nil
}
// callStmt returns the ExprStmt that directly encloses the function
// call (given as a PathEnclosingInterval), ignoring parens, or nil
// if the call is not the operand of an ExprStmt.
//
// If unrestricted, callStmt also returns nil when the ExprStmt f()
// appears in a restricted context (such as "if f(); cond {") where
// it cannot be replaced by an arbitrary statement.
// (See "statement theory".)
func callStmt(callPath []ast.Node, unrestricted bool) *ast.ExprStmt {
	parent, _ := callContext(callPath)
	stmt, isExprStmt := parent.(*ast.ExprStmt)
	if !isExprStmt || !unrestricted {
		return stmt
	}

	// Inspect the node that contains the statement.
	container := callPath[nodeIndex(callPath, stmt)+1]
	switch container.(type) {
	case *ast.LabeledStmt,
		*ast.BlockStmt,
		*ast.CaseClause,
		*ast.CommClause:
		return stmt // unrestricted
	}

	// TODO(adonovan): handle restricted
	// XYZStmt.Init contexts (but not ForStmt.Post)
	// by creating a block around the if/for/switch:
	// "if f(); cond {" -> "{ stmts; if cond {"
	return nil // restricted
}
// Statement theory
//
// These are all the places a statement may appear in the AST:
//
// LabeledStmt.Stmt Stmt -- any
// BlockStmt.List []Stmt -- any (but see switch/select)
// IfStmt.Init Stmt? -- simple
// IfStmt.Body BlockStmt
// IfStmt.Else Stmt? -- IfStmt or BlockStmt
// CaseClause.Body []Stmt -- any
// SwitchStmt.Init Stmt? -- simple
// SwitchStmt.Body BlockStmt -- CaseClauses only
// TypeSwitchStmt.Init Stmt? -- simple
// TypeSwitchStmt.Assign Stmt -- AssignStmt(TypeAssertExpr) or ExprStmt(TypeAssertExpr)
// TypeSwitchStmt.Body BlockStmt -- CaseClauses only
// CommClause.Comm Stmt? -- SendStmt or ExprStmt(UnaryExpr) or AssignStmt(UnaryExpr)
// CommClause.Body []Stmt -- any
// SelectStmt.Body BlockStmt -- CommClauses only
// ForStmt.Init Stmt? -- simple
// ForStmt.Post Stmt? -- simple
// ForStmt.Body BlockStmt
// RangeStmt.Body BlockStmt
//
// simple = AssignStmt | SendStmt | IncDecStmt | ExprStmt.
//
// A BlockStmt cannot replace an ExprStmt in a "simple" statement
// position: {If,Switch,TypeSwitch,For}Stmt.Init or ForStmt.Post.
// Such a replacement is allowed only within:
// LabeledStmt.Stmt Stmt
// BlockStmt.List []Stmt
// CaseClause.Body []Stmt
// CommClause.Body []Stmt
// replaceNode performs a destructive update of the tree rooted at
// root, replacing each occurrence of "from" with "to". If to is nil and
// the element is within a slice, the slice element is removed.
//
// The tree is traversed using reflection, so every pointer-valued
// field or element equal to from is updated. After visiting a slice
// in which some element is nil, all nil elements of that slice are
// removed and the slice is shortened in place.
//
// The root itself cannot be replaced; an attempt will panic.
// replaceNode also panics if from is nil (or a typed nil pointer),
// or if no occurrence of from is found.
//
// This function must not be called on the caller's syntax tree.
//
// TODO(adonovan): polish this up and move it to astutil package.
// TODO(adonovan): needs a unit test.
func replaceNode(root ast.Node, from, to ast.Node) {
	if from == nil {
		panic("from == nil")
	}
	if reflect.ValueOf(from).IsNil() {
		panic(fmt.Sprintf("from == (%T)(nil)", from))
	}
	if from == root {
		panic("from == root")
	}
	found := false
	// parent is the interface-typed variable that holds the pointer
	// currently being visited. It is set only while visiting the
	// dynamic value of an interface, and is reset to the zero Value
	// at the end of every call to visit.
	var parent reflect.Value // parent variable of interface type, containing a pointer
	var visit func(reflect.Value)
	visit = func(v reflect.Value) {
		switch v.Kind() {
		case reflect.Ptr:
			if v.Interface() == from {
				found = true

				// If v is a struct field or array element
				// (e.g. Field.Comment or Field.Names[i])
				// then it is addressable (a pointer variable).
				//
				// But if it was the value of an interface
				// (e.g. *ast.Ident within ast.Node)
				// then it is non-addressable, and we need
				// to set the enclosing interface (parent).
				if !v.CanAddr() {
					v = parent
				}

				// to=nil => use zero value
				var toV reflect.Value
				if to != nil {
					toV = reflect.ValueOf(to)
				} else {
					toV = reflect.Zero(v.Type()) // e.g. ast.Expr(nil)
				}
				v.Set(toV)

			} else if !v.IsNil() {
				switch v.Interface().(type) {
				case *ast.Object, *ast.Scope:
					// Skip fields of types potentially involved in cycles.
				default:
					// Descend into the pointed-to struct.
					visit(v.Elem())
				}
			}

		case reflect.Struct:
			for i := 0; i < v.Type().NumField(); i++ {
				visit(v.Field(i))
			}

		case reflect.Slice:
			// compact records whether any element is nil
			// once it has been visited.
			compact := false
			for i := 0; i < v.Len(); i++ {
				visit(v.Index(i))
				if v.Index(i).IsNil() {
					compact = true
				}
			}
			if compact {
				// Elements were deleted. Eliminate nils.
				// (Do this in a second pass to avoid
				// unnecessary writes in the common case.)
				j := 0
				for i := 0; i < v.Len(); i++ {
					if !v.Index(i).IsNil() {
						v.Index(j).Set(v.Index(i))
						j++
					}
				}
				v.SetLen(j)
			}

		case reflect.Interface:
			// Remember the interface variable so that a
			// matching pointer within it can be replaced
			// by assigning to the interface.
			parent = v
			visit(v.Elem())

		case reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.UnsafePointer:
			panic(v) // unreachable in AST

		default:
			// bool, string, number: nop
		}
		parent = reflect.Value{}
	}
	visit(reflect.ValueOf(root))
	if !found {
		panic(fmt.Sprintf("%T not found", from))
	}
}
// clearPositions destroys token.Pos information within the tree
// rooted at root, because positions in callee trees may cause caller
// comments to be emitted prematurely.
//
// Zeroing a valid Pos is not safe in general since some positions
// (e.g. CallExpr.Ellipsis, TypeSpec.Assign) are significant to
// go/printer. So instead each non-zero Pos field is set to 1, which
// is enough to avoid advancing the printer's comment cursor.
// Fields that are already token.NoPos are left alone.
//
// This function mutates its argument; do not invoke on caller syntax.
//
// TODO(adonovan): remove this horrendous workaround when #20744 is finally fixed.
func clearPositions(root ast.Node) {
	var (
		posType = reflect.TypeOf(token.NoPos)
		one     = reflect.ValueOf(token.Pos(1))
	)
	ast.Inspect(root, func(n ast.Node) bool {
		if n == nil {
			return true
		}
		node := reflect.ValueOf(n).Elem() // the struct behind the node pointer
		for i, nfields := 0, node.Type().NumField(); i < nfields; i++ {
			field := node.Field(i)
			if field.Type() != posType {
				continue
			}
			// Clearing a Pos arbitrarily is destructive,
			// as its presence may be semantically significant
			// (e.g. CallExpr.Ellipsis, TypeSpec.Assign)
			// or affect formatting preferences (e.g. GenDecl.Lparen).
			//
			// Note: for proper formatting, it may be necessary
			// to be selective about which positions we set to 1
			// vs which we set to token.NoPos. (e.g. we can set
			// most to token.NoPos, save the few that are
			// significant).
			if field.Interface() != token.NoPos {
				field.Set(one)
			}
		}
		return true
	})
}
// findIdent returns the first Ident beneath root (in ast.Inspect
// order) whose position is pos. It panics if there is none.
func findIdent(root ast.Node, pos token.Pos) *ast.Ident {
	// TODO(adonovan): opt: skip subtrees that don't contain pos.
	var match *ast.Ident
	ast.Inspect(root, func(n ast.Node) bool {
		if match != nil {
			return false // already found; stop descending
		}
		id, ok := n.(*ast.Ident)
		if ok && id.Pos() == pos {
			match = id
		}
		return true
	})
	if match != nil {
		return match
	}
	panic(fmt.Sprintf("findIdent %d not found in %s",
		pos, debugFormatNode(token.NewFileSet(), root)))
}
// prepend returns a new slice consisting of elem
// followed by the elements of slice.
func prepend[T any](elem T, slice ...T) []T {
	result := []T{elem}
	result = append(result, slice...)
	return result
}
// debugFormatNode returns the formatted source of node n. If
// formatting fails, the error message is appended to whatever
// output was produced. This sloppy treatment of errors is
// appropriate only for logging.
func debugFormatNode(fset *token.FileSet, n ast.Node) string {
	var sb strings.Builder
	err := format.Node(&sb, fset, n)
	if err != nil {
		sb.WriteString(err.Error())
	}
	return sb.String()
}
// shallowCopy returns a pointer to a new variable
// holding a copy of the value *ptr.
func shallowCopy[T any](ptr *T) *T {
	dup := new(T)
	*dup = *ptr
	return dup
}
// forall (∀) reports whether f(i, list[i]) holds for every element
// of list. It stops at the first element for which f is false.
func forall[T any](list []T, f func(i int, x T) bool) bool {
	for i := 0; i < len(list); i++ {
		if ok := f(i, list[i]); !ok {
			return false
		}
	}
	return true
}
// exists (∃) reports whether f(i, list[i]) holds for at least one
// element of list. It stops at the first element for which f is true.
func exists[T any](list []T, f func(i int, x T) bool) bool {
	for i := 0; i < len(list); i++ {
		if ok := f(i, list[i]); ok {
			return true
		}
	}
	return false
}
// last returns the final element of slice,
// or the zero value of T if the slice is empty.
func last[T any](slice []T) T {
	var zero T
	if len(slice) == 0 {
		return zero
	}
	return slice[len(slice)-1]
}
// canImport reports whether the package with import path "from" is
// allowed to import the package with import path "to", according to
// the rules for internal packages.
//
// A path whose final or interior segment is "internal" may be
// imported only from a package whose path begins with the portion
// preceding that segment. A path that is "internal" itself, or
// starts with "internal/", belongs to the standard library and may
// be imported only by std packages.
//
// TODO(adonovan): allow customization of the accessibility relation
// (e.g. for Bazel).
func canImport(from, to string) bool {
	// TODO(adonovan): better segment hygiene.
	// (The prefix tests below compare strings, not path segments,
	// so "a/bc" is considered to be beneath "a/b".)
	if to == "internal" || strings.HasPrefix(to, "internal/") {
		// Special case: only std packages may import internal/...
		// We can't reliably know whether we're in std, so we
		// use a heuristic on the first segment.
		first, _, _ := strings.Cut(from, "/")
		if strings.Contains(first, ".") {
			return false // example.com/foo ∉ std
		}
		if first == "testdata" {
			return false // testdata/foo ∉ std
		}
	}
	// The target is itself the internal directory: a/b/internal.
	if strings.HasSuffix(to, "/internal") {
		return strings.HasPrefix(from, to[:len(to)-len("/internal")])
	}
	// The target lies beneath an internal directory: a/b/internal/c.
	if i := strings.LastIndex(to, "/internal/"); i >= 0 {
		return strings.HasPrefix(from, to[:i])
	}
	return true
}
// consistentOffsets reports whether the region of caller.Content
// denoted by the extent of caller.Call parses as a call expression.
// If it does not, the client has supplied inconsistent information,
// possibly because it forgot to ignore line directives when
// computing the name of the file enclosing the call.
// This is just a heuristic.
func consistentOffsets(caller *Caller) bool {
	var (
		start = offsetOf(caller.Fset, caller.Call.Pos())
		end   = offsetOf(caller.Fset, caller.Call.End())
	)
	// Reject offsets that are not a valid non-empty
	// interval within the content.
	if start <= 0 || start >= end || end > len(caller.Content) {
		return false
	}
	src := string(caller.Content[start:end])
	expr, err := parser.ParseExpr(src)
	return err == nil && is[*ast.CallExpr](expr)
}
// needsParens reports whether parentheses are required to avoid
// ambiguity around the node "new" when it replaces the node "old",
// which is some ancestor of the CallExpr identified by its
// PathEnclosingInterval callPath. It panics if old is not an
// element of callPath.
func needsParens(callPath []ast.Node, old, new ast.Node) bool {
	// Locate the old node; its parent is the next path element.
	i := nodeIndex(callPath, old)
	if i == -1 {
		panic("not found")
	}

	// Replacing a non-expression (e.g. a statement enclosing
	// the call) cannot cause a precedence ambiguity.
	if !is[ast.Expr](old) {
		return false
	}

	// Nor can replacing an expression whose
	// parent is not itself an expression.
	parent, ok := callPath[i+1].(ast.Expr)
	if !ok {
		return false
	}

	// prec returns the operator precedence of a unary or
	// binary expression, or -1 for any other node.
	prec := func(n ast.Node) int {
		if bin, ok := n.(*ast.BinaryExpr); ok {
			return bin.Op.Precedence()
		}
		switch n.(type) {
		case *ast.UnaryExpr, *ast.StarExpr:
			return token.UnaryPrec
		}
		return -1
	}

	// A new node that is neither unary nor binary
	// never requires parens.
	newprec := prec(new)
	if newprec < 0 {
		return false
	}

	// Parens are required if parent and child are both
	// unary or binary and the parent binds more tightly.
	if prec(parent) > newprec {
		return true
	}

	// Otherwise parens are needed only if the old node was
	// the operand of a postfix operator:
	//
	//	f().sel
	//	f()[i:j]
	//	f()[i]
	//	f().(T)
	//	f()(x)
	var operand ast.Expr
	switch parent := parent.(type) {
	case *ast.CallExpr:
		operand = parent.Fun
	case *ast.SelectorExpr:
		operand = parent.X
	case *ast.IndexExpr:
		operand = parent.X
	case *ast.SliceExpr:
		operand = parent.X
	case *ast.TypeAssertExpr:
		operand = parent.X
	default:
		return false
	}
	return operand == old
}
// nodeIndex returns the index of the first element of nodes
// that is equal to n, or -1 if there is none.
func nodeIndex(nodes []ast.Node, n ast.Node) int {
	// TODO(adonovan): Use index[ast.Node]() in go1.20.
	for i := 0; i < len(nodes); i++ {
		if nodes[i] == n {
			return i
		}
	}
	return -1
}
// declares returns the set of lexical names declared by a sequence
// of statements belonging to the same block: the names of var,
// const, and type declarations and of := assignments. Sub-blocks
// are not inspected, and the blank identifier is never included.
// (Lexical names do not include control labels.)
func declares(stmts []ast.Stmt) map[string]bool {
	declared := make(map[string]bool)
	for _, s := range stmts {
		if assign, ok := s.(*ast.AssignStmt); ok {
			if assign.Tok != token.DEFINE {
				continue // plain assignment declares nothing
			}
			for _, lhs := range assign.Lhs {
				declared[lhs.(*ast.Ident).Name] = true
			}
			continue
		}

		decl, ok := s.(*ast.DeclStmt)
		if !ok {
			continue
		}
		for _, spec := range decl.Decl.(*ast.GenDecl).Specs {
			if vs, ok := spec.(*ast.ValueSpec); ok {
				for _, id := range vs.Names {
					declared[id.Name] = true
				}
			} else if ts, ok := spec.(*ast.TypeSpec); ok {
				declared[ts.Name.Name] = true
			}
		}
	}
	delete(declared, "_")
	return declared
}
// assignStmts rewrites a statement assigning the results of a call into zero
// or more statements that assign its return operands, or (nil, false) if no
// such rewrite is possible. The set of bindings created by the result of
// assignStmts is the same as the set of bindings created by the callerStmt.
//
// The callee must contain exactly one return statement.
//
// This is (once again) a surprisingly complex task. For example, depending on
// types and existing bindings, the assignment
//
//	a, b := f()
//
// could be rewritten as:
//
//	a, b := 1, 2
//
// but may need to be written as:
//
//	a, b := int8(1), int32(2)
//
// In the case where the return statement within f is a spread call to another
// function g(), we cannot explicitly convert the return values inline, and so
// it may be necessary to split the declaration and assignment of variables
// into separate statements:
//
//	a, b := g()
//
// or
//
//	var a int32
//	a, b = g()
//
// or
//
//	var (
//		a int8
//		b int32
//	)
//	a, b = g()
//
// Note: assignStmts may return (nil, true) if it determines that the rewritten
// assignment consists only of _ = nil assignments.
func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Expr) ([]ast.Stmt, bool) {
	logf, caller, callee := st.opts.Logf, st.caller, &st.callee.impl

	assert(len(callee.Returns) == 1, "unexpected multiple returns")
	resultInfo := callee.Returns[0]

	// When constructing assign statements, we need to make sure that we don't
	// modify types on the left-hand side, such as would happen if the type of a
	// RHS expression does not match the corresponding LHS type at the caller
	// (due to untyped conversion or interface widening).
	//
	// This turns out to be remarkably tricky to handle correctly.
	//
	// Substrategies below are labeled as `Substrategy <name>:`.

	// Collect LHS information.
	var (
		lhs    []ast.Expr                                // shallow copy of the LHS slice, for mutation
		defs   = make([]*ast.Ident, len(callerStmt.Lhs)) // indexes in lhs of defining identifiers
		blanks = make([]bool, len(callerStmt.Lhs))       // indexes in lhs of blank identifiers
		byType typeutil.Map                              // map of distinct types -> indexes, for writing specs later
	)
	for i, expr := range callerStmt.Lhs {
		lhs = append(lhs, expr)
		if name, ok := expr.(*ast.Ident); ok {
			if name.Name == "_" {
				blanks[i] = true
				continue // no type
			}

			if obj, isDef := caller.Info.Defs[name]; isDef {
				defs[i] = name
				typ := obj.Type()
				idxs, _ := byType.At(typ).([]int)
				idxs = append(idxs, i)
				byType.Set(typ, idxs)
			}
		}
	}

	// Collect RHS information
	//
	// The RHS is either a parallel assignment or spread assignment, but by
	// looping over both callerStmt.Rhs and returnOperands we handle both.
	var (
		rhs             []ast.Expr              // new RHS of assignment, owned by the inliner
		callIdx         = -1                    // index of the call among the original RHS
		nilBlankAssigns = make(map[int]unit)    // indexes in rhs of _ = nil assignments, which can be deleted
		freeNames       = make(map[string]bool) // free(ish) names among rhs expressions
		nonTrivial      = make(map[int]bool)    // indexes in rhs of nontrivial result conversions
	)
	for i, expr := range callerStmt.Rhs {
		if expr == caller.Call {
			assert(callIdx == -1, "malformed (duplicative) AST")
			callIdx = i
			for j, returnOperand := range returnOperands {
				freeishNames(freeNames, returnOperand)
				rhs = append(rhs, returnOperand)
				if resultInfo[j]&nonTrivialResult != 0 {
					nonTrivial[i+j] = true
				}
				if blanks[i+j] && resultInfo[j]&untypedNilResult != 0 {
					nilBlankAssigns[i+j] = unit{}
				}
			}
		} else {
			// We must clone before clearing positions, since expr came from the caller.
			expr = internalastutil.CloneNode(expr)
			clearPositions(expr)
			freeishNames(freeNames, expr)
			rhs = append(rhs, expr)
		}
	}
	assert(callIdx >= 0, "failed to find call in RHS")

	// Substrategy "splice": Check to see if we can simply splice in the result
	// expressions from the callee, such as simplifying
	//
	//	x, y := f()
	//
	// to
	//
	//	x, y := e1, e2
	//
	// where the types of x and y match the types of e1 and e2.
	//
	// This works as long as we don't need to write any additional type
	// information.
	if callerStmt.Tok == token.ASSIGN && // LHS types already determined before call
		len(nonTrivial) == 0 { // no non-trivial conversions to worry about

		logf("substrategy: splice assignment")
		return []ast.Stmt{&ast.AssignStmt{
			Lhs:    lhs,
			Tok:    callerStmt.Tok,
			TokPos: callerStmt.TokPos,
			Rhs:    rhs,
		}}, true
	}

	// Inlining techniques below will need to write type information in order to
	// preserve the correct types of LHS identifiers.
	//
	// typeExpr is a simple helper to write out type expressions. It
	// returns nil if the type cannot be written as a plain identifier
	// that resolves, at the call, to the identical type, or if the
	// type's name is in one of the provided sets of shadowed names.
	// TODO(rfindley):
	//   1. handle qualified type names (potentially adding new imports)
	//   2. expand this to handle more type expressions.
	//   3. refactor to share logic with callee rewriting.
	universeAny := types.Universe.Lookup("any")
	typeExpr := func(typ types.Type, shadows ...map[string]bool) ast.Expr {
		var typeName string
		switch typ := typ.(type) {
		case *types.Basic:
			typeName = typ.Name()
		case interface{ Obj() *types.TypeName }: // Named, Alias, TypeParam
			typeName = typ.Obj().Name()
		}

		// Special case: check for universe "any".
		// TODO(golang/go#66921): this may become unnecessary if any becomes a proper alias.
		if typ == universeAny.Type() {
			typeName = "any"
		}

		if typeName == "" {
			return nil
		}

		for _, shadow := range shadows {
			if shadow[typeName] {
				logf("cannot write shadowed type name %q", typeName)
				return nil
			}
		}
		obj, _ := caller.lookup(typeName).(*types.TypeName)
		if obj != nil && types.Identical(obj.Type(), typ) {
			return ast.NewIdent(typeName)
		}
		return nil
	}

	// Substrategy "spread": in the case of a spread call (func f() (T1, T2) return
	// g()), since we didn't hit the 'splice' substrategy, there must be some
	// non-declaring expression on the LHS. Simplify this by pre-declaring
	// variables, rewriting
	//
	//	x, y := f()
	//
	// to
	//
	//	var x int
	//	x, y = g()
	//
	// Which works as long as the predeclared variables do not overlap with free
	// names on the RHS.
	if len(rhs) != len(lhs) {
		assert(len(rhs) == 1 && len(returnOperands) == 1, "expected spread call")

		for _, id := range defs {
			if id != nil && freeNames[id.Name] {
				// By predeclaring variables, we're changing them to be in scope of the
				// RHS. We can't do this if their names are free on the RHS.
				return nil, false
			}
		}

		// Write out the specs, being careful to avoid shadowing free names in
		// their type expressions.
		var (
			specs  []ast.Spec
			shadow = make(map[string]bool)
		)
		failed := false
		byType.Iterate(func(typ types.Type, v any) {
			if failed {
				return
			}
			idxs := v.([]int)
			texpr := typeExpr(typ, shadow)
			if texpr == nil {
				failed = true
				return
			}
			spec := &ast.ValueSpec{
				Type: texpr,
			}
			for _, idx := range idxs {
				spec.Names = append(spec.Names, ast.NewIdent(defs[idx].Name))
			}
			specs = append(specs, spec)
		})
		if failed {
			return nil, false
		}
		logf("substrategy: spread assignment")
		return []ast.Stmt{
			&ast.DeclStmt{
				Decl: &ast.GenDecl{
					Tok:   token.VAR,
					Specs: specs,
				},
			},
			&ast.AssignStmt{
				Lhs: callerStmt.Lhs,
				Tok: token.ASSIGN,
				Rhs: returnOperands,
			},
		}, true
	}

	assert(len(lhs) == len(rhs), "mismatching LHS and RHS")

	// Substrategy "convert": write out RHS expressions with explicit type conversions
	// as necessary, rewriting
	//
	//	x, y := f()
	//
	// to
	//
	//	x, y := 1, int32(2)
	//
	// As required to preserve types.
	//
	// In the special case of _ = nil, which is disallowed by the type checker
	// (since nil has no default type), we delete the assignment.
	var origIdxs []int // maps back to original indexes after lhs and rhs are pruned
	i := 0
	for j := range lhs {
		if _, ok := nilBlankAssigns[j]; !ok {
			lhs[i] = lhs[j]
			rhs[i] = rhs[j]
			origIdxs = append(origIdxs, j)
			i++
		}
	}
	lhs = lhs[:i]
	rhs = rhs[:i]

	if len(lhs) == 0 {
		logf("trivial assignment after pruning nil blanks assigns")
		// After pruning, we have no remaining assignments.
		// Signal this by returning an empty (nil) list of
		// statements along with true.
		return nil, true
	}

	// Write out explicit conversions as necessary.
	//
	// A conversion is necessary if the LHS is being defined, and the RHS return
	// involved a nontrivial implicit conversion.
	for i, expr := range rhs {
		idx := origIdxs[i]
		if nonTrivial[idx] && defs[idx] != nil {
			typ := caller.Info.TypeOf(lhs[i])
			texpr := typeExpr(typ)
			if texpr == nil {
				return nil, false
			}
			if _, ok := texpr.(*ast.StarExpr); ok {
				// TODO(rfindley): is this necessary? Doesn't the formatter add these parens?
				texpr = &ast.ParenExpr{X: texpr} // *T -> (*T)   so that (*T)(x) is valid
			}
			rhs[i] = &ast.CallExpr{
				Fun:  texpr,
				Args: []ast.Expr{expr},
			}
		}
	}
	logf("substrategy: convert assignment")
	return []ast.Stmt{&ast.AssignStmt{
		Lhs: lhs,
		Tok: callerStmt.Tok,
		Rhs: rhs,
	}}, true
}
// tailCallSafeReturn reports whether the callee's return statements
// may safely be used to return from the function that encloses the
// call (which must exist).
func tailCallSafeReturn(caller *Caller, calleeSymbol *types.Func, callee *gobCallee) bool {
	// If every callee return involves only trivial
	// conversions, the returns are safe as they stand.
	if !hasNonTrivialReturn(callee.Returns) {
		return true
	}

	// Determine the type of the innermost function enclosing the call.
	// (Beware: Caller.enclosingFunc is the outermost.)
	var callerType types.Type
	for _, n := range caller.path {
		isFunc := true
		switch f := n.(type) {
		case *ast.FuncLit:
			callerType = caller.Info.TypeOf(f)
		case *ast.FuncDecl:
			callerType = caller.Info.ObjectOf(f.Name).Type()
		default:
			isFunc = false
		}
		if isFunc {
			break
		}
	}

	// A non-trivial return conversion in the callee is acceptable
	// if the very same conversion would occur after inlining,
	// that is, if the result tuples of caller and callee are identical.
	var (
		callerResults = callerType.(*types.Signature).Results()
		calleeResults = calleeSymbol.Type().(*types.Signature).Results()
	)
	return types.Identical(callerResults, calleeResults)
}
// hasNonTrivialReturn reports whether any result operand of any of
// the returns is flagged as involving a nontrivial implicit
// conversion.
func hasNonTrivialReturn(returnInfo [][]returnOperandFlags) bool {
	for i := range returnInfo {
		operands := returnInfo[i]
		for j := range operands {
			if operands[j]&nonTrivialResult != 0 {
				return true
			}
		}
	}
	return false
}
// soleUse returns the one identifier in info.Uses that refers to
// obj, or nil if there are no such identifiers or more than one.
func soleUse(info *types.Info, obj types.Object) (sole *ast.Ident) {
	// A linear scan of all uses is inefficient,
	// but this function is called infrequently.
	for id, used := range info.Uses {
		if used != obj {
			continue
		}
		if sole != nil {
			return nil // not unique
		}
		sole = id
	}
	return sole
}
// unit is an empty struct type, used as the value type of maps that
// represent sets.
type unit struct{} // for representing sets as maps
// slicesDeleteFunc removes any elements from s for which del returns true,
// returning the modified slice. The retained elements keep their
// relative order and are moved to the front of s's backing array.
// slicesDeleteFunc zeroes the elements between the new length and the original length.
// If no element is deleted, s is returned unmodified.
// TODO(adonovan): use go1.21 slices.DeleteFunc
func slicesDeleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
	i := slicesIndexFunc(s, del)
	if i == -1 {
		return s
	}
	// Don't start copying elements until we find one to delete.
	for j := i + 1; j < len(s); j++ {
		if v := s[j]; !del(v) {
			s[i] = v
			i++
		}
	}
	// Zero/nil out the obsolete elements, for GC, so that the
	// backing array does not retain references to deleted values.
	var zero E
	for k := i; k < len(s); k++ {
		s[k] = zero
	}
	return s[:i]
}

// slicesIndexFunc returns the first index i satisfying f(s[i]),
// or -1 if none do.
func slicesIndexFunc[S ~[]E, E any](s S, f func(E) bool) int {
	for i := range s {
		if f(s[i]) {
			return i
		}
	}
	return -1
}