vulncheck: fix support for packages whose every symbol is vulnerable

Some vulnerabilities have an empty slice for the osv.Entry Symbols
field, meaning every symbol is vulnerable. This CL adds support for such
vulnerabilities by doing the following, when an import of the
corresponding packages is seen, we list every top-level method and
function of that package and add it as a Vuln in vulncheck.Result.

Change-Id: I0edac51f7e3923ecfd3f203db80e6f7d22272dd2
Reviewed-on: https://go-review.googlesource.com/c/exp/+/371254
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/vulncheck/source.go b/vulncheck/source.go
index cd20d50..72151d4 100644
--- a/vulncheck/source.go
+++ b/vulncheck/source.go
@@ -127,7 +127,15 @@
 			if affected.Package.Name != pkgNode.Path {
 				continue
 			}
-			for _, symbol := range affected.EcosystemSpecific.Symbols {
+
+			var symbols []string
+			if len(affected.EcosystemSpecific.Symbols) != 0 {
+				symbols = affected.EcosystemSpecific.Symbols
+			} else {
+				symbols = allSymbols(pkg.Pkg)
+			}
+
+			for _, symbol := range symbols {
 				vuln := &Vuln{
 					OSV:        osv,
 					Symbol:     symbol,
@@ -333,9 +341,7 @@
 			if affected.Package.Name != funNode.PkgPath {
 				continue
 			}
-			for _, symbol := range affected.EcosystemSpecific.Symbols {
-				addCallSinkForVuln(funNode.ID, osv, symbol, funNode.PkgPath, result)
-			}
+			addCallSinkForVuln(funNode.ID, osv, dbFuncName(f), funNode.PkgPath, result)
 		}
 	}
 	return funNode
diff --git a/vulncheck/source_test.go b/vulncheck/source_test.go
index c97adf0..24127e0 100644
--- a/vulncheck/source_test.go
+++ b/vulncheck/source_test.go
@@ -467,3 +467,81 @@
 		t.Errorf("want 0 Vulns, got %d", len(result.Vulns))
 	}
 }
