blob: b93ade010650babf3ad1f48f031ce58a3e4bf490 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fillswitch
import (
"bytes"
"fmt"
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/inspector"
)
// Diagnose computes diagnostics for switch statements with missing cases
// overlapping with the provided start and end position.
//
// If either start or end is invalid, the entire package is inspected.
func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic {
var diags []analysis.Diagnostic
nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)}
inspect.Preorder(nodeFilter, func(n ast.Node) {
if start.IsValid() && n.End() < start ||
end.IsValid() && n.Pos() > end {
return // non-overlapping
}
var fix *analysis.SuggestedFix
switch n := n.(type) {
case *ast.SwitchStmt:
fix = suggestedFixSwitch(n, pkg, info)
case *ast.TypeSwitchStmt:
fix = suggestedFixTypeSwitch(n, pkg, info)
}
if fix == nil {
return
}
diags = append(diags, analysis.Diagnostic{
Message: fix.Message,
Pos: n.Pos(),
End: n.Pos() + token.Pos(len("switch")),
SuggestedFixes: []analysis.SuggestedFix{*fix},
})
})
return diags
}
func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix {
if hasDefaultCase(stmt.Body) {
return nil
}
namedType := namedTypeFromTypeSwitch(stmt, info)
if namedType == nil {
return nil
}
existingCases := caseTypes(stmt.Body, info)
// Gather accessible package-level concrete types
// that implement the switch interface type.
scope := namedType.Obj().Pkg().Scope()
var buf bytes.Buffer
for _, name := range scope.Names() {
obj := scope.Lookup(name)
if tname, ok := obj.(*types.TypeName); !ok || tname.IsAlias() {
continue // not a defined type
}
if types.IsInterface(obj.Type()) {
continue
}
samePkg := obj.Pkg() == pkg
if !samePkg && !obj.Exported() {
continue // inaccessible
}
var key caseType
if types.AssignableTo(obj.Type(), namedType.Obj().Type()) {
key.named = obj.Type().(*types.Named)
} else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) {
key.named = obj.Type().(*types.Named)
key.ptr = true
}
if key.named != nil {
if existingCases[key] {
continue
}
if buf.Len() > 0 {
buf.WriteString("\t")
}
buf.WriteString("case ")
if key.ptr {
buf.WriteByte('*')
}
if p := key.named.Obj().Pkg(); p != pkg {
// TODO: use the correct package name when the import is renamed
buf.WriteString(p.Name())
buf.WriteByte('.')
}
buf.WriteString(key.named.Obj().Name())
buf.WriteString(":\n")
}
}
if buf.Len() == 0 {
return nil
}
switch assign := stmt.Assign.(type) {
case *ast.AssignStmt:
addDefaultCase(&buf, namedType, assign.Lhs[0])
case *ast.ExprStmt:
if assert, ok := assign.X.(*ast.TypeAssertExpr); ok {
addDefaultCase(&buf, namedType, assert.X)
}
}
return &analysis.SuggestedFix{
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
TextEdits: []analysis.TextEdit{{
Pos: stmt.End() - token.Pos(len("}")),
End: stmt.End() - token.Pos(len("}")),
NewText: buf.Bytes(),
}},
}
}
func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix {
if hasDefaultCase(stmt.Body) {
return nil
}
namedType, ok := info.TypeOf(stmt.Tag).(*types.Named)
if !ok {
return nil
}
existingCases := caseConsts(stmt.Body, info)
// Gather accessible named constants of the same type as the switch value.
scope := namedType.Obj().Pkg().Scope()
var buf bytes.Buffer
for _, name := range scope.Names() {
obj := scope.Lookup(name)
if c, ok := obj.(*types.Const); ok &&
(obj.Pkg() == pkg || obj.Exported()) && // accessible
types.Identical(obj.Type(), namedType.Obj().Type()) &&
!existingCases[c] {
if buf.Len() > 0 {
buf.WriteString("\t")
}
buf.WriteString("case ")
if c.Pkg() != pkg {
buf.WriteString(c.Pkg().Name())
buf.WriteByte('.')
}
buf.WriteString(c.Name())
buf.WriteString(":\n")
}
}
if buf.Len() == 0 {
return nil
}
addDefaultCase(&buf, namedType, stmt.Tag)
return &analysis.SuggestedFix{
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
TextEdits: []analysis.TextEdit{{
Pos: stmt.End() - token.Pos(len("}")),
End: stmt.End() - token.Pos(len("}")),
NewText: buf.Bytes(),
}},
}
}
func addDefaultCase(buf *bytes.Buffer, named *types.Named, expr ast.Expr) {
var dottedBuf bytes.Buffer
// writeDotted emits a dotted path a.b.c.
var writeDotted func(e ast.Expr) bool
writeDotted = func(e ast.Expr) bool {
switch e := e.(type) {
case *ast.SelectorExpr:
if !writeDotted(e.X) {
return false
}
dottedBuf.WriteByte('.')
dottedBuf.WriteString(e.Sel.Name)
return true
case *ast.Ident:
dottedBuf.WriteString(e.Name)
return true
}
return false
}
buf.WriteString("\tdefault:\n")
typeName := fmt.Sprintf("%s.%s", named.Obj().Pkg().Name(), named.Obj().Name())
if writeDotted(expr) {
// Switch tag expression is a dotted path.
// It is safe to re-evaluate it in the default case.
format := fmt.Sprintf("unexpected %s: %%#v", typeName)
fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, dottedBuf.String())
} else {
// Emit simpler message, without re-evaluating tag expression.
fmt.Fprintf(buf, "\t\tpanic(%q)\n\t", "unexpected "+typeName)
}
}
func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named {
switch assign := stmt.Assign.(type) {
case *ast.ExprStmt:
if typ, ok := assign.X.(*ast.TypeAssertExpr); ok {
if named, ok := info.TypeOf(typ.X).(*types.Named); ok {
return named
}
}
case *ast.AssignStmt:
if typ, ok := assign.Rhs[0].(*ast.TypeAssertExpr); ok {
if named, ok := info.TypeOf(typ.X).(*types.Named); ok {
return named
}
}
}
return nil
}
func hasDefaultCase(body *ast.BlockStmt) bool {
for _, clause := range body.List {
if len(clause.(*ast.CaseClause).List) == 0 {
return true
}
}
return false
}
func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool {
out := map[*types.Const]bool{}
for _, stmt := range body.List {
for _, e := range stmt.(*ast.CaseClause).List {
if info.Types[e].Value == nil {
continue // not a constant
}
if sel, ok := e.(*ast.SelectorExpr); ok {
e = sel.Sel // replace pkg.C with C
}
if e, ok := e.(*ast.Ident); ok {
if c, ok := info.Uses[e].(*types.Const); ok {
out[c] = true
}
}
}
}
return out
}
type caseType struct {
named *types.Named
ptr bool
}
func caseTypes(body *ast.BlockStmt, info *types.Info) map[caseType]bool {
out := map[caseType]bool{}
for _, stmt := range body.List {
for _, e := range stmt.(*ast.CaseClause).List {
if tv, ok := info.Types[e]; ok && tv.IsType() {
t := tv.Type
ptr := false
if p, ok := t.(*types.Pointer); ok {
t = p.Elem()
ptr = true
}
if named, ok := t.(*types.Named); ok {
out[caseType{named, ptr}] = true
}
}
}
}
return out
}