blob: 987a325bf51638ebee074eb1da6dfd92bcc0def6 [file] [log] [blame] [edit]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fix
import (
"go/token"
"go/types"
"github.com/dave/dst"
"github.com/dave/dst/dstutil"
"google.golang.org/open2opaque/internal/protodetecttypes"
)
// usePointersPre rewrites value-typed uses of Go proto structs into pointers
// to those structs (pointers being the only correct way to handle protos).
//
// NOTE: usePointersPre runs (once) on the *dst.File. It always returns false,
// which stops the outer recursion; the dstutil.Apply() calls made inside the
// function walk the entire file on their own.
func usePointersPre(c *cursor) bool {
	// Nothing below applies unless we are at the Red level.
	if c.lvl.le(Yellow) {
		return false
	}

	// First pass: short-form definitions and field types.
	//
	//   x := pb.M{}     => x := &pb.M{}
	//   x := pb.M{F: V} => x := &pb.M{F: V}
	//
	// Turning the definition into a pointer is fine as long as x is never
	// (shallow) copied. Other situations may be fine as well, but only the
	// common case is implemented. Long-form definitions are not handled
	// because they have not shown up in practice.
	pointerizeDefsAndFields := func(cur *dstutil.Cursor) bool {
		switch node := cur.Node().(type) {
		case *dst.BlockStmt:
			for pos, stmt := range node.List {
				name, isDef := nonPtrProtoStructDef(c, stmt)
				// Only the statements following the definition can use the
				// identifier, so only those are inspected and rewritten.
				if !isDef || !replaceWithPtr(c, name, node.List[pos+1:]) {
					continue
				}
				c.setType(name, types.NewPointer(c.typeOf(name)))
				def := stmt.(*dst.AssignStmt)
				def.Rhs[0] = addr(c, def.Rhs[0])
				c.numUnsafeRewritesByReason[PotentialBuildBreakage]++
			}
		case *dst.Field:
			fieldType := c.typeOf(node.Type)
			_, isStar := node.Type.(*dst.StarExpr)
			if !(protodetecttypes.Type{T: fieldType}).IsMessage() || isStar {
				// Either not a proto message or already spelled as *T.
				break
			}
			ptrExpr := &dst.StarExpr{X: node.Type}
			c.setType(ptrExpr, types.NewPointer(fieldType))
			node.Type = ptrExpr
			c.numUnsafeRewritesByReason[PotentialBuildBreakage]++
		}
		return true
	}
	dstutil.Apply(c.Node(), pointerizeDefsAndFields, nil)

	// Second pass: take the address of every remaining non-pointer composite
	// literal, even though doing so breaks compilation. Each such rewrite is
	// annotated with a FIXME-style comment; adapting the surrounding code to
	// the literal now being a pointer is left to the user.
	pointerizeLiterals := func(cur *dstutil.Cursor) bool {
		lit, isLit := cur.Node().(*dst.CompositeLit)
		switch {
		case !isLit:
			return true
		case isAddr(cur.Parent()):
			// The address of the literal is already taken
			// (e.g. &pb.M2{literal}), so the result is a pointer.
			return true
		}
		litType := c.typeOf(lit)
		if _, isPtr := litType.Underlying().(*types.Pointer); isPtr {
			// The literal is implicitly a pointer already
			// (e.g. []*pb.M2{{literal}}).
			return true
		}
		if c.shouldUpdateType(litType) {
			addCommentAbove(c.Node(), lit, "// DO NOT SUBMIT: fix callers to work with a pointer (go/goprotoapi-findings#message-value)")
			c.numUnsafeRewritesByReason[IncompleteRewrite]++
			cur.Replace(addr(c, lit))
		}
		return true
	}
	dstutil.Apply(c.Node(), pointerizeLiterals, nil)

	return false
}
// nonPtrProtoStructDef reports whether stmt is a short variable definition
// whose right-hand side is a composite literal of a type accepted by
// c.shouldUpdateType, i.e. a statement of the form:
//
//	x := pb.M{...}
//
// When it is, the defined identifier ('x' above) is returned as well;
// otherwise the result is (nil, false).
func nonPtrProtoStructDef(c *cursor, stmt dst.Stmt) (*dst.Ident, bool) {
	assign, isAssign := stmt.(*dst.AssignStmt)
	// Only single-variable ":=" definitions qualify; multi-assign is not handled.
	if !isAssign || assign.Tok != token.DEFINE || len(assign.Lhs) != 1 {
		return nil, false
	}

	lit, isLit := assign.Rhs[0].(*dst.CompositeLit)
	if !isLit {
		return nil, false
	}

	// The left-hand side of a definition should always be an identifier, so
	// the failing branch is purely defensive.
	name, isIdent := assign.Lhs[0].(*dst.Ident)
	if !isIdent {
		return nil, false
	}

	if c.shouldUpdateType(c.typeOf(lit)) {
		return name, true
	}
	return nil, false
}
// replaceWithPtr adapts stmts (when possible) so that they keep working once
// ident, which must currently be a non-pointer, becomes a pointer. For an
// identifier 'x', the statements
//
//	x.S = nil
//	_ = x.GetS()
//	f(&x)
//
// turn into
//
//	x.S = nil
//	_ = x.GetS()
//	f(x) // the "&" is gone
//
// Nothing is changed unless every use of the identifier permits it. Any use
// that amounts to a shallow copy vetoes the rewrite. For instance, the
// statements below are left untouched:
//
//	f(&x)
//	E
//
// where E is any of
//
//	x = pb.M{}
//	y = x
//	*p = x
//	f(x)
//	[]pb.M{x}
//
// or any other use that is not a field access, a method call, &x, or an
// argument to a printer function.
//
// replaceWithPtr returns true if the rewrite is permitted, in which case every
// "&ident" expression found in stmts has been replaced with "ident". It
// returns false, leaving stmts unmodified, otherwise.
func replaceWithPtr(c *cursor, ident *dst.Ident, stmts []dst.Stmt) bool {
	scope := &dst.BlockStmt{List: stmts}

	// Phase 1: look for any use of ident that rules the rewrite out.
	vetoed := false
	dstutil.Apply(scope, func(cur *dstutil.Cursor) bool {
		use, isIdent := cur.Node().(*dst.Ident)
		if !isIdent {
			return true
		}
		if c.objectOf(use) != c.objectOf(ident) {
			// Some other identifier; not our concern.
			return false
		}

		parent := cur.Parent()
		_, viaSelector := parent.(*dst.SelectorExpr)
		unary, isUnary := parent.(*dst.UnaryExpr)
		switch {
		case isAddr(parent):
			// Address already taken.
		case viaSelector:
			// A selector expression as parent means one of
			//
			//	use.Field
			//	use.Func()
			//
			// both of which work for pointers and values alike.
		case isUnary && unary.Op == token.AND:
			// Address expressions are fine; the "&" is dropped in phase 2.
		case c.lvl.ge(Yellow) && c.looksLikePrintf(parent):
			// An argument to t.Errorf or another printing function may
			// become a pointer. The output format changes from
			//	{F:...}
			// to
			//	&{F:...}
			// Shallow copying like this is incorrect anyway, and code rarely
			// (though surely sometimes) depends on the exact log output.
		default:
			// This use cannot cope with ident being a pointer.
			vetoed = true
		}
		return false
	}, nil)
	if vetoed {
		return false
	}

	// Phase 2: ident may become a pointer, so rewrite every "&ident" to "ident".
	dstutil.Apply(scope, func(cur *dstutil.Cursor) bool {
		unary, isUnary := cur.Node().(*dst.UnaryExpr)
		if isUnary && unary.Op == token.AND {
			if operand, isIdent := unary.X.(*dst.Ident); isIdent && c.objectOf(ident) == c.objectOf(operand) {
				cur.Replace(operand)
				return false
			}
		}
		return true
	}, nil)
	return true
}