exp/vulncheck: populate module graph in Source

Change-Id: I5dc940125d9e5341e983bc58bbbc0ba167f05ff8
Reviewed-on: https://go-review.googlesource.com/c/exp/+/360774
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/fetch.go b/vulncheck/fetch.go
index b065fe7..5cef708 100644
--- a/vulncheck/fetch.go
+++ b/vulncheck/fetch.go
@@ -20,21 +20,7 @@
 	if mod == nil {
 		return ""
 	}
-	return fmt.Sprintf("%s@%s", modPath(mod), modVersion(mod))
-}
-
-func modPath(mod *packages.Module) string {
-	if mod.Replace != nil {
-		return mod.Replace.Path
-	}
-	return mod.Path
-}
-
-func modVersion(mod *packages.Module) string {
-	if mod.Replace != nil {
-		return mod.Replace.Version
-	}
-	return mod.Version
+	return fmt.Sprintf("%s@%s", mod.Path, mod.Version)
 }
 
 // extractModules collects modules in `pkgs` up to uniqueness of
@@ -49,7 +35,11 @@
 			return
 		}
 		if pkg.Module != nil {
-			modMap[modKey(pkg.Module)] = pkg.Module
+			if pkg.Module.Replace != nil {
+				modMap[modKey(pkg.Module.Replace)] = pkg.Module
+			} else {
+				modMap[modKey(pkg.Module)] = pkg.Module
+			}
 		}
 		seen[pkg] = true
 		for _, imp := range pkg.Imports {
diff --git a/vulncheck/helpers_test.go b/vulncheck/helpers_test.go
index 8aa6d18..453d47a 100644
--- a/vulncheck/helpers_test.go
+++ b/vulncheck/helpers_test.go
@@ -82,6 +82,17 @@
 	return m
 }
 
+func reqGraphToStrMap(rg *RequireGraph) map[string][]string {
+	m := make(map[string][]string)
+	for _, n := range rg.Modules {
+		for _, predId := range n.RequiredBy {
+			pred := rg.Modules[predId]
+			m[pred.Path] = append(m[pred.Path], n.Path)
+		}
+	}
+	return m
+}
+
 func loadPackages(e *packagestest.Exported, patterns ...string) ([]*packages.Package, error) {
 	e.Config.Mode |= packages.NeedModule | packages.NeedName | packages.NeedFiles |
 		packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedTypes |
diff --git a/vulncheck/source.go b/vulncheck/source.go
index 46eab37..28a665a 100644
--- a/vulncheck/source.go
+++ b/vulncheck/source.go
@@ -29,34 +29,38 @@
 		Imports:  &ImportGraph{Packages: make(map[int]*PkgNode)},
 		Requires: &RequireGraph{Modules: make(map[int]*ModNode)},
 	}
-	vulnPkgImportSlice(pkgs, modVulns, result)
-	// TODO(zpavlinovic): compute module and call graph slice.
+	vulnPkgModSlice(pkgs, modVulns, result)
 	return result, nil
 }
 
-// pkgId is an id counter for nodes of Imports graph.
+// pkgID is an id counter for nodes of Imports graph.
 var pkgID int = 0
 
 func nextPkgID() int {
-	pkgID += 1
+	pkgID++
 	return pkgID
 }
 
-// vulnPkgImportSlice computes the slice of pkg imports graph leading to imports of vulnerable
-// packages in modVulns and stores the slice to result.
-func vulnPkgImportSlice(pkgs []*packages.Package, modVulns moduleVulnerabilities, result *Result) {
-	// analyzed contains information on packages analyzed thus far.
+// vulnPkgModSlice computes the slices of the imports and requires graphs
+// of pkgs leading to imports/requires of vulnerable packages/modules in
+// modVulns and stores the computed slices to result.
+func vulnPkgModSlice(pkgs []*packages.Package, modVulns moduleVulnerabilities, result *Result) {
+	// analyzedPkgs contains information on packages analyzed thus far.
 	// If a package is mapped to nil, this means it has been visited
 	// but it does not lead to a vulnerable imports. Otherwise, a
 	// visited package is mapped to Imports package node.
-	analyzed := make(map[*packages.Package]*PkgNode)
+	analyzedPkgs := make(map[*packages.Package]*PkgNode)
 	for _, pkg := range pkgs {
 		// Top level packages that lead to vulnerable imports are
 		// stored as result.Imports graph entry points.
-		if e := vulnImportSlice(pkg, modVulns, result, analyzed); e != nil {
+		if e := vulnImportSlice(pkg, modVulns, result, analyzedPkgs); e != nil {
 			result.Imports.Entries = append(result.Imports.Entries, e)
 		}
 	}
+
+	// Populate module requires slice as an overlay
+	// of package imports slice.
+	vulnModuleSlice(result)
 }
 
 // vulnImportSlice checks if pkg has some vulnerabilities or transitively imports
@@ -90,6 +94,7 @@
 	pkgNode := &PkgNode{
 		Name: pkg.Name,
 		Path: pkg.PkgPath,
+		pkg:  pkg,
 	}
 	analyzed[pkg] = pkgNode
 
@@ -112,7 +117,6 @@
 					OSV:        osv,
 					Symbol:     symbol,
 					PkgPath:    pkgNode.Path,
-					ModPath:    modPath(pkg.Module),
 					ImportSink: id,
 				}
 				result.Vulns = append(result.Vulns, vuln)
@@ -121,3 +125,99 @@
 	}
 	return pkgNode
 }
