internal/symbols, cmd/vulnreport: move logic to add symbols to reports

Move the logic that populates symbols, without modifying it, to its own
file in internal/symbols. Add some basic tests that confirm the current
behavior (which will likely be tweaked in follow-up CLs).

Change-Id: I10593154c343adb680733ebd66a4dd97abed2c43
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/560778
Reviewed-by: Maceo Thompson <maceothompson@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/cmd/vulnreport/symbols.go b/cmd/vulnreport/symbols.go
index 11ba99c..bc0aa78 100644
--- a/cmd/vulnreport/symbols.go
+++ b/cmd/vulnreport/symbols.go
@@ -6,12 +6,8 @@
 
 import (
 	"context"
-	"fmt"
-	"path/filepath"
-	"strings"
 
 	"golang.org/x/vulndb/cmd/vulnreport/log"
-	"golang.org/x/vulndb/internal/osv"
 	"golang.org/x/vulndb/internal/report"
 	"golang.org/x/vulndb/internal/symbols"
 )
@@ -34,66 +30,10 @@
 	if err != nil {
 		return err
 	}
-	var defaultFixes []string
 
-	for _, ref := range r.References {
-		if ref.Type == osv.ReferenceTypeFix {
-			if filepath.Base(filepath.Dir(ref.URL)) == "commit" {
-				defaultFixes = append(defaultFixes, ref.URL)
-			}
-		}
-	}
-	if len(defaultFixes) == 0 {
-		return fmt.Errorf("no commit fix links found")
+	if err = symbols.Populate(r, log.Err); err != nil {
+		return err
 	}
 
-	for _, mod := range r.Modules {
-		hasFixLink := mod.FixLink != ""
-		if hasFixLink {
-			defaultFixes = append(defaultFixes, mod.FixLink)
-		}
-		numFixedSymbols := make([]int, len(defaultFixes))
-		for i, fixLink := range defaultFixes {
-			fixHash := filepath.Base(fixLink)
-			fixRepo := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
-			pkgsToSymbols, err := symbols.Patched(mod.Module, fixRepo, fixHash)
-			if err != nil {
-				log.Err(err)
-				continue
-			}
-			packages := mod.AllPackages()
-			for pkg, symbols := range pkgsToSymbols {
-				if _, exists := packages[pkg]; exists {
-					packages[pkg].Symbols = append(packages[pkg].Symbols, symbols...)
-				} else {
-					mod.Packages = append(mod.Packages, &report.Package{
-						Package: pkg,
-						Symbols: symbols,
-					})
-				}
-				numFixedSymbols[i] += len(symbols)
-			}
-		}
-		// if the module's link field wasn't already populated, populate it with
-		// the link that results in the most symbols
-		if hasFixLink {
-			defaultFixes = defaultFixes[:len(defaultFixes)-1]
-		} else {
-			mod.FixLink = defaultFixes[indexMax(numFixedSymbols)]
-		}
-	}
 	return r.Write(filename)
 }
