blob: a953f13119970871f52df3be6d1d39753502b0cc [file] [log] [blame] [edit]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fix
import (
"fmt"
"go/token"
"go/types"
"slices"
"strings"
"github.com/dave/dst"
"github.com/dave/dst/dstutil"
)
// oneofSwitchPost handles type switches of protocol buffer oneofs.
func oneofSwitchPost(c *cursor) bool {
stmt, ok := c.Node().(*dst.TypeSwitchStmt)
if !ok {
c.Logf("ignoring %T (looking for TypeSwitchStmt)", c.Node())
return true
}
initStmt := stmt.Init
var typeAssert *dst.TypeAssertExpr
// oneofIdent is the variable introduced for the expression under the type
// assertion (e.g. "x" in "x := m.F.(type)").
//
// We keep track of it so that we can replace it with accesses to m later on.
//
// This can be nil as "m.F.(type)" is also valid (type-switch without
// assignment), in which case we don't have to rewrite any references to the
// oneof in type-switch clauses.
var oneofIdent *dst.Ident
switch a := stmt.Assign.(type) {
case *dst.AssignStmt: // "switch x := m.F.(type) {"
if len(a.Lhs) != 1 || len(a.Rhs) != 1 {
c.Logf("ignoring AssignStmt with len(lhs)=%d, len(rhs)=%d (looking for 1, 1)", len(a.Lhs), len(a.Rhs))
return true
}
ident, ok := a.Lhs[0].(*dst.Ident)
if !ok {
c.Logf("ignoring Lhs=%T (looking for Ident)", a.Lhs[0])
return true
}
oneofIdent = ident
typeAssert, ok = a.Rhs[0].(*dst.TypeAssertExpr)
if !ok {
c.Logf("ignoring Rhs=%T (looking for TypeAssertExpr)", a.Rhs[0])
return true
}
case *dst.ExprStmt: // "switch m.F.(type) {"
typeAssert, ok = a.X.(*dst.TypeAssertExpr)
if !ok {
c.Logf("ignoring TypeSwitchStmt.Assign.X=%T (looking for TypeAssertExpr)", a.X)
return true
}
default:
c.Logf("ignoring TypeSwitchStmt.Assign=%T (looking for TypeAssertExpr or AssignStmt)", a)
return true
}
var oneofScope *types.Scope
if oneofIdent != nil {
ooi, ok := c.typesInfo.astMap[oneofIdent]
if !ok {
panic(fmt.Sprintf("failed to retrieve ast.Node for dst.Node: %v", oneofIdent))
}
oneofScope = c.pkg.TypePkg.Scope().Innermost(ooi.Pos())
}
// Below we assume how the "Which" method is named based on either the name of
// the corresponding field or "Get" method. This isn't necessarily sound in
// the presence of naming conflicts but we don't want to repeat that
// complicated logic here.
var whichSig *types.Signature
var whichName string
var recv dst.Expr // "X" in "X.F.(type)" or "X.GetF().(type)"
// We track oneof field and getter name here to keep the name-related logic
// in one place, at the top. We use those to rewrite oneof field access
// through field/getter later on. Ideally, we would use object identity
// instead of comparison by name, but we can't get that without messing with
// names, so we might as well got for name comaprisons.
var oneofFieldName string // "F" in "X.F.(type)" or "X.GetF().(type)"
var oneofGetName string // "GetF" corresponding to oneofFieldName
// Find the WhichF() method that should replace "m.F.(type)" in
// switch m.F.(type) {
// =>
// switch m.WhichF() {
if field, ok := c.trackedProtoFieldSelector(typeAssert.X); ok {
if !isOneof(c.typeOf(field)) {
c.Logf("ignoring field %v because it is not a oneof", field)
return true
}
oneofFieldName = field.Sel.Name
oneofGetName = "Get" + field.Sel.Name
whichName = "Which" + oneofFieldName
c.Logf("rewriting TypeAssertExpr on oneof field %s to %s CallExpr", oneofFieldName, whichName)
whichSig = whichOneofSignature(c.typeOf(field.X), whichName)
recv = field.X
}
// Find the WhichF() method that should replace "m.GetF().(type)" in
// switch m.GetF().(type) {
// =>
// switch m.WhichF() {
if call, ok := typeAssert.X.(*dst.CallExpr); ok {
sel, ok := call.Fun.(*dst.SelectorExpr)
if !ok || !isOneof(c.typeOf(call)) || !strings.HasPrefix(sel.Sel.Name, "Get") || !c.shouldUpdateType(c.typeOf(sel.X)) {
// Even if this is a call returning a oneof, we can't do anything about
// it. We don't add DO_NOT_SUBMIT tag because it's too easy to
// accidentally add it to unrelated type switches here.
c.Logf("ignoring: TypeAssertExpr recondition is not met")
return true
}
recv = sel.X
oneofGetName = sel.Sel.Name
oneofFieldName = oneofGetName[len("Get"):]
whichName = "Which" + oneofFieldName
c.Logf("rewriting TypeAssertExpr on oneof field %s to %s CallExpr", oneofFieldName, whichName)
whichSig = whichOneofSignature(c.typeOf(sel.X), whichName)
}
// Not a oneof access (call or field) or we couldn't find the "Which" method.
if whichSig == nil {
c.Logf("ignoring: cannot find Which${Field} method")
return true
}
// If recv may have side effects, then move it to the init statement if
// possible (bail otherwise) in order to avoid cloning side effects into
// the body of the type-switch.
if !c.isSideEffectFree(recv) {
if initStmt != nil {
if c.lvl.ge(Red) {
c.numUnsafeRewritesByReason[IncompleteRewrite]++
markMissingRewrite(stmt, "type switch with side effects and init statement")
}
c.Logf("ignoring: cannot move init statement with side effects")
return true
}
newVar := dst.NewIdent("xmsg")
c.setType(newVar, c.typeOf(recv))
initStmt = &dst.AssignStmt{
Lhs: []dst.Expr{newVar},
Tok: token.DEFINE,
Rhs: []dst.Expr{recv},
}
recv = cloneIdent(c, newVar)
}
// Construct the "recv.WhichOneofField()" method call.
whichMethod := &dst.SelectorExpr{
X: recv,
Sel: dst.NewIdent(whichName),
}
c.setType(whichMethod, whichSig)
c.setType(whichMethod.Sel, whichSig)
whichCall := &dst.CallExpr{Fun: whichMethod}
c.setType(whichCall, whichSig.Results().At(0).Type())
// First, verify that we can do necessary rewrites to the entire type switch
// statement.
//
// In "switch x := m.F.(type) {" and similar, we replace all references
// to "x" (the oneof object, "oneofIdent") with method calls on "m".
definedVarUsedInDefaultCase := false
if oneofIdent != nil {
var maybeUnsafe bool
if !c.lvl.ge(Red) {
// Unfortunately "x" in "x := X.(type)" has no identity, so we must
// rely on name matching to replace all instances of "x" with
// something else. We want to be careful: if any clause refers to
// more than one "x" ("x" has identity in non-default clauses) that
// could be a oneof then we don't do anything.
//
// If any non-default clause refers to "x" (that could be a oneof)
// not in the context of a selector (i.e. it's not "x.F") then we
// also don't do anything. Oneofs are not a "thing" and can't be
// referenced like this in the new API.
//
// The sole exception is printf-like calls in the default clause
// (see below), where we replace the oneof reference with a "Which"
// method call.
for _, clause := range stmt.Body.List {
var clauseIdent *dst.Ident
dstutil.Apply(clause, func(cur *dstutil.Cursor) bool {
ident, ok := cur.Node().(*dst.Ident)
if !ok || ident.Name != oneofIdent.Name {
return true
}
isDefaultClause := clause.(*dst.CaseClause).List == nil // Ident has no type in "default".
if !isDefaultClause && !isOneofWrapper(c, ident) {
return true
}
// Usually we don't allow references to the "oneof" itself, but
// we make an exception for print-like statements as
// it's comment to say: `fmt.Errorf("unknown case %v", oneofField)`.
if c.looksLikePrintf(cur.Parent()) {
return true
}
if _, ok := cur.Parent().(*dst.SelectorExpr); !ok {
maybeUnsafe = true
return false
}
if clauseIdent == nil {
clauseIdent = ident
return true
}
if c.objectOf(clauseIdent) == c.objectOf(ident) {
return true
}
maybeUnsafe = true
return false
}, nil)
if maybeUnsafe {
c.Logf("ignoring: cannot rewrite oneof usage safely")
return true
}
}
}
// Then do the necessary rewrites.
for _, clause := range stmt.Body.List {
dstutil.Apply(clause, func(cur *dstutil.Cursor) bool {
n := cur.Node()
// First, handle direct references to the oneof itself,
// in default clauses:
// fmt.Errorf("Unknown case %v", oneofField)
// =>
// fmt.Errorf("Unknown case %v", m.WhichOneofField())
//
// This works well, because oneof cases have String() method.
//
// Note that we do the rewrite regardless of what's in the format
// string (we don't want to get into the business of parsing format
// strings now), so we risk:
// fmt.Errorf("Unknown case %T", m.WhichOneofField())
// Which is not ideal, but not unreasonable.
if c.looksLikePrintf(cur.Parent()) {
// rewrite the formatting verb (if any) to %v:
// fmt.Errorf("%T", m2.OneofField) => fmt.Errorf("%v", m2.OneofField)
// The other part of the expression is rewritten below.
// This function does not handle escaped '%' character.
// Implementing a full formatting string parser is out of scope here.
rewriteVerb := func() {
call := cur.Parent().(*dst.CallExpr)
argIndex := slices.Index(call.Args, cur.Node().(dst.Expr))
if argIndex == -1 {
panic(fmt.Sprintf("could not find argument %v in call expression %v. Did you call rewriteVerb after replacing the Node?", cur.Node(), call))
}
for i, p := range call.Args {
slit, ok := p.(*dst.BasicLit)
if !ok {
continue
}
if slit.Kind != token.STRING {
continue
}
numFormattedArg := argIndex - i
idx := -1
for i, r := range slit.Value {
if r == '%' {
numFormattedArg--
if numFormattedArg == 0 {
idx = i
break
}
}
}
// There are not enough formatting verbs.
if idx == -1 {
return
}
if idx == len(slit.Value) || slit.Value[idx+1] == 'v' {
break
}
// No need to clone as a literal cannot be shared between Nodes.
slit.Value = slit.Value[:idx] + "%v" + slit.Value[idx+2:]
}
}
// switch m := f(); oneofField := m.OneofField.(type) {...
// default:
// return fmt.Errorf("Unknown case %v", oneofField)
// =>
// switch m := f(); oneofField := m.OneofField.(type) {...
// default:
// return fmt.Errorf("Unknown case %v", m.WhichOneofField())
if ident, ok := n.(*dst.Ident); ok && ident.Name == oneofIdent.Name {
rewriteVerb()
definedVarUsedInDefaultCase = true
if initStmt == nil {
return true
}
c.Logf("rewriting usage of %q", oneofIdent.Name)
if maybeUnsafe {
c.numUnsafeRewritesByReason[MaybeSemanticChange]++
}
cur.Replace(cloneSelectorCallExpr(c, whichCall))
return true
}
// fmt.Errorf("Unknown case %v", m.OneofField)
// =>
// fmt.Errorf("Unknown case %v", m.WhichOneofField())
if field, ok := c.trackedProtoFieldSelector(n); ok && isOneof(c.typeOf(field)) && field.Sel.Name == oneofFieldName {
c.Logf("rewriting usage of %q", oneofIdent.Name)
if maybeUnsafe {
c.numUnsafeRewritesByReason[MaybeSemanticChange]++
}
rewriteVerb()
cur.Replace(cloneSelectorCallExpr(c, whichCall))
return true
}
// fmt.Errorf("Unknown case %v", m.GetOneofField())
// =>
// fmt.Errorf("Unknown case %v", m.WhichOneofField())
if call, ok := n.(*dst.CallExpr); ok && isOneof(c.typeOf(call)) {
if sel, ok := call.Fun.(*dst.SelectorExpr); ok && sel.Sel.Name == oneofGetName {
c.Logf("rewriting usage of %q", oneofIdent.Name)
if maybeUnsafe {
c.numUnsafeRewritesByReason[MaybeSemanticChange]++
}
rewriteVerb()
cur.Replace(cloneSelectorCallExpr(c, whichCall))
return true
}
}
}
// Handle "oneofField.F" => "recv.F"
// Subsequent passes may then do more rewrites. For example:
// "recv.F" => "recv.{Get,Set,Has,Clear}F()"
sel, ok := n.(*dst.SelectorExpr)
if !ok {
c.Logf("ignoring usage %v of type %T (looking for SelectorExpr)", n, n)
return true
}
ident, ok := sel.X.(*dst.Ident)
if !ok {
c.Logf("ignoring usage %v of type SelectorExpr.X.%T (looking for SelectorExpr.X.Ident)", sel, sel.X)
return true
}
// We cannot do an object based comparison here
// because variable declared in the initializer
// of a switch don't have an associated object
// (or there is a bug in our tooling that removes
// this association).
if ident.Name != oneofIdent.Name {
c.Logf("ignoring %v because it is not using %v", sel, oneofIdent)
return true
}
isDefaultClause := clause.(*dst.CaseClause).List == nil
if !isDefaultClause && !isOneofWrapper(c, ident) {
c.Logf("ignoring usage %v: not in default clause and not a oneof wrapper", sel)
return true
}
// Check if the name is shadwed in a nested block:
//
// switch oneofField := m2.OneofField.(type) {
// case *pb.M2_StringOneof:
// {
// oneofField := &O{}
// _ = oneofField.StringOneof
// }
// }
//
// For switch-case statements there is one scope for the switch
// which has child scopes, one for each case clause. A variable
// defined in the switch condition (oneofField in the example)
// is not part of the parent (switch) scope but is defined in
// every child (case) scope.
// This means to check if an identifier references the variable
// defined by the switch condition, we must check if the
// referenced object is any of the child (case) scopes.
if oneofScope == nil {
c.Logf("ignoring usage %v: no scope for oneof found", n)
return true
}
astNode, ok := c.typesInfo.astMap[ident]
if !ok {
panic(fmt.Sprintf("failed to retrieve ast.Node for dst.Node: %p", ident))
}
scope := c.pkg.TypePkg.Scope().Innermost(astNode.Pos())
s, _ := scope.LookupParent(ident.Name, astNode.Pos())
if s == nil {
panic(fmt.Sprintf("invalid package: cannot resolve %v", astNode))
}
found := false
for i := 0; i < oneofScope.NumChildren(); i++ {
if oneofScope.Child(i) == s {
found = true
break
}
}
// The variable defined in the switch is shadowed by another
// definition.
if !found {
c.Logf("ignoring usage %v: does not reference oneof type switch variable", astNode)
return true
}
c.Logf("rewriting usage %v", recv)
switch recv := recv.(type) {
case *dst.Ident:
sel.X = cloneIdent(c, recv)
case *dst.IndexExpr:
sel.X = cloneIndexExpr(c, recv)
case *dst.SelectorExpr:
sel.X = cloneSelectorExpr(c, recv)
case *dst.CallExpr:
sel.X = cloneSelectorCallExpr(c, recv)
default:
panic(fmt.Sprintf("unsupported receiver AST type %T in oneof type switch", recv))
}
if maybeUnsafe {
c.numUnsafeRewritesByReason[MaybeSemanticChange]++
}
return true
}, nil)
}
}
// Rewrite clauses:
// case *pb.M_Oneof:
// =>
// case pb.M_Oneof_case:
for _, clause := range stmt.Body.List {
cl, ok := clause.(*dst.CaseClause)
if !ok {
continue
}
for i, typ := range cl.List {
if ident, ok := typ.(*dst.Ident); ok && ident.Name == "nil" {
zero := &dst.BasicLit{Kind: token.INT, Value: "0"}
c.setType(zero, types.Typ[types.Int32])
c.Logf("rewriting nil-case to 0-case in %v", cl)
cl.List[i] = zero
continue
}
se, ok := typ.(*dst.StarExpr)
if !ok {
c.Logf("skipping case %v of type %T (looking for StarExpr)", cl, typ)
continue
}
sel, ok := se.X.(*dst.SelectorExpr)
if !ok {
c.Logf("skipping case %v of type %T (looking for SelectorExpr)", cl, se)
continue
}
// Split the oneof wrapper type name into its parts:
parts := strings.Split(strings.TrimRight(sel.Sel.Name, "_"), "_")
// parts[0] is the parent of the oneof field
// parts[len(parts)-1] is the oneof field
if len(parts) < 2 {
c.Logf("skipping case %v, sel %q does not split into parent_field", cl, sel.Sel.Name)
continue
}
field := parts[len(parts)-1]
suffix := "_case"
// There are two cases in which the wrapper type name ends in an
// underscore:
//
// 1. The field name itself conflicts. This means the field was
// called reset, marshal, etc., conflicting with the proto
// generated code method names. In this case, the _case constant
// will use two underscores.
//
// 2. The field name conflicts with another name in the .proto file.
// In this case, the _case constant will not use two underscores.
if strings.HasSuffix(sel.Sel.Name, "_") && !conflictingNames[field] {
suffix = "case"
}
sel.Sel.Name += suffix
sel.Decs.Before = se.Decs.Before
sel.Decs.Start = append(sel.Decs.Start, se.Decs.Start...)
sel.Decs.End = append(sel.Decs.End, se.Decs.End...)
sel.Decs.After = se.Decs.After
c.Logf("rewriting case %+#v", cl.List[i])
cl.List[i] = sel
}
}
c.Logf("rewriting SwitchStmt")
// The TypeSwitchStmt initializes a local variable that we need to keep
// and there is no init statement.
// switch oneofField := m.OneofField.(type) {...
// default:
// return fmt.Errorf("Unknown case %v", oneofField)
// =>
// switch oneofField := m.WhichOneof() {...
// default:
// return fmt.Errorf("Unknown case %v", oneofField)
// If it were not used we would rewrite it to:
// switch oneofField := m.OneofField.(type) {
// ...
// default:
// return fmt.Errorf("Unknown case")
// =>
// switch m.WhichOneof() {...
// ...
// default:
// return fmt.Errorf("Unknown case")
var tag dst.Expr = cloneSelectorCallExpr(c, whichCall)
if definedVarUsedInDefaultCase && initStmt == nil {
if oneofIdent == nil {
panic("identNil")
}
c.setType(oneofIdent, types.Typ[types.Invalid])
initStmt = &dst.AssignStmt{
Lhs: []dst.Expr{
cloneIdent(c, oneofIdent),
},
Rhs: []dst.Expr{tag},
Tok: token.DEFINE,
}
tag = cloneIdent(c, oneofIdent)
c.setType(tag, c.typeOf(oneofIdent))
}
// Change the type from TypeSwitchStmt to SwitchStmt.
c.Replace(&dst.SwitchStmt{
Init: initStmt,
Tag: tag,
Body: stmt.Body,
Decs: dst.SwitchStmtDecorations{
NodeDecs: stmt.Decs.NodeDecs,
Switch: stmt.Decs.Switch,
Init: stmt.Decs.Init,
Tag: stmt.Decs.Assign,
},
})
return true
}
// whichOneofSignature returns the signature type of the method, of type t, with
// the given name. It must have "Which" prefix (i.e. it is the method for
// getting the oneof type). It returns an in-band nil (instead of a separate
// "ok" bool) if such method does not exist in order to simplify the caller code
// above.
func whichOneofSignature(t types.Type, name string) *types.Signature {
if !strings.HasPrefix(name, "Which") {
panic(fmt.Sprintf("whichOneofSignature called with %q; must have 'Which' prefix", name))
}
ptr, ok := types.Unalias(t).(*types.Pointer)
if !ok {
return nil
}
typ, ok := types.Unalias(ptr.Elem()).(*types.Named)
if !ok {
return nil
}
var m *types.Func
for i := 0; i < typ.NumMethods(); i++ {
m = typ.Method(i)
if m.Name() == name {
break
}
}
if m == nil {
return nil
}
sig := m.Type().(*types.Signature)
if sig.Results().Len() != 1 {
return nil
}
return sig
}