internal/symbols: download each repo only once

This change modifies symbols.Patched to take a repository instead of a
repoURL. If there are multiple fix links for the same repository, the
clone is reused instead of re-cloning the repository for every fix link.

Change-Id: Iafc125c28e852b583859a47114b4497ae6e8cf12
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/572036
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
diff --git a/internal/symbols/patched_functions.go b/internal/symbols/patched_functions.go
index b6861bd..8a9e23d 100644
--- a/internal/symbols/patched_functions.go
+++ b/internal/symbols/patched_functions.go
@@ -6,7 +6,6 @@
 
 import (
 	"bytes"
-	"context"
 	"errors"
 	"fmt"
 	"go/ast"
@@ -26,39 +25,25 @@
 	"github.com/go-git/go-git/v5/plumbing/object"
 	"golang.org/x/mod/modfile"
 	"golang.org/x/vulndb/internal/derrors"
-	"golang.org/x/vulndb/internal/gitrepo"
 )
 
 // Patched returns symbols of module patched in commit identified
-// by commitHash. repoURL is URL of the git repository containing
-// the module.
+// by commitHash. r is the git repository containing the module.
 //
 // Patched returns a map from package import paths to symbols
 // patched in the package. Test packages and symbols are omitted.
 //
 // If the commit has more than one parent, an error is returned.
-func Patched(module, repoURL, commitHash string) (_ map[string][]string, err error) {
-	defer derrors.Wrap(&err, "Patched(%s, %s, %s)", module, repoURL, commitHash)
-
-	repoRoot, err := os.MkdirTemp("", commitHash)
-	if err != nil {
-		return nil, err
-	}
-	defer func() {
-		_ = os.RemoveAll(repoRoot)
-	}()
-
-	ctx := context.Background()
-	repo, err := gitrepo.PlainClone(ctx, repoRoot, repoURL)
-	if err != nil {
-		return nil, err
-	}
-
+func Patched(module, commitHash string, r *repository) (_ map[string][]string, err error) {
+	defer derrors.Wrap(&err, "Patched(%s, %s, %s)", module, r.url, commitHash)
+	repo := r.repo
 	w, err := repo.Worktree()
 	if err != nil {
 		return nil, err
 	}
 
+	defer resetWorktree(r.repo, w)
+
 	hash := plumbing.NewHash(commitHash)
 	commit, err := findCommit(repo, w, hash)
 	if err != nil {
@@ -78,7 +63,7 @@
 		return nil, err
 	}
 
-	newSymbols, err := moduleSymbols(repoRoot, module)
+	newSymbols, err := moduleSymbols(r.root, module)
 	if err != nil {
 		return nil, err
 	}
@@ -87,7 +72,7 @@
 		return nil, err
 	}
 
-	oldSymbols, err := moduleSymbols(repoRoot, module)
+	oldSymbols, err := moduleSymbols(r.root, module)
 	if err != nil {
 		return nil, err
 	}
@@ -103,6 +88,14 @@
 	return pkgSyms, nil
 }
 
+// resetWorktree fetches updates for r and hard-resets w (to HEAD, since no commit is set), ignoring errors.
+func resetWorktree(r *git.Repository, w *git.Worktree) {
+	r.Fetch(&git.FetchOptions{})
+	w.Reset(&git.ResetOptions{
+		Mode: git.HardReset,
+	})
+}
+
 // findCommit attempts to find a commit with hash in repo's w work tree.
 // If it cannot find the fix at the current branch, it tries to identify
 // the commit at all remote branches. Once it finds a commit, it returns
diff --git a/internal/symbols/populate.go b/internal/symbols/populate.go
index 2a1021e..b09b47f 100644
--- a/internal/symbols/populate.go
+++ b/internal/symbols/populate.go
@@ -5,48 +5,74 @@
 package symbols
 
 import (
+	"context"
 	"errors"
 	"fmt"
+	"os"
 	"path/filepath"
 	"strings"
 
+	"github.com/go-git/go-git/v5"
 	"golang.org/x/exp/slices"
+	"golang.org/x/vulndb/internal/gitrepo"
 	"golang.org/x/vulndb/internal/report"
 )
 
