internal/symbols, cmd/vulnreport: move logic to add symbols to reports
Move the logic (but don't modify it) to populate symbols to its own file in
internal/symbols. Add some basic tests that confirm the current behavior
(which will likely be tweaked in follow up CLs).
Change-Id: I10593154c343adb680733ebd66a4dd97abed2c43
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/560778
Reviewed-by: Maceo Thompson <maceothompson@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/cmd/vulnreport/symbols.go b/cmd/vulnreport/symbols.go
index 11ba99c..bc0aa78 100644
--- a/cmd/vulnreport/symbols.go
+++ b/cmd/vulnreport/symbols.go
@@ -6,12 +6,8 @@
import (
"context"
- "fmt"
- "path/filepath"
- "strings"
"golang.org/x/vulndb/cmd/vulnreport/log"
- "golang.org/x/vulndb/internal/osv"
"golang.org/x/vulndb/internal/report"
"golang.org/x/vulndb/internal/symbols"
)
@@ -34,66 +30,10 @@
if err != nil {
return err
}
- var defaultFixes []string
- for _, ref := range r.References {
- if ref.Type == osv.ReferenceTypeFix {
- if filepath.Base(filepath.Dir(ref.URL)) == "commit" {
- defaultFixes = append(defaultFixes, ref.URL)
- }
- }
- }
- if len(defaultFixes) == 0 {
- return fmt.Errorf("no commit fix links found")
+ if err = symbols.Populate(r, log.Err); err != nil {
+ return err
}
- for _, mod := range r.Modules {
- hasFixLink := mod.FixLink != ""
- if hasFixLink {
- defaultFixes = append(defaultFixes, mod.FixLink)
- }
- numFixedSymbols := make([]int, len(defaultFixes))
- for i, fixLink := range defaultFixes {
- fixHash := filepath.Base(fixLink)
- fixRepo := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
- pkgsToSymbols, err := symbols.Patched(mod.Module, fixRepo, fixHash)
- if err != nil {
- log.Err(err)
- continue
- }
- packages := mod.AllPackages()
- for pkg, symbols := range pkgsToSymbols {
- if _, exists := packages[pkg]; exists {
- packages[pkg].Symbols = append(packages[pkg].Symbols, symbols...)
- } else {
- mod.Packages = append(mod.Packages, &report.Package{
- Package: pkg,
- Symbols: symbols,
- })
- }
- numFixedSymbols[i] += len(symbols)
- }
- }
- // if the module's link field wasn't already populated, populate it with
- // the link that results in the most symbols
- if hasFixLink {
- defaultFixes = defaultFixes[:len(defaultFixes)-1]
- } else {
- mod.FixLink = defaultFixes[indexMax(numFixedSymbols)]
- }
- }
return r.Write(filename)
}
-
-// indexMax takes a slice of nonempty ints and returns the index of the maximum value
-func indexMax(s []int) (index int) {
- maxVal := s[0]
- index = 0
- for i, val := range s {
- if val > maxVal {
- maxVal = val
- index = i
- }
- }
- return index
-}
diff --git a/internal/symbols/populate.go b/internal/symbols/populate.go
new file mode 100644
index 0000000..840cfdc
--- /dev/null
+++ b/internal/symbols/populate.go
@@ -0,0 +1,86 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package symbols
+
+import (
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ "golang.org/x/vulndb/internal/osv"
+ "golang.org/x/vulndb/internal/report"
+)
+
+// Populate attempts to populate the report with symbols derived
+// from the patch link(s) in the report.
+func Populate(r *report.Report, errln logln) error {
+ return populate(r, Patched, errln)
+}
+
+func populate(r *report.Report, patched func(string, string, string) (map[string][]string, error), errln logln) error {
+ var defaultFixes []string
+
+ for _, ref := range r.References {
+ if ref.Type == osv.ReferenceTypeFix {
+ if filepath.Base(filepath.Dir(ref.URL)) == "commit" {
+ defaultFixes = append(defaultFixes, ref.URL)
+ }
+ }
+ }
+ if len(defaultFixes) == 0 {
+ return fmt.Errorf("no commit fix links found")
+ }
+
+ for _, mod := range r.Modules {
+ hasFixLink := mod.FixLink != ""
+ if hasFixLink {
+ defaultFixes = append(defaultFixes, mod.FixLink)
+ }
+ numFixedSymbols := make([]int, len(defaultFixes))
+ for i, fixLink := range defaultFixes {
+ fixHash := filepath.Base(fixLink)
+ fixRepo := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
+ pkgsToSymbols, err := patched(mod.Module, fixRepo, fixHash)
+ if err != nil {
+ errln(err)
+ continue
+ }
+ packages := mod.AllPackages()
+ for pkg, symbols := range pkgsToSymbols {
+ if _, exists := packages[pkg]; exists {
+ packages[pkg].Symbols = append(packages[pkg].Symbols, symbols...)
+ } else {
+ mod.Packages = append(mod.Packages, &report.Package{
+ Package: pkg,
+ Symbols: symbols,
+ })
+ }
+ numFixedSymbols[i] += len(symbols)
+ }
+ }
+ // if the module's link field wasn't already populated, populate it with
+ // the link that results in the most symbols
+ if hasFixLink {
+ defaultFixes = defaultFixes[:len(defaultFixes)-1]
+ } else {
+ mod.FixLink = defaultFixes[indexMax(numFixedSymbols)]
+ }
+ }
+
+ return nil
+}
+
+// indexMax takes a slice of nonempty ints and returns the index of the maximum value
+func indexMax(s []int) (index int) {
+ maxVal := s[0]
+ index = 0
+ for i, val := range s {
+ if val > maxVal {
+ maxVal = val
+ index = i
+ }
+ }
+ return index
+}
diff --git a/internal/symbols/populate_test.go b/internal/symbols/populate_test.go
new file mode 100644
index 0000000..8b35f1d
--- /dev/null
+++ b/internal/symbols/populate_test.go
@@ -0,0 +1,116 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package symbols
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "golang.org/x/vulndb/internal/osv"
+ "golang.org/x/vulndb/internal/report"
+)
+
+func TestPopulate(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ input *report.Report
+ want *report.Report
+ }{
+ {
+ name: "basic",
+ input: &report.Report{
+ Modules: []*report.Module{{
+ Module: "example.com/module",
+ }},
+ References: []*report.Reference{{
+ Type: osv.ReferenceTypeFix,
+ URL: "https://example.com/module/commit/1234",
+ }},
+ },
+ want: &report.Report{
+ Modules: []*report.Module{{
+ Module: "example.com/module",
+ Packages: []*report.Package{{
+ Package: "example.com/module/package",
+ Symbols: []string{"symbol1", "symbol2"},
+ }},
+ FixLink: "https://example.com/module/commit/1234",
+ }},
+ References: []*report.Reference{
+ {
+ Type: osv.ReferenceTypeFix,
+ URL: "https://example.com/module/commit/1234",
+ },
+ },
+ },
+ },
+ {
+ name: "multiple_fixes",
+ input: &report.Report{
+ Modules: []*report.Module{{
+ Module: "example.com/module",
+ }},
+ References: []*report.Reference{
+ {
+ Type: osv.ReferenceTypeFix,
+ URL: "https://example.com/module/commit/1234",
+ },
+ {
+ Type: osv.ReferenceTypeFix,
+ URL: "https://example.com/module/commit/5678",
+ },
+ },
+ },
+ want: &report.Report{
+ Modules: []*report.Module{{
+ Module: "example.com/module",
+ Packages: []*report.Package{{
+ Package: "example.com/module/package",
+ // We don't yet dedupe the symbols.
+ Symbols: []string{"symbol1", "symbol2", "symbol1", "symbol2", "symbol3"},
+ }},
+ // This commit is picked because it results in the most symbols.
+ FixLink: "https://example.com/module/commit/5678",
+ }},
+ References: []*report.Reference{
+ {
+ Type: osv.ReferenceTypeFix,
+ URL: "https://example.com/module/commit/1234",
+ },
+ {
+ Type: osv.ReferenceTypeFix,
+ URL: "https://example.com/module/commit/5678",
+ },
+ },
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ discardLog := func(...any) {}
+ if err := populate(tc.input, patchedFake, discardLog); err != nil {
+ t.Fatal(err)
+ }
+ got := tc.input
+ if diff := cmp.Diff(tc.want, got); diff != "" {
+ t.Errorf("populate mismatch (-want, +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func patchedFake(module string, repo string, hash string) (map[string][]string, error) {
+ if module == "example.com/module" && repo == "https://example.com/module" && hash == "1234" {
+ return map[string][]string{
+ "example.com/module/package": {"symbol1", "symbol2"},
+ }, nil
+ }
+ if module == "example.com/module" && repo == "https://example.com/module" && hash == "5678" {
+ return map[string][]string{
+ "example.com/module/package": {"symbol1", "symbol2", "symbol3"},
+ }, nil
+ }
+ return nil, fmt.Errorf("unrecognized inputs: module=%s,repo=%s,hash=%s", module, repo, hash)
+}