go/cfg: record block kind and associated statement
This change adds Block.Kind, and enumeration of possible
control structures that give rise to blocks, and Block.Stmt,
which records the syntax node associated with it. This
allows clients to reconstruct the control more accurately,
and compute positions.
It also adds a CFG.Digraph method to dump the graph
in AT&T GraphViz form for debugging convenience, and
a simple helper command to call this function.
Also, stop using ast.Object.
Fixes golang/go#53367
Change-Id: I19557d636eb4c620899463c489411c360540289b
Reviewed-on: https://go-review.googlesource.com/c/tools/+/555255
Reviewed-by: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Tim King <taking@google.com>
diff --git a/go/cfg/builder.go b/go/cfg/builder.go
index dad6a44..ac4d63c 100644
--- a/go/cfg/builder.go
+++ b/go/cfg/builder.go
@@ -16,8 +16,8 @@
cfg *CFG
mayReturn func(*ast.CallExpr) bool
current *Block
- lblocks map[*ast.Object]*lblock // labeled blocks
- targets *targets // linked stack of branch targets
+ lblocks map[string]*lblock // labeled blocks
+ targets *targets // linked stack of branch targets
}
func (b *builder) stmt(_s ast.Stmt) {
@@ -42,7 +42,7 @@
b.add(s)
if call, ok := s.X.(*ast.CallExpr); ok && !b.mayReturn(call) {
// Calls to panic, os.Exit, etc, never return.
- b.current = b.newBlock("unreachable.call")
+ b.current = b.newBlock(KindUnreachable, s)
}
case *ast.DeclStmt:
@@ -57,7 +57,7 @@
}
case *ast.LabeledStmt:
- label = b.labeledBlock(s.Label)
+ label = b.labeledBlock(s.Label, s)
b.jump(label._goto)
b.current = label._goto
_s = s.Stmt
@@ -65,7 +65,7 @@
case *ast.ReturnStmt:
b.add(s)
- b.current = b.newBlock("unreachable.return")
+ b.current = b.newBlock(KindUnreachable, s)
case *ast.BranchStmt:
b.branchStmt(s)
@@ -77,11 +77,11 @@
if s.Init != nil {
b.stmt(s.Init)
}
- then := b.newBlock("if.then")
- done := b.newBlock("if.done")
+ then := b.newBlock(KindIfThen, s)
+ done := b.newBlock(KindIfDone, s)
_else := done
if s.Else != nil {
- _else = b.newBlock("if.else")
+ _else = b.newBlock(KindIfElse, s)
}
b.add(s.Cond)
b.ifelse(then, _else)
@@ -128,7 +128,7 @@
switch s.Tok {
case token.BREAK:
if s.Label != nil {
- if lb := b.labeledBlock(s.Label); lb != nil {
+ if lb := b.labeledBlock(s.Label, nil); lb != nil {
block = lb._break
}
} else {
@@ -139,7 +139,7 @@
case token.CONTINUE:
if s.Label != nil {
- if lb := b.labeledBlock(s.Label); lb != nil {
+ if lb := b.labeledBlock(s.Label, nil); lb != nil {
block = lb._continue
}
} else {
@@ -155,14 +155,14 @@
case token.GOTO:
if s.Label != nil {
- block = b.labeledBlock(s.Label)._goto
+ block = b.labeledBlock(s.Label, nil)._goto
}
}
- if block == nil {
- block = b.newBlock("undefined.branch")
+ if block == nil { // ill-typed (e.g. undefined label)
+ block = b.newBlock(KindUnreachable, s)
}
b.jump(block)
- b.current = b.newBlock("unreachable.branch")
+ b.current = b.newBlock(KindUnreachable, s)
}
func (b *builder) switchStmt(s *ast.SwitchStmt, label *lblock) {
@@ -172,7 +172,7 @@
if s.Tag != nil {
b.add(s.Tag)
}
- done := b.newBlock("switch.done")
+ done := b.newBlock(KindSwitchDone, s)
if label != nil {
label._break = done
}
@@ -188,13 +188,13 @@
for i, clause := range s.Body.List {
body := fallthru
if body == nil {
- body = b.newBlock("switch.body") // first case only
+ body = b.newBlock(KindSwitchCaseBody, clause) // first case only
}
// Preallocate body block for the next case.
fallthru = done
if i+1 < ncases {
- fallthru = b.newBlock("switch.body")
+ fallthru = b.newBlock(KindSwitchCaseBody, s.Body.List[i+1])
}
cc := clause.(*ast.CaseClause)
@@ -208,7 +208,7 @@
var nextCond *Block
for _, cond := range cc.List {
- nextCond = b.newBlock("switch.next")
+ nextCond = b.newBlock(KindSwitchNextCase, cc)
b.add(cond) // one half of the tag==cond condition
b.ifelse(body, nextCond)
b.current = nextCond
@@ -247,7 +247,7 @@
b.add(s.Assign)
}
- done := b.newBlock("typeswitch.done")
+ done := b.newBlock(KindSwitchDone, s)
if label != nil {
label._break = done
}
@@ -258,10 +258,10 @@
default_ = cc
continue
}
- body := b.newBlock("typeswitch.body")
+ body := b.newBlock(KindSwitchCaseBody, cc)
var next *Block
for _, casetype := range cc.List {
- next = b.newBlock("typeswitch.next")
+ next = b.newBlock(KindSwitchNextCase, cc)
// casetype is a type, so don't call b.add(casetype).
// This block logically contains a type assertion,
// x.(casetype), but it's unclear how to represent x.
@@ -300,7 +300,7 @@
}
}
- done := b.newBlock("select.done")
+ done := b.newBlock(KindSelectDone, s)
if label != nil {
label._break = done
}
@@ -312,8 +312,8 @@
defaultBody = &clause.Body
continue
}
- body := b.newBlock("select.body")
- next := b.newBlock("select.next")
+ body := b.newBlock(KindSelectCaseBody, clause)
+ next := b.newBlock(KindSelectAfterCase, clause)
b.ifelse(body, next)
b.current = body
b.targets = &targets{
@@ -358,15 +358,15 @@
if s.Init != nil {
b.stmt(s.Init)
}
- body := b.newBlock("for.body")
- done := b.newBlock("for.done") // target of 'break'
- loop := body // target of back-edge
+ body := b.newBlock(KindForBody, s)
+ done := b.newBlock(KindForDone, s) // target of 'break'
+ loop := body // target of back-edge
if s.Cond != nil {
- loop = b.newBlock("for.loop")
+ loop = b.newBlock(KindForLoop, s)
}
cont := loop // target of 'continue'
if s.Post != nil {
- cont = b.newBlock("for.post")
+ cont = b.newBlock(KindForPost, s)
}
if label != nil {
label._break = done
@@ -414,12 +414,12 @@
// jump loop
// done: (target of break)
- loop := b.newBlock("range.loop")
+ loop := b.newBlock(KindRangeLoop, s)
b.jump(loop)
b.current = loop
- body := b.newBlock("range.body")
- done := b.newBlock("range.done")
+ body := b.newBlock(KindRangeBody, s)
+ done := b.newBlock(KindRangeDone, s)
b.ifelse(body, done)
b.current = body
@@ -461,14 +461,19 @@
// labeledBlock returns the branch target associated with the
// specified label, creating it if needed.
-func (b *builder) labeledBlock(label *ast.Ident) *lblock {
- lb := b.lblocks[label.Obj]
+func (b *builder) labeledBlock(label *ast.Ident, stmt *ast.LabeledStmt) *lblock {
+ lb := b.lblocks[label.Name]
if lb == nil {
- lb = &lblock{_goto: b.newBlock(label.Name)}
+ lb = &lblock{_goto: b.newBlock(KindLabel, nil)}
if b.lblocks == nil {
- b.lblocks = make(map[*ast.Object]*lblock)
+ b.lblocks = make(map[string]*lblock)
}
- b.lblocks[label.Obj] = lb
+ b.lblocks[label.Name] = lb
+ }
+ // Fill in the label later (in case of forward goto).
+ // Stmt may be set already if labels are duplicated (ill-typed).
+ if stmt != nil && lb._goto.Stmt == nil {
+ lb._goto.Stmt = stmt
}
return lb
}
@@ -477,11 +482,12 @@
// slice and returns it.
// It does not automatically become the current block.
// comment is an optional string for more readable debugging output.
-func (b *builder) newBlock(comment string) *Block {
+func (b *builder) newBlock(kind BlockKind, stmt ast.Stmt) *Block {
g := b.cfg
block := &Block{
- Index: int32(len(g.Blocks)),
- comment: comment,
+ Index: int32(len(g.Blocks)),
+ Kind: kind,
+ Stmt: stmt,
}
block.Succs = block.succs2[:0]
g.Blocks = append(g.Blocks, block)
diff --git a/go/cfg/cfg.go b/go/cfg/cfg.go
index e9c48d5..0166835 100644
--- a/go/cfg/cfg.go
+++ b/go/cfg/cfg.go
@@ -9,7 +9,10 @@
//
// The blocks of the CFG contain all the function's non-control
// statements. The CFG does not contain control statements such as If,
-// Switch, Select, and Branch, but does contain their subexpressions.
+// Switch, Select, and Branch, but does contain their subexpressions;
+// also, each block records the control statement (Block.Stmt) that
+// gave rise to it and its relationship (Block.Kind) to that statement.
+//
// For example, this source code:
//
// if x := f(); x != nil {
@@ -20,14 +23,14 @@
//
// produces this CFG:
//
-// 1: x := f()
+// 1: x := f() Body
// x != nil
// succs: 2, 3
-// 2: T()
+// 2: T() IfThen
// succs: 4
-// 3: F()
+// 3: F() IfElse
// succs: 4
-// 4:
+// 4: IfDone
//
// The CFG does contain Return statements; even implicit returns are
// materialized (at the position of the function's closing brace).
@@ -50,6 +53,7 @@
//
// The entry point is Blocks[0]; there may be multiple return blocks.
type CFG struct {
+ fset *token.FileSet
Blocks []*Block // block[0] is entry; order otherwise undefined
}
@@ -64,9 +68,63 @@
Succs []*Block // successor nodes in the graph
Index int32 // index within CFG.Blocks
Live bool // block is reachable from entry
+ Kind BlockKind // block kind
+ Stmt ast.Stmt // statement that gave rise to this block (see BlockKind for details)
- comment string // for debugging
- succs2 [2]*Block // underlying array for Succs
+ succs2 [2]*Block // underlying array for Succs
+}
+
+// A BlockKind identifies the purpose of a block.
+// It also determines the possible types of its Stmt field.
+type BlockKind uint8
+
+const (
+ KindInvalid BlockKind = iota // Stmt=nil
+
+ KindUnreachable // unreachable block after {Branch,Return}Stmt / no-return call ExprStmt
+ KindBody // function body BlockStmt
+ KindForBody // body of ForStmt
+ KindForDone // block after ForStmt
+ KindForLoop // head of ForStmt
+ KindForPost // post condition of ForStmt
+ KindIfDone // block after IfStmt
+ KindIfElse // else block of IfStmt
+ KindIfThen // then block of IfStmt
+ KindLabel // labeled block of BranchStmt (Stmt may be nil for dangling label)
+ KindRangeBody // body of RangeStmt
+ KindRangeDone // block after RangeStmt
+ KindRangeLoop // head of RangeStmt
+ KindSelectCaseBody // body of SelectStmt
+ KindSelectDone // block after SelectStmt
+ KindSelectAfterCase // block after a CommClause
+ KindSwitchCaseBody // body of CaseClause
+ KindSwitchDone // block after {Type.}SwitchStmt
+ KindSwitchNextCase // secondary expression of a multi-expression CaseClause
+)
+
+func (kind BlockKind) String() string {
+ return [...]string{
+ KindInvalid: "Invalid",
+ KindUnreachable: "Unreachable",
+ KindBody: "Body",
+ KindForBody: "ForBody",
+ KindForDone: "ForDone",
+ KindForLoop: "ForLoop",
+ KindForPost: "ForPost",
+ KindIfDone: "IfDone",
+ KindIfElse: "IfElse",
+ KindIfThen: "IfThen",
+ KindLabel: "Label",
+ KindRangeBody: "RangeBody",
+ KindRangeDone: "RangeDone",
+ KindRangeLoop: "RangeLoop",
+ KindSelectCaseBody: "SelectCaseBody",
+ KindSelectDone: "SelectDone",
+ KindSelectAfterCase: "SelectAfterCase",
+ KindSwitchCaseBody: "SwitchCaseBody",
+ KindSwitchDone: "SwitchDone",
+ KindSwitchNextCase: "SwitchNextCase",
+ }[kind]
}
// New returns a new control-flow graph for the specified function body,
@@ -82,7 +140,7 @@
mayReturn: mayReturn,
cfg: new(CFG),
}
- b.current = b.newBlock("entry")
+ b.current = b.newBlock(KindBody, body)
b.stmt(body)
// Compute liveness (reachability from entry point), breadth-first.
@@ -110,7 +168,15 @@
}
func (b *Block) String() string {
- return fmt.Sprintf("block %d (%s)", b.Index, b.comment)
+ return fmt.Sprintf("block %d (%s)", b.Index, b.comment(nil))
+}
+
+func (b *Block) comment(fset *token.FileSet) string {
+ s := b.Kind.String()
+ if fset != nil && b.Stmt != nil {
+ s = fmt.Sprintf("%s@L%d", s, fset.Position(b.Stmt.Pos()).Line)
+ }
+ return s
}
// Return returns the return statement at the end of this block if present, nil
@@ -129,7 +195,7 @@
func (g *CFG) Format(fset *token.FileSet) string {
var buf bytes.Buffer
for _, b := range g.Blocks {
- fmt.Fprintf(&buf, ".%d: # %s\n", b.Index, b.comment)
+ fmt.Fprintf(&buf, ".%d: # %s\n", b.Index, b.comment(fset))
for _, n := range b.Nodes {
fmt.Fprintf(&buf, "\t%s\n", formatNode(fset, n))
}
@@ -145,6 +211,35 @@
return buf.String()
}
+// digraph emits AT&T GraphViz (dot) syntax for the CFG.
+// TODO(adonovan): publish; needs a proposal.
+func (g *CFG) digraph(fset *token.FileSet) string {
+ var buf bytes.Buffer
+ buf.WriteString("digraph CFG {\n")
+ buf.WriteString(" node [shape=box];\n")
+ for _, b := range g.Blocks {
+ // node label
+ var text bytes.Buffer
+ text.WriteString(b.comment(fset))
+ for _, n := range b.Nodes {
+ fmt.Fprintf(&text, "\n%s", formatNode(fset, n))
+ }
+
+ // node and edges
+ fmt.Fprintf(&buf, " n%d [label=%q];\n", b.Index, &text)
+ for _, succ := range b.Succs {
+ fmt.Fprintf(&buf, " n%d -> n%d;\n", b.Index, succ.Index)
+ }
+ }
+ buf.WriteString("}\n")
+ return buf.String()
+}
+
+// exposed to main.go
+func digraph(g *CFG, fset *token.FileSet) string {
+ return g.digraph(fset)
+}
+
func formatNode(fset *token.FileSet, n ast.Node) string {
var buf bytes.Buffer
format.Node(&buf, fset, n)
diff --git a/go/cfg/cfg_test.go b/go/cfg/cfg_test.go
index f22bda3..536d2fe 100644
--- a/go/cfg/cfg_test.go
+++ b/go/cfg/cfg_test.go
@@ -2,15 +2,20 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package cfg
+package cfg_test
import (
"bytes"
"fmt"
"go/ast"
+ "go/format"
"go/parser"
"go/token"
"testing"
+
+ "golang.org/x/tools/go/cfg"
+ "golang.org/x/tools/go/packages"
+ "golang.org/x/tools/internal/testenv"
)
const src = `package main
@@ -140,7 +145,7 @@
}
for _, decl := range f.Decls {
if decl, ok := decl.(*ast.FuncDecl); ok {
- g := New(decl.Body, mayReturn)
+ g := cfg.New(decl.Body, mayReturn)
// Print statements in unreachable blocks
// (in order determined by builder).
@@ -165,6 +170,57 @@
}
}
+// TestSmoke runs the CFG builder on every FuncDecl in the standard
+// library and x/tools. (This is all well-typed code, but it gives
+// some coverage.)
+func TestSmoke(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ testenv.NeedsTool(t, "go")
+
+ // The Mode API is just hateful.
+ // https://github.com/golang/go/issues/48226#issuecomment-1948792315
+ mode := packages.NeedDeps | packages.NeedImports | packages.NeedSyntax | packages.NeedTypes
+ pkgs, err := packages.Load(&packages.Config{Mode: mode}, "std", "golang.org/x/tools/...")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, pkg := range pkgs {
+ for _, file := range pkg.Syntax {
+ for _, decl := range file.Decls {
+ if decl, ok := decl.(*ast.FuncDecl); ok && decl.Body != nil {
+ g := cfg.New(decl.Body, mayReturn)
+
+ // Run a few quick sanity checks.
+ failed := false
+ for i, b := range g.Blocks {
+ errorf := func(format string, args ...any) {
+ if !failed {
+ t.Errorf("%s\n%s", pkg.Fset.Position(decl.Pos()), g.Format(pkg.Fset))
+ failed = true
+ }
+ msg := fmt.Sprintf(format, args...)
+ t.Errorf("block %d: %s", i, msg)
+ }
+
+ if b.Kind == cfg.KindInvalid {
+ errorf("invalid Block.Kind %v", b.Kind)
+ }
+ if b.Stmt == nil && b.Kind != cfg.KindLabel {
+ errorf("nil Block.Stmt (Kind=%v)", b.Kind)
+ }
+ if i != int(b.Index) {
+ errorf("invalid Block.Index")
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
// A trivial mayReturn predicate that looks only at syntax, not types.
func mayReturn(call *ast.CallExpr) bool {
switch fun := call.Fun.(type) {
@@ -175,3 +231,10 @@
}
return true
}
+
+func formatNode(fset *token.FileSet, n ast.Node) string {
+ var buf bytes.Buffer
+ format.Node(&buf, fset, n)
+ // Indent secondary lines by a tab.
+ return string(bytes.Replace(buf.Bytes(), []byte("\n"), []byte("\n\t"), -1))
+}
diff --git a/go/cfg/main.go b/go/cfg/main.go
new file mode 100644
index 0000000..e25b368
--- /dev/null
+++ b/go/cfg/main.go
@@ -0,0 +1,72 @@
+//go:build ignore
+
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The cfg command prints the control-flow graph of the first function
+// or method whose name matches 'funcname' in the specified package.
+//
+// Usage: cfg package funcname
+//
+// Example:
+//
+// $ go build -o cfg ./go/cfg/main.go
+// $ cfg ./go/cfg stmt | dot -Tsvg > cfg.svg && open cfg.svg
+package main
+
+import (
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "log"
+ "os"
+ _ "unsafe" // for linkname
+
+ "golang.org/x/tools/go/cfg"
+ "golang.org/x/tools/go/packages"
+)
+
+func main() {
+ flag.Parse()
+ if len(flag.Args()) != 2 {
+ log.Fatal("Usage: package funcname")
+ }
+ pattern, funcname := flag.Args()[0], flag.Args()[1]
+ pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadSyntax}, pattern)
+ if err != nil {
+ log.Fatal(err)
+ }
+ if packages.PrintErrors(pkgs) > 0 {
+ os.Exit(1)
+ }
+ for _, pkg := range pkgs {
+ for _, f := range pkg.Syntax {
+ for _, decl := range f.Decls {
+ if decl, ok := decl.(*ast.FuncDecl); ok {
+ if decl.Name.Name == funcname {
+ g := cfg.New(decl.Body, mayReturn)
+ fmt.Println(digraph(g, pkg.Fset))
+ os.Exit(0)
+ }
+ }
+ }
+ }
+ }
+ log.Fatalf("no function %q found in %s", funcname, pattern)
+}
+
+// A trivial mayReturn predicate that looks only at syntax, not types.
+func mayReturn(call *ast.CallExpr) bool {
+ switch fun := call.Fun.(type) {
+ case *ast.Ident:
+ return fun.Name != "panic"
+ case *ast.SelectorExpr:
+ return fun.Sel.Name != "Fatal"
+ }
+ return true
+}
+
+//go:linkname digraph golang.org/x/tools/go/cfg.digraph
+func digraph(g *cfg.CFG, fset *token.FileSet) string