-
-// indexMax takes a slice of nonempty ints and returns the index of the maximum value
-func indexMax(s []int) (index int) {
-	maxVal := s[0]
-	index = 0
-	for i, val := range s {
-		if val > maxVal {
-			maxVal = val
-			index = i
-		}
-	}
-	return index
-}
diff --git a/internal/symbols/populate.go b/internal/symbols/populate.go
new file mode 100644
index 0000000..840cfdc
--- /dev/null
+++ b/internal/symbols/populate.go
@@ -0,0 +1,86 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package symbols
+
+import (
+	"fmt"
+	"path/filepath"
+	"strings"
+
+	"golang.org/x/vulndb/internal/osv"
+	"golang.org/x/vulndb/internal/report"
+)
+
+// Populate adds symbols derived from the report's fix commits to its modules.
+// Per-link errors go to errln; it fails if no reference is a commit fix link.
+func Populate(r *report.Report, errln logln) error {
+	return populate(r, Patched, errln)
+}
+
+func populate(r *report.Report, patched func(string, string, string) (map[string][]string, error), errln logln) error {
+	var defaultFixes []string
+
+	for _, ref := range r.References {
+		if ref.Type == osv.ReferenceTypeFix {
+			if filepath.Base(filepath.Dir(ref.URL)) == "commit" {
+				defaultFixes = append(defaultFixes, ref.URL)
+			}
+		}
+	}
+	if len(defaultFixes) == 0 {
+		return fmt.Errorf("no commit fix links found")
+	}
+
+	for _, mod := range r.Modules {
+		hasFixLink := mod.FixLink != ""
+		if hasFixLink {
+			defaultFixes = append(defaultFixes, mod.FixLink)
+		}
+		numFixedSymbols := make([]int, len(defaultFixes))
+		for i, fixLink := range defaultFixes {
+			fixHash := filepath.Base(fixLink)
+			fixRepo := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
+			pkgsToSymbols, err := patched(mod.Module, fixRepo, fixHash)
+			if err != nil {
+				errln(err)
+				continue
+			}
+			packages := mod.AllPackages()
+			for pkg, symbols := range pkgsToSymbols {
+				if _, exists := packages[pkg]; exists {
+					packages[pkg].Symbols = append(packages[pkg].Symbols, symbols...)
+				} else {
+					mod.Packages = append(mod.Packages, &report.Package{
+						Package: pkg,
+						Symbols: symbols,
+					})
+				}
+				numFixedSymbols[i] += len(symbols)
+			}
+		}
+		// If the module already had a fix link, drop it from the shared defaults;
+		// otherwise set it to the default link that yielded the most symbols.
+		if hasFixLink {
+			defaultFixes = defaultFixes[:len(defaultFixes)-1]
+		} else {
+			mod.FixLink = defaultFixes[indexMax(numFixedSymbols)]
+		}
+	}
+
+	return nil
+}
+
+// indexMax takes a nonempty slice of ints and returns the index of its first maximum value.
+func indexMax(s []int) (index int) {
+	maxVal := s[0]
+	index = 0
+	for i, val := range s {
+		if val > maxVal {
+			maxVal = val
+			index = i
+		}
+	}
+	return index
+}
diff --git a/internal/symbols/populate_test.go b/internal/symbols/populate_test.go
new file mode 100644
index 0000000..8b35f1d
--- /dev/null
+++ b/internal/symbols/populate_test.go
@@ -0,0 +1,116 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package symbols
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"golang.org/x/vulndb/internal/osv"
+	"golang.org/x/vulndb/internal/report"
+)
+
+func TestPopulate(t *testing.T) {
+	for _, tc := range []struct {
+		name  string
+		input *report.Report
+		want  *report.Report
+	}{
+		{
+			name: "basic",
+			input: &report.Report{
+				Modules: []*report.Module{{
+					Module: "example.com/module",
+				}},
+				References: []*report.Reference{{
+					Type: osv.ReferenceTypeFix,
+					URL:  "https://example.com/module/commit/1234",
+				}},
+			},
+			want: &report.Report{
+				Modules: []*report.Module{{
+					Module: "example.com/module",
+					Packages: []*report.Package{{
+						Package: "example.com/module/package",
+						Symbols: []string{"symbol1", "symbol2"},
+					}},
+					FixLink: "https://example.com/module/commit/1234",
+				}},
+				References: []*report.Reference{
+					{
+						Type: osv.ReferenceTypeFix,
+						URL:  "https://example.com/module/commit/1234",
+					},
+				},
+			},
+		},
+		{
+			name: "multiple_fixes",
+			input: &report.Report{
+				Modules: []*report.Module{{
+					Module: "example.com/module",
+				}},
+				References: []*report.Reference{
+					{
+						Type: osv.ReferenceTypeFix,
+						URL:  "https://example.com/module/commit/1234",
+					},
+					{
+						Type: osv.ReferenceTypeFix,
+						URL:  "https://example.com/module/commit/5678",
+					},
+				},
+			},
+			want: &report.Report{
+				Modules: []*report.Module{{
+					Module: "example.com/module",
+					Packages: []*report.Package{{
+						Package: "example.com/module/package",
+						// We don't yet dedupe the symbols.
+						Symbols: []string{"symbol1", "symbol2", "symbol1", "symbol2", "symbol3"},
+					}},
+					// This commit is picked because it results in the most symbols.
+					FixLink: "https://example.com/module/commit/5678",
+				}},
+				References: []*report.Reference{
+					{
+						Type: osv.ReferenceTypeFix,
+						URL:  "https://example.com/module/commit/1234",
+					},
+					{
+						Type: osv.ReferenceTypeFix,
+						URL:  "https://example.com/module/commit/5678",
+					},
+				},
+			},
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			discardLog := func(...any) {}
+			if err := populate(tc.input, patchedFake, discardLog); err != nil {
+				t.Fatal(err)
+			}
+			got := tc.input
+			if diff := cmp.Diff(tc.want, got); diff != "" {
+				t.Errorf("populate mismatch (-want, +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func patchedFake(module string, repo string, hash string) (map[string][]string, error) {
+	if module == "example.com/module" && repo == "https://example.com/module" && hash == "1234" {
+		return map[string][]string{
+			"example.com/module/package": {"symbol1", "symbol2"},
+		}, nil
+	}
+	if module == "example.com/module" && repo == "https://example.com/module" && hash == "5678" {
+		return map[string][]string{
+			"example.com/module/package": {"symbol1", "symbol2", "symbol3"},
+		}, nil
+	}
+	return nil, fmt.Errorf("unrecognized inputs: module=%s,repo=%s,hash=%s", module, repo, hash)
+}