+
+// vulnModuleSlice populates result.Requires as an overlay
+// of result.Imports.
+func vulnModuleSlice(result *Result) {
+	// Map from module nodes, identified with their
+	// path and version, to their unique ids.
+	modNodeIDs := make(map[string]int)
+	// We first collect the inverse of the requires relation,
+	// i.e., the required-by (predecessor) relation, on module node ids.
+	modPredRelation := make(map[int]map[int]bool)
+	for _, pkgNode := range result.Imports.Packages {
+		// Create or get module node for pkgNode.
+		pkgModID := moduleNodeID(pkgNode, result, modNodeIDs)
+		pkgNode.Module = pkgModID
+
+		// Get the set of predecessors.
+		predSet := make(map[int]bool)
+		for _, predPkgID := range pkgNode.ImportedBy {
+			predModID := moduleNodeID(result.Imports.Packages[predPkgID], result, modNodeIDs)
+			predSet[predModID] = true
+		}
+		modPredRelation[pkgModID] = predSet
+	}
+
+	// Store the predecessor requires relation to result.
+	for modID := range modPredRelation {
+		if modID == 0 {
+			continue
+		}
+
+		var predIDs []int
+		for predID := range modPredRelation[modID] {
+			predIDs = append(predIDs, predID)
+		}
+		modNode := result.Requires.Modules[modID]
+		modNode.RequiredBy = predIDs
+	}
+
+	// And finally update Vulns with module information.
+	for _, vuln := range result.Vulns {
+		pkgNode := result.Imports.Packages[vuln.ImportSink]
+		modNode := result.Requires.Modules[pkgNode.Module]
+
+		vuln.RequireSink = pkgNode.Module
+		vuln.ModPath = modNode.Path
+	}
+}
+
+// modID is an id counter for nodes of Requires graph.
+var modID int = 0
+
+func nextModID() int {
+	modID++
+	return modID
+}
+
+// moduleNodeID creates a module node associated with pkgNode, if one does
+// not exist already, and returns id of the module node. The actual module
+// node is stored to result.
+func moduleNodeID(pkgNode *PkgNode, result *Result, modNodeIDs map[string]int) int {
+	mod := pkgNode.pkg.Module
+	if mod == nil {
+		return 0
+	}
+
+	mk := modKey(mod)
+	if id, ok := modNodeIDs[mk]; ok {
+		return id
+	}
+
+	id := nextModID()
+	n := &ModNode{
+		Path:    mod.Path,
+		Version: mod.Version,
+	}
+	result.Requires.Modules[id] = n
+	modNodeIDs[mk] = id
+
+	// Create a replace module too when applicable.
+	if mod.Replace != nil {
+		rmk := modKey(mod.Replace)
+		if rid, ok := modNodeIDs[rmk]; ok {
+			n.Replace = rid
+		} else {
+			rid := nextModID()
+			rn := &ModNode{
+				Path:    mod.Replace.Path,
+				Version: mod.Replace.Version,
+			}
+			result.Requires.Modules[rid] = rn
+			modNodeIDs[rmk] = rid
+			n.Replace = rid
+		}
+	}
+	return id
+}
diff --git a/vulncheck/source_test.go b/vulncheck/source_test.go
index ef8a850..8e1effa 100644
--- a/vulncheck/source_test.go
+++ b/vulncheck/source_test.go
@@ -114,8 +114,6 @@
 		t.Fatal(err)
 	}
 
-	// TODO(zpavlinovic): add test below for module graph too.
-
 	// Check that we find the right number of vulnerabilities.
 	// There should be three entries as there are three vulnerable
 	// symbols in the two import-reachable OSVs.
@@ -123,14 +121,15 @@
 		t.Errorf("want 3 Vulns, got %d", len(result.Vulns))
 	}
 
