blob: 0dee837e8af6591d63f9b4b594ce67d497a3fbb3 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fix
import (
"fmt"
"go/token"
"strings"
"github.com/dave/dst"
)
// invariant: selFromExpr() must only be called on expressions for which
// c.protoFieldSelectorOrAccessor() returns ok. assignmentIsSwap() ensures that
// for all lhs/rhs expressions.
func selFromExpr(c *cursor, expr dst.Expr) *dst.SelectorExpr {
if ce, ok := expr.(*dst.CallExpr); ok {
sel, _, _ := c.protoFieldSelectorOrAccessor(ce.Fun)
return sel
}
return expr.(*dst.SelectorExpr)
}
func idFromExpr(expr dst.Expr) *dst.Ident {
if expr == nil {
return nil
}
switch x := expr.(type) {
case *dst.Ident:
return x
case *dst.CallExpr:
return idFromExpr(x.Fun)
case *dst.SelectorExpr:
return idFromExpr(x.Sel)
default:
return nil
}
}
func assignmentIsSwap(c *cursor, stmt *dst.AssignStmt) error {
if stmt.Tok != token.ASSIGN {
return fmt.Errorf("ignoring AssignStmt with Tok %v (looking for ASSIGN)", stmt.Tok)
}
if len(stmt.Lhs) != len(stmt.Rhs) {
return fmt.Errorf("ignoring AssignStmt with len(lhs)=%d != len(rhs)=%d (cannot be a swap)", len(stmt.Lhs), len(stmt.Rhs))
}
if len(stmt.Lhs) == 1 {
return fmt.Errorf("ignoring AssignStmt with 1 lhs/rhs (assignment, not swap)")
}
for _, expr := range append(append([]dst.Expr(nil), stmt.Lhs...), stmt.Rhs...) {
if ce, ok := expr.(*dst.CallExpr); ok {
expr = ce.Fun
}
se, ok := expr.(*dst.SelectorExpr)
if !ok {
return fmt.Errorf("ignoring AssignStmt: %T is not a SelectorExpr (cannot be a swap)", expr)
}
if idFromExpr(se.X) == nil {
return fmt.Errorf("ignoring AssignStmt: could not find Ident in SelectorExpr.X %T", expr)
}
sel, _, ok := c.protoFieldSelectorOrAccessor(expr)
if !ok {
return fmt.Errorf("ignoring AssignStmt: %T is not a proto field selector", expr)
}
t := c.typeOfOrNil(sel.X)
if t == nil {
continue // no type info (silo'ed?), assume tracked
}
if !c.shouldUpdateType(t) {
return fmt.Errorf("should not update type %v", t)
}
}
c.Logf("rewriting swap")
// As soon as we find any repetition among the left and right hand side, we
// treat the assignment as a swap.
for _, lhs := range stmt.Lhs {
lhsSel := selFromExpr(c, lhs)
lhsXObj := c.objectOf(idFromExpr(lhsSel.X))
lhsSelName := c.objectOf(lhsSel.Sel).Name()
for _, rhs := range stmt.Rhs {
rhsSel := selFromExpr(c, rhs)
rhsXObj := c.objectOf(idFromExpr(rhsSel.X))
rhsSelName := c.objectOf(rhsSel.Sel).Name()
// If the RHS is a function call (LHS of an assignment cannot be a
// function call), it must be a getter (the other accessors do not
// return a value), so remove the Get prefix to match the LHS name.
if _, ok := rhs.(*dst.CallExpr); ok {
rhsSelName = strings.TrimPrefix(rhsSelName, "Get")
}
if lhsXObj == rhsXObj && lhsSelName == rhsSelName {
return nil
}
}
}
return fmt.Errorf("ignoring AssignStmt: no repetitions among lhs/rhs")
}
// assignSwapPre splits swaps (m.F1, m.F2 = m.F2, m.F1) into two separate assign
// statements which can then be rewritten by subsequent rewrite stages.
func assignSwapPre(c *cursor) bool {
// NOTE(stapelberg): This stage is only yellow level because the subsequent
// getPost and assignPre stages only deal with pointer-typed variables in
// the yellow level. But, safety-wise, this could go into the green level.
if !c.lvl.ge(Yellow) {
return true
}
stmt, ok := c.Node().(*dst.AssignStmt)
if !ok {
c.Logf("ignoring %T (looking for AssignStmt)", c.Node())
return true
}
if err := assignmentIsSwap(c, stmt); err != nil {
c.Logf("%s", err.Error())
return true
}
assign2 := &dst.AssignStmt{
Lhs: stmt.Lhs,
Tok: token.ASSIGN,
Rhs: nil, // will be filled with helper variable names
}
stmt.Lhs = nil // will be filled with helper variable names
for _, rhs := range stmt.Rhs {
rhsSel := selFromExpr(c, rhs)
helperName := c.helperNameFor(rhs, c.typeOf(rhsSel.X))
helperIdent := &dst.Ident{Name: helperName}
updateASTMap(c, rhs, helperIdent)
c.setType(helperIdent, c.typeOf(rhs))
stmt.Lhs = append(stmt.Lhs, helperIdent)
assign2.Rhs = append(assign2.Rhs, cloneIdent(c, helperIdent))
}
stmt.Tok = token.DEFINE
assign2.Decorations().After = stmt.Decorations().After
assign2.Decorations().End = stmt.Decorations().End
stmt.Decorations().After = dst.None
stmt.Decorations().End = nil
c.InsertAfter(assign2)
return true
}