+
+func TestAllSymbolsVulnerable(t *testing.T) {
+	e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
+		{
+			Name: "golang.org/entry",
+			Files: map[string]interface{}{
+				"x/x.go": `
+			package x
+
+			import "golang.org/vmod/vuln"
+
+			func X() {
+				vuln.V1()
+			}`,
+			},
+		},
+		{
+			Name: "golang.org/vmod@v1.2.3",
+			Files: map[string]interface{}{"vuln/vuln.go": `
+			package vuln
+
+			func V1() {}
+			func V2() {}
+			func v() {}
+			type a struct{}
+			func (x a) foo() {}
+			func (x *a) bar() {}
+			`},
+		},
+	})
+	defer e.Cleanup()
+
+	client := &mockClient{
+		ret: map[string][]*osv.Entry{
+			"golang.org/vmod": []*osv.Entry{
+				{
+					ID: "V",
+					Affected: []osv.Affected{{
+						Package:           osv.Package{Name: "golang.org/vmod/vuln"},
+						Ranges:            osv.Affects{{Type: osv.TypeSemver, Events: []osv.RangeEvent{{Introduced: "1.2.0"}}}},
+						EcosystemSpecific: osv.EcosystemSpecific{Symbols: []string{}},
+					}},
+				},
+			},
+		},
+	}
+
+	// Make sure local vulns can be loaded.
+	fetchingInTesting = true
+	// Load x as entry package.
+	pkgs, err := loadPackages(e, path.Join(e.Temp(), "entry/x"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(pkgs) != 1 {
+		t.Fatal("failed to load x test package")
+	}
+
+	cfg := &Config{
+		Client: client,
+	}
+	result, err := Source(context.Background(), Convert(pkgs), cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(result.Vulns) != 5 {
+		t.Errorf("want 5 Vulns, got %d", len(result.Vulns))
+	}
+
+	for _, v := range result.Vulns {
+		if v.Symbol == "V1" && v.CallSink == 0 {
+			t.Errorf("expected a call sink for V1; got none")
+		} else if v.Symbol != "V1" && v.CallSink != 0 {
+			t.Errorf("expected no call sink for %v; got %v", v.Symbol, v.CallSink)
+		}
+	}
+}
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
index 70d37f5..08448da 100644
--- a/vulncheck/utils.go
+++ b/vulncheck/utils.go
@@ -96,6 +96,21 @@
 	return matches
 }
 
+// dbTypeFormat formats the name of t according how types
+// are encoded in vulnerability database:
+//  - pointer designation * is skipped
+//  - full path prefix is skipped as well
+func dbTypeFormat(t types.Type) string {
+	switch tt := t.(type) {
+	case *types.Pointer:
+		return dbTypeFormat(tt.Elem())
+	case *types.Named:
+		return tt.Obj().Name()
+	default:
+		return types.TypeString(t, func(p *types.Package) string { return "" })
+	}
+}
+
 // dbFuncName computes a function name consistent with the namings used in vulnerability
 // databases. Effectively, a qualified name of a function local to its enclosing package.
 // If a receiver is a pointer, this information is not encoded in the resulting name. The
@@ -107,17 +122,6 @@
 //   func foo(...) {...}         -> foo
 //   func (b *B) bar (...) {...} -> B.bar
 func dbFuncName(f *ssa.Function) string {
-	var typeFormat func(t types.Type) string
-	typeFormat = func(t types.Type) string {
-		switch tt := t.(type) {
-		case *types.Pointer:
-			return typeFormat(tt.Elem())
-		case *types.Named:
-			return tt.Obj().Name()
-		default:
-			return types.TypeString(t, func(p *types.Package) string { return "" })
-		}
-	}
 	selectBound := func(f *ssa.Function) types.Type {
 		// If f is a "bound" function introduced by ssa for a given type, return the type.
 		// When "f" is a "bound" function, it will have 1 free variable of that type within
@@ -141,11 +145,11 @@
 	}
 	var qprefix string
 	if recv := f.Signature.Recv(); recv != nil {
-		qprefix = typeFormat(recv.Type())
+		qprefix = dbTypeFormat(recv.Type())
 	} else if btype := selectBound(f); btype != nil {
-		qprefix = typeFormat(btype)
+		qprefix = dbTypeFormat(btype)
 	} else if ttype := selectThunk(f); ttype != nil {
-		qprefix = typeFormat(ttype)
+		qprefix = dbTypeFormat(ttype)
 	}
 
 	if qprefix == "" {
@@ -154,6 +158,15 @@
 	return qprefix + "." + f.Name()
 }
 
+// dbTypesFuncName is dbFuncName defined over *types.Func.
+func dbTypesFuncName(f *types.Func) string {
+	sig := f.Type().(*types.Signature)
+	if sig.Recv() == nil {
+		return f.Name()
+	}
+	return dbTypeFormat(sig.Recv().Type()) + "." + f.Name()
+}
+
 // memberFuncs returns functions associated with the `member`:
 // 1) `member` itself if `member` is a function
 // 2) `member` methods if `member` is a type
@@ -222,3 +235,24 @@
 	}
 	return defaultValue
 }
+
+// allSymbols returns all top-level functions and methods defined in pkg.
+func allSymbols(pkg *types.Package) []string {
+	var names []string
+	scope := pkg.Scope()
+	for _, name := range scope.Names() {
+		o := scope.Lookup(name)
+		switch o := o.(type) {
+		case *types.Func:
+			names = append(names, dbTypesFuncName(o))
+		case *types.TypeName:
+			ms := types.NewMethodSet(types.NewPointer(o.Type()))
+			for i := 0; i < ms.Len(); i++ {
+				if f, ok := ms.At(i).Obj().(*types.Func); ok {
+					names = append(names, dbTypesFuncName(f))
+				}
+			}
+		}
+	}
+	return names
+}