-	// Check that vulnerabilities are connected to the imports graph.
+	// Check that vulnerabilities are connected to the imports
+	// and requires graphs.
 	for _, v := range result.Vulns {
-		if v.ImportSink == 0 {
-			t.Errorf("want ImportSink !=0 for vuln %v:%v; got 0", v.Symbol, v.PkgPath)
+		if v.ImportSink == 0 || v.RequireSink == 0 {
+			t.Errorf("want ImportSink !=0 and RequireSink !=0 for %v:%v; got %v and %v", v.Symbol, v.PkgPath, v.ImportSink, v.RequireSink)
 		}
 	}
 
-	// The slice should include import chains:
+	// The imports slice should include import chains:
 	//   x -> avuln -> w -> bvuln
 	//         |
 	//   y ---->
@@ -145,4 +144,17 @@
 	if igStrMap := impGraphToStrMap(result.Imports); !reflect.DeepEqual(wantImports, igStrMap) {
 		t.Errorf("want %v imports graph; got %v", wantImports, igStrMap)
 	}
+
+	// The requires slice should include requires chains:
+	//   entry -> amod -> wmod -> bmod
+	// That is, zmod module should not appear in the slice.
+	wantRequires := map[string][]string{
+		"golang.org/entry": {"golang.org/amod"},
+		"golang.org/amod":  {"golang.org/wmod"},
+		"golang.org/wmod":  {"golang.org/bmod"},
+	}
+
+	if rgStrMap := reqGraphToStrMap(result.Requires); !reflect.DeepEqual(wantRequires, rgStrMap) {
+		t.Errorf("want %v requires graph; got %v", wantRequires, rgStrMap)
+	}
 }
diff --git a/vulncheck/vulncheck.go b/vulncheck/vulncheck.go
index 132ebf9..2506045 100644
--- a/vulncheck/vulncheck.go
+++ b/vulncheck/vulncheck.go
@@ -80,9 +80,11 @@
 	RequireSink int
 }
 
-// CallGraph whose sinks are vulnerable functions and sources are entry points of user
-// packages. CallGraph is backwards directed, i.e., from a function node to the place
-// where the function is called.
+// CallGraph is a slice of a full program call graph whose sinks are conceptually
+// vulnerable functions and sources are entry points of user packages. In order to
+// support succinct traversal of the slice related to a particular vulnerability,
+// CallGraph is technically backwards directed, i.e., from a vulnerable function
+// towards the program entry functions (see FuncNode).
 type CallGraph struct {
 	// Funcs contains all call graph nodes as a map: func node id -> func node.
 	Funcs map[int]*FuncNode
@@ -113,10 +115,11 @@
 	Resolved bool
 }
 
-// RequireGraph models part of module requires graph where sinks are modules with
-// some known vulnerabilities and sources are modules of user entry packages.
-// RequireGraph is backwards directed, i.e., from a module to the set of modules
-// it is required by.
+// RequireGraph is a slice of a full program module requires graph whose sinks
+// are conceptually modules with some known vulnerabilities and sources are modules
+// of user entry packages. In order to support succinct traversal of the slice
+// related to a particular vulnerability, RequireGraph is technically backwards
+// directed, i.e., from a vulnerable module towards the program entry modules (see ModNode).
 type RequireGraph struct {
 	// Modules contains all module nodes as a map: module node id -> module node.
 	Modules map[int]*ModNode
@@ -127,14 +130,17 @@
 type ModNode struct {
 	Path    string
 	Version string
-	Replace *ModNode
+	// Replace is the ID of the replacement module node, if any.
+	Replace int
 	// RequiredBy contains IDs of the modules requiring this module.
 	RequiredBy []int
 }
 
-// ImportGraph models part of package import graph where sinks are packages with
-// some known vulnerabilities and sources are user specified packages. The graph
-// is backwards directed, i.e., from a package to the set of packages importing it.
+// ImportGraph is a slice of a full program package import graph whose sinks are
+// conceptually packages with some known vulnerabilities and sources are user
+// specified packages. In order to support succinct traversal of the slice related
+// to a particular vulnerability, ImportGraph is technically backwards directed,
+// i.e., from a vulnerable package towards the program entry packages (see PkgNode).
 type ImportGraph struct {
 	// Packages contains all package nodes as a map: package node id -> package node.
 	Packages map[int]*PkgNode
@@ -150,6 +156,9 @@
 	Module int
 	// ImportedBy contains IDs of packages directly importing this package.
 	ImportedBy []int
+
+	// pkg is used for connecting package node to module and call graph nodes.
+	pkg *packages.Package
 }
 
 // moduleVulnerabilities is an internal structure for