vta: finalizes VTA graph construction by adding support for function calls

These include call and return statements as well as closure creations
and panics/recovers.

Change-Id: Iee4a4e48e1b9c304959fbce4f3eb43eecd8cb851
Reviewed-on: https://go-review.googlesource.com/c/tools/+/323049
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go
index 30e514f..1b7b105 100644
--- a/go/callgraph/vta/graph.go
+++ b/go/callgraph/vta/graph.go
@@ -262,6 +262,9 @@
 }
 
 func (b *builder) visit(funcs map[*ssa.Function]bool) {
+	// Add the fixed edge Panic -> Recover
+	b.graph.addEdge(panicArg{}, recoverReturn{})
+
 	for f, in := range funcs {
 		if in {
 			b.fun(f)
@@ -283,6 +286,8 @@
 		b.addInFlowAliasEdges(b.nodeFromVal(i.Addr), b.nodeFromVal(i.Val))
 	case *ssa.MakeInterface:
 		b.addInFlowEdge(b.nodeFromVal(i.X), b.nodeFromVal(i))
+	case *ssa.MakeClosure:
+		b.closure(i)
 	case *ssa.UnOp:
 		b.unop(i)
 	case *ssa.Phi:
@@ -333,14 +338,19 @@
 		b.mapUpdate(i)
 	case *ssa.Next:
 		b.next(i)
+	case ssa.CallInstruction:
+		b.call(i)
+	case *ssa.Panic:
+		b.panic(i)
+	case *ssa.Return:
+		b.rtrn(i)
 	case *ssa.MakeChan, *ssa.MakeMap, *ssa.MakeSlice, *ssa.BinOp,
 		*ssa.Alloc, *ssa.DebugRef, *ssa.Convert, *ssa.Jump, *ssa.If,
 		*ssa.Slice, *ssa.Range, *ssa.RunDefers:
 		// No interesting flow here.
 		return
 	default:
-		// TODO(zpavlinovic): make into a panic once all instructions are supported.
-		fmt.Printf("unsupported instruction %v\n", instr)
+		panic(fmt.Sprintf("unsupported instruction %v\n", instr))
 	}
 }
 
@@ -500,6 +510,97 @@
 	}
 }
 
+func (b *builder) closure(c *ssa.MakeClosure) {
+	f := c.Fn.(*ssa.Function)
+	b.addInFlowEdge(function{f: f}, b.nodeFromVal(c))
+
+	for i, fv := range f.FreeVars {
+		b.addInFlowAliasEdges(b.nodeFromVal(fv), b.nodeFromVal(c.Bindings[i]))
+	}
+}
+
+// panic creates a flow from arguments to panic instructions to return
+// registers of all recover statements in the program. Introduces a
+// global panic node Panic and
+//  1) for every panic statement p: add p -> Panic
+//  2) for every recover statement r: add Panic -> r (handled in call)
+// TODO(zpavlinovic): improve precision by explicitly modeling how panic
+// values flow from callees to callers and into deferred recover instructions.
+func (b *builder) panic(p *ssa.Panic) {
+	// Panics often have, for instance, strings as arguments which do
+	// not create interesting flows.
+	if !canHaveMethods(p.X.Type()) {
+		return
+	}
+
+	b.addInFlowEdge(b.nodeFromVal(p.X), panicArg{})
+}
+
+// call adds flows between arguments/parameters and return values/registers
+// for both static and dynamic calls, as well as go and defer calls.
+func (b *builder) call(c ssa.CallInstruction) {
+	// When c is r := recover() call register instruction, we add Recover -> r.
+	if bf, ok := c.Common().Value.(*ssa.Builtin); ok && bf.Name() == "recover" {
+		b.addInFlowEdge(recoverReturn{}, b.nodeFromVal(c.(*ssa.Call)))
+		return
+	}
+
+	for _, f := range siteCallees(c, b.callGraph) {
+		addArgumentFlows(b, c, f)
+	}
+}
+
+func addArgumentFlows(b *builder, c ssa.CallInstruction, f *ssa.Function) {
+	cc := c.Common()
+	// When c is an unresolved method call (cc.Method != nil), cc.Value contains
+	// the receiver object rather than cc.Args[0].
+	if cc.Method != nil {
+		b.addInFlowAliasEdges(b.nodeFromVal(f.Params[0]), b.nodeFromVal(cc.Value))
+	}
+
+	offset := 0
+	if cc.Method != nil {
+		offset = 1
+	}
+	for i, v := range cc.Args {
+		b.addInFlowAliasEdges(b.nodeFromVal(f.Params[i+offset]), b.nodeFromVal(v))
+	}
+}
+
+// rtrn produces flows between values of r and c where
+// c is a call instruction that resolves to the enclosing
+// function of r based on b.callGraph.
+func (b *builder) rtrn(r *ssa.Return) {
+	n := b.callGraph.Nodes[r.Parent()]
+	// n != nil when b.callgraph is sound, but the client can
+	// pass any callgraph, including an underapproximate one.
+	if n == nil {
+		return
+	}
+
+	for _, e := range n.In {
+		if cv, ok := e.Site.(ssa.Value); ok {
+			addReturnFlows(b, r, cv)
+		}
+	}
+}
+
+func addReturnFlows(b *builder, r *ssa.Return, site ssa.Value) {
+	results := r.Results
+	if len(results) == 1 {
+		// When there is only one return value, the destination register does not
+		// have a tuple type.
+		b.addInFlowEdge(b.nodeFromVal(results[0]), b.nodeFromVal(site))
+		return
+	}
+
+	tup := site.Type().Underlying().(*types.Tuple)
+	for i, r := range results {
+		local := indexedLocal{val: site, typ: tup.At(i).Type(), index: i}
+		b.addInFlowEdge(b.nodeFromVal(r), local)
+	}
+}
+
 // addInFlowEdge adds s -> d to g if d is node that can have an inflow, i.e., a node
 // that represents an interface or an unresolved function value. Otherwise, there
 // is no interesting type flow so the edge is ommited.
