go/callgraph/vta: improve support for function value flow

Nodes involving function types should be treated the same way as the
nodes involving interfaces. This was not the case earlier, causing vta
to miss producing edges for, say, struct fields that are functions. This
CL addresses that.

Change-Id: I1f6969868babfd0eeec8991f7403192d4ba0afe3
Reviewed-on: https://go-review.googlesource.com/c/tools/+/350732
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Tim King <taking@google.com>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go
index 6c9e6a5..f846418 100644
--- a/go/callgraph/vta/graph.go
+++ b/go/callgraph/vta/graph.go
@@ -190,6 +190,25 @@
 	return fmt.Sprintf("PtrInterface(%v)", l.typ)
 }
 
+// nestedPtrFunction node represents all references and dereferences of locals
+// and globals that have a nested pointer to function type. We merge such
+// constructs into a single node for simplicity and without much precision
+// sacrifice as such variables are rare in practice. Both a and b would be
+// represented as the same PtrFunction(func()) node in:
+//   var a *func()
+//   var b **func()
+type nestedPtrFunction struct {
+	typ types.Type
+}
+
+func (p nestedPtrFunction) Type() types.Type {
+	return p.typ
+}
+
+func (p nestedPtrFunction) String() string {
+	return fmt.Sprintf("PtrFunction(%v)", p.typ)
+}
+
 // panicArg models types of all arguments passed to panic.
 type panicArg struct{}
 
