internal/worker: combine module download code

Refactor the module download code so most of it is shared
for sandbox and insecure modes.

Change-Id: Idcbeff26165e821832bc6bcfcbfb8e4a111ce18b
Reviewed-on: https://go-review.googlesource.com/c/pkgsite-metrics/+/475475
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/internal/worker/analysis.go b/internal/worker/analysis.go
index cace21c..ca88015 100644
--- a/internal/worker/analysis.go
+++ b/internal/worker/analysis.go
@@ -21,8 +21,6 @@
 
 	"golang.org/x/pkgsite-metrics/internal/analysis"
 	"golang.org/x/pkgsite-metrics/internal/derrors"
-	"golang.org/x/pkgsite-metrics/internal/log"
-	"golang.org/x/pkgsite-metrics/internal/modules"
 	"golang.org/x/pkgsite-metrics/internal/queue"
 	"golang.org/x/pkgsite-metrics/internal/sandbox"
 	"golang.org/x/pkgsite-metrics/internal/scan"
@@ -55,8 +53,6 @@
 	return writeResult(ctx, req.Serve, w, s.bqClient, analysis.TableName, row)
 }
 
-const sandboxRoot = "/bundle/rootfs"
-
 func (s *analysisServer) scan(ctx context.Context, req *analysis.ScanRequest) *analysis.Result {
 	row := &analysis.Result{
 		ModulePath: req.Module,
@@ -98,12 +94,7 @@
 		if err != nil {
 			return nil, nil, err
 		}
-		defer func() {
-			err1 := os.RemoveAll(tempDir)
-			if err == nil {
-				err = err1
-			}
-		}()
+		defer removeDir(&err, tempDir)
 	}
 
 	var destPath string
@@ -124,29 +115,18 @@
 		return nil, nil, err
 	}
 
-	if !req.Insecure {
-		sandboxDir, cleanup, err := downloadModuleSandbox(ctx, req.Module, req.Version, s.proxyClient)
-		if err != nil {
-			return nil, nil, err
-		}
-		defer cleanup()
-		log.Infof(ctx, "running %s on %s@%s in sandbox", req.Binary, req.Module, req.Version)
-		sbox := sandbox.New("/bundle")
-		sbox.Runsc = "/usr/local/bin/runsc"
-		tree, err := runAnalysisBinary(sbox, strings.TrimPrefix(destPath, sandboxRoot), req.Args, sandboxDir)
-		if err != nil {
-			return nil, nil, err
-		}
-		return tree, binaryHash, nil
-	}
-	// Insecure mode.
-	// Download the module.
-	log.Debugf(ctx, "fetching module zip: %s@%s", req.Module, req.Version)
-	const stripModulePrefix = true
-	if err := modules.Download(ctx, req.Module, req.Version, tempDir, s.proxyClient, stripModulePrefix); err != nil {
+	mdir := moduleDir(req.Module, req.Version, req.Insecure)
+	defer removeDir(&err, mdir)
+	if err := prepareModule(ctx, req.Module, req.Version, mdir, s.proxyClient, req.Insecure); err != nil {
 		return nil, nil, err
 	}
-	tree, err := runAnalysisBinary(nil, destPath, req.Args, tempDir)
+	var sbox *sandbox.Sandbox
+	if !req.Insecure {
+		sbox = sandbox.New("/bundle")
+		sbox.Runsc = "/usr/local/bin/runsc"
+		destPath = strings.TrimPrefix(destPath, sandboxRoot)
+	}
+	tree, err := runAnalysisBinary(sbox, destPath, req.Args, mdir)
 	if err != nil {
 		return nil, nil, err
 	}
diff --git a/internal/worker/scan.go b/internal/worker/scan.go
index a3b5c37..ef5d53c 100644
--- a/internal/worker/scan.go
+++ b/internal/worker/scan.go
@@ -13,6 +13,7 @@
 	"net/http"
 	"os"
 	"os/exec"
+	"path/filepath"
 	"runtime/debug"
 	"strconv"
 	"strings"
@@ -23,12 +24,16 @@
 	"golang.org/x/pkgsite-metrics/internal/config"
 	"golang.org/x/pkgsite-metrics/internal/derrors"
 	"golang.org/x/pkgsite-metrics/internal/log"
+	"golang.org/x/pkgsite-metrics/internal/modules"
+	"golang.org/x/pkgsite-metrics/internal/proxy"
 )
 
