internal/symbols: only download any repo once
This change modifies symbols.Patched to take a repository instead of a
repoURL - if there are multiple fix links for the same repository, it
will be reused instead of re cloning the repository for every fix link.
Change-Id: Iafc125c28e852b583859a47114b4497ae6e8cf12
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/572036
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
diff --git a/internal/symbols/patched_functions.go b/internal/symbols/patched_functions.go
index b6861bd..8a9e23d 100644
--- a/internal/symbols/patched_functions.go
+++ b/internal/symbols/patched_functions.go
@@ -6,7 +6,6 @@
import (
"bytes"
- "context"
"errors"
"fmt"
"go/ast"
@@ -26,39 +25,25 @@
"github.com/go-git/go-git/v5/plumbing/object"
"golang.org/x/mod/modfile"
"golang.org/x/vulndb/internal/derrors"
- "golang.org/x/vulndb/internal/gitrepo"
)
// Patched returns symbols of module patched in commit identified
-// by commitHash. repoURL is URL of the git repository containing
-// the module.
+// by commitHash. r is the git repository containing the module.
//
// Patched returns a map from package import paths to symbols
// patched in the package. Test packages and symbols are omitted.
//
// If the commit has more than one parent, an error is returned.
-func Patched(module, repoURL, commitHash string) (_ map[string][]string, err error) {
- defer derrors.Wrap(&err, "Patched(%s, %s, %s)", module, repoURL, commitHash)
-
- repoRoot, err := os.MkdirTemp("", commitHash)
- if err != nil {
- return nil, err
- }
- defer func() {
- _ = os.RemoveAll(repoRoot)
- }()
-
- ctx := context.Background()
- repo, err := gitrepo.PlainClone(ctx, repoRoot, repoURL)
- if err != nil {
- return nil, err
- }
-
+func Patched(module, commitHash string, r *repository) (_ map[string][]string, err error) {
+ defer derrors.Wrap(&err, "Patched(%s, %s, %s)", module, r.url, commitHash)
+ repo := r.repo
w, err := repo.Worktree()
if err != nil {
return nil, err
}
+ defer resetWorktree(r.repo, w)
+
hash := plumbing.NewHash(commitHash)
commit, err := findCommit(repo, w, hash)
if err != nil {
@@ -78,7 +63,7 @@
return nil, err
}
- newSymbols, err := moduleSymbols(repoRoot, module)
+ newSymbols, err := moduleSymbols(r.root, module)
if err != nil {
return nil, err
}
@@ -87,7 +72,7 @@
return nil, err
}
- oldSymbols, err := moduleSymbols(repoRoot, module)
+ oldSymbols, err := moduleSymbols(r.root, module)
if err != nil {
return nil, err
}
@@ -103,6 +88,14 @@
return pkgSyms, nil
}
+// resetWorktree takes a repository and its worktree and resets it to MAIN/MASTER@HEAD
+func resetWorktree(r *git.Repository, w *git.Worktree) {
+ r.Fetch(&git.FetchOptions{})
+ w.Reset(&git.ResetOptions{
+ Mode: git.HardReset,
+ })
+}
+
// findCommit attempts to find a commit with hash in repo's w work tree.
// If it cannot find the fix at the current branch, it tries to identify
// the commit at all remote branches. Once it finds a commit, it returns
diff --git a/internal/symbols/populate.go b/internal/symbols/populate.go
index 2a1021e..b09b47f 100644
--- a/internal/symbols/populate.go
+++ b/internal/symbols/populate.go
@@ -5,48 +5,74 @@
package symbols
import (
+ "context"
"errors"
"fmt"
+ "os"
"path/filepath"
"strings"
+ "github.com/go-git/go-git/v5"
"golang.org/x/exp/slices"
+ "golang.org/x/vulndb/internal/gitrepo"
"golang.org/x/vulndb/internal/report"
)
+// repository represents a repository that may contain fixes for a given report.
+type repository struct {
+ repo *git.Repository
+ url string
+ root string
+ fixHashes []string
+}
+
// Populate attempts to populate the report with symbols derived
// from the patch link(s) in the report.
func Populate(r *report.Report, update bool) error {
- return populate(r, update, Patched)
+ return populate(r, update, gitrepo.PlainClone, Patched)
}
-func populate(r *report.Report, update bool, patched func(string, string, string) (map[string][]string, error)) error {
- var errs []error
+func populate(r *report.Report, update bool, clone func(context.Context, string, string) (*git.Repository, error), patched func(string, string, *repository) (map[string][]string, error)) error {
+ reportFixRepos, errs := getFixRepos(r.CommitLinks(), clone)
for _, mod := range r.Modules {
- hasFixLink := len(mod.FixLinks) >= 0
- fixLinks := mod.FixLinks
- if len(fixLinks) == 0 {
- c := r.CommitLinks()
- if len(c) == 0 {
- errs = append(errs, fmt.Errorf("no commit fix links found for module %s", mod.Module))
+ hasFixLinks := len(mod.FixLinks) > 0
+ fixRepos := reportFixRepos
+ if hasFixLinks {
+ frs, ers := getFixRepos(mod.FixLinks, clone)
+ if len(ers) != 0 {
+ errs = append(errs, ers...)
+ }
+ if len(frs) == 0 {
+ errs = append(errs, fmt.Errorf("no working repos found for %s", mod.Module))
continue
}
- fixLinks = c
+ fixRepos = frs
}
foundSymbols := false
- for _, fixLink := range fixLinks {
- found, err := populateFromFixLink(fixLink, mod, update, patched)
- if err != nil {
- errs = append(errs, err)
+ for _, repo := range fixRepos {
+ for _, hash := range repo.fixHashes {
+ found, err := populateFromFixHash(repo, hash, mod, patched)
+ if err != nil {
+ errs = append(errs, err)
+ }
+ if !hasFixLinks && update && found {
+ fixLink := fmt.Sprintf("%s/commit/%s", repo.url, hash)
+ mod.FixLinks = append(mod.FixLinks, fixLink)
+ }
+ foundSymbols = foundSymbols || found
}
- foundSymbols = foundSymbols || found
+ root := repo.root
+ defer func() {
+ _ = os.RemoveAll(root)
+ }()
}
- if !foundSymbols && fixLinks != nil {
+
+ if !foundSymbols {
errs = append(errs, fmt.Errorf("no vulnerable symbols found for module %s", mod.Module))
}
// Sort fix links for testing/deterministic output
- if !hasFixLink && update {
+ if !hasFixLinks && update {
slices.Sort(mod.FixLinks)
}
}
@@ -54,12 +80,10 @@
return errors.Join(errs...)
}
-// populateFromFixLink takes a fixLink and a module and returns true if any symbols
-// are found for the given fix/module pair.
-func populateFromFixLink(fixLink string, m *report.Module, update bool, patched func(string, string, string) (map[string][]string, error)) (foundSymbols bool, err error) {
- fixHash := filepath.Base(fixLink)
- fixRepo := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
- pkgsToSymbols, err := patched(m.Module, fixRepo, fixHash)
+// populateFromFixHash takes a repository, fix hash and corresponding module and returns true
+// if any symbols are found for the given fix/module pairs.
+func populateFromFixHash(repo *repository, fixHash string, m *report.Module, patched func(string, string, *repository) (map[string][]string, error)) (foundSymbols bool, err error) {
+ pkgsToSymbols, err := patched(m.Module, fixHash, repo)
if err != nil {
return false, err
}
@@ -80,8 +104,37 @@
})
}
}
- if update && foundSymbols {
- m.FixLinks = append(m.FixLinks, fixLink)
- }
return foundSymbols, nil
}
+
+// getFixRepos takes a list of fix links and returns the repositories and hashes of those fix links.
+func getFixRepos(links []string, clone func(context.Context, string, string) (*git.Repository, error)) (fixRepos map[string]*repository, errs []error) {
+ fixRepos = make(map[string]*repository)
+ for _, fixLink := range links {
+ fixHash := filepath.Base(fixLink)
+ repoURL := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
+ if _, found := fixRepos[repoURL]; !found {
+ repoRoot, err := os.MkdirTemp("", fixHash)
+ if err != nil {
+ errs = append(errs, fmt.Errorf("error making temp dir for repo %s: %v", repoURL, err))
+ continue
+ }
+ ctx := context.Background()
+ r, err := clone(ctx, repoRoot, repoURL)
+ if err != nil {
+ errs = append(errs, fmt.Errorf("error cloning repo: %v", err.Error()))
+ continue
+ }
+ fixRepos[repoURL] = &repository{
+ repo: r,
+ url: repoURL,
+ root: repoRoot,
+ fixHashes: []string{fixHash},
+ }
+ } else {
+ r := fixRepos[repoURL]
+ r.fixHashes = append(r.fixHashes, fixHash)
+ }
+ }
+ return fixRepos, errs
+}
diff --git a/internal/symbols/populate_test.go b/internal/symbols/populate_test.go
index aa9f83a..59f3621 100644
--- a/internal/symbols/populate_test.go
+++ b/internal/symbols/populate_test.go
@@ -5,9 +5,11 @@
package symbols
import (
+ "context"
"fmt"
"testing"
+ "github.com/go-git/go-git/v5"
"github.com/google/go-cmp/cmp"
"golang.org/x/vulndb/internal/osv"
"golang.org/x/vulndb/internal/report"
@@ -127,9 +129,29 @@
},
},
},
+ {
+ name: "has fix link",
+ update: false,
+ input: &report.Report{
+ Modules: []*report.Module{{
+ Module: "example.com/module",
+ FixLinks: []string{"https://example.com/module/commit/1234", "https://example.com/module/commit/5678"},
+ }},
+ },
+ want: &report.Report{
+ Modules: []*report.Module{{
+ Module: "example.com/module",
+ Packages: []*report.Package{{
+ Package: "example.com/module/package",
+ Symbols: []string{"symbol1", "symbol2", "symbol3"},
+ }},
+ FixLinks: []string{"https://example.com/module/commit/1234", "https://example.com/module/commit/5678"},
+ }},
+ },
+ },
} {
t.Run(tc.name, func(t *testing.T) {
- if err := populate(tc.input, tc.update, patchedFake); err != nil {
+ if err := populate(tc.input, tc.update, mockClone, patchedFake); err != nil {
t.Fatal(err)
}
got := tc.input
@@ -140,16 +162,20 @@
}
}
-func patchedFake(module string, repo string, hash string) (map[string][]string, error) {
- if module == "example.com/module" && repo == "https://example.com/module" && hash == "1234" {
+func patchedFake(module string, hash string, repo *repository) (map[string][]string, error) {
+ if module == "example.com/module" && repo.url == "https://example.com/module" && hash == "1234" {
return map[string][]string{
"example.com/module/package": {"symbol1", "symbol2"},
}, nil
}
- if module == "example.com/module" && repo == "https://example.com/module" && hash == "5678" {
+ if module == "example.com/module" && repo.url == "https://example.com/module" && hash == "5678" {
return map[string][]string{
"example.com/module/package": {"symbol1", "symbol2", "symbol3"},
}, nil
}
- return nil, fmt.Errorf("unrecognized inputs: module=%s,repo=%s,hash=%s", module, repo, hash)
+ return nil, fmt.Errorf("unrecognized inputs: module=%s,repo=%s,hash=%s", module, repo.url, hash)
+}
+
+func mockClone(ctx context.Context, dir, repoURL string) (repo *git.Repository, err error) {
+ return nil, err
}