+// repository represents a repository that may contain fixes for a given report.
+type repository struct {
+	repo      *git.Repository
+	url       string
+	root      string
+	fixHashes []string
+}
+
 // Populate attempts to populate the report with symbols derived
 // from the patch link(s) in the report.
 func Populate(r *report.Report, update bool) error {
-	return populate(r, update, Patched)
+	return populate(r, update, gitrepo.PlainClone, Patched)
 }
 
-func populate(r *report.Report, update bool, patched func(string, string, string) (map[string][]string, error)) error {
-	var errs []error
+func populate(r *report.Report, update bool, clone func(context.Context, string, string) (*git.Repository, error), patched func(string, string, *repository) (map[string][]string, error)) error {
+	reportFixRepos, errs := getFixRepos(r.CommitLinks(), clone)
 	for _, mod := range r.Modules {
-		hasFixLink := len(mod.FixLinks) >= 0
-		fixLinks := mod.FixLinks
-		if len(fixLinks) == 0 {
-			c := r.CommitLinks()
-			if len(c) == 0 {
-				errs = append(errs, fmt.Errorf("no commit fix links found for module %s", mod.Module))
+		hasFixLinks := len(mod.FixLinks) > 0
+		fixRepos := reportFixRepos
+		if hasFixLinks {
+			frs, ers := getFixRepos(mod.FixLinks, clone)
+			if len(ers) != 0 {
+				errs = append(errs, ers...)
+			}
+			if len(frs) == 0 {
+				errs = append(errs, fmt.Errorf("no working repos found for %s", mod.Module))
 				continue
 			}
-			fixLinks = c
+			fixRepos = frs
 		}
 
 		foundSymbols := false
-		for _, fixLink := range fixLinks {
-			found, err := populateFromFixLink(fixLink, mod, update, patched)
-			if err != nil {
-				errs = append(errs, err)
+		for _, repo := range fixRepos {
+			for _, hash := range repo.fixHashes {
+				found, err := populateFromFixHash(repo, hash, mod, patched)
+				if err != nil {
+					errs = append(errs, err)
+				}
+				if !hasFixLinks && update && found {
+					fixLink := fmt.Sprintf("%s/commit/%s", repo.url, hash)
+					mod.FixLinks = append(mod.FixLinks, fixLink)
+				}
+				foundSymbols = foundSymbols || found
 			}
-			foundSymbols = foundSymbols || found
+			root := repo.root
+			defer func() {
+				_ = os.RemoveAll(root)
+			}()
 		}
-		if !foundSymbols && fixLinks != nil {
+
+		if !foundSymbols {
 			errs = append(errs, fmt.Errorf("no vulnerable symbols found for module %s", mod.Module))
 		}
 		// Sort fix links for testing/deterministic output
-		if !hasFixLink && update {
+		if !hasFixLinks && update {
 			slices.Sort(mod.FixLinks)
 		}
 	}
@@ -54,12 +80,10 @@
 	return errors.Join(errs...)
 }
 
-// populateFromFixLink takes a fixLink and a module and returns true if any symbols
-// are found for the given fix/module pair.
-func populateFromFixLink(fixLink string, m *report.Module, update bool, patched func(string, string, string) (map[string][]string, error)) (foundSymbols bool, err error) {
-	fixHash := filepath.Base(fixLink)
-	fixRepo := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
-	pkgsToSymbols, err := patched(m.Module, fixRepo, fixHash)
+// populateFromFixHash takes a repository, a fix hash, and the corresponding module, and
+// returns true if any symbols are found for the given fix/module pair.
+func populateFromFixHash(repo *repository, fixHash string, m *report.Module, patched func(string, string, *repository) (map[string][]string, error)) (foundSymbols bool, err error) {
+	pkgsToSymbols, err := patched(m.Module, fixHash, repo)
 	if err != nil {
 		return false, err
 	}
@@ -80,8 +104,37 @@
 			})
 		}
 	}
-	if update && foundSymbols {
-		m.FixLinks = append(m.FixLinks, fixLink)
-	}
 	return foundSymbols, nil
 }
+
+// getFixRepos clones the repository of each fix link once and returns the clones keyed by repo URL, each with its fix hashes.
+func getFixRepos(links []string, clone func(context.Context, string, string) (*git.Repository, error)) (fixRepos map[string]*repository, errs []error) {
+	fixRepos = make(map[string]*repository)
+	for _, fixLink := range links {
+		fixHash := filepath.Base(fixLink)
+		repoURL := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
+		if _, found := fixRepos[repoURL]; !found {
+			repoRoot, err := os.MkdirTemp("", fixHash)
+			if err != nil {
+				errs = append(errs, fmt.Errorf("error making temp dir for repo %s: %v", repoURL, err))
+				continue
+			}
+			ctx := context.Background()
+			r, err := clone(ctx, repoRoot, repoURL)
+			if err != nil {
+				errs = append(errs, fmt.Errorf("error cloning repo: %v", err.Error()))
+				continue
+			}
+			fixRepos[repoURL] = &repository{
+				repo:      r,
+				url:       repoURL,
+				root:      repoRoot,
+				fixHashes: []string{fixHash},
+			}
+		} else {
+			r := fixRepos[repoURL]
+			r.fixHashes = append(r.fixHashes, fixHash)
+		}
+	}
+	return fixRepos, errs
+}
diff --git a/internal/symbols/populate_test.go b/internal/symbols/populate_test.go
index aa9f83a..59f3621 100644
--- a/internal/symbols/populate_test.go
+++ b/internal/symbols/populate_test.go
@@ -5,9 +5,11 @@
 package symbols
 
 import (
+	"context"
 	"fmt"
 	"testing"
 
+	"github.com/go-git/go-git/v5"
 	"github.com/google/go-cmp/cmp"
 	"golang.org/x/vulndb/internal/osv"
 	"golang.org/x/vulndb/internal/report"
@@ -127,9 +129,29 @@
 				},
 			},
 		},
+		{
+			name:   "has fix link",
+			update: false,
+			input: &report.Report{
+				Modules: []*report.Module{{
+					Module:   "example.com/module",
+					FixLinks: []string{"https://example.com/module/commit/1234", "https://example.com/module/commit/5678"},
+				}},
+			},
+			want: &report.Report{
+				Modules: []*report.Module{{
+					Module: "example.com/module",
+					Packages: []*report.Package{{
+						Package: "example.com/module/package",
+						Symbols: []string{"symbol1", "symbol2", "symbol3"},
+					}},
+					FixLinks: []string{"https://example.com/module/commit/1234", "https://example.com/module/commit/5678"},
+				}},
+			},
+		},
 	} {
 		t.Run(tc.name, func(t *testing.T) {
-			if err := populate(tc.input, tc.update, patchedFake); err != nil {
+			if err := populate(tc.input, tc.update, mockClone, patchedFake); err != nil {
 				t.Fatal(err)
 			}
 			got := tc.input
@@ -140,16 +162,20 @@
 	}
 }
 
-func patchedFake(module string, repo string, hash string) (map[string][]string, error) {
-	if module == "example.com/module" && repo == "https://example.com/module" && hash == "1234" {
+func patchedFake(module string, hash string, repo *repository) (map[string][]string, error) {
+	if module == "example.com/module" && repo.url == "https://example.com/module" && hash == "1234" {
 		return map[string][]string{
 			"example.com/module/package": {"symbol1", "symbol2"},
 		}, nil
 	}
-	if module == "example.com/module" && repo == "https://example.com/module" && hash == "5678" {
+	if module == "example.com/module" && repo.url == "https://example.com/module" && hash == "5678" {
 		return map[string][]string{
 			"example.com/module/package": {"symbol1", "symbol2", "symbol3"},
 		}, nil
 	}
-	return nil, fmt.Errorf("unrecognized inputs: module=%s,repo=%s,hash=%s", module, repo, hash)
+	return nil, fmt.Errorf("unrecognized inputs: module=%s,repo=%s,hash=%s", module, repo.url, hash)
+}
+
+func mockClone(ctx context.Context, dir, repoURL string) (repo *git.Repository, err error) {
+	return nil, err
 }