@@ -615,12 +634,16 @@
 
 // Creates const, pointer, global, func, and local nodes based on register instructions.
 func (b *builder) nodeFromVal(val ssa.Value) node {
-	if p, ok := val.Type().(*types.Pointer); ok && !isInterface(p.Elem()) {
+	if p, ok := val.Type().(*types.Pointer); ok && !isInterface(p.Elem()) && !isFunction(p.Elem()) {
 		// Nested pointer to interfaces are modeled as a special
 		// nestedPtrInterface node.
 		if i := interfaceUnderPtr(p.Elem()); i != nil {
 			return nestedPtrInterface{typ: i}
 		}
+		// The same goes for nested function types.
+		if f := functionUnderPtr(p.Elem()); f != nil {
+			return nestedPtrFunction{typ: f}
+		}
 		return pointer{typ: p}
 	}
 
@@ -665,6 +688,8 @@
 		return channelElem{typ: t}
 	case nestedPtrInterface:
 		return nestedPtrInterface{typ: t}
+	case nestedPtrFunction:
+		return nestedPtrFunction{typ: t}
 	case field:
 		return field{StructType: canonicalize(i.StructType, &b.canon), index: i.index}
 	case indexedLocal:
diff --git a/go/callgraph/vta/graph_test.go b/go/callgraph/vta/graph_test.go
index 61bb05a..7ccfe49 100644
--- a/go/callgraph/vta/graph_test.go
+++ b/go/callgraph/vta/graph_test.go
@@ -43,6 +43,8 @@
 	pint := types.NewPointer(bint)
 	i := types.NewInterface(nil, nil)
 
+	voidFunc := main.Signature.Underlying()
+
 	for _, test := range []struct {
 		n node
 		s string
@@ -59,8 +61,9 @@
 		{global{val: gl}, "Global(gl)", gl.Type()},
 		{local{val: reg}, "Local(t0)", bint},
 		{indexedLocal{val: reg, typ: X, index: 0}, "Local(t0[0])", X},
-		{function{f: main}, "Function(main)", main.Signature.Underlying()},
+		{function{f: main}, "Function(main)", voidFunc},
 		{nestedPtrInterface{typ: i}, "PtrInterface(interface{})", i},
+		{nestedPtrFunction{typ: voidFunc}, "PtrFunction(func())", voidFunc},
 		{panicArg{}, "Panic", nil},
 		{recoverReturn{}, "Recover", nil},
 	} {
@@ -181,6 +184,7 @@
 		"testdata/maps.go",
 		"testdata/ranges.go",
 		"testdata/closures.go",
+		"testdata/function_alias.go",
 		"testdata/static_calls.go",
 		"testdata/dynamic_calls.go",
 		"testdata/returns.go",
diff --git a/go/callgraph/vta/testdata/callgraph_field_funcs.go b/go/callgraph/vta/testdata/callgraph_field_funcs.go
new file mode 100644
index 0000000..cf4c0f1
--- /dev/null
+++ b/go/callgraph/vta/testdata/callgraph_field_funcs.go
@@ -0,0 +1,67 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+type WrappedFunc struct {
+	F func() complex64
+}
+
+func callWrappedFunc(f WrappedFunc) {
+	f.F()
+}
+
+func foo() complex64 {
+	println("foo")
+	return -1
+}
+
+func Foo(b bool) {
+	callWrappedFunc(WrappedFunc{foo})
+	x := func() {}
+	y := func() {}
+	var a *func()
+	if b {
+		a = &x
+	} else {
+		a = &y
+	}
+	(*a)()
+}
+
+// Relevant SSA:
+// func Foo(b bool):
+//         t0 = local WrappedFunc (complit)
+//         t1 = &t0.F [#0]
+//         *t1 = foo
+//         t2 = *t0
+//         t3 = callWrappedFunc(t2)
+//         t4 = new func() (x)
+//         *t4 = Foo$1
+//         t5 = new func() (y)
+//         *t5 = Foo$2
+//         if b goto 1 else 3
+// 1:
+//         jump 2
+// 2:
+//         t6 = phi [1: t4, 3: t5] #a
+//         t7 = *t6
+//         t8 = t7()
+//         return
+// 3:
+//         jump 2
+//
+// func callWrappedFunc(f WrappedFunc):
+//         t0 = local WrappedFunc (f)
+//         *t0 = f
+//         t1 = &t0.F [#0]
+//         t2 = *t1
+//         t3 = t2()
+//         return
+
+// WANT:
+// callWrappedFunc: t2() -> foo
+// Foo: callWrappedFunc(t2) -> callWrappedFunc; t7() -> Foo$1, Foo$2
diff --git a/go/callgraph/vta/testdata/function_alias.go b/go/callgraph/vta/testdata/function_alias.go
new file mode 100644
index 0000000..b38e0e0
--- /dev/null
+++ b/go/callgraph/vta/testdata/function_alias.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+type Doer func()
+
+type A struct {
+	foo func()
+	do  Doer
+}
+
+func Baz(f func()) {
+	j := &f
+	k := &j
+	**k = func() {}
+	a := A{}
+	a.foo = **k
+	a.foo()
+	a.do = a.foo
+	a.do()
+}
+
+// Relevant SSA:
+// func Baz(f func()):
+//        t0 = new func() (f)
+//        *t0 = f
+//        t1 = new *func() (j)
+//        *t1 = t0
+//        t2 = *t1
+//        *t2 = Baz$1
+//        t3 = local A (a)
+//        t4 = &t3.foo [#0]
+//        t5 = *t1
+//        t6 = *t5
+//        *t4 = t6
+//        t7 = &t3.foo [#0]
+//        t8 = *t7
+//        t9 = t8()
+//        t10 = &t3.do [#1]                                                 *Doer
+//        t11 = &t3.foo [#0]                                              *func()
+//        t12 = *t11                                                       func()
+//        t13 = changetype Doer <- func() (t12)                              Doer
+//        *t10 = t13
+//        t14 = &t3.do [#1]                                                 *Doer
+//        t15 = *t14                                                         Doer
+//        t16 = t15()                                                          ()
+
+// Flow chain showing that Baz$1 reaches t8():
+//   Baz$1 -> t2 <-> PtrFunction(func()) <-> t5 -> t6 -> t4 <-> Field(testdata.A:foo) <-> t7 -> t8
+// Flow chain showing that Baz$1 reaches t15():
+//  Field(testdata.A:foo) <-> t11 -> t12 -> t13 -> t10 <-> Field(testdata.A:do) <-> t14 -> t15
+
+// WANT:
+// Local(f) -> Local(t0)
+// Local(t0) -> PtrFunction(func())
+// Function(Baz$1) -> Local(t2)
+// PtrFunction(func()) -> Local(t0), Local(t2), Local(t5)
+// Local(t2) -> PtrFunction(func())
+// Local(t4) -> Field(testdata.A:foo)
+// Local(t5) -> Local(t6), PtrFunction(func())
+// Local(t6) -> Local(t4)
+// Local(t7) -> Field(testdata.A:foo), Local(t8)
+// Field(testdata.A:foo) -> Local(t11), Local(t4), Local(t7)
+// Local(t4) -> Field(testdata.A:foo)
+// Field(testdata.A:do) -> Local(t10), Local(t14)
+// Local(t10) -> Field(testdata.A:do)
+// Local(t11) -> Field(testdata.A:foo), Local(t12)
+// Local(t12) -> Local(t13)
+// Local(t13) -> Local(t10)
+// Local(t14) -> Field(testdata.A:do), Local(t15)
diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go
index cabc93b..9633b86 100644
--- a/go/callgraph/vta/utils.go
+++ b/go/callgraph/vta/utils.go
@@ -19,6 +19,9 @@
 	if _, ok := n.(nestedPtrInterface); ok {
 		return true
 	}
+	if _, ok := n.(nestedPtrFunction); ok {
+		return true
+	}
 
 	if _, ok := n.Type().(*types.Pointer); ok {
 		return true
@@ -33,7 +36,9 @@
 //  2) is a (nested) pointer to interface (needed for, say,
 //     slice elements of nested pointers to interface type)
 //  3) is a function type (needed for higher-order type flow)
-//  4) is a global Recover or Panic node
+//  4) is a (nested) pointer to function (needed for, say,
+//     slice elements of nested pointers to function type)
+//  5) is a global Recover or Panic node
 func hasInFlow(n node) bool {
 	if _, ok := n.(panicArg); ok {
 		return true
@@ -44,15 +49,14 @@
 
 	t := n.Type()
 
-	if _, ok := t.Underlying().(*types.Signature); ok {
-		return true
-	}
-
 	if i := interfaceUnderPtr(t); i != nil {
 		return true
 	}
+	if f := functionUnderPtr(t); f != nil {
+		return true
+	}
 
-	return isInterface(t)
+	return isInterface(t) || isFunction(t)
 }
 
 // hasInitialTypes check if a node can have initial types.
@@ -72,6 +76,11 @@
 	return ok
 }
 
+func isFunction(t types.Type) bool {
+	_, ok := t.Underlying().(*types.Signature)
+	return ok
+}
+
 // interfaceUnderPtr checks if type `t` is a potentially nested
 // pointer to interface and if yes, returns the interface type.
 // Otherwise, returns nil.
@@ -88,6 +97,22 @@
 	return interfaceUnderPtr(p.Elem())
 }
 
+// functionUnderPtr checks if type `t` is a potentially nested
+// pointer to function type and if yes, returns the function type.
+// Otherwise, returns nil.
+func functionUnderPtr(t types.Type) types.Type {
+	p, ok := t.Underlying().(*types.Pointer)
+	if !ok {
+		return nil
+	}
+
+	if isFunction(p.Elem()) {
+		return p.Elem()
+	}
+
+	return functionUnderPtr(p.Elem())
+}
+
 // sliceArrayElem returns the element type of type `t` that is
 // expected to be a (pointer to) array or slice, consistent with
 // the ssa.Index and ssa.IndexAddr instructions. Panics otherwise.
diff --git a/go/callgraph/vta/vta_test.go b/go/callgraph/vta/vta_test.go
index b0d2de7..e5a9b41 100644
--- a/go/callgraph/vta/vta_test.go
+++ b/go/callgraph/vta/vta_test.go
@@ -20,6 +20,7 @@
 		"testdata/callgraph_pointers.go",
 		"testdata/callgraph_collections.go",
 		"testdata/callgraph_fields.go",
+		"testdata/callgraph_field_funcs.go",
 	} {
 		t.Run(file, func(t *testing.T) {
 			prog, want, err := testProg(file)