+const sandboxRoot = "/bundle/rootfs"
+
 var activeScans atomic.Int32
 
 func doScan(ctx context.Context, modulePath, version string, insecure bool, f func() error) (err error) {
-	defer derrors.Wrap(&err, "scan(%q, %q)", modulePath, version)
+	defer derrors.Wrap(&err, "doScan(%q, %q)", modulePath, version)
 
 	defer func() {
 		if e := recover(); e != nil {
@@ -198,3 +203,53 @@
 		return bucket.Object(name).NewReader(ctx)
 	}
 }
+
+// prepareModule prepares a module for scanning.
+// It downloads the module to the given directory and
+// takes other actions that increase the chance that
+// packages.Load will succeed.
+func prepareModule(ctx context.Context, modulePath, version, dir string, proxyClient *proxy.Client, insecure bool) error {
+	log.Debugf(ctx, "%s@%s: downloading to %s", modulePath, version, dir)
+	if err := modules.Download(ctx, modulePath, version, dir, proxyClient, true); err != nil {
+		log.Debugf(ctx, "download error: %v (%[1]T)", err)
+		return err
+	}
+	if !insecure {
+		// Download all dependencies outside of the sandbox, but use the Go build
+		// cache ("/bundle/rootfs/" + sandboxGoCache) inside the bundle.
+		log.Debugf(ctx, "%s@%s: running go mod download", modulePath, version)
+		cmd := exec.Command("go", "mod", "download")
+		cmd.Dir = dir
+		cmd.Env = append(cmd.Environ(),
+			"GOPROXY=https://proxy.golang.org",
+			"GOMODCACHE=/bundle/rootfs/"+sandboxGoModCache)
+		_, err := cmd.Output()
+		if err != nil {
+			return fmt.Errorf("%w: 'go mod download' for %s@%s returned %s",
+				derrors.BadModule, modulePath, version, derrors.IncludeStderr(err))
+		}
+		log.Debugf(ctx, "go mod download succeeded")
+	}
+	return nil
+}
+
+// moduleDir returns a the path of a directory where the module can be downloaded.
+func moduleDir(modulePath, version string, insecure bool) string {
+	dir := sandboxRoot
+	if insecure {
+		dir = os.TempDir()
+	}
+	return filepath.Join(dir, "modules", modulePath+"@"+version)
+}
+
+// removeDir calls os.RemoveAll(dir) and combines the error with errp.
+// It is meant to be deferred.
+func removeDir(errp *error, dir string) {
+	if err := os.RemoveAll(dir); err != nil {
+		if *errp == nil {
+			*errp = err
+		} else {
+			*errp = fmt.Errorf("RemoveAll(%q): %v, and also %w", dir, err, *errp)
+		}
+	}
+}
diff --git a/internal/worker/vulncheck_scan.go b/internal/worker/vulncheck_scan.go
index ae1a882..fea0667 100644
--- a/internal/worker/vulncheck_scan.go
+++ b/internal/worker/vulncheck_scan.go
@@ -285,13 +285,7 @@
 	if err != nil {
 		return nil, err
 	}
-
-	defer func() {
-		err1 := os.RemoveAll(tempDir)
-		if err == nil {
-			err = err1
-		}
-	}()
+	defer removeDir(&err, tempDir)
 
 	log.Debugf(ctx, "fetching module zip: %s@%s", modulePath, version)
 	if err = modules.Download(ctx, modulePath, version, tempDir, s.proxyClient, true); err != nil {
@@ -347,15 +341,18 @@
 	return res.Vulns, nil
 }
 
-func (s *scanner) runImportsScanSandbox(ctx context.Context, modulePath, version string, stats *vulncheckStats) ([]*vulncheck.Vuln, error) {
-	sandboxDir, cleanup, err := downloadModuleSandbox(ctx, modulePath, version, s.proxyClient)
-	if err != nil {
+func (s *scanner) runImportsScanSandbox(ctx context.Context, modulePath, version string, stats *vulncheckStats) (_ []*vulncheck.Vuln, err error) {
+	const insecure = false
+	mdir := moduleDir(modulePath, version, insecure)
+	defer removeDir(&err, mdir)
+	if err := prepareModule(ctx, modulePath, version, mdir, s.proxyClient, insecure); err != nil {
 		return nil, err
 	}
-	defer cleanup()
 
 	log.Infof(ctx, "running imports analysis in sandbox: %s@%s", modulePath, version)
-	stdout, err := s.sbox.Command("/binaries/vulncheck_sandbox", ModeImports, sandboxDir).Output()
+
+	smdir := strings.TrimPrefix(mdir, sandboxRoot)
+	stdout, err := s.sbox.Command("/binaries/vulncheck_sandbox", ModeImports, smdir).Output()
 	log.Infof(ctx, "done with imports analysis in sandbox: %s@%s err=%v", modulePath, version, err)
 
 	if err != nil {
@@ -368,19 +365,21 @@
 	return res.Vulns, nil
 }
 
-func (s *scanner) runGovulncheckScanSandbox(ctx context.Context, modulePath, version, binaryDir, mode string, stats *vulncheckStats) ([]*govulncheck.Vuln, error) {
+func (s *scanner) runGovulncheckScanSandbox(ctx context.Context, modulePath, version, binaryDir, mode string, stats *vulncheckStats) (_ []*govulncheck.Vuln, err error) {
 	if mode == ModeBinary {
 		return s.runBinaryScanSandbox(ctx, modulePath, version, binaryDir, stats)
 	}
 
-	sandboxDir, cleanup, err := downloadModuleSandbox(ctx, modulePath, version, s.proxyClient)
-	if err != nil {
+	const insecure = false
+	mdir := moduleDir(modulePath, version, insecure)
+	defer removeDir(&err, mdir)
+	if err := prepareModule(ctx, modulePath, version, mdir, s.proxyClient, insecure); err != nil {
 		return nil, err
 	}
-	defer cleanup()
 
 	log.Infof(ctx, "running govulncheck in sandbox: %s@%s", modulePath, version)
-	stdout, err := s.sbox.Command("/binaries/vulncheck_sandbox", ModeGovulncheck, sandboxDir).Output()
+	smdir := strings.TrimPrefix(mdir, sandboxRoot)
+	stdout, err := s.sbox.Command("/binaries/vulncheck_sandbox", ModeGovulncheck, smdir).Output()
 	log.Infof(ctx, "done with govulncheck in sandbox: %s@%s err=%v", modulePath, version, err)
 
 	if err != nil {
@@ -393,32 +392,6 @@
 	return res.Vulns, nil
 }
 
-func downloadModuleSandbox(ctx context.Context, modulePath, version string, proxyClient *proxy.Client) (string, func(), error) {
-	sandboxDir := "/modules/" + modulePath + "@" + version
-	imageDir := "/bundle/rootfs" + sandboxDir
-
-	log.Debugf(ctx, "downloading %s@%s to %s", modulePath, version, imageDir)
-	if err := modules.Download(ctx, modulePath, version, imageDir, proxyClient, true); err != nil {
-		log.Debugf(ctx, "download error: %v (%[1]T)", err)
-		return "", nil, err
-	}
-	// Download all dependencies outside of the sandbox, but use the Go build
-	// cache ("/bundle/rootfs/" + sandboxGoCache) inside the bundle.
-	log.Debugf(ctx, "running go mod download")
-	cmd := exec.Command("go", "mod", "download")
-	cmd.Dir = imageDir
-	cmd.Env = append(cmd.Environ(),
-		"GOPROXY=https://proxy.golang.org",
-		"GOMODCACHE=/bundle/rootfs/"+sandboxGoModCache)
-	_, err := cmd.Output()
-	if err != nil {
-		return "", nil, fmt.Errorf("%w: 'go mod download' for %s@%s returned %s",
-			derrors.BadModule, modulePath, version, derrors.IncludeStderr(err))
-	}
-	log.Debugf(ctx, "go mod download succeeded")
-	return sandboxDir, func() { os.RemoveAll(imageDir) }, nil
-}
-
 func (s *scanner) runBinaryScanSandbox(ctx context.Context, modulePath, version, binDir string, stats *vulncheckStats) ([]*govulncheck.Vuln, error) {
 	if s.gcsBucket == nil {
 		return nil, errors.New("binary bucket not configured; set GO_ECOSYSTEM_BINARY_BUCKET")
@@ -461,28 +434,17 @@
 }
 
 func (s *scanner) runGovulncheckScanInsecure(ctx context.Context, modulePath, version, binaryDir, mode string, stats *vulncheckStats) (_ []*govulncheck.Vuln, err error) {
-	tempDir, err := os.MkdirTemp("", "runGovulncheckScan")
-	if err != nil {
-		return nil, err
-	}
-
-	defer func() {
-		err1 := os.RemoveAll(tempDir)
-		if err == nil {
-			err = err1
-		}
-	}()
-
 	if mode == ModeBinary {
-		return s.runBinaryScanInsecure(ctx, modulePath, version, binaryDir, tempDir, stats)
+		return s.runBinaryScanInsecure(ctx, modulePath, version, binaryDir, os.TempDir(), stats)
 	}
 
-	log.Debugf(ctx, "fetching module zip: %s@%s", modulePath, version)
-	if err := modules.Download(ctx, modulePath, version, tempDir, s.proxyClient, true); err != nil {
+	mdir := moduleDir(modulePath, version, true)
+	defer removeDir(&err, mdir)
+	if err := prepareModule(ctx, modulePath, version, mdir, s.proxyClient, true); err != nil {
 		return nil, err
 	}
 	start := time.Now()
-	vulns, err := runGovulncheckCmd(ctx, "./...", tempDir, stats)
+	vulns, err := runGovulncheckCmd(ctx, "./...", mdir, stats)
 	if err != nil {
 		return nil, err
 	}