vulncheck: build callgraph in parallel with fetching db

Source(...) now builds the *ssa.Program and callgraph from
the *ssa.Program in parallel with fetching vulnerabilities.
Returns as soon as the vuln set is empty.

Updates golang/go#57357

Change-Id: I310b93f7125b5edcc2a5744db9f9f595c70aa5d4
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/460420
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Alan Donovan <adonovan@google.com>
Run-TryBot: Tim King <taking@google.com>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/vulncheck/source.go b/vulncheck/source.go
index 49cfce4..fcdfb26 100644
--- a/vulncheck/source.go
+++ b/vulncheck/source.go
@@ -9,6 +9,7 @@
 	"fmt"
 	"go/token"
 	"sort"
+	"sync"
 
 	"golang.org/x/tools/go/callgraph"
 	"golang.org/x/tools/go/ssa"
@@ -54,6 +55,28 @@
 		stdlibModule.Version = semver.GoTagToSemver(internal.GoEnv("GOVERSION"))
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// If we are building the callgraph, build ssa and the callgraph in parallel
+	// with fetching vulnerabilities. If the vulns set is empty, return without
+	// waiting for SSA construction or callgraph to finish.
+	var (
+		wg       sync.WaitGroup // guards entries, cg, and buildErr
+		entries  []*ssa.Function
+		cg       *callgraph.Graph
+		buildErr error
+	)
+	if !cfg.ImportsOnly {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			prog, ssaPkgs := buildSSA(pkgs, fset)
+			entries = entryPoints(ssaPkgs)
+			cg, buildErr = callGraph(ctx, prog, entries)
+		}()
+	}
+
 	mods := extractModules(pkgs)
 	modVulns, err := fetchVulnerabilities(ctx, cfg.Client, mods)
 	if err != nil {
@@ -69,15 +92,16 @@
 	vulnPkgModSlice(pkgs, modVulns, result)
 	setModules(result, mods)
 	// Return result immediately if in ImportsOnly mode or
-	// if there are no vulnerable packages, as there is no
-	// need to build the call graph.
+	// if there are no vulnerable packages.
 	if cfg.ImportsOnly || len(result.Imports.Packages) == 0 {
 		return result, nil
 	}
 
-	prog, ssaPkgs := buildSSA(pkgs, fset)
-	entries := entryPoints(ssaPkgs)
-	cg := callGraph(prog, entries)
+	wg.Wait() // wait for build to finish
+	if buildErr != nil {
+		return nil, err
+	}
+
 	vulnCallGraphSlice(entries, modVulns, cg, result)
 
 	// Release residual memory.
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
index 29effbe..ee7ab9b 100644
--- a/vulncheck/utils.go
+++ b/vulncheck/utils.go
@@ -6,6 +6,7 @@
 
 import (
 	"bytes"
+	"context"
 	"go/token"
 	"go/types"
 	"strings"
@@ -57,17 +58,25 @@
 }
 
 // callGraph builds a call graph of prog based on VTA analysis.
-func callGraph(prog *ssa.Program, entries []*ssa.Function) *callgraph.Graph {
+func callGraph(ctx context.Context, prog *ssa.Program, entries []*ssa.Function) (*callgraph.Graph, error) {
 	entrySlice := make(map[*ssa.Function]bool)
 	for _, e := range entries {
 		entrySlice[e] = true
 	}
+
+	if err := ctx.Err(); err != nil { // cancelled?
+		return nil, err
+	}
 	initial := cha.CallGraph(prog)
 	allFuncs := ssautil.AllFunctions(prog)
 
 	fslice := forwardReachableFrom(entrySlice, initial)
 	// Keep only actually linked functions.
 	pruneSet(fslice, allFuncs)
+
+	if err := ctx.Err(); err != nil { // cancelled?
+		return nil, err
+	}
 	vtaCg := vta.CallGraph(fslice, initial)
 
 	// Repeat the process once more, this time using
@@ -75,9 +84,12 @@
 	fslice = forwardReachableFrom(entrySlice, vtaCg)
 	pruneSet(fslice, allFuncs)
 
+	if err := ctx.Err(); err != nil { // cancelled?
+		return nil, err
+	}
 	cg := vta.CallGraph(fslice, vtaCg)
 	cg.DeleteSyntheticNodes()
-	return cg
+	return cg, nil
 }
 
 // siteCallees computes a set of callees for call site `call` given program `callgraph`.