diff --git a/go/callgraph/vta/graph_test.go b/go/callgraph/vta/graph_test.go
index 499817c..08866c8 100644
--- a/go/callgraph/vta/graph_test.go
+++ b/go/callgraph/vta/graph_test.go
@@ -241,6 +241,11 @@
 		"testdata/stores_arrays.go",
 		"testdata/maps.go",
 		"testdata/ranges.go",
+		"testdata/closures.go",
+		"testdata/static_calls.go",
+		"testdata/dynamic_calls.go",
+		"testdata/returns.go",
+		"testdata/panic.go",
 	} {
 		t.Run(file, func(t *testing.T) {
 			prog, want, err := testProg(file)
diff --git a/go/callgraph/vta/testdata/closures.go b/go/callgraph/vta/testdata/closures.go
new file mode 100644
index 0000000..6e6c0ac
--- /dev/null
+++ b/go/callgraph/vta/testdata/closures.go
@@ -0,0 +1,53 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+type I interface {
+	Foo()
+}
+
+func Do(i I) { i.Foo() }
+
+func Baz(b bool, h func(I)) {
+	var i I
+	a := func(g func(I)) {
+		g(i)
+	}
+
+	if b {
+		h = Do
+	}
+
+	a(h)
+}
+
+// Relevant SSA:
+//  func Baz(b bool, h func(I)):
+//    t0 = new I (i)
+//    t1 = make closure Baz$1 [t0]
+//    if b goto 1 else 2
+//   1:
+//         jump 2
+//   2:
+//    t2 = phi [0: h, 1: Do] #h
+//    t3 = t1(t2)
+//    return
+//
+// func Baz$1(g func(I)):
+//    t0 = *i
+//    t1 = g(t0)
+//    return
+
+// In the edge set Local(i) -> Local(t0), Local(t0) below,
+// two occurrences of t0 come from t0 in Baz and Baz$1.
+
+// WANT:
+// Function(Do) -> Local(t2)
+// Function(Baz$1) -> Local(t1)
+// Local(h) -> Local(t2)
+// Local(t0) -> Local(i)
+// Local(i) -> Local(t0), Local(t0)
diff --git a/go/callgraph/vta/testdata/dynamic_calls.go b/go/callgraph/vta/testdata/dynamic_calls.go
new file mode 100644
index 0000000..fa4270b
--- /dev/null
+++ b/go/callgraph/vta/testdata/dynamic_calls.go
@@ -0,0 +1,43 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+type I interface {
+	foo(I)
+}
+
+type A struct{}
+
+func (a A) foo(ai I) {}
+
+type B struct{}
+
+func (b B) foo(bi I) {}
+
+func doWork() I { return nil }
+func close() I  { return nil }
+
+func Baz(x B, h func() I, i I) I {
+	i.foo(x)
+
+	return h()
+}
+
+// Relevant SSA:
+// func Baz(x B, h func() I, i I) I:
+//   t0 = local B (x)
+//   *t0 = x
+//   t1 = *t0
+//   t2 = make I <- B (t1)
+//   t3 = invoke i.foo(t2)
+//   t4 = h()
+//   return t4
+
+// WANT:
+// Local(t2) -> Local(ai), Local(bi)
+// Constant(testdata.I) -> Local(t4)
+// Local(t1) -> Local(t2)
diff --git a/go/callgraph/vta/testdata/panic.go b/go/callgraph/vta/testdata/panic.go
new file mode 100644
index 0000000..2d39c70
--- /dev/null
+++ b/go/callgraph/vta/testdata/panic.go
@@ -0,0 +1,66 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+type I interface {
+	foo()
+}
+
+type A struct{}
+
+func (a A) foo() {}
+
+func recover1() {
+	print("only this recover should execute")
+	if r, ok := recover().(I); ok {
+		r.foo()
+	}
+}
+
+func recover2() {
+	recover()
+}
+
+func Baz(a A) {
+	defer recover1()
+	panic(a)
+}
+
+// Relevant SSA:
+// func recover1():
+// 	0:
+//   t0 = print("only this recover...":string)
+//   t1 = recover()
+//   t2 = typeassert,ok t1.(I)
+//   t3 = extract t2 #0
+//   t4 = extract t2 #1
+//   if t4 goto 1 else 2
+//  1:
+//   t5 = invoke t3.foo()
+//   jump 2
+//  2:
+//   return
+//
+// func recover2():
+//   t0 = recover()
+//   return
+//
+// func Baz(i I):
+//   t0 = local A (a)
+//   *t0 = a
+//   defer recover1()
+//   t1 = *t0
+//   t2 = make interface{} <- A (t1)
+//   panic t2
+
+// t2 argument to panic in Baz gets ultimately connected to recover
+// registers t1 in recover1() and t0 in recover2().
+
+// WANT:
+// Panic -> Recover
+// Local(t2) -> Panic
+// Recover -> Local(t0), Local(t1)
diff --git a/go/callgraph/vta/testdata/returns.go b/go/callgraph/vta/testdata/returns.go
new file mode 100644
index 0000000..b11b432
--- /dev/null
+++ b/go/callgraph/vta/testdata/returns.go
@@ -0,0 +1,57 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+type I interface{}
+
+func Bar(ii I) (I, I) {
+	return Foo(ii)
+}
+
+func Foo(iii I) (I, I) {
+	return iii, iii
+}
+
+func Do(j I) *I {
+	return &j
+}
+
+func Baz(i I) *I {
+	Bar(i)
+	return Do(i)
+}
+
+// Relevant SSA:
+// func Bar(ii I) (I, I):
+//   t0 = Foo(ii)
+//   t1 = extract t0 #0
+//   t2 = extract t0 #1
+//   return t1, t2
+//
+// func Foo(iii I) (I, I):
+//   return iii, iii
+//
+// func Do(j I) *I:
+//   t0 = new I (j)
+//   *t0 = j
+//   return t0
+//
+// func Baz(i I):
+//   t0 = Bar(i)
+//   t1 = Do(i)
+//   return t1
+
+// t0 and t1 in the last edge correspond to the nodes
+// of Do and Baz. This edge is induced by Do(i).
+
+// WANT:
+// Local(i) -> Local(ii), Local(j)
+// Local(ii) -> Local(iii)
+// Local(iii) -> Local(t0[0]), Local(t0[1])
+// Local(t1) -> Local(t0[0])
+// Local(t2) -> Local(t0[1])
+// Local(t0) -> Local(t1)
diff --git a/go/callgraph/vta/testdata/static_calls.go b/go/callgraph/vta/testdata/static_calls.go
new file mode 100644
index 0000000..74a27c1
--- /dev/null
+++ b/go/callgraph/vta/testdata/static_calls.go
@@ -0,0 +1,41 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+type I interface{}
+
+func foo(i I) (I, I) {
+	return i, i
+}
+
+func doWork(ii I) {}
+
+func close(iii I) {}
+
+func Baz(inp I) {
+	a, b := foo(inp)
+	defer close(a)
+	go doWork(b)
+}
+
+// Relevant SSA:
+// func Baz(inp I):
+//   t0 = foo(inp)
+//   t1 = extract t0 #0
+//   t2 = extract t0 #1
+//   defer close(t1)
+//   go doWork(t2)
+//   rundefers
+//   ...
+// func foo(i I) (I, I):
+//   return i, i
+
+// WANT:
+// Local(inp) -> Local(i)
+// Local(t1) -> Local(iii)
+// Local(t2) -> Local(ii)
+// Local(i) -> Local(t0[0]), Local(t0[1])
diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go
index 274d9a3..3142893 100644
--- a/go/callgraph/vta/utils.go
+++ b/go/callgraph/vta/utils.go
@@ -6,6 +6,9 @@
 
 import (
 	"go/types"
+
+	"golang.org/x/tools/go/callgraph"
+	"golang.org/x/tools/go/ssa"
 )
 
 func canAlias(n1, n2 node) bool {
@@ -100,3 +103,36 @@
 	}
 	return u.(*types.Slice).Elem()
 }
+
+// siteCallees computes a set of callees for call site `c` given program `callgraph`.
+func siteCallees(c ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function {
+	var matches []*ssa.Function
+
+	node := callgraph.Nodes[c.Parent()]
+	if node == nil {
+		return nil
+	}
+
+	for _, edge := range node.Out {
+		callee := edge.Callee.Func
+		// Skip synthetic functions wrapped around source functions.
+		if edge.Site == c && callee.Synthetic == "" {
+			matches = append(matches, callee)
+		}
+	}
+	return matches
+}
+
+func canHaveMethods(t types.Type) bool {
+	if _, ok := t.(*types.Named); ok {
+		return true
+	}
+
+	u := t.Underlying()
+	switch u.(type) {
+	case *types.Interface, *types.Signature, *types.Struct:
+		return true
+	default:
+		return